UI-R1-ShowUI / test_screenspot_showui.py
Fahad-S's picture
Upload test_screenspot_showui.py with huggingface_hub
ee392c3 verified
from tqdm import tqdm
import os
import json
import argparse
import torch
import sys
sys.path.append("/proj/cvl/users/x_fahkh2/UI-R1-Extention/UI-R1/src/ui_r1/src/open_r1")
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor,Qwen2_5_VLForConditionalGeneration
#from ..showui import ShowUIForConditionalGeneration, ShowUIProcessor
from showui import ShowUIForConditionalGeneration
from showui import ShowUIProcessor
from qwen_vl_utils import process_vision_info
import sys
import re
import multiprocessing as mp
import logging
from multiprocessing import Pool
import functools
import torch.multiprocessing as mp
logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
rank = 0
def extract_coord(content):
# Try to find the bbox within <answer> tags, if can not find, return [0, 0, 0, 0]
answer_tag_pattern = r'<answer>(.*?)</answer>'
bbox_pattern = r'\{.*\[(\d+),\s*(\d+)]\s*.*\}'
content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
if content_answer_match:
content_answer = content_answer_match.group(1).strip()
coord_match = re.search(bbox_pattern, content_answer)
if coord_match:
coord = [int(coord_match.group(1)), int(coord_match.group(2))]
return coord, True
else:
coord_pattern = r'\{.*\((\d+),\s*(\d+))\s*.*\}'
coord_match = re.search(coord_pattern, content)
if coord_match:
coord = [int(coord_match.group(1)), int(coord_match.group(2))]
return coord, True
return [0, 0, 0, 0], False
logger = logging.getLogger(__name__)
def run(rank, world_size, args):
model = ShowUIForConditionalGeneration.from_pretrained(args.model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="cpu")
'''
if "Qwen2.5" in args.model_path:
model = ShowUIForConditionalGeneration.from_pretrained(
args.model_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="cpu",
)
else:
model = Qwen2VLForConditionalGeneration.from_pretrained(
args.model_path,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="cpu",
)
'''
if args.ori_processor_path is None:
ori_processor_path = args.model_path
infer_dir = os.path.join(args.model_path,'infer')
if not os.path.exists(infer_dir):
os.makedirs(infer_dir)
output_file = os.path.join(infer_dir, f'prediction_results_{args.test_name}.jsonl')
processor = ShowUIProcessor.from_pretrained(args.model_path)
model = model.to(torch.device(rank))
model = model.eval()
error_count = 0
correct_count = 0
pred_results = []
dataset = args.test_json
data = json.load(open(dataset, "r"))
data = data[rank::world_size]
print(f"Process {rank} handling {len(data)} samples", flush=True)
for j, item in tqdm(enumerate(data), total=len(data)):
image_path = os.path.join(args.image_path, item["img_filename"]) # 通过 args 传递路径
task_prompt = item["instruction"]
question_template_think = (
f"In this UI screenshot, I want to perform the command '{task_prompt}'.\n"
"Please provide the action to perform (enumerate in ['click', 'scroll']) and the coordinate where the cursor is moved to(integer) if click is performed.\n"
"Output the thinking process in <think> </think> and final answer in <answer> </answer> tags."
"The output answer format should be as follows:\n"
"<think> ... </think> <answer>[{'action': enum['click', 'scroll'], 'coordinate': [x, y]}]</answer>\n"
"Please strictly follow the format."
)
question_template = (
f"In this UI screenshot, I want to perform the command '{task_prompt}'.\n"
"Please provide the action to perform (enumerate in ['click'])"
"and the coordinate where the cursor is moved to(integer) if click is performed.\n"
"Output the final answer in <answer> </answer> tags directly."
"The output answer format should be as follows:\n"
"<answer>[{'action': 'click', 'coordinate': [x, y]}]</answer>\n"
"Please strictly follow the format."
)
query = '<image>\n' + question_template
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image_path}
] + [{"type": "text", "text": query}],
}
]
try:
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
#print("processor: ", processor)
#print("image_inputs shape: ", image_inputs.shape)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
# optional: resize coord due to image resize
resized_height = inputs['image_grid_thw'][0][1] * processor.image_processor.patch_size
resized_width = inputs['image_grid_thw'][0][2] * processor.image_processor.patch_size
origin_height = image_inputs[0].size[1]
origin_width = image_inputs[0].size[0]
scale_x = origin_width / resized_width
scale_y = origin_height / resized_height
inputs = inputs.to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=1024, use_cache=True)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
response = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
response = response[0]
gt_bbox = item["bbox"]
pred_coord, _ = extract_coord(response)
pred_coord[0] = int(pred_coord[0] * scale_x)
pred_coord[1] = int(pred_coord[1] * scale_y)
#success = gt_bbox[0] <= pred_coord[0] <= gt_bbox[2] and gt_bbox[1] <= pred_coord[1] <= gt_bbox[3]
success = gt_bbox[0] <= pred_coord[0] <= (gt_bbox[0]+gt_bbox[2]) and gt_bbox[1] <= pred_coord[1] <= (gt_bbox[1]+gt_bbox[3])
if success:
correct_count += 1
else:
error_count += 1
new_pred_dict = {
'image_id': item["img_filename"],
'gt_bbox': gt_bbox,
'pred_coord': pred_coord,
'response': response,
'pred_result': success
}
print("new_pred_dict: ", new_pred_dict)
with open(output_file, 'a') as json_file:
json.dump(new_pred_dict, json_file)
json_file.write('\n')
pred_results.append(new_pred_dict)
except Exception as e:
print(f"Process {rank} error: {e}", flush=True)
error_count += 1
return [error_count, correct_count, pred_results]
def main(args):
multiprocess = torch.cuda.device_count() >= 2
mp.set_start_method('spawn')
if multiprocess:
logger.info('Started generation')
n_gpus = torch.cuda.device_count()
world_size = n_gpus
with Pool(world_size) as pool:
func = functools.partial(run, world_size=world_size, args=args)
result_lists = pool.map(func, range(world_size))
global_count_error = 0
global_count_correct = 0
global_results = []
for i in range(world_size):
global_count_error += int(result_lists[i][0])
global_count_correct += int(result_lists[i][1])
global_results.extend(result_lists[i][2]) # 修正拼接方式
logger.info(f'Error number: {global_count_error}')
logger.info('Finished running')
else:
logger.info("Not enough GPUs")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--ori_processor_path", type=str, default=None)
parser.add_argument("--image_path", type=str, default=None)
parser.add_argument("--test_json", type=str, required=True)
parser.add_argument("--test_name", type=str, required=True)
args = parser.parse_args()
main(args)