{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Qwen2.5-VL Grounding任务\n",
    "\n",
    "这里介绍使用qwen2.5-vl进行grounding任务的全流程介绍。当然,你也可以使用internvl2.5或者qwen2-vl等多模态模型。\n",
    "\n",
    "我们使用[AI-ModelScope/coco](https://modelscope.cn/datasets/AI-ModelScope/coco)数据集来展示整个流程。\n",
    "\n",
    "如果需要使用自定义数据集,需要符合以下格式:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful assistant.\"}, {\"role\": \"user\", \"content\": \"<image>描述图像\"}, {\"role\": \"assistant\", \"content\": \"<ref-object><bbox>和<ref-object><bbox>正在沙滩上玩耍\"}], \"images\": [\"/xxx/x.jpg\"], \"objects\": {\"ref\": [\"一只狗\", \"一个女人\"], \"bbox\": [[331.5, 761.4, 853.5, 1594.8], [676.5, 685.8, 1099.5, 1427.4]]}}\n",
    "{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful assistant.\"}, {\"role\": \"user\", \"content\": \"<image>找到图像中的<ref-object>\"}, {\"role\": \"assistant\", \"content\": \"<bbox><bbox>\"}], \"images\": [\"/xxx/x.jpg\"], \"objects\": {\"ref\": [\"\"], \"bbox\": [[90.9, 160.8, 135, 212.8], [360.9, 480.8, 495, 532.8]]}}\n",
    "{\"messages\": [{\"role\": \"system\", \"content\": \"You are a helpful assistant.\"}, {\"role\": \"user\", \"content\": \"<image>帮我打开谷歌浏览器\"}, {\"role\": \"assistant\", \"content\": \"Action: click(start_box='<bbox>')\"}], \"images\": [\"/xxx/x.jpg\"], \"objects\": {\"ref\": [], \"bbox\": [[615, 226]]}}"
   ]
  },
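  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, here is a minimal sketch of building one such JSONL sample with the standard-library `json` module (the output file name and image path are placeholders):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Hypothetical sample mirroring the first format line above; the image path is a placeholder.\n",
    "# Each <ref-object>/<bbox> tag in the messages is matched, in order, to the entries of\n",
    "# objects['ref'] and objects['bbox'] (pixel coordinates: x1, y1, x2, y2).\n",
    "sample = {\n",
    "    'messages': [\n",
    "        {'role': 'system', 'content': 'You are a helpful assistant.'},\n",
    "        {'role': 'user', 'content': '<image>Describe the image'},\n",
    "        {'role': 'assistant', 'content': '<ref-object><bbox> and <ref-object><bbox> are playing on the beach'},\n",
    "    ],\n",
    "    'images': ['/xxx/x.jpg'],\n",
    "    'objects': {\n",
    "        'ref': ['a dog', 'a woman'],\n",
    "        'bbox': [[331.5, 761.4, 853.5, 1594.8], [676.5, 685.8, 1099.5, 1427.4]],\n",
    "    },\n",
    "}\n",
    "\n",
    "# The dataset file is JSONL: one JSON object per line\n",
    "with open('grounding_train.jsonl', 'w', encoding='utf-8') as f:\n",
    "    f.write(json.dumps(sample, ensure_ascii=False) + '\\n')"
   ]
  },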
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ms-swift在预处理数据集时,会使用模型特有的grounding任务格式,将objects中的ref填充`<ref-object>`,bbox会根据模型类型选择是否进行0-1000的归一化,并填充`<bbox>`。例如:qwen2-vl为`f'<|object_ref_start|>羊<|object_ref_end|>'`和`f'<|box_start|>(101,201),(150,266)<|box_end|>'`(qwen2.5-vl不进行归一化,只将float型转成int型),internvl2.5则为`f'<ref>羊</ref>'`和`f'<box>[[101, 201, 150, 266]]</box>'`等。\n",
    "\n",
    "\n",
    "训练之前,你需要从main分支安装ms-swift:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "# pip install git+https://github.com/modelscope/ms-swift.git\n",
    "\n",
    "git clone https://github.com/modelscope/ms-swift.git\n",
    "cd ms-swift\n",
    "pip install -e .\n",
    "\n",
    "# 如果'transformers>=4.49'已经发版,则无需从main分支安装\n",
    "pip install git+https://github.com/huggingface/transformers.git"
   ]
  },
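  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of the coordinate conventions described above, here is a minimal sketch (illustrative only: the helper and the image size are invented, and this is not ms-swift's actual preprocessing code):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only; `to_norm1000` and the image size are made up for this example.\n",
    "def to_norm1000(bbox, width, height):\n",
    "    # Scale pixel coordinates into the 0-1000 range used by qwen2-vl\n",
    "    x1, y1, x2, y2 = bbox\n",
    "    return [round(x1 / width * 1000), round(y1 / height * 1000),\n",
    "            round(x2 / width * 1000), round(y2 / height * 1000)]\n",
    "\n",
    "ref, bbox = 'sheep', [331.5, 761.4, 853.5, 1594.8]\n",
    "\n",
    "# qwen2-vl: norm1000 coordinates inside the special box tokens\n",
    "x1, y1, x2, y2 = to_norm1000(bbox, width=2250, height=3000)\n",
    "print(f'<|object_ref_start|>{ref}<|object_ref_end|><|box_start|>({x1},{y1}),({x2},{y2})<|box_end|>')\n",
    "\n",
    "# qwen2.5-vl: absolute pixel coordinates, floats simply cast to ints\n",
    "x1, y1, x2, y2 = (int(v) for v in bbox)\n",
    "print(f'<|object_ref_start|>{ref}<|object_ref_end|><|box_start|>({x1},{y1}),({x2},{y2})<|box_end|>')"
   ]
  },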
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后,使用以下shell进行训练。MAX_PIXELS的参数含义可以查看[这里](https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html#specific-model-arguments)\n",
    "\n",
    "### 训练\n",
    "\n",
    "单卡训练:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "# 显存资源:24GiB\n",
    "CUDA_VISIBLE_DEVICES=0 \\\n",
    "MAX_PIXELS=1003520 \\\n",
    "swift sft \\\n",
    "    --model Qwen/Qwen2.5-VL-7B-Instruct \\\n",
    "    --dataset 'AI-ModelScope/coco#2000' \\\n",
    "    --train_type lora \\\n",
    "    --torch_dtype bfloat16 \\\n",
    "    --num_train_epochs 1 \\\n",
    "    --per_device_train_batch_size 1 \\\n",
    "    --per_device_eval_batch_size 1 \\\n",
    "    --learning_rate 1e-4 \\\n",
    "    --lora_rank 8 \\\n",
    "    --lora_alpha 32 \\\n",
    "    --target_modules all-linear \\\n",
    "    --freeze_vit true \\\n",
    "    --gradient_accumulation_steps 16 \\\n",
    "    --eval_steps 100 \\\n",
    "    --save_steps 100 \\\n",
    "    --save_total_limit 5 \\\n",
    "    --logging_steps 5 \\\n",
    "    --max_length 2048 \\\n",
    "    --output_dir output \\\n",
    "    --warmup_ratio 0.05 \\\n",
    "    --dataloader_num_workers 4 \\\n",
    "    --dataset_num_proc 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后我们将训练的模型推送到ModelScope:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "swift export \\\n",
    "    --adapters output/vx-xxx/checkpoint-xxx \\\n",
    "    --push_to_hub true \\\n",
    "    --hub_model_id '<model-id>' \\\n",
    "    --hub_token '<sdk-token>' \\\n",
    "    --use_hf false"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们将训练的checkpoint推送到[swift/test_grounding](https://modelscope.cn/models/swift/test_grounding)。\n",
    "\n",
    "### 推理\n",
    "\n",
    "训练完成后,我们使用以下命令对训练时的验证集进行推理。这里`--adapters`需要替换成训练生成的last checkpoint文件夹。由于adapters文件夹中包含了训练的参数文件,因此不需要额外指定`--model`。\n",
    "\n",
    "若模型采用的是绝对坐标的方式进行输出,推理时请提前对图像进行缩放而不使用`MAX_PIXELS`或者`--max_pixels`。若是千分位坐标,则没有此约束。\n",
    "\n",
    "由于我们已经将训练后的checkpoint推送到了ModelScope上,以下推理脚本可以直接运行:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "CUDA_VISIBLE_DEVICES=0 \\\n",
    "swift infer \\\n",
    "    --adapters swift/test_grounding \\\n",
    "    --stream true \\\n",
    "    --load_data_args true \\\n",
    "    --max_new_tokens 512 \\\n",
    "    --dataset_num_proc 4"
   ]
  },
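  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the absolute-coordinate case mentioned above, here is a minimal sketch of pre-resizing an image so its pixel count stays under a cap, mirroring what `MAX_PIXELS` does during training (illustrative only; the paths and the helper name are made up):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: cap an image's total pixel count before inference so that\n",
    "# absolute-coordinate outputs line up with the image the model actually sees.\n",
    "import math\n",
    "from PIL import Image\n",
    "\n",
    "def resize_to_max_pixels(image, max_pixels=1003520):\n",
    "    w, h = image.size\n",
    "    if w * h <= max_pixels:\n",
    "        return image\n",
    "    scale = math.sqrt(max_pixels / (w * h))\n",
    "    return image.resize((int(w * scale), int(h * scale)))\n",
    "\n",
    "image = resize_to_max_pixels(Image.open('/xxx/x.jpg'))\n",
    "image.save('/xxx/x_resized.jpg')"
   ]
  },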
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们也可以使用代码的方式进行推理:\n",
    "\n",
    "单样本推理的例子可以查看[这里](https://github.com/modelscope/ms-swift/blob/main/examples/infer/demo_grounding.py)。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "import re\n",
    "from typing import Literal\n",
    "from swift.llm import (\n",
    "    PtEngine, RequestConfig, BaseArguments, InferRequest, safe_snapshot_download, draw_bbox, load_image, load_dataset, InferEngine\n",
    ")\n",
    "from IPython.display import display\n",
    "\n",
    "def infer_stream(engine: InferEngine, infer_request: InferRequest):\n",
    "    request_config = RequestConfig(max_tokens=512, temperature=0, stream=True)\n",
    "    gen_list = engine.infer([infer_request], request_config)\n",
    "    query = infer_request.messages[0]['content']\n",
    "    print(f'query: {query}\\nresponse: ', end='')\n",
    "    response = ''\n",
    "    for resp in gen_list[0]:\n",
    "        if resp is None:\n",
    "            continue\n",
    "        delta = resp.choices[0].delta.content\n",
    "        response += delta\n",
    "        print(delta, end='', flush=True)\n",
    "    print()\n",
    "    return response\n",
    "\n",
    "def draw_bbox_qwen2_vl(image, response, norm_bbox: Literal['norm1000', 'none']):\n",
    "    matches = re.findall(\n",
    "        r'<\\|object_ref_start\\|>(.*?)<\\|object_ref_end\\|><\\|box_start\\|>\\((\\d+),(\\d+)\\),\\((\\d+),(\\d+)\\)<\\|box_end\\|>',\n",
    "        response)\n",
    "    ref = []\n",
    "    bbox = []\n",
    "    for match_ in matches:\n",
    "        ref.append(match_[0])\n",
    "        bbox.append(list(match_[1:]))\n",
    "    draw_bbox(image, ref, bbox, norm_bbox=norm_bbox)\n",
    "\n",
    "# 下载权重,并加载模型\n",
    "output_dir = 'images_bbox'\n",
    "model_id_or_path = 'swift/test_grounding'\n",
    "output_dir = os.path.abspath(os.path.expanduser(output_dir))\n",
    "adapter_path = safe_snapshot_download(model_id_or_path)\n",
    "args = BaseArguments.from_pretrained(adapter_path)\n",
    "engine = PtEngine(args.model, adapters=[adapter_path])\n",
    "\n",
    "# 获取验证集并推理\n",
    "_, val_dataset = load_dataset(args.dataset, split_dataset_ratio=args.split_dataset_ratio, num_proc=4, seed=args.seed)\n",
    "print(f'output_dir: {output_dir}')\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "for i, data in enumerate(val_dataset):\n",
    "    image = data['images'][0]\n",
    "    image = load_image(image['bytes'] or image['path'])\n",
    "    display(image)\n",
    "    response = infer_stream(engine, InferRequest(**data))\n",
    "    draw_bbox_qwen2_vl(image, response, norm_bbox=args.norm_bbox)\n",
    "    print('-' * 50)\n",
    "    image.save(os.path.join(output_dir, f'{i}.png'))\n",
    "    display(image)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test_py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}