jiyatai commited on
Commit
e456723
·
verified ·
1 Parent(s): fd9d77f

Delete qwen25vl.py

Browse files
Files changed (1) hide show
  1. qwen25vl.py +0 -250
qwen25vl.py DELETED
@@ -1,250 +0,0 @@
1
- from datasets import load_dataset
2
- import json
3
- import random
4
- import io
5
- import ast
6
- from PIL import Image, ImageDraw, ImageFont
7
- from PIL import ImageColor
8
- from tqdm import tqdm
9
- import torch
10
- import os
11
- import torch.distributed as dist
12
- import xml.etree.ElementTree as ET
13
-
14
# All color names PIL knows about, used to extend the hand-picked palette
# when more boxes than base colors must be drawn.
additional_colors = list(ImageColor.colormap)
15
-
16
def decode_xml_points(text):
    """Parse an XML points tag like <points x1=".." y1=".." alt="..">phrase</points>.

    Returns a dict {"points": [[x, y], ...], "alt": ..., "phrase": ...}
    with coordinates kept as the raw attribute strings, or None if the
    text is not parseable XML (the exception is printed).
    """
    try:
        element = ET.fromstring(text)
        attrs = element.attrib
        # One attribute is 'alt'; the remainder arrive as (xN, yN) pairs.
        pair_count = (len(attrs) - 1) // 2
        coords = [
            [attrs.get(f'x{k}'), attrs.get(f'y{k}')]
            for k in range(1, pair_count + 1)
        ]
        label = element.text.strip() if element.text else None
        return {
            "points": coords,
            "alt": attrs.get('alt'),
            "phrase": label,
        }
    except Exception as err:
        print(err)
        return None
35
-
36
def plot_bounding_boxes(im, bounding_boxes, input_width, input_height):
    """
    Draw the model-predicted bounding boxes onto *im* (in place) and return
    the last box's absolute pixel coordinates.

    Args:
        im: PIL.Image that gets drawn on (mutated).
        bounding_boxes: raw model response text — a JSON-like list of dicts,
            each carrying a "bbox_2d" [x1, y1, x2, y2] in the model's input
            coordinate space, possibly wrapped in a ```json markdown fence.
        input_width: width of the image as the model saw it.
        input_height: height of the image as the model saw it.

    Returns:
        [abs_x1, abs_y1, abs_x2, abs_y2] of the LAST box in the list, rescaled
        to *im*'s actual size. NOTE(review): if the parsed list is empty this
        raises NameError (abs_x1 is never bound) — callers wrap this function
        in try/except and treat any failure as "skip".
    """

    # Load the image
    img = im
    width, height = img.size
    # print(img.size)
    # Create a drawing object
    draw = ImageDraw.Draw(img)

    # Define a list of colors (cycled, one per box), extended with every
    # named color PIL provides.
    colors = [
        'red',
        'green',
        'blue',
        'yellow',
        'orange',
        'pink',
        'purple',
        'brown',
        'gray',
        'beige',
        'turquoise',
        'cyan',
        'magenta',
        'lime',
        'navy',
        'maroon',
        'teal',
        'olive',
        'coral',
        'lavender',
        'violet',
        'gold',
        'silver',
    ] + additional_colors

    # Strip markdown fencing so ast.literal_eval sees bare JSON.
    bounding_boxes = parse_json(bounding_boxes)

    # font = ImageFont.truetype("NotoSansCJK-Regular.ttc", size=14)

    try:
        json_output = ast.literal_eval(bounding_boxes)
    except Exception as e:
        # Generation was likely truncated mid-list: cut at the last complete
        # '"}' and close the list so it parses as a (shorter) valid literal.
        end_idx = bounding_boxes.rfind('"}') + len('"}')
        truncated_text = bounding_boxes[:end_idx] + "]"
        json_output = ast.literal_eval(truncated_text)

    # Iterate over the bounding boxes
    for i, bounding_box in enumerate(json_output):
        # Select a color from the list
        color = colors[i % len(colors)]

        # Rescale from model-input coordinates to absolute pixel coordinates.
        abs_y1 = int(bounding_box["bbox_2d"][1]/input_height * height)
        abs_x1 = int(bounding_box["bbox_2d"][0]/input_width * width)
        abs_y2 = int(bounding_box["bbox_2d"][3]/input_height * height)
        abs_x2 = int(bounding_box["bbox_2d"][2]/input_width * width)

        # Normalize corner order so (abs_x1, abs_y1) is top-left.
        if abs_x1 > abs_x2:
            abs_x1, abs_x2 = abs_x2, abs_x1

        if abs_y1 > abs_y2:
            abs_y1, abs_y2 = abs_y2, abs_y1

        # Draw the bounding box
        draw.rectangle(
            ((abs_x1, abs_y1), (abs_x2, abs_y2)), outline=color, width=4
        )

        # # Draw the text
        # if "label" in bounding_box:
        #     draw.text((abs_x1 + 8, abs_y1 + 6), bounding_box["label"], fill=color, font=font)

    # Display the image
    # img.show()
    # img.save('output.png')
    return [abs_x1, abs_y1, abs_x2, abs_y2]
122
-
123
-
124
-
125
# Helper: strip markdown code fencing from a model response.
def parse_json(json_output):
    """Return the content of the first ```json ... ``` fence in *json_output*.

    If no line is exactly "```json", the input is returned unchanged.
    """
    content_lines = json_output.splitlines()
    for idx, current in enumerate(content_lines):
        if current != "```json":
            continue
        # Everything after the opening fence, up to (not including) the
        # closing ``` fence.
        fenced = "\n".join(content_lines[idx + 1:])
        return fenced.split("```")[0]
    return json_output
