CanerDedeoglu commited on
Commit
ca30e1a
·
verified ·
1 Parent(s): e7fc237

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +89 -45
handler.py CHANGED
@@ -1,8 +1,9 @@
1
  # -*- coding: utf-8 -*-
2
- # handler.py — Rapid_ECG / PULSE-7B — Stabil ve DEBUG'li sürüm (local/hub + vision tower fix)
3
- # - HF Endpoint uyumlu (EndpointHandler.load().__call__)
4
- # - Yerel klasörden (HF_MODEL_DIR) veya hub'dan (HF_MODEL_ID) yükleme
5
- # - Görsel sadece .preprocess() ile işlenir
 
6
  # - Vision tower kontrolü: mm_vision_tower veya vision_tower
7
  # - IMAGE_TOKEN_INDEX kullanımı ve kapsamlı [DEBUG] logları
8
 
@@ -17,7 +18,8 @@ import torch
17
  from PIL import Image
18
  import requests
19
 
20
- # ===== LLaVA kütüphanesi (gerekirse kur) =====
 
21
  def _ensure_llava(tag: str = "v1.2.0"):
22
  try:
23
  import llava # noqa
@@ -86,16 +88,23 @@ def _load_image_from_any(image_input: Any) -> Image.Image:
86
  if isinstance(image_input, str):
87
  s = image_input.strip()
88
  if s.startswith("data:image"):
89
- _, b64 = s.split(",", 1)
90
- data = base64.b64decode(b64)
91
- return Image.open(io.BytesIO(data)).convert("RGB")
 
 
 
92
  if _is_probably_base64(s) and not s.startswith(("http://", "https://")):
93
- data = base64.b64decode(s)
94
- return Image.open(io.BytesIO(data)).convert("RGB")
 
 
 
95
  if s.startswith(("http://", "https://")):
96
  resp = requests.get(s, timeout=20)
97
  resp.raise_for_status()
98
  return Image.open(io.BytesIO(resp.content)).convert("RGB")
 
99
  return Image.open(s).convert("RGB")
100
  raise ValueError(f"Unsupported image input type: {type(image_input)}")
101
 
@@ -114,7 +123,7 @@ def _get_conv_mode(model_name: str) -> str:
114
  return "llava_v0"
115
 
116
  def _build_prompt_with_image(prompt: str, model_cfg) -> str:
117
- # Kullanıcı prompt'a image token eklediyse yeniden eklemeyelim
118
  if DEFAULT_IMAGE_TOKEN in prompt or DEFAULT_IM_START_TOKEN in prompt:
119
  return prompt
120
  if getattr(model_cfg, "mm_use_im_start_end", False):
@@ -123,7 +132,7 @@ def _build_prompt_with_image(prompt: str, model_cfg) -> str:
123
  return f"{DEFAULT_IMAGE_TOKEN}\n{prompt}"
124
 
125
  def _resolve_model_path(model_dir_hint: Optional[str], default_dir: str = "/repository") -> str:
126
- # Öncelik sırası: HF_MODEL_DIR (yerel) -> verilen model_dir_hint -> default_dir
127
  p = _get_env("HF_MODEL_DIR") or model_dir_hint or default_dir
128
  p = os.path.abspath(p)
129
  print(f"[DEBUG] resolved model path: {p}")
