Create handler.py

#1
by EugeneZhao - opened
Files changed (1) hide show
  1. handler.py +88 -0
handler.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from io import BytesIO
3
+
4
+ import requests
5
+ import torch
6
+ from PIL import Image
7
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
8
+
9
+
10
+ class EndpointHandler:
11
+ def __init__(self, path=""):
12
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
13
+
14
+ self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
15
+ path,
16
+ torch_dtype=dtype,
17
+ device_map="auto",
18
+ )
19
+ self.processor = AutoProcessor.from_pretrained(path)
20
+
21
+ def _load_image(self, image_ref):
22
+ if image_ref is None:
23
+ raise ValueError("Missing image. Please provide `inputs.image_url` or `inputs.image_base64`.")
24
+
25
+ if isinstance(image_ref, str) and image_ref.startswith("http"):
26
+ resp = requests.get(image_ref, timeout=30)
27
+ resp.raise_for_status()
28
+ return Image.open(BytesIO(resp.content)).convert("RGB")
29
+
30
+ if isinstance(image_ref, str) and image_ref.startswith("data:image"):
31
+ _, b64data = image_ref.split(",", 1)
32
+ return Image.open(BytesIO(base64.b64decode(b64data))).convert("RGB")
33
+
34
+ # 默认当作本地路径处理
35
+ return Image.open(image_ref).convert("RGB")
36
+
37
+ def __call__(self, data):
38
+ payload = data.get("inputs", {}) or {}
39
+
40
+ prompt = payload.get("prompt", "Please analyze this image and infer its location.")
41
+ image_url = payload.get("image_url")
42
+ image_base64 = payload.get("image_base64")
43
+ max_new_tokens = int(payload.get("max_new_tokens", 256))
44
+
45
+ image = self._load_image(image_url or image_base64)
46
+
47
+ messages = [
48
+ {
49
+ "role": "user",
50
+ "content": [
51
+ {"type": "image"},
52
+ {"type": "text", "text": prompt},
53
+ ],
54
+ }
55
+ ]
56
+
57
+ text = self.processor.apply_chat_template(
58
+ messages,
59
+ tokenize=False,
60
+ add_generation_prompt=True,
61
+ )
62
+
63
+ model_inputs = self.processor(
64
+ text=[text],
65
+ images=[image],
66
+ return_tensors="pt",
67
+ ).to(self.model.device)
68
+
69
+ with torch.no_grad():
70
+ output_ids = self.model.generate(
71
+ **model_inputs,
72
+ max_new_tokens=max_new_tokens,
73
+ )
74
+
75
+ generated_ids = [
76
+ out_ids[len(in_ids):]
77
+ for in_ids, out_ids in zip(model_inputs.input_ids, output_ids)
78
+ ]
79
+
80
+ output_text = self.processor.batch_decode(
81
+ generated_ids,
82
+ skip_special_tokens=True,
83
+ clean_up_tokenization_spaces=True,
84
+ )[0]
85
+
86
+ return {
87
+ "generated_text": output_text
88
+ }