|
|
from tqdm import tqdm |
|
|
import os |
|
|
import json |
|
|
import argparse |
|
|
import torch |
|
|
import sys |
|
|
sys.path.append("/proj/cvl/users/x_fahkh2/UI-R1-Extention/UI-R1/src/ui_r1/src/open_r1") |
|
|
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor,Qwen2_5_VLForConditionalGeneration |
|
|
|
|
|
from showui import ShowUIForConditionalGeneration |
|
|
from showui import ShowUIProcessor |
|
|
from qwen_vl_utils import process_vision_info |
|
|
import sys |
|
|
import re |
|
|
import multiprocessing as mp |
|
|
import logging |
|
|
from multiprocessing import Pool |
|
|
import functools |
|
|
import torch.multiprocessing as mp |
|
|
# Configure the root logging handler once at import time so worker-process
# log lines actually reach stderr.
logging.basicConfig()

logger = logging.getLogger(__name__)

logger.setLevel(logging.INFO)

# Make the repository root importable (three directory levels above this file).
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# Default rank for single-process use; worker processes receive their own
# rank as an argument to run() instead.
rank = 0
|
|
def extract_coord(content):
    """Extract a predicted (x, y) click coordinate from a model response.

    First looks for an ``[x, y]`` pair inside a ``{...}`` dict within
    ``<answer>...</answer>`` tags; if no answer tag is present, falls back to
    an ``(x, y)`` tuple pattern anywhere in the raw text.

    Args:
        content: Raw decoded model output string.

    Returns:
        A ``(coord, found)`` tuple: ``coord`` is ``[x, y]`` as ints and
        ``found`` is True when a coordinate was matched, otherwise
        ``([0, 0], False)``.
    """
    answer_tag_pattern = r'<answer>(.*?)</answer>'
    bbox_pattern = r'\{.*\[(\d+),\s*(\d+)]\s*.*\}'
    content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
    if content_answer_match:
        content_answer = content_answer_match.group(1).strip()
        coord_match = re.search(bbox_pattern, content_answer)
        if coord_match:
            return [int(coord_match.group(1)), int(coord_match.group(2))], True
    else:
        # BUG FIX: the original pattern r'\{.*\((\d+),\s*(\d+))\s*.*\}' had an
        # unbalanced ')' and raised re.error whenever this branch executed.
        coord_pattern = r'\{.*\((\d+),\s*(\d+)\)\s*.*\}'
        coord_match = re.search(coord_pattern, content)
        if coord_match:
            return [int(coord_match.group(1)), int(coord_match.group(2))], True
    # Consistent 2-element fallback (the original returned [0, 0, 0, 0] here,
    # inconsistent with the [x, y] success shape; callers only index [0], [1]).
    return [0, 0], False
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): redundant — `logger` is already created (and set to INFO)
# near the top of the module; this rebinding is harmless and could be removed.
logger = logging.getLogger(__name__)
|
|
|
|
|
def run(rank, world_size, args):
    """Worker entry point: evaluate one shard of the test set on GPU `rank`.

    Loads the model on CPU, moves it to device `rank`, runs generation on
    every `world_size`-th sample of `args.test_json` starting at `rank`,
    extracts the predicted click coordinate, scores it against the
    ground-truth bbox, and appends each prediction to a shared JSONL file.

    Args:
        rank: Index of this worker and of the CUDA device it uses.
        world_size: Total number of workers (dataset shard stride).
        args: Parsed CLI namespace with model_path, ori_processor_path,
            image_path, test_json, and test_name.

    Returns:
        ``[error_count, correct_count, pred_results]`` for aggregation.
    """
    # Load on CPU first; moved to the per-process CUDA device below.
    model = ShowUIForConditionalGeneration.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="cpu",
    )

    # BUG FIX: --ori_processor_path was parsed and assigned but never used;
    # honour it when given, falling back to the model path (which was the
    # previous effective behavior for all invocations).
    ori_processor_path = args.ori_processor_path or args.model_path

    infer_dir = os.path.join(args.model_path, 'infer')
    # BUG FIX: exist_ok avoids a FileExistsError race when several worker
    # processes reach this line at the same time.
    os.makedirs(infer_dir, exist_ok=True)
    output_file = os.path.join(infer_dir, f'prediction_results_{args.test_name}.jsonl')

    processor = ShowUIProcessor.from_pretrained(ori_processor_path)

    model = model.to(torch.device(rank))
    model = model.eval()

    error_count = 0
    correct_count = 0
    pred_results = []

    # BUG FIX: the original left the dataset file handle open.
    with open(args.test_json, "r") as f:
        data = json.load(f)

    # Shard the dataset round-robin across workers.
    data = data[rank::world_size]
    print(f"Process {rank} handling {len(data)} samples", flush=True)

    for j, item in tqdm(enumerate(data), total=len(data)):
        image_path = os.path.join(args.image_path, item["img_filename"])
        task_prompt = item["instruction"]

        # Prompt text must stay byte-identical: it conditions the model output.
        question_template = (
            f"In this UI screenshot, I want to perform the command '{task_prompt}'.\n"
            "Please provide the action to perform (enumerate in ['click'])"
            "and the coordinate where the cursor is moved to(integer) if click is performed.\n"
            "Output the final answer in <answer> </answer> tags directly."
            "The output answer format should be as follows:\n"
            "<answer>[{'action': 'click', 'coordinate': [x, y]}]</answer>\n"
            "Please strictly follow the format."
        )

        query = '<image>\n' + question_template
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": query},
                ],
            }
        ]

        try:
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)

            inputs = processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )

            # Recover the resize the processor applied so predicted coordinates
            # (emitted in resized-image space) can be mapped back to original
            # pixels. image_grid_thw rows are (t, h, w) in patch units.
            resized_height = inputs['image_grid_thw'][0][1] * processor.image_processor.patch_size
            resized_width = inputs['image_grid_thw'][0][2] * processor.image_processor.patch_size
            origin_height = image_inputs[0].size[1]
            origin_width = image_inputs[0].size[0]
            scale_x = origin_width / resized_width
            scale_y = origin_height / resized_height
            inputs = inputs.to(model.device)

            generated_ids = model.generate(**inputs, max_new_tokens=1024, use_cache=True)
            # Strip the prompt tokens so only newly generated text is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            response = processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]

            gt_bbox = item["bbox"]  # assumes [x, y, w, h] — see hit test below
            pred_coord, _ = extract_coord(response)
            # Map from resized-image space back to original-image pixels.
            pred_coord[0] = int(pred_coord[0] * scale_x)
            pred_coord[1] = int(pred_coord[1] * scale_y)

            # Hit iff the predicted point lies inside the ground-truth box.
            success = (
                gt_bbox[0] <= pred_coord[0] <= (gt_bbox[0] + gt_bbox[2])
                and gt_bbox[1] <= pred_coord[1] <= (gt_bbox[1] + gt_bbox[3])
            )

            if success:
                correct_count += 1
            else:
                error_count += 1

            new_pred_dict = {
                'image_id': item["img_filename"],
                'gt_bbox': gt_bbox,
                'pred_coord': pred_coord,
                'response': response,
                'pred_result': success
            }
            print("new_pred_dict: ", new_pred_dict)
            # Append incrementally so partial results survive a crash.
            with open(output_file, 'a') as json_file:
                json.dump(new_pred_dict, json_file)
                json_file.write('\n')
            pred_results.append(new_pred_dict)

        except Exception as e:
            # Best-effort: log and keep evaluating the remaining samples.
            print(f"Process {rank} error: {e}", flush=True)
            error_count += 1

    return [error_count, correct_count, pred_results]
