CanerDedeoglu committed on
Commit
547115d
·
verified ·
1 Parent(s): 92e08dc

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +95 -20
handler.py CHANGED
@@ -1,31 +1,106 @@
 
 
 
 
1
  import torch
2
  from PIL import Image
3
- from transformers import AutoProcessor, AutoModelForCausalLM
4
- import requests
5
 
6
  class EndpointHandler:
7
- def __init__(self, path=""):
8
- model_id = "PULSE-ECG/PULSE-7B"
9
- self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
10
- self.model = AutoModelForCausalLM.from_pretrained(
11
- model_id,
12
- torch_dtype=torch.float16,
13
  device_map="auto",
14
- trust_remote_code=True
 
15
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- def __call__(self, data: dict) -> dict:
18
- image_url = data.get("image_url")
19
- text = data.get("text", "Interpret this ECG image.")
 
 
 
 
 
 
 
 
 
 
20
 
21
- if not image_url:
22
- return {"error": "No image_url provided"}
 
23
 
24
- image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
25
- inputs = self.processor(text=text, images=image, return_tensors="pt").to(self.model.device)
26
 
27
- with torch.no_grad():
28
- outputs = self.model.generate(**inputs, max_new_tokens=512)
 
29
 
30
- result = self.processor.decode(outputs[0], skip_special_tokens=True)
31
- return {"result": result}
 
1
+ import base64
2
+ import io
3
+ import os
4
+ import requests
5
  import torch
6
  from PIL import Image
7
+ from transformers import pipeline
 
8
 
9
  class EndpointHandler:
10
+ def __init__(self, path: str = ""):
11
+ # PULSE-7B (PyTorch/safetensors) + LLaVA tabanlı VLM
12
+ self.pipe = pipeline(
13
+ task="image-text-to-text",
14
+ model="PULSE-ECG/PULSE-7B",
 
15
  device_map="auto",
16
+ torch_dtype="auto",
17
+ trust_remote_code=True,
18
  )
19
+ # Uyarıları susturmak için pad_token
20
+ try:
21
+ if getattr(self.pipe.model.generation_config, "pad_token_id", None) is None:
22
+ self.pipe.model.generation_config.pad_token_id = self.pipe.model.config.eos_token_id
23
+ except Exception:
24
+ pass
25
+
26
+ # ---- yardımcılar ----
27
+ def _load_image(self, src: str) -> Image.Image:
28
+ """HTTP URL veya data URL (base64) kabul eder."""
29
+ if src.startswith("data:image/"):
30
+ # data:image/png;base64,xxxx
31
+ b64 = src.split(",", 1)[1]
32
+ return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
33
+ # HTTP(S) URL
34
+ r = requests.get(src, stream=True, timeout=20)
35
+ r.raise_for_status()
36
+ return Image.open(r.raw).convert("RGB")
37
+
38
+ def _normalize_inputs(self, data: dict):
39
+ """
40
+ Aşağıdaki iki biçimi de destekler:
41
+ 1) {"image_url": "...", "text": "..."}
42
+ 2) {
43
+ "inputs": [
44
+ {"role":"user","content":[
45
+ {"type":"image","image_url":"..."},
46
+ {"type":"text","text":"..."}
47
+ ]}
48
+ ],
49
+ "parameters": {...}
50
+ }
51
+ """
52
+ # Biçim-1
53
+ if "image_url" in data or "text" in data:
54
+ image_url = data.get("image_url")
55
+ text = data.get("text", "Interpret this ECG image.")
56
+ if not image_url:
57
+ raise ValueError("No image_url provided")
58
+ return [
59
+ {"role": "user", "content": [
60
+ {"type": "image", "image_url": image_url},
61
+ {"type": "text", "text": text},
62
+ ]}
63
+ ], data.get("parameters", {})
64
+
65
+ # Biçim-2 (multimodal chat)
66
+ if "inputs" in data:
67
+ return data.get("inputs", []), data.get("parameters", {})
68
+
69
+ raise ValueError("Invalid payload: expected 'image_url'+'text' or 'inputs' format.")
70
+
71
+ # ---- giriş noktası ----
72
+ def __call__(self, data: dict):
73
+ """
74
+ Dönen sonuç, TGI/Toolkit ile uyumlu ham çıktı olur.
75
+ Client, aşağıdaki iki şemadan birini gönderebilir:
76
+
77
+ # Şema-1 (basit)
78
+ { "image_url": "https://.../ecg.png", "text": "Kısa yorum yaz", "parameters": {"max_new_tokens":256} }
79
 
80
+ # Şema-2 (multimodal chat)
81
+ {
82
+ "inputs": [
83
+ {"role":"user","content":[
84
+ {"type":"image","image_url":"https://.../ecg.png"},
85
+ {"type":"text","text":"Kısa yorum yaz"}
86
+ ]}
87
+ ],
88
+ "parameters": {"max_new_tokens":256, "temperature":0.2}
89
+ }
90
+ """
91
+ # 1) İstek şemasını normalize et
92
+ msgs, params = self._normalize_inputs(data)
93
 
94
+ # 2) Eğer image_url base64/URL ise pipe doğrudan kabul eder;
95
+ # ekstra dönüştürme gerekmiyor (pipeline image-text-to-text bunu destekler).
96
+ # Ancak bazı eski sürümlerde local image objesi istenirse yukarıdaki _load_image kullanılabilir.
97
 
98
+ # 3) Varsayılan parametreler
99
+ params = {"max_new_tokens": 512, "temperature": 0.2, **(params or {})}
100
 
101
+ # 4) Çalıştır
102
+ with torch.inference_mode():
103
+ out = self.pipe(msgs, **params)
104
 
105
+ # pipe, genellikle metni doğrudan string veya dict listesi olarak döndürür.
106
+ return out