File size: 2,735 Bytes
854eade
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import base64
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration


class EndpointHandler:
    """Inference Endpoints handler wrapping a Qwen2.5-VL model.

    The model and processor are loaded once in ``__init__``.
    Each call takes one image plus an optional text prompt and returns the
    newly generated text.
    """

    def __init__(self, path=""):
        """Load the model and processor.

        Args:
            path: Model directory or hub id, forwarded to ``from_pretrained``.
        """
        # bfloat16 when CUDA is available, float32 otherwise.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=dtype,
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(path)

    @staticmethod
    def _open_rgb(source):
        """Open ``source`` (path or file-like object) with PIL and return an RGB image.

        ``convert`` returns a new image.
        The ``with`` block can therefore close the underlying file without
        invalidating the result.
        """
        with Image.open(source) as img:
            return img.convert("RGB")

    def _load_image(self, image_ref, is_base64=False):
        """Resolve an image reference to an RGB ``PIL.Image``.

        Accepted forms:
          * ``data:image/...;base64,...`` URI (always decoded);
          * raw base64 string, when ``is_base64`` is true;
          * ``http://`` / ``https://`` URL, downloaded with a 30 s timeout;
          * anything else is passed to ``PIL.Image.open`` (local path or
            file-like object).

        Args:
            image_ref: The image reference to resolve.
            is_base64: Treat a plain string as raw base64 data instead of a path.

        Raises:
            ValueError: If ``image_ref`` is None.
            requests.HTTPError: If a URL download returns an error status.
        """
        if image_ref is None:
            raise ValueError("Missing image. Please provide `inputs.image_url` or `inputs.image_base64`.")

        if isinstance(image_ref, str):
            if is_base64 or image_ref.startswith("data:image"):
                b64data = image_ref
                if b64data.startswith("data:"):
                    # Drop the data-URI header, e.g. "data:image/png;base64,".
                    _, b64data = b64data.split(",", 1)
                return self._open_rgb(BytesIO(base64.b64decode(b64data)))

            # Match the full scheme so local paths like "httpdocs/x.png"
            # are not mistaken for URLs.
            if image_ref.lower().startswith(("http://", "https://")):
                # NOTE(review): fetching caller-supplied URLs server-side is an
                # SSRF vector; consider an allow-list if the endpoint is public.
                resp = requests.get(image_ref, timeout=30)
                resp.raise_for_status()
                return self._open_rgb(BytesIO(resp.content))

        # Otherwise treat it as a local path (or file-like object).
        # NOTE(review): this lets a request payload open any image file the
        # server process can read; restrict it if inputs are untrusted.
        return self._open_rgb(image_ref)

    def __call__(self, data):
        """Run one image + prompt generation request.

        Args:
            data: Request body.
                ``data["inputs"]`` must be a dict containing ``image_url`` or
                ``image_base64``; ``image_url`` wins if both are set.
                It may also contain ``prompt`` and ``max_new_tokens``
                (default 256).

        Returns:
            ``{"generated_text": str}``, holding only the newly generated text.

        Raises:
            ValueError: If ``inputs`` is not an object or no image is supplied.
        """
        payload = data.get("inputs", {}) or {}
        if not isinstance(payload, dict):
            raise ValueError("`inputs` must be a JSON object with `image_url` or `image_base64`.")

        prompt = payload.get("prompt", "Please analyze this image and infer its location.")
        image_url = payload.get("image_url")
        image_base64 = payload.get("image_base64")
        max_new_tokens = int(payload.get("max_new_tokens", 256))

        # Route each field to the matching decoder.
        # A raw base64 string must not be treated as a file path.
        if image_url:
            image = self._load_image(image_url)
        else:
            image = self._load_image(image_base64, is_base64=True)

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        model_inputs = self.processor(
            text=[text],
            images=[image],
            return_tensors="pt",
        ).to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
            )

        # Keep only the tokens after the prompt.
        generated_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(model_inputs.input_ids, output_ids)
        ]

        output_text = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )[0]

        return {
            "generated_text": output_text
        }