yuccaaa committed on
Commit
0a0c584
·
verified ·
1 Parent(s): a58504d

Upload ms-swift/examples/notebook/qwen2vl-ocr/infer.ipynb with huggingface_hub

Browse files
ms-swift/examples/notebook/qwen2vl-ocr/infer.ipynb ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Inference\n",
8
 +             "We obtained a fine-tuned checkpoint through the `ocr-sft.ipynb` tutorial, and here we use `PtEngine` to run inference with it."
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "metadata": {},
15
+ "outputs": [],
16
+ "source": [
17
+ "# import some libraries\n",
18
+ "import os\n",
19
+ "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
20
+ "\n",
21
+ "from swift.llm import (\n",
22
+ " InferEngine, InferRequest, PtEngine, RequestConfig, get_template, load_dataset, load_image\n",
23
+ ")\n",
24
+ "from swift.utils import get_model_parameter_info, get_logger, seed_everything\n",
25
+ "logger = get_logger()\n",
26
+ "seed_everything(42)"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": 2,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "# Hyperparameters for inference\n",
36
+ "last_model_checkpoint = 'output/checkpoint-xxx'\n",
37
+ "\n",
38
+ "# model\n",
39
+ "model_id_or_path = 'Qwen/Qwen2-VL-2B-Instruct' # model_id or model_path\n",
40
+ "system = None\n",
41
+ "infer_backend = 'pt'\n",
42
+ "\n",
43
+ "# dataset\n",
44
+ "dataset = ['AI-ModelScope/LaTeX_OCR#20000']\n",
45
+ "data_seed = 42\n",
46
+ "split_dataset_ratio = 0.01\n",
47
+ "num_proc = 4\n",
48
+ "strict = False\n",
49
+ "\n",
50
+ "# generation_config\n",
51
+ "max_new_tokens = 512\n",
52
+ "temperature = 0\n",
53
+ "stream = True"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "# Get model and template, and load LoRA weights.\n",
63
+ "engine = PtEngine(model_id_or_path, adapters=[last_model_checkpoint])\n",
64
+ "template = get_template(engine.model_meta.template, engine.tokenizer, default_system=system)\n",
65
+ "# The default mode of the template is 'pt', so there is no need to make any changes.\n",
66
+ "# template.set_mode('pt')\n",
67
+ "\n",
68
+ "model_parameter_info = get_model_parameter_info(engine.model)\n",
69
+ "logger.info(f'model_parameter_info: {model_parameter_info}')"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "# Due to the data_seed setting, the validation set here is the same as the validation set used during training.\n",
79
+ "_, val_dataset = load_dataset(dataset, split_dataset_ratio=split_dataset_ratio, num_proc=num_proc,\n",
80
+ " strict=strict, seed=data_seed)\n",
81
+ "val_dataset = val_dataset.select(range(10)) # Take the first 10 items"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": null,
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": [
90
+ "# Streaming inference and save images from the validation set.\n",
91
+ "# The batch processing code can be found here: https://github.com/modelscope/ms-swift/blob/main/examples/infer/demo_mllm.py\n",
92
+ "def infer_stream(engine: InferEngine, infer_request: InferRequest):\n",
93
+ " request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature, stream=True)\n",
94
+ " gen_list = engine.infer([infer_request], request_config)\n",
95
+ " query = infer_request.messages[0]['content']\n",
96
+ " print(f'query: {query}\\nresponse: ', end='')\n",
97
+ " for resp in gen_list[0]:\n",
98
+ " if resp is None:\n",
99
+ " continue\n",
100
+ " print(resp.choices[0].delta.content, end='', flush=True)\n",
101
+ " print()\n",
102
+ "\n",
103
+ "from IPython.display import display\n",
104
+ "os.makedirs('images', exist_ok=True)\n",
105
+ "for i, data in enumerate(val_dataset):\n",
106
+ " image = data['images'][0]\n",
107
+ " image = load_image(image['bytes'] or image['path'])\n",
108
+ " image.save(f'images/{i}.png')\n",
109
+ " display(image)\n",
110
+ " infer_stream(engine, InferRequest(**data))\n",
111
+ " print('-' * 50)"
112
+ ]
113
+ }
114
+ ],
115
+ "metadata": {
116
+ "kernelspec": {
117
+ "display_name": "test_py310",
118
+ "language": "python",
119
+ "name": "python3"
120
+ },
121
+ "language_info": {
122
+ "codemirror_mode": {
123
+ "name": "ipython",
124
+ "version": 3
125
+ },
126
+ "file_extension": ".py",
127
+ "mimetype": "text/x-python",
128
+ "name": "python",
129
+ "nbconvert_exporter": "python",
130
+ "pygments_lexer": "ipython3",
131
+ "version": "3.10.15"
132
+ }
133
+ },
134
+ "nbformat": 4,
135
+ "nbformat_minor": 2
136
+ }