CanerDedeoglu commited on
Commit
6183d4f
·
verified ·
1 Parent(s): c20029c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +87 -16
handler.py CHANGED
@@ -1,12 +1,12 @@
1
  # -*- coding: utf-8 -*-
2
  # handler.py — Rapid_ECG / PULSE-7B için HF Inference Endpoints custom handler
3
  # - LLAVA otomatik kurulum (requirements'a yazmak zorunda değilsin)
4
- # - EndpointHandler.load()/__call__ sözleşmesi
 
5
  # - URL / base64 / yerel yol görüntü girişi
6
  # - <image> sentineli (+ IM_START/END gerekiyorsa)
7
- # - attention_mask fix (NoneType.new_ones önler)
8
- # - CUDA'da bf16/fp16, CPU'da fp32; echo-fix (sadece yeni tokenları decode)
9
-
10
  import os
11
  import io
12
  import sys
@@ -18,7 +18,7 @@ import torch
18
  from PIL import Image
19
  import requests
20
 
21
- # ===== LLaVA: handler içinden kur (tag'e sabitle) =====
22
  def _ensure_llava(tag: str = "v1.2.0"):
23
  try:
24
  import llava # noqa
@@ -43,6 +43,12 @@ from llava.constants import (
43
  from llava.model.builder import load_pretrained_model
44
  from llava.mm_utils import process_images, tokenizer_image_token
45
 
 
 
 
 
 
 
46
 
47
  # ---------- yardımcılar ----------
48
  def _get_env(name: str, default: Optional[str] = None) -> Optional[str]:
@@ -99,8 +105,6 @@ def _load_image_from_any(image_input: Any) -> Image.Image:
99
  return Image.open(s).convert("RGB")
100
  raise ValueError(f"Unsupported image input type: {type(image_input)}")
101
 
102
-
103
- # --- Senin istediğin: güvenli conv template & prompt build ---
104
  def _get_conv_mode(model_name: str) -> str:
105
  name = (model_name or "").lower()
106
  if "llama-2" in name:
@@ -123,18 +127,77 @@ def _build_prompt_with_image(prompt: str, model_cfg) -> str:
123
  return f"{token}\n{prompt}"
124
  return f"{DEFAULT_IMAGE_TOKEN}\n{prompt}"
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  # ---------- Endpoint Handler ----------
128
  class EndpointHandler:
129
  """
130
  HF Inference Toolkit çağrı akışı:
131
- handler = EndpointHandler()
132
  handler.load()
133
  handler(inputs_dict)
134
  """
135
 
136
  def __init__(self, model_dir: Optional[str] = None):
137
- self.model_dir = model_dir # HF endpoint burayı geçiriyor
138
  self.model = None
139
  self.tokenizer = None
140
  self.image_processor = None
@@ -142,17 +205,14 @@ class EndpointHandler:
142
  self.device = _pick_device()
143
  self.dtype = _pick_dtype(self.device)
144
  self.model_name = None
145
-
146
  def load(self):
147
- # Model seçimleri (ENV ile yönetilebilir)
148
  model_path = _get_env("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
149
  model_base = _get_env("HF_MODEL_BASE", None)
150
 
151
- # (varsa) flash-attn ipuçları — yoksa zarar vermez
152
  os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
153
  os.environ.setdefault("FLASH_ATTENTION", "1")
154
 
155
- # Modeli yükle
156
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
157
  model_path=model_path,
158
  model_base=model_base,
@@ -163,6 +223,9 @@ class EndpointHandler:
163
  )
164
  self.model_name = getattr(self.model.config, "name_or_path", str(model_path))
165
 
 
 
 
166
  # tokenizer güvenliği
167
  try:
168
  self.tokenizer.padding_side = "left"
@@ -176,7 +239,7 @@ class EndpointHandler:
176
 
177
  @torch.inference_mode()
178
  def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
179
-
180
  if "inputs" in inputs and isinstance(inputs["inputs"], dict):
181
  inputs = inputs["inputs"]
182
 
@@ -202,11 +265,19 @@ class EndpointHandler:
202
  images = [image]
203
  image_sizes = [image.size]
204
 
205
- # process_images -> tensör
206
  try:
 
 
 
207
  images_tensor = process_images(images, self.image_processor, self.model.config)
208
  except Exception:
209
- images_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
 
 
 
 
 
210
  images_tensor = images_tensor.to(self.device, dtype=self.dtype)
211
 
212
  # ---- konuşma şablonu + prompt
 
1
  # -*- coding: utf-8 -*-
2
  # handler.py — Rapid_ECG / PULSE-7B için HF Inference Endpoints custom handler
3
  # - LLAVA otomatik kurulum (requirements'a yazmak zorunda değilsin)
4
+ # - EndpointHandler.load()/__call__(inputs) sözleşmesi
5
+ # - {"inputs": {...}} ve düz payload formatlarını destekler
6
  # - URL / base64 / yerel yol görüntü girişi
7
  # - <image> sentineli (+ IM_START/END gerekiyorsa)
8
+ # - attention_mask fix + echo-fix
9
+ # - CUDA: bf16/fp16, CPU: fp32
 
10
  import os
11
  import io
12
  import sys
 
18
  from PIL import Image
19
  import requests
20
 
21
+ # ===== LLaVA: handler içinden kur =====
22
  def _ensure_llava(tag: str = "v1.2.0"):
23
  try:
24
  import llava # noqa
 
43
  from llava.model.builder import load_pretrained_model
44
  from llava.mm_utils import process_images, tokenizer_image_token
45
 
46
+ # (gerekirse fallback için)
47
+ try:
48
+ from transformers import AutoProcessor, CLIPImageProcessor # type: ignore
49
+ except Exception:
50
+ AutoProcessor = None
51
+ CLIPImageProcessor = None
52
 
53
  # ---------- yardımcılar ----------
54
  def _get_env(name: str, default: Optional[str] = None) -> Optional[str]:
 
105
  return Image.open(s).convert("RGB")
106
  raise ValueError(f"Unsupported image input type: {type(image_input)}")
107
 
 
 
108
  def _get_conv_mode(model_name: str) -> str:
109
  name = (model_name or "").lower()
110
  if "llama-2" in name:
 
127
  return f"{token}\n{prompt}"
128
  return f"{DEFAULT_IMAGE_TOKEN}\n{prompt}"
129
 
130
+ # ---- image_processor yoksa oluşturmak için yardımcılar ----
131
+ def _maybe_get_vision_tower_from_cfg(cfg) -> Optional[str]:
132
+ vt = getattr(cfg, "vision_tower", None)
133
+ if isinstance(vt, (list, tuple)) and vt:
134
+ return str(vt[0])
135
+ if isinstance(vt, str):
136
+ return vt
137
+ return _get_env("HF_VISION_TOWER_ID", None)
138
+
139
+ class _ProcessorWrapper:
140
+ """AutoProcessor/FeatureExtractor için .preprocess uyum katmanı."""
141
+ def __init__(self, proc):
142
+ self.proc = proc
143
+ def preprocess(self, image, return_tensors="pt"):
144
+ out = self.proc(image, return_tensors=return_tensors)
145
+ # AutoProcessor bazen dict döner, bazen tensor; normalize edelim
146
+ if isinstance(out, dict):
147
+ return out
148
+ return {"pixel_values": out}
149
+
150
+ def _ensure_image_processor(image_processor, model_cfg, model_path: str):
151
+ if image_processor is not None:
152
+ # bazı AutoProcessor'larda gerçek işleyici proc.image_processor altında
153
+ if hasattr(image_processor, "preprocess"):
154
+ return image_processor
155
+ if hasattr(image_processor, "image_processor"):
156
+ ip = image_processor.image_processor
157
+ if hasattr(ip, "preprocess"):
158
+ return ip
159
+ return _ProcessorWrapper(ip)
160
+ if callable(image_processor):
161
+ return _ProcessorWrapper(image_processor)
162
+
163
+ # 1) AutoProcessor (trust_remote_code ile) dene
164
+ if AutoProcessor is not None:
165
+ try:
166
+ proc = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
167
+ if hasattr(proc, "preprocess"):
168
+ return proc
169
+ if hasattr(proc, "image_processor"):
170
+ ip = proc.image_processor
171
+ if hasattr(ip, "preprocess"):
172
+ return ip
173
+ return _ProcessorWrapper(ip)
174
+ return _ProcessorWrapper(proc)
175
+ except Exception:
176
+ pass
177
+
178
+ # 2) Vision tower'dan CLIPImageProcessor üret
179
+ vt = _maybe_get_vision_tower_from_cfg(model_cfg)
180
+ if vt and CLIPImageProcessor is not None:
181
+ try:
182
+ ip = CLIPImageProcessor.from_pretrained(vt)
183
+ return ip
184
+ except Exception:
185
+ pass
186
+
187
+ # 3) en sonda None kalsın; çağrı tarafında fallback var
188
+ return None
189
 
190
  # ---------- Endpoint Handler ----------
191
  class EndpointHandler:
192
  """
193
  HF Inference Toolkit çağrı akışı:
194
+ handler = EndpointHandler(model_dir)
195
  handler.load()
196
  handler(inputs_dict)
197
  """
198
 
199
  def __init__(self, model_dir: Optional[str] = None):
200
+ self.model_dir = model_dir
201
  self.model = None
202
  self.tokenizer = None
203
  self.image_processor = None
 
205
  self.device = _pick_device()
206
  self.dtype = _pick_dtype(self.device)
207
  self.model_name = None
208
+
209
  def load(self):
 
210
  model_path = _get_env("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
211
  model_base = _get_env("HF_MODEL_BASE", None)
212
 
 
213
  os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
214
  os.environ.setdefault("FLASH_ATTENTION", "1")
215
 
 
216
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
217
  model_path=model_path,
218
  model_base=model_base,
 
223
  )
224
  self.model_name = getattr(self.model.config, "name_or_path", str(model_path))
225
 
226
+ # image_processor fallback (kritik!)
227
+ self.image_processor = _ensure_image_processor(self.image_processor, self.model.config, model_path)
228
+
229
  # tokenizer güvenliği
230
  try:
231
  self.tokenizer.padding_side = "left"
 
239
 
240
  @torch.inference_mode()
241
  def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
242
+ # HF bazen payload'ı {"inputs": {...}} diye sarar
243
  if "inputs" in inputs and isinstance(inputs["inputs"], dict):
244
  inputs = inputs["inputs"]
245
 
 
265
  images = [image]
266
  image_sizes = [image.size]
267
 
268
+ # process_images -> tensör (image_processor None olabilir; o zaman plain preprocess)
269
  try:
270
+ if self.image_processor is None:
271
+ # en kaba yedek: AutoProcessor başarısız olduysa
272
+ raise RuntimeError("image_processor is None")
273
  images_tensor = process_images(images, self.image_processor, self.model.config)
274
  except Exception:
275
+ # plain preprocess
276
+ if hasattr(self.image_processor, "preprocess"):
277
+ images_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
278
+ else:
279
+ # en son çare: AutoProcessor benzeri çağrı
280
+ images_tensor = _ProcessorWrapper(self.image_processor).preprocess(image, return_tensors="pt")["pixel_values"]
281
  images_tensor = images_tensor.to(self.device, dtype=self.dtype)
282
 
283
  # ---- konuşma şablonu + prompt