| | |
| | from typing import Dict, List, Any |
| | import random |
| | import base64 |
| | import requests |
| | from pathlib import Path |
| | from huggingface_hub import HfFolder |
| | from huggingface_hub import get_inference_endpoint, list_inference_endpoints |
| | from transformers import pipeline |
| | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, AutoTokenizer |
| | from PIL import Image |
| | from peft import PeftModel |
| | from io import BytesIO |
| | import torch |
| | import re |
| |
|
| |
|
| |
|
class EndpointHandler():
    """Inference-endpoint handler for a Qwen2.5-VL GUI-grounding model.

    Given a natural-language description of a UI element and a base64-encoded
    screenshot, the model is prompted to answer in the form
    ``click(start_box='(x,y)')`` and the parsed pixel coordinates are returned.
    """

    def __init__(self, path="."):
        """Load model, processor, and tokenizer from *path*.

        Args:
            path: local directory (or hub id) containing the model artifacts.
        """
        # device_map="auto" shards/places weights on the available device(s);
        # all later tensor movement must therefore follow self.model.device
        # rather than assuming CUDA.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.processor = AutoProcessor.from_pretrained(
            path,
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model.eval()

    @staticmethod
    def _parse_click(response):
        """Return all (x, y) integer pairs matching click(start_box='(x,y)')."""
        return [
            (int(x), int(y))
            for x, y in re.findall(r"click\(start_box='\((\d+),(\d+)\)'\)", response)
        ]

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            description (:obj:`str`): description of the UI element to locate
            image_base64 (:obj:`str`): base64-encoded screenshot image
        Return:
            :obj:`dict`: {"x": int, "y": int} on success, {"error": str} on failure
        """
        description = data.pop("description", None)
        image_base64 = data.pop("image_base64", None)
        # Fail with a structured error instead of a TypeError on missing keys.
        if description is None or image_base64 is None:
            return {"error": "both 'description' and 'image_base64' are required"}
        try:
            image = Image.open(BytesIO(base64.b64decode(image_base64)))
        except Exception as exc:
            return {"error": f"could not decode image: {exc}"}

        original_height, original_width = image.height, image.width

        # Reject images outside the processor's pixel budget before doing any
        # work (the original checked after building the prompt).
        min_pixels = self.processor.image_processor.min_pixels
        max_pixels = self.processor.image_processor.max_pixels
        num_pixels = original_width * original_height
        if not min_pixels <= num_pixels <= max_pixels:
            return {"error": f"Image size {num_pixels} is not within the allowed range {min_pixels} to {max_pixels}"}

        conversation = [
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": f"You are a GUI agent. You are given a screenshot of width {original_width} and height {original_height}. You need to locate the described element. \n\n## Output Format\n\nclick(start_box='(x,y)')\n\n## User Instruction\nClick the element with description '{description}'\n\n<|vision_start|><|image_pad|><|vision_end|>"
            }
        ]
        formatted_text = self.processor.apply_chat_template(conversation, tokenize=False)

        inputs = self.processor(
            text=formatted_text,
            images=[image],
            return_tensors="pt",
            padding=True,
            add_special_tokens=False,
            padding_side="left",
        )
        # Follow the model's actual placement instead of hard-coding .cuda(),
        # so the handler also works on CPU-only hosts.
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        # Qwen tokenizers may define no pad token; fall back to EOS so that
        # generate() does not receive pad_token_id=None.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            # Greedy decoding; temperature is meaningless (and triggers a
            # transformers warning) when do_sample=False, so it is omitted.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=False,
                pad_token_id=pad_token_id,
                # Forbid emitting the image placeholder token in the answer.
                bad_words_ids=[[self.model.config.image_token_id]],
            )

        # Decode only the newly generated tokens (strip the echoed prompt).
        input_length = inputs['input_ids'].shape[1]
        response_tokens = outputs[0][input_length:]
        response = self.processor.decode(response_tokens, skip_special_tokens=True)

        matches = self._parse_click(response)
        if len(matches) != 1:
            print("not exactly 1 match found in:", response)
            return {"error": "not exactly 1 match found in: " + response}
        return {"x": matches[0][0], "y": matches[0][1]}
| |
|