jiyatai commited on
Commit
d8618db
·
verified ·
1 Parent(s): 428367e

Upload qwen25vl.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. qwen25vl.py +250 -0
qwen25vl.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from datasets import load_dataset
import json
import random
import io
import ast
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageColor
from tqdm import tqdm
import torch
import os
import torch.distributed as dist
import xml.etree.ElementTree as ET

# Every named color PIL knows about; extends the hand-picked bbox palette
# in plot_bounding_boxes so many boxes can get distinct outline colors.
additional_colors = [colorname for (colorname, colorcode) in ImageColor.colormap.items()]
15
+
16
def decode_xml_points(text):
    """Parse an XML points tag like ``<points x1=".." y1=".." alt="..">phrase</points>``.

    Args:
        text: XML string with paired ``xN``/``yN`` coordinate attributes, an
            optional ``alt`` attribute, and optional element text.

    Returns:
        dict with keys ``points`` (list of [x, y] attribute strings — values
        are NOT converted to numbers), ``alt``, and ``phrase``; or ``None``
        if the XML cannot be parsed.
    """
    try:
        root = ET.fromstring(text)
        # Count the coordinate pairs directly. The previous formula
        # ((len(attrib) - 1) // 2) assumed the optional `alt` attribute was
        # always present and silently dropped the last point when it wasn't.
        num_points = sum(1 for k in root.attrib if k.startswith('x') and k[1:].isdigit())
        points = []
        for i in range(num_points):
            x = root.attrib.get(f'x{i+1}')
            y = root.attrib.get(f'y{i+1}')
            points.append([x, y])
        alt = root.attrib.get('alt')
        phrase = root.text.strip() if root.text else None
        return {
            "points": points,
            "alt": alt,
            "phrase": phrase
        }
    except Exception as e:
        # Best-effort parser: report the failure and signal "no result".
        print(e)
        return None
