# JunzheJosephZhu's picture
# handler
# 64807a4
#!/usr/bin/env python3
from typing import Dict, List, Any
import random
import base64
import requests
from pathlib import Path
from huggingface_hub import HfFolder
from huggingface_hub import get_inference_endpoint, list_inference_endpoints
from transformers import pipeline
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, AutoTokenizer
from PIL import Image
from peft import PeftModel
from io import BytesIO
import torch
import re
class EndpointHandler():
    """Inference-endpoint handler that locates a GUI element in a screenshot.

    Given a base64-encoded screenshot and a natural-language description of a
    target element, the Qwen2.5-VL model is prompted to emit exactly one
    ``click(start_box='(x,y)')`` action, which is parsed into pixel
    coordinates.
    """

    # Compiled once: the single click action the prompt instructs the model to emit.
    _CLICK_RE = re.compile(r"click\(start_box='\((\d+),(\d+)\)'\)")

    def __init__(self, path="."):
        # Load weights in bfloat16; device_map="auto" may place (or shard) the
        # model on whatever accelerators are available.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(
            path,
            trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model.eval()

    @staticmethod
    def _parse_click(response: str):
        """Return ``(x, y)`` ints if *response* contains exactly one click action, else ``None``."""
        matches = EndpointHandler._CLICK_RE.findall(response)
        if len(matches) != 1:
            return None
        x, y = matches[0]
        return int(x), int(y)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            description (:obj:`str`): natural-language description of the target element
            image_base64 (:obj:`str`): base64-encoded screenshot image
        Return:
            A :obj:`dict` with keys ``x``/``y`` (pixel coordinates) on success,
            or a single ``error`` key describing the failure.
        """
        description = data.pop("description", None)
        image_base64 = data.pop("image_base64", None)
        # Fail fast with a clean error instead of crashing inside b64decode/PIL.
        if description is None or image_base64 is None:
            return {"error": "both 'description' and 'image_base64' are required"}
        image = Image.open(BytesIO(base64.b64decode(image_base64)))
        original_height, original_width = image.height, image.width
        conversation = [
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": f"You are a GUI agent. You are given a screenshot of width {original_width} and height {original_height}. You need to locate the described element. \n\n## Output Format\n\nclick(start_box='(x,y)')\n\n## User Instruction\nClick the element with description '{description}'\n\n<|vision_start|><|image_pad|><|vision_end|>"
            }
        ]
        formatted_text = self.processor.apply_chat_template(conversation, tokenize=False)
        # Reject images outside the processor's supported pixel budget before
        # spending time on preprocessing.
        if not self.processor.image_processor.min_pixels <= original_width * original_height <= self.processor.image_processor.max_pixels:
            return {"error": f"Image size {original_width * original_height} is not within the allowed range {self.processor.image_processor.min_pixels} to {self.processor.image_processor.max_pixels}"}
        # Process the image and text
        inputs = self.processor(
            text=formatted_text,
            images=[image],
            return_tensors="pt",
            padding=True,
            add_special_tokens=False,
            padding_side="left"
        )
        # Move tensors to the model's own device rather than assuming CUDA, so
        # the handler also works on CPU or wherever device_map placed the model.
        device = self.model.device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Generate response (greedy decoding; temperature is ignored when
        # do_sample=False, so it is intentionally not passed).
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
                # Forbid the image placeholder token from being generated.
                bad_words_ids=[[self.model.config.image_token_id]],
            )
        # Decode response (only the new tokens)
        input_length = inputs['input_ids'].shape[1]
        response_tokens = outputs[0][input_length:]
        response = self.processor.decode(response_tokens, skip_special_tokens=True)
        coords = self._parse_click(response)
        if coords is None:
            print("not exactly 1 match found in:", response)
            return {"error": "not exactly 1 match found in: " + response}
        return {"x": coords[0], "y": coords[1]}