Fahad-S committed on
Commit
ee392c3
·
verified ·
1 Parent(s): 4468254

Upload test_screenspot_showui.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_screenspot_showui.py +226 -0
test_screenspot_showui.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ import os
3
+ import json
4
+ import argparse
5
+ import torch
6
+ import sys
7
+ sys.path.append("/proj/cvl/users/x_fahkh2/UI-R1-Extention/UI-R1/src/ui_r1/src/open_r1")
8
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor,Qwen2_5_VLForConditionalGeneration
9
+ #from ..showui import ShowUIForConditionalGeneration, ShowUIProcessor
10
+ from showui import ShowUIForConditionalGeneration
11
+ from showui import ShowUIProcessor
12
+ from qwen_vl_utils import process_vision_info
13
+ import sys
14
+ import re
15
+ import multiprocessing as mp
16
+ import logging
17
+ from multiprocessing import Pool
18
+ import functools
19
+ import torch.multiprocessing as mp
20
+ logging.basicConfig()
21
+ logger = logging.getLogger(__name__)
22
+ logger.setLevel(logging.INFO)
23
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
24
+
25
+ rank = 0
26
+ def extract_coord(content):
27
+ # Try to find the bbox within <answer> tags, if can not find, return [0, 0, 0, 0]
28
+ answer_tag_pattern = r'<answer>(.*?)</answer>'
29
+ bbox_pattern = r'\{.*\[(\d+),\s*(\d+)]\s*.*\}'
30
+ content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
31
+ if content_answer_match:
32
+ content_answer = content_answer_match.group(1).strip()
33
+ coord_match = re.search(bbox_pattern, content_answer)
34
+ if coord_match:
35
+ coord = [int(coord_match.group(1)), int(coord_match.group(2))]
36
+ return coord, True
37
+ else:
38
+ coord_pattern = r'\{.*\((\d+),\s*(\d+))\s*.*\}'
39
+ coord_match = re.search(coord_pattern, content)
40
+ if coord_match:
41
+ coord = [int(coord_match.group(1)), int(coord_match.group(2))]
42
+ return coord, True
43
+ return [0, 0, 0, 0], False
44
+
45
+
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+ def run(rank, world_size, args):
50
+ model = ShowUIForConditionalGeneration.from_pretrained(args.model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="cpu")
51
+ '''
52
+ if "Qwen2.5" in args.model_path:
53
+ model = ShowUIForConditionalGeneration.from_pretrained(
54
+ args.model_path,
55
+ torch_dtype=torch.bfloat16,
56
+ attn_implementation="flash_attention_2",
57
+ device_map="cpu",
58
+ )
59
+ else:
60
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
61
+ args.model_path,
62
+ torch_dtype=torch.bfloat16,
63
+ attn_implementation="flash_attention_2",
64
+ device_map="cpu",
65
+ )
66
+ '''
67
+ if args.ori_processor_path is None:
68
+ ori_processor_path = args.model_path
69
+ infer_dir = os.path.join(args.model_path,'infer')
70
+ if not os.path.exists(infer_dir):
71
+ os.makedirs(infer_dir)
72
+ output_file = os.path.join(infer_dir, f'prediction_results_{args.test_name}.jsonl')
73
+
74
+ processor = ShowUIProcessor.from_pretrained(args.model_path)
75
+
76
+ model = model.to(torch.device(rank))
77
+ model = model.eval()
78
+
79
+ error_count = 0
80
+ correct_count = 0
81
+ pred_results = []
82
+
83
+
84
+ dataset = args.test_json
85
+ data = json.load(open(dataset, "r"))
86
+
87
+ data = data[rank::world_size]
88
+ print(f"Process {rank} handling {len(data)} samples", flush=True)
89
+
90
+ for j, item in tqdm(enumerate(data), total=len(data)):
91
+ image_path = os.path.join(args.image_path, item["img_filename"]) # 通过 args 传递路径
92
+ task_prompt = item["instruction"]
93
+
94
+ question_template_think = (
95
+ f"In this UI screenshot, I want to perform the command '{task_prompt}'.\n"
96
+ "Please provide the action to perform (enumerate in ['click', 'scroll']) and the coordinate where the cursor is moved to(integer) if click is performed.\n"
97
+ "Output the thinking process in <think> </think> and final answer in <answer> </answer> tags."
98
+ "The output answer format should be as follows:\n"
99
+ "<think> ... </think> <answer>[{'action': enum['click', 'scroll'], 'coordinate': [x, y]}]</answer>\n"
100
+ "Please strictly follow the format."
101
+ )
102
+ question_template = (
103
+ f"In this UI screenshot, I want to perform the command '{task_prompt}'.\n"
104
+ "Please provide the action to perform (enumerate in ['click'])"
105
+ "and the coordinate where the cursor is moved to(integer) if click is performed.\n"
106
+ "Output the final answer in <answer> </answer> tags directly."
107
+ "The output answer format should be as follows:\n"
108
+ "<answer>[{'action': 'click', 'coordinate': [x, y]}]</answer>\n"
109
+ "Please strictly follow the format."
110
+ )
111
+
112
+ query = '<image>\n' + question_template
113
+ messages = [
114
+ {
115
+ "role": "user",
116
+ "content": [
117
+ {"type": "image", "image": image_path}
118
+ ] + [{"type": "text", "text": query}],
119
+ }
120
+ ]
121
+
122
+ try:
123
+ text = processor.apply_chat_template(
124
+ messages, tokenize=False, add_generation_prompt=True
125
+ )
126
+ image_inputs, video_inputs = process_vision_info(messages)
127
+ #print("processor: ", processor)
128
+ #print("image_inputs shape: ", image_inputs.shape)
129
+ inputs = processor(
130
+ text=[text],
131
+ images=image_inputs,
132
+ videos=video_inputs,
133
+ padding=True,
134
+ return_tensors="pt",
135
+ )
136
+ # optional: resize coord due to image resize
137
+ resized_height = inputs['image_grid_thw'][0][1] * processor.image_processor.patch_size
138
+ resized_width = inputs['image_grid_thw'][0][2] * processor.image_processor.patch_size
139
+ origin_height = image_inputs[0].size[1]
140
+ origin_width = image_inputs[0].size[0]
141
+ scale_x = origin_width / resized_width
142
+ scale_y = origin_height / resized_height
143
+ inputs = inputs.to(model.device)
144
+
145
+ generated_ids = model.generate(**inputs, max_new_tokens=1024, use_cache=True)
146
+ generated_ids_trimmed = [
147
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
148
+ ]
149
+ response = processor.batch_decode(
150
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
151
+ )
152
+ response = response[0]
153
+ gt_bbox = item["bbox"]
154
+ pred_coord, _ = extract_coord(response)
155
+ pred_coord[0] = int(pred_coord[0] * scale_x)
156
+ pred_coord[1] = int(pred_coord[1] * scale_y)
157
+ #success = gt_bbox[0] <= pred_coord[0] <= gt_bbox[2] and gt_bbox[1] <= pred_coord[1] <= gt_bbox[3]
158
+ success = gt_bbox[0] <= pred_coord[0] <= (gt_bbox[0]+gt_bbox[2]) and gt_bbox[1] <= pred_coord[1] <= (gt_bbox[1]+gt_bbox[3])
159
+
160
+
161
+
162
+ if success:
163
+ correct_count += 1
164
+ else:
165
+ error_count += 1
166
+
167
+ new_pred_dict = {
168
+ 'image_id': item["img_filename"],
169
+ 'gt_bbox': gt_bbox,
170
+ 'pred_coord': pred_coord,
171
+ 'response': response,
172
+ 'pred_result': success
173
+ }
174
+ print("new_pred_dict: ", new_pred_dict)
175
+ with open(output_file, 'a') as json_file:
176
+ json.dump(new_pred_dict, json_file)
177
+ json_file.write('\n')
178
+ pred_results.append(new_pred_dict)
179
+
180
+ except Exception as e:
181
+ print(f"Process {rank} error: {e}", flush=True)
182
+ error_count += 1
183
+
184
+ return [error_count, correct_count, pred_results]
185
+
186
+ def main(args):
187
+ multiprocess = torch.cuda.device_count() >= 2
188
+ mp.set_start_method('spawn')
189
+
190
+ if multiprocess:
191
+ logger.info('Started generation')
192
+ n_gpus = torch.cuda.device_count()
193
+ world_size = n_gpus
194
+
195
+ with Pool(world_size) as pool:
196
+ func = functools.partial(run, world_size=world_size, args=args)
197
+ result_lists = pool.map(func, range(world_size))
198
+
199
+ global_count_error = 0
200
+ global_count_correct = 0
201
+ global_results = []
202
+
203
+ for i in range(world_size):
204
+ global_count_error += int(result_lists[i][0])
205
+ global_count_correct += int(result_lists[i][1])
206
+ global_results.extend(result_lists[i][2]) # 修正拼接方式
207
+
208
+ logger.info(f'Error number: {global_count_error}')
209
+
210
+ logger.info('Finished running')
211
+
212
+ else:
213
+ logger.info("Not enough GPUs")
214
+
215
+
216
+ if __name__ == "__main__":
217
+
218
+
219
+ parser = argparse.ArgumentParser()
220
+ parser.add_argument("--model_path", type=str, required=True)
221
+ parser.add_argument("--ori_processor_path", type=str, default=None)
222
+ parser.add_argument("--image_path", type=str, default=None)
223
+ parser.add_argument("--test_json", type=str, required=True)
224
+ parser.add_argument("--test_name", type=str, required=True)
225
+ args = parser.parse_args()
226
+ main(args)