35
+
36
def plot_bounding_boxes(im, bounding_boxes, input_width, input_height):
    """Draw all predicted boxes onto ``im`` (in place) and return the last one.

    Args:
        im: PIL.Image to draw on; modified in place.
        bounding_boxes: Raw model output text — a (possibly ```json-fenced)
            JSON-ish list of dicts, each with a "bbox_2d" entry holding
            [x1, y1, x2, y2] in the model's input-resolution coordinates.
            (The old docstring claimed normalized [y1 x1 y2 x2]; the code
            below has always read bbox_2d as [x1, y1, x2, y2].)
        input_width: Width of the model-side resized input image.
        input_height: Height of the model-side resized input image.

    Returns:
        [abs_x1, abs_y1, abs_x2, abs_y2] of the LAST box, rescaled to
        ``im``'s pixel coordinates.

    Raises:
        ValueError: if no box could be parsed from the text (previously this
            surfaced as an UnboundLocalError; callers catch both via a bare
            ``except``, so the change is compatible within this file).
    """
    img = im
    width, height = img.size
    draw = ImageDraw.Draw(img)

    # Distinct outline colors, cycled per box index.
    colors = [
        'red', 'green', 'blue', 'yellow', 'orange', 'pink', 'purple',
        'brown', 'gray', 'beige', 'turquoise', 'cyan', 'magenta', 'lime',
        'navy', 'maroon', 'teal', 'olive', 'coral', 'lavender', 'violet',
        'gold', 'silver',
    ] + additional_colors

    # Strip any ```json markdown fencing around the model output.
    bounding_boxes = parse_json(bounding_boxes)

    try:
        json_output = ast.literal_eval(bounding_boxes)
    except Exception:
        # The model sometimes truncates mid-list; salvage everything up to
        # the last complete `..."}` dict terminator and close the list.
        end_idx = bounding_boxes.rfind('"}') + len('"}')
        truncated_text = bounding_boxes[:end_idx] + "]"
        json_output = ast.literal_eval(truncated_text)

    if not json_output:
        # Explicit failure instead of falling through to an unbound abs_x1.
        raise ValueError("no bounding boxes parsed from model output")

    for i, bounding_box in enumerate(json_output):
        color = colors[i % len(colors)]

        # Rescale from model input resolution to the actual image size.
        abs_y1 = int(bounding_box["bbox_2d"][1] / input_height * height)
        abs_x1 = int(bounding_box["bbox_2d"][0] / input_width * width)
        abs_y2 = int(bounding_box["bbox_2d"][3] / input_height * height)
        abs_x2 = int(bounding_box["bbox_2d"][2] / input_width * width)

        # Normalize corner order so (x1, y1) is the top-left corner.
        if abs_x1 > abs_x2:
            abs_x1, abs_x2 = abs_x2, abs_x1
        if abs_y1 > abs_y2:
            abs_y1, abs_y2 = abs_y2, abs_y1

        draw.rectangle(
            ((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4
        )

    return [abs_x1, abs_y1, abs_x2, abs_y2]
122
+
123
+
124
+
125
# @title Parsing JSON output
def parse_json(json_output):
    """Strip a leading ```json markdown fence from *json_output*, if any.

    Returns the text between the first ``` ```json`` line and the next
    ``` ``` `` marker; if no fence line is found, returns the input unchanged.
    """
    rows = json_output.splitlines()
    for idx, row in enumerate(rows):
        if row != "```json":
            continue
        # Keep only what sits between the opening fence and the closing ```.
        fenced = "\n".join(rows[idx + 1:])
        return fenced.split("```")[0]
    return json_output
135
+
136
+
137
# ---- Model + distributed setup --------------------------------------------
# Module-level side effects: joins an NCCL process group and loads the model
# onto this rank's GPU. (Removed a duplicate `import torch` — torch is
# already imported at the top of the file.)
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

world_size = torch.cuda.device_count()
rank = int(os.environ.get("LOCAL_RANK", 0))
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

# Initialize the process group (NCCL backend, GPU-only).
dist.init_process_group(
    backend="nccl",
    init_method="env://",
    rank=rank,
    world_size=world_size
)
print(f"Rank {rank} initialized")
device = torch.device(f"cuda:{rank}")

model_path = "/lustre/fsw/portfolios/nvr/users/yataij/pretrained/Qwen2.5-VL-7B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map={"": device},
)
processor = AutoProcessor.from_pretrained(model_path)
158
+
159
def inference(image, prompt, system_prompt="You are a helpful assistant", max_new_tokens=1024):
    """Run one image+text generation round with the module-level Qwen2.5-VL model.

    Args:
        image: PIL image actually fed to the processor.
        prompt: User text prompt.
        system_prompt: System message for the chat template.
        max_new_tokens: Generation budget. Bug fix: this argument was
            previously accepted but ignored — generate() hard-coded 1024.

    Returns:
        (output_text, input_height, input_width): the decoded completion and
        the model-side resized image dimensions (grid_thw * 14-px patches),
        needed to map predicted boxes back to pixel coordinates.
    """
    # NOTE(review): the chat template only sees this fixed dummy path; the
    # real image tensor is supplied separately via `images=[image]` below.
    img_url_dummy = "/lustre/fsw/portfolios/nvr/users/yataij/data/SPAR-7M-RGBD/example.png"
    messages = [
        {
            "role": "system",
            "content": system_prompt
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                {
                    "image": img_url_dummy
                }
            ]
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(device)

    # Fix: honor the max_new_tokens argument instead of the hard-coded 1024.
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Drop the prompt tokens so only the newly generated tail is decoded.
    generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    # image_grid_thw is (t, h, w) counted in 14-pixel vision patches.
    input_height = inputs['image_grid_thw'][0][1]*14
    input_width = inputs['image_grid_thw'][0][2]*14

    return output_text[0], input_height, input_width
193
+
194
# ---- Prepare the model input ----------------------------------------------
dataset = load_dataset('/lustre/fs12/portfolios/nvr/projects/nvr_lpr_nvgptvision/users/yataij/data/SPAR_Bench')
print(len(dataset['test']))
print(dataset['test'].features.keys())

# Keep only single-view multiple-choice ("select") questions.
data_list = []
for i, example in tqdm(enumerate(dataset['test'])):
    if example['img_type'] == 'single_view' and example['format_type'] == 'select':  # select fill
        data_list.append(example)

print('test', len(data_list))
# Fix: close the file handle (json.load(open(...)) leaked it).
with open('qwen3_visual_prompt_extract.json') as f:
    visual_prompts = json.load(f)
print(len(visual_prompts))

# Shard the work across ranks: each rank processes every world_size-th item.
data_list = data_list[rank::world_size]
visual_prompts = visual_prompts[rank::world_size]

res = []
for i in tqdm(range(len(data_list))):
    instance = data_list[i]
    visual_prompt = visual_prompts[i]
    # The two shards must stay aligned item-for-item.
    assert instance['id'] == visual_prompt['id']

    image = instance['image'][0]
    width, height = image.size
    vp_bbox = {}
    for vp, ins in visual_prompt['visual_prompt'].items():

        if 'point' in vp:
            color = vp.split()[0]
            prompt = f"Locate the {color} round point, output its bbox coordinates using JSON format."
            response, input_height, input_width = inference(image, prompt)
            # Fix: narrowed the bare `except:` (which also swallowed
            # KeyboardInterrupt/SystemExit) to Exception.
            try:
                coord = plot_bounding_boxes(image, response, input_width, input_height)
            except Exception:
                print(i, vp)
                continue
            # Pad the point's box by 50 px per side, clamped to the image.
            coord = [coord[0]-50, coord[1]-50, coord[2]+50, coord[3]+50]
            if coord[0] < 0: coord[0] = 0
            if coord[1] < 0: coord[1] = 0
            if coord[2] > width: coord[2] = width
            if coord[3] > height: coord[3] = height
            vp_bbox[vp] = coord
        elif 'bbox' in vp:
            anno = f"the {ins} in {vp}"
            prompt = f"Locate {anno}, output its bbox coordinates using JSON format."
            response, input_height, input_width = inference(image, prompt)
            try:
                coord = plot_bounding_boxes(image, response, input_width, input_height)
            except Exception:
                print(i, vp)
                continue
            vp_bbox[vp] = coord

    visual_prompt['visual_prompt_bbox'] = vp_bbox
    res.append(visual_prompt)
    # Rewrite this rank's result file after every item as a cheap checkpoint.
    # NOTE(review): original indentation was lost in the paste — this may
    # have been a single post-loop write; the final file is identical either
    # way. Confirm against the original source.
    with open(f'qwen25vl_sparbench_singleimg_select_bbox_rank{rank}.json', 'w') as f:
        json.dump(res, f, indent=4)