File size: 3,803 Bytes
64807a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#!/usr/bin/env python3
from typing import Dict, List, Any
import random
import base64
import requests
from pathlib import Path
from huggingface_hub import HfFolder
from huggingface_hub import get_inference_endpoint, list_inference_endpoints
from transformers import pipeline
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, AutoTokenizer
from PIL import Image
from peft import PeftModel
from io import BytesIO
import torch
import re



class EndpointHandler():
    """Endpoint handler that maps an element description to a click point.

    The handler loads a Qwen2.5-VL model together with its processor and
    tokenizer from ``path`` and, per request, asks the model to emit
    ``click(start_box='(x,y)')`` for the element described in the payload.
    """

    # Matches the output format requested in the prompt; compiled once and
    # shared by all requests.
    _CLICK_PATTERN = re.compile(r"click\(start_box='\((\d+),(\d+)\)'\)")

    def __init__(self, path="."):
        """Load model, processor and tokenizer from ``path``.

        Args:
            path: Directory (or hub id) passed to every ``from_pretrained``
                call.
        """
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(
            path,
            trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Inference only: switch off dropout and similar training behavior.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Locate the described element in a screenshot.

        Args:
            data: Request payload with the keys
                ``description`` (:obj:`str`) - text describing the element;
                ``image_base64`` (:obj:`str`) - base64-encoded screenshot.

        Returns:
            ``{"x": int, "y": int}`` when the model output contains exactly
            one ``click(start_box='(x,y)')`` match, otherwise
            ``{"error": str}`` describing why no coordinates were produced
            (missing/invalid input, image size out of range, or an
            unparsable model response).
        """
        # Read without popping so the caller's payload is left untouched.
        description = data.get("description")
        image_base64 = data.get("image_base64")

        # Reject incomplete requests up front instead of crashing inside
        # b64decode or prompting the model with the literal text 'None'.
        if description is None:
            return {"error": "missing required field 'description'"}
        if image_base64 is None:
            return {"error": "missing required field 'image_base64'"}

        try:
            image_bytes = base64.b64decode(image_base64)
        except (TypeError, ValueError) as exc:  # binascii.Error is a ValueError
            return {"error": f"image_base64 is not valid base64: {exc}"}
        try:
            image = Image.open(BytesIO(image_bytes))
        except OSError as exc:  # includes PIL's UnidentifiedImageError
            return {"error": f"could not decode image: {exc}"}

        original_height, original_width = image.height, image.width

        # Check the pixel budget before doing any prompt/tensor work.
        image_processor = self.processor.image_processor
        min_pixels = image_processor.min_pixels
        max_pixels = image_processor.max_pixels
        if not min_pixels <= original_width * original_height <= max_pixels:
            return {"error": f"Image size {original_width * original_height} is not within the allowed range {min_pixels} to {max_pixels}"}

        conversation = [
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": f"You are a GUI agent. You are given a screenshot of width {original_width} and height {original_height}. You need to locate the described element. \n\n## Output Format\n\nclick(start_box='(x,y)')\n\n## User Instruction\nClick the element with description '{description}'\n\n<|vision_start|><|image_pad|><|vision_end|>"
            }
        ]
        # NOTE(review): no add_generation_prompt=True is passed here, so the
        # rendered text does not end with the assistant header - confirm this
        # matches how the checkpoint was trained before changing it.
        formatted_text = self.processor.apply_chat_template(conversation, tokenize=False)

        # Process the image and text
        inputs = self.processor(
            text=formatted_text,
            images=[image],
            return_tensors="pt",
            padding=True,
            add_special_tokens=False,
            padding_side="left"
        )
        # Follow the device the model was actually placed on rather than
        # assuming CUDA is available.
        device = self.model.device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate response. Decoding is greedy (do_sample=False), so no
        # temperature is passed: it would be ignored and only emit a warning.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
                # Never let the model emit the image placeholder token.
                bad_words_ids=[[self.model.config.image_token_id]],
            )

        # Decode response (only the new tokens)
        input_length = inputs['input_ids'].shape[1]
        response_tokens = outputs[0][input_length:]
        response = self.processor.decode(response_tokens, skip_special_tokens=True)

        matches = [
            (int(x), int(y))
            for x, y in self._CLICK_PATTERN.findall(response)
        ]
        if len(matches) != 1:
            print("not exactly 1 match found in:", response)
            return {"error": "not exactly 1 match found in: " + response}
        x, y = matches[0]
        return {"x": x, "y": y}