@@ -133,22 +142,44 @@ def _resolve_model_path(model_dir_hint: Optional[str], default_dir: str = "/repo
133
  # ---------- Endpoint Handler ----------
134
  class EndpointHandler:
135
  def __init__(self, model_dir: Optional[str] = None):
136
- print(f"[DEBUG] EndpointHandler.__init__ model_dir={model_dir}")
 
 
 
 
 
 
 
 
 
 
137
  self.model_dir = model_dir
 
 
 
 
 
 
 
 
 
138
  self.model = None
139
  self.tokenizer = None
140
  self.image_processor = None
141
  self.context_len = None
142
- self.device = _pick_device()
143
- self.dtype = _pick_dtype(self.device)
144
  self.model_name = None
145
 
146
- def load(self):
147
- """
148
- Yükleme stratejisi:
149
- - Eğer HF_MODEL_DIR set edilmişse veya repo kökünde ağırlıklar varsa: YERELDEN yükle.
150
- - Aksi halde HF_MODEL_ID ile hub'dan yükle.
151
- """
 
 
 
 
 
152
  local_path = _resolve_model_path(self.model_dir)
153
  use_local = os.path.isdir(local_path) and any(
154
  os.path.exists(os.path.join(local_path, f))
@@ -156,9 +187,6 @@ class EndpointHandler:
156
  )
157
  model_base = _get_env("HF_MODEL_BASE", None)
158
 
159
- os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
160
- os.environ.setdefault("FLASH_ATTENTION", "1")
161
-
162
  if use_local:
163
  model_path = local_path
164
  print(f"[DEBUG] loading model LOCALLY from: {model_path}")
@@ -166,7 +194,6 @@ class EndpointHandler:
166
  model_path = _get_env("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
167
  print(f"[DEBUG] loading model from HUB: {model_path} (HF_MODEL_BASE={model_base})")
168
 
169
- # Modeli yükle
170
  print("[DEBUG] calling load_pretrained_model ...")
171
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
172
  model_path=model_path,
@@ -179,7 +206,7 @@ class EndpointHandler:
179
  self.model_name = getattr(self.model.config, "name_or_path", str(model_path))
180
  print(f"[DEBUG] model loaded: name={self.model_name}")
181
 
182
- # ---- Vision tower kontrolü: mm_vision_tower veya vision_tower
183
  vt = (
184
  getattr(self.model.config, "mm_vision_tower", None)
185
  or getattr(self.model.config, "vision_tower", None)
@@ -188,26 +215,29 @@ class EndpointHandler:
188
  if self.image_processor is None or vt is None:
189
  raise RuntimeError(
190
  "[ERROR] Vision tower not loaded (mm_vision_tower/vision_tower None). "
191
- "Bu model multimodal değil veya yanlış checkpoint yüklendi. "
192
- "Yerelden yükleyecekseniz HF_MODEL_DIR doğru klasörü göstermeli; "
193
- "hub'dan yükleyecekseniz HF_MODEL_ID olarak PULSE/LLaVA tabanlı bir model verin (örn: 'PULSE-ECG/PULSE-7B')."
194
  )
195
 
196
- # tokenizer güvenliği
197
  try:
198
  self.tokenizer.padding_side = "left"
199
- if self.tokenizer.pad_token_id is None:
200
  self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
201
- except Exception:
202
- pass
203
 
204
  self.model.eval()
205
- print("[DEBUG] model.eval() done")
 
 
 
206
  return True
207
 
208
  @torch.inference_mode()
209
  def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
210
  print(f"[DEBUG] __call__ inputs keys={list(inputs.keys()) if hasattr(inputs,'keys') else 'N/A'}")
 
211
  if "inputs" in inputs and isinstance(inputs["inputs"], dict):
212
  inputs = inputs["inputs"]
213
 
@@ -218,13 +248,14 @@ class EndpointHandler:
218
  if not isinstance(prompt, str) or not prompt.strip():
219
  return {"error": "Missing 'query'/'prompt' text"}
220
 
 
221
  temperature = float(inputs.get("temperature", 0.0))
222
  top_p = float(inputs.get("top_p", 0.9))
223
  max_new = int(inputs.get("max_new_tokens", inputs.get("max_tokens", 512)))
224
  repetition_penalty = float(inputs.get("repetition_penalty", 1.0))
225
  conv_mode_override = inputs.get("conv_mode") or _get_env("CONV_MODE", None)
226
 
227
- # ---- image load + preprocess
228
  try:
229
  image = _load_image_from_any(image_in)
230
  print(f"[DEBUG] loaded image size={image.size}")
@@ -237,31 +268,39 @@ class EndpointHandler:
237
  try:
238
  out = self.image_processor.preprocess(image, return_tensors="pt")
239
  images_tensor = out["pixel_values"].to(self.device, dtype=self.dtype)
240
- image_sizes = [image.size] # bazı LLaVA sürümleri image_sizes ister
241
  print(f"[DEBUG] preprocess OK; images_tensor.shape={images_tensor.shape}")
242
  except Exception as e:
243
  return {"error": f"Image preprocessing failed: {e}"}
244
 
245
- # ---- conversation + prompt
246
  mode = conv_mode_override or _get_conv_mode(self.model_name)
247
  conv = (conv_templates.get(mode) or conv_templates[list(conv_templates.keys())[0]]).copy()
248
  conv.append_message(conv.roles[0], _build_prompt_with_image(prompt.strip(), self.model.config))
249
  conv.append_message(conv.roles[1], None)
250
  full_prompt = conv.get_prompt()
 
251
 
252
- # ---- tokenization (IMAGE_TOKEN_INDEX ile)
253
  try:
254
  input_ids = tokenizer_image_token(
255
  full_prompt, self.tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors="pt"
256
  ).unsqueeze(0).to(self.device)
257
- except Exception:
258
- toks = self.tokenizer([full_prompt], return_tensors="pt", padding=True, truncation=True)
259
- input_ids = toks["input_ids"].to(self.device)
 
 
 
 
 
 
260
 
261
  attention_mask = torch.ones_like(input_ids, device=self.device)
262
 
263
- # ---- generate
264
  try:
 
265
  gen_ids = self.model.generate(
266
  input_ids=input_ids,
267
  attention_mask=attention_mask,
@@ -274,12 +313,17 @@ class EndpointHandler:
274
  repetition_penalty=repetition_penalty,
275
  use_cache=True,
276
  )
 
277
  except Exception as e:
278
  return {"error": f"Generation failed: {e}"}
279
 
280
- # ---- decode (sadece yeni tokenlar)
281
- new_tokens = gen_ids[0, input_ids.shape[1]:]
282
- text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
 
 
 
 
283
 
284
  return {
285
  "generated_text": text,
 
1
  # -*- coding: utf-8 -*-
2
+ # handler.py — Rapid_ECG / PULSE-7B — Startup-load, Stabil ve DEBUG'li sürüm
3
+ # - Sunucu açılır açılmaz model yüklenir (cold start only once)
4
+ # - HF Endpoint sözleşmesi (EndpointHandler.load().__call__)
5
+ # - Yerel (HF_MODEL_DIR) → Hub (HF_MODEL_ID) yükleme sırası
6
+ # - Görsel sadece .preprocess() ile işlenir (process_images yok)
7
  # - Vision tower kontrolü: mm_vision_tower veya vision_tower
8
  # - IMAGE_TOKEN_INDEX kullanımı ve kapsamlı [DEBUG] logları
9
 
 
18
  from PIL import Image
19
  import requests
20
 
21
+
22
+ # ===== LLaVA kütüphanesini garantiye al =====
23
  def _ensure_llava(tag: str = "v1.2.0"):
24
  try:
25
  import llava # noqa
 
88
  if isinstance(image_input, str):
89
  s = image_input.strip()
90
  if s.startswith("data:image"):
91
+ try:
92
+ _, b64 = s.split(",", 1)
93
+ data = base64.b64decode(b64)
94
+ return Image.open(io.BytesIO(data)).convert("RGB")
95
+ except Exception as e:
96
+ raise ValueError(f"Bad data URL: {e}")
97
  if _is_probably_base64(s) and not s.startswith(("http://", "https://")):
98
+ try:
99
+ data = base64.b64decode(s)
100
+ return Image.open(io.BytesIO(data)).convert("RGB")
101
+ except Exception as e:
102
+ raise ValueError(f"Bad base64 image: {e}")
103
  if s.startswith(("http://", "https://")):
104
  resp = requests.get(s, timeout=20)
105
  resp.raise_for_status()
106
  return Image.open(io.BytesIO(resp.content)).convert("RGB")
107
+ # local path
108
  return Image.open(s).convert("RGB")
109
  raise ValueError(f"Unsupported image input type: {type(image_input)}")
110
 
 
123
  return "llava_v0"
124
 
125
  def _build_prompt_with_image(prompt: str, model_cfg) -> str:
126
+ # Kullanıcı image token eklediyse yeniden eklemeyelim
127
  if DEFAULT_IMAGE_TOKEN in prompt or DEFAULT_IM_START_TOKEN in prompt:
128
  return prompt
129
  if getattr(model_cfg, "mm_use_im_start_end", False):
 
132
  return f"{DEFAULT_IMAGE_TOKEN}\n{prompt}"
133
 
134
  def _resolve_model_path(model_dir_hint: Optional[str], default_dir: str = "/repository") -> str:
135
+ # Öncelik: HF_MODEL_DIR (yerel) -> ctor'dan gelen model_dir_hint -> default_dir
136
  p = _get_env("HF_MODEL_DIR") or model_dir_hint or default_dir
137
  p = os.path.abspath(p)
138
  print(f"[DEBUG] resolved model path: {p}")
 
142
  # ---------- Endpoint Handler ----------
143
  class EndpointHandler:
144
  def __init__(self, model_dir: Optional[str] = None):
145
+ # DEBUG banner
146
+ print("🚀 Starting up PULSE-7B handler (startup load)...")
147
+ print("📝 Enhanced by Kefstacks")
148
+ print(f"🔧 Python: {sys.version}")
149
+ print(f"🔧 PyTorch: {torch.__version__}")
150
+ try:
151
+ import transformers
152
+ print(f"🔧 Transformers: {transformers.__version__}")
153
+ except Exception as e:
154
+ print(f"[DEBUG] transformers import failed: {e}")
155
+
156
  self.model_dir = model_dir
157
+ self.device = _pick_device()
158
+ self.dtype = _pick_dtype(self.device)
159
+
160
+ # Ortam ayarları (flash attn ipucu, zarar vermez)
161
+ os.environ.setdefault("ATTN_IMPLEMENTATION", "flash_attention_2")
162
+ os.environ.setdefault("FLASH_ATTENTION", "1")
163
+ print(f"[DEBUG] ATTN_IMPLEMENTATION={os.getenv('ATTN_IMPLEMENTATION')} FLASH_ATTENTION={os.getenv('FLASH_ATTENTION')}")
164
+
165
+ # Model/Tokenizer/ImageProcessor konteynerleri
166
  self.model = None
167
  self.tokenizer = None
168
  self.image_processor = None
169
  self.context_len = None
 
 
170
  self.model_name = None
171
 
172
+ # ---- Modeli burada (startup’ta) yükle ----
173
+ try:
174
+ self._startup_load_model()
175
+ print("✅ Model loaded & ready in __init__")
176
+ except Exception as e:
177
+ # Kritik hata: init'te patladıysa endpoint zaten ayağa kalkamaz
178
+ print(f"💥 CRITICAL: model startup load failed: {e}")
179
+ raise
180
+
181
+ def _startup_load_model(self):
182
+ # Yerel dizin varsa onu kullan, yoksa hub
183
  local_path = _resolve_model_path(self.model_dir)
184
  use_local = os.path.isdir(local_path) and any(
185
  os.path.exists(os.path.join(local_path, f))
 
187
  )
188
  model_base = _get_env("HF_MODEL_BASE", None)
189
 
 
 
 
190
  if use_local:
191
  model_path = local_path
192
  print(f"[DEBUG] loading model LOCALLY from: {model_path}")
 
194
  model_path = _get_env("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
195
  print(f"[DEBUG] loading model from HUB: {model_path} (HF_MODEL_BASE={model_base})")
196
 
 
197
  print("[DEBUG] calling load_pretrained_model ...")
198
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
199
  model_path=model_path,
 
206
  self.model_name = getattr(self.model.config, "name_or_path", str(model_path))
207
  print(f"[DEBUG] model loaded: name={self.model_name}")
208
 
209
+ # Vision tower kontrolü (yeni/eskı alan adları)
210
  vt = (
211
  getattr(self.model.config, "mm_vision_tower", None)
212
  or getattr(self.model.config, "vision_tower", None)
 
215
  if self.image_processor is None or vt is None:
216
  raise RuntimeError(
217
  "[ERROR] Vision tower not loaded (mm_vision_tower/vision_tower None). "
218
+ "Yerel yükleme için HF_MODEL_DIR doğru klasörü göstermeli; "
219
+ "Hub için HF_MODEL_ID PULSE/LLaVA tabanlı olmalı (örn: 'PULSE-ECG/PULSE-7B')."
 
220
  )
221
 
222
+ # Tokenizer güvenliği
223
  try:
224
  self.tokenizer.padding_side = "left"
225
+ if getattr(self.tokenizer, "pad_token_id", None) is None:
226
  self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
227
+ except Exception as e:
228
+ print(f"[DEBUG] tokenizer safety patch failed: {e}")
229
 
230
  self.model.eval()
231
+
232
+ # HF inference toolkit load() yine çağıracağı için no-op
233
+ def load(self):
234
+ print("[DEBUG] load(): model is already initialized in __init__")
235
  return True
236
 
237
  @torch.inference_mode()
238
  def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
239
  print(f"[DEBUG] __call__ inputs keys={list(inputs.keys()) if hasattr(inputs,'keys') else 'N/A'}")
240
+ # HF {"inputs": {...}} sarmasını aç
241
  if "inputs" in inputs and isinstance(inputs["inputs"], dict):
242
  inputs = inputs["inputs"]
243
 
 
248
  if not isinstance(prompt, str) or not prompt.strip():
249
  return {"error": "Missing 'query'/'prompt' text"}
250
 
251
+ # Üretim parametreleri
252
  temperature = float(inputs.get("temperature", 0.0))
253
  top_p = float(inputs.get("top_p", 0.9))
254
  max_new = int(inputs.get("max_new_tokens", inputs.get("max_tokens", 512)))
255
  repetition_penalty = float(inputs.get("repetition_penalty", 1.0))
256
  conv_mode_override = inputs.get("conv_mode") or _get_env("CONV_MODE", None)
257
 
258
+ # ---- Görsel yükle + preprocess
259
  try:
260
  image = _load_image_from_any(image_in)
261
  print(f"[DEBUG] loaded image size={image.size}")
 
268
  try:
269
  out = self.image_processor.preprocess(image, return_tensors="pt")
270
  images_tensor = out["pixel_values"].to(self.device, dtype=self.dtype)
271
+ image_sizes = [image.size]
272
  print(f"[DEBUG] preprocess OK; images_tensor.shape={images_tensor.shape}")
273
  except Exception as e:
274
  return {"error": f"Image preprocessing failed: {e}"}
275
 
276
+ # ---- Konuşma + prompt
277
  mode = conv_mode_override or _get_conv_mode(self.model_name)
278
  conv = (conv_templates.get(mode) or conv_templates[list(conv_templates.keys())[0]]).copy()
279
  conv.append_message(conv.roles[0], _build_prompt_with_image(prompt.strip(), self.model.config))
280
  conv.append_message(conv.roles[1], None)
281
  full_prompt = conv.get_prompt()
282
+ print(f"[DEBUG] conv_mode={mode}; full_prompt_len={len(full_prompt)}")
283
 
284
+ # ---- Tokenization (IMAGE_TOKEN_INDEX ile)
285
  try:
286
  input_ids = tokenizer_image_token(
287
  full_prompt, self.tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors="pt"
288
  ).unsqueeze(0).to(self.device)
289
+ print(f"[DEBUG] tokenizer_image_token OK; input_ids.shape={input_ids.shape}")
290
+ except Exception as e:
291
+ print(f"[DEBUG] tokenizer_image_token failed: {e}; fallback to plain tokenizer")
292
+ try:
293
+ toks = self.tokenizer([full_prompt], return_tensors="pt", padding=True, truncation=True)
294
+ input_ids = toks["input_ids"].to(self.device)
295
+ print(f"[DEBUG] plain tokenizer OK; input_ids.shape={input_ids.shape}")
296
+ except Exception as e2:
297
+ return {"error": f"Tokenization failed: {e} / {e2}"}
298
 
299
  attention_mask = torch.ones_like(input_ids, device=self.device)
300
 
301
+ # ---- Generate
302
  try:
303
+ print(f"[DEBUG] generate(max_new_tokens={max_new}, temp={temperature}, top_p={top_p}, rep={repetition_penalty})")
304
  gen_ids = self.model.generate(
305
  input_ids=input_ids,
306
  attention_mask=attention_mask,
 
313
  repetition_penalty=repetition_penalty,
314
  use_cache=True,
315
  )
316
+ print(f"[DEBUG] generate OK; gen_ids.shape={gen_ids.shape}")
317
  except Exception as e:
318
  return {"error": f"Generation failed: {e}"}
319
 
320
+ # ---- Decode (sadece yeni tokenlar)
321
+ try:
322
+ new_tokens = gen_ids[0, input_ids.shape[1]:]
323
+ text = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
324
+ print(f"[DEBUG] decoded_text_len={len(text)}")
325
+ except Exception as e:
326
+ return {"error": f"Decode failed: {e}"}
327
 
328
  return {
329
  "generated_text": text,