135
-
136
-
137
- import torch
138
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
139
-
140
# One process per GPU. LOCAL_RANK is set by the distributed launcher
# (torchrun etc.) and defaults to 0 for a single-process run.
world_size = torch.cuda.device_count()
rank = int(os.environ.get("LOCAL_RANK", 0))
# Rendezvous address for init_method="env://" below.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

# Initialize the process group
dist.init_process_group(
    backend="nccl", # NCCL backend (for GPUs)
    init_method="env://",
    rank=rank,
    world_size=world_size
)
print(f"Rank {rank} initialized")
device = torch.device(f"cuda:{rank}")

# Load Qwen2.5-VL once per rank, pinned to this rank's GPU via device_map.
model_path = "/lustre/fsw/portfolios/nvr/users/yataij/pretrained/Qwen2.5-VL-7B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2",device_map={"": device},)
processor = AutoProcessor.from_pretrained(model_path)
158
-
159
def inference(image, prompt, system_prompt="You are a helpful assistant", max_new_tokens=1024):
    """Run one image+text generation with the module-level Qwen2.5-VL model.

    Args:
        image: PIL image handed to the processor (the pixels actually used).
        prompt: user text prompt.
        system_prompt: system message for the chat template.
        max_new_tokens: generation budget.

    Returns:
        (output_text, input_height, input_width) — the decoded response and
        the pixel resolution the vision encoder actually processed.
    """
    # The chat template only needs an image placeholder entry; the real image
    # is supplied via `images=[image]` below. NOTE(review): presumably the
    # dummy path is never opened when images are passed explicitly — confirm.
    img_url_dummy = "/lustre/fsw/portfolios/nvr/users/yataij/data/SPAR-7M-RGBD/example.png"
    messages = [
        {
            "role": "system",
            "content": system_prompt
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt
                },
                {
                    "image": img_url_dummy
                }
            ]
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # print("input:\n",text)
    inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(device)

    # Bug fix: previously hard-coded max_new_tokens=1024 here, silently
    # ignoring the function parameter.
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Drop the prompt tokens from each sequence, keeping only generated ones.
    generated_ids = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    # print("output:\n",output_text[0])

    # image_grid_thw holds (t, h, w) counted in 14-pixel vision patches;
    # multiplying by the patch size recovers the resized pixel resolution.
    input_height = inputs['image_grid_thw'][0][1]*14
    input_width = inputs['image_grid_thw'][0][2]*14

    return output_text[0], input_height, input_width
193
-
194
- # prepare the model input
195
# ---- Build the evaluation subset: single-view, multiple-choice items ----
dataset = load_dataset('/lustre/fs12/portfolios/nvr/projects/nvr_lpr_nvgptvision/users/yataij/data/SPAR_Bench')
print(len(dataset['test']))
print(dataset['test'].features.keys())
data_list = []
for i,example in tqdm(enumerate(dataset['test'])):
    # Keep only single-image, "select"-format (multiple choice) questions.
    if example['img_type'] == 'single_view' and example['format_type'] == 'select': # select fill
        data_list.append(example)

print('test', len(data_list))
# Previously extracted visual-prompt descriptions. NOTE(review): assumed to be
# ordered one-to-one with data_list; the assert in the loop below checks ids.
visual_prompts = json.load(open('qwen3_visual_prompt_extract.json'))
print(len(visual_prompts))

# Shard work across ranks: each rank takes every world_size-th item, applying
# the same stride to both lists so they stay aligned.
data_list = data_list[rank::world_size]
visual_prompts = visual_prompts[rank::world_size]

res = []
for i in tqdm(range(len(data_list))):
    instance = data_list[i]
    visual_prompt = visual_prompts[i]
    # Both lists were sharded identically, so the ids must still match.
    assert instance['id'] == visual_prompt['id']

    image = instance['image'][0]
    width, height = image.size
    vp_bbox = {}  # visual-prompt key -> [x1, y1, x2, y2] in image pixels
    for vp,ins in visual_prompt['visual_prompt'].items():

        if 'point' in vp:
            # e.g. vp == "red point": ask the model to localize the marker.
            color = vp.split()[0]
            prompt = f"Locate the {color} round point, output its bbox coordinates using JSON format."
            response, input_height, input_width = inference(image, prompt)
            try:
                coord = plot_bounding_boxes(image,response,input_width,input_height)
            except:
                # Unparseable or empty model output — skip this visual prompt.
                print(i, vp)
                continue
            # Pad the point's box by 50 px per side, clamped to the image.
            coord = [coord[0]-50, coord[1]-50, coord[2]+50, coord[3]+50]
            if coord[0] < 0: coord[0] = 0
            if coord[1] < 0: coord[1] = 0
            if coord[2] > width: coord[2] = width
            if coord[3] > height: coord[3] = height
            vp_bbox[vp] = coord
        elif 'bbox' in vp:
            # Referring-expression grounding for a named object.
            anno = f"the {ins} in {vp}"
            prompt = f"Locate {anno}, output its bbox coordinates using JSON format."
            response, input_height, input_width = inference(image, prompt)
            try:
                coord = plot_bounding_boxes(image,response,input_width,input_height)
            except:
                print(i, vp)
                continue
            vp_bbox[vp] = coord

    visual_prompt['visual_prompt_bbox'] = vp_bbox
    res.append(visual_prompt)
    # Checkpoint: rewrite this rank's output file after every instance so
    # partial results survive interruption. NOTE(review): indentation was
    # reconstructed — confirm this write belongs inside the loop.
    with open(f'qwen25vl_sparbench_singleimg_select_bbox_rank{rank}.json', 'w') as f:
        json.dump(res, f, indent=4)