|
|
|
|
|
def main(args):
    """Spawn one evaluation worker per GPU and aggregate their results.

    Requires at least two GPUs (matching the original behavior); otherwise
    logs a message and returns without evaluating.

    Args:
        args: Parsed CLI namespace forwarded to each ``run`` worker.
    """
    n_gpus = torch.cuda.device_count()
    multiprocess = n_gpus >= 2

    if multiprocess:
        # CUDA state cannot be shared with fork()ed children; spawn is
        # required. BUG FIX: only set it on the path that actually spawns, and
        # force=True so a previously-set start method doesn't raise.
        mp.set_start_method('spawn', force=True)
        logger.info('Started generation')
        world_size = n_gpus

        with Pool(world_size) as pool:
            func = functools.partial(run, world_size=world_size, args=args)
            result_lists = pool.map(func, range(world_size))

        global_count_error = 0
        global_count_correct = 0
        global_results = []

        for i in range(world_size):
            global_count_error += int(result_lists[i][0])
            global_count_correct += int(result_lists[i][1])
            global_results.extend(result_lists[i][2])

        logger.info(f'Error number: {global_count_error}')
        # BUG FIX: the correct count was aggregated but never reported.
        logger.info(f'Correct number: {global_count_correct}')
        total = global_count_error + global_count_correct
        if total:
            logger.info(f'Accuracy: {global_count_correct / total:.4f}')

        logger.info('Finished running')

    else:
        logger.info("Not enough GPUs")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI: model/processor locations plus the benchmark JSON and a name used
    # to tag the output prediction file.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model_path", type=str, required=True)
    cli.add_argument("--ori_processor_path", type=str, default=None)
    cli.add_argument("--image_path", type=str, default=None)
    cli.add_argument("--test_json", type=str, required=True)
    cli.add_argument("--test_name", type=str, required=True)
    main(cli.parse_args())
|
|
|