CanerDedeoglu committed on
Commit a8b66ea · verified · 1 Parent(s): 1906c8c

Update handler.py

Files changed (1)
  1. handler.py +140 -82
handler.py CHANGED
@@ -1,14 +1,19 @@
 # -*- coding: utf-8 -*-
 """
-PULSE ECG Handler — Demo Parity + Style Hint
+PULSE ECG Handler — Demo Parity + Style Hint + Robust Fallbacks + Debug
 - Same generation settings as the demo app.py:
   do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
 - Stopping: safe token-matching criterion on the conversation separator (conv.sep/sep2)
 - Image tensor: .half() and on the model's device
 - Streamer: TextIteratorStreamer (as in the demo), generate in a thread
 - Seed/deterministic OFF (unless you send one); stochastic like the demo
-- STYLE_HINT: to approximate the demo style (narrative + a single structured-impression line at the end)
-- Post-process: ONLY whitespace/format normalization (management/recommendation sentences are preserved)
+- STYLE_HINT: the demo style (narrative + a single structured-impression line at the end)
+- Post-process: whitespace/format cleanup only
+- Additions:
+  * DEBUG helpers (ENV: DEBUG=1)
+  * image_processor fallback (AutoProcessor → CLIPImageProcessor)
+  * process_images fallback (torchvision + CLIP norm)
+  * FastAPI wrapper: /health, /info, /query, /debug
 """
 
 import os
@@ -19,37 +24,46 @@ import hashlib
 import datetime
 from io import BytesIO
 from threading import Thread
-from typing import Optional, Union
+from typing import Optional, Union, Any, Dict
 
 import torch
 from PIL import Image
 import requests
 
+# ====== Debug Helpers ======
+def _env_bool(name: str, default: bool = False) -> bool:
+    v = os.getenv(name)
+    if v is None:
+        return default
+    return str(v).strip().lower() in {"1", "true", "yes", "y", "on"}
+
+DEBUG = _env_bool("DEBUG", False)
+
+def dbg(*args, **kwargs):
+    if DEBUG:
+        print("[DEBUG]", *args, **kwargs)
+
+def warn(*args, **kwargs):
+    print("[WARN]", *args, **kwargs)
+
 # ====== LLaVA & Transformers ======
 try:
-    from llava.constants import (
-        IMAGE_TOKEN_INDEX,
-        DEFAULT_IMAGE_TOKEN,
-    )
+    from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
     from llava.conversation import conv_templates, SeparatorStyle
     from llava.model.builder import load_pretrained_model
-    from llava.mm_utils import (
-        tokenizer_image_token,
-        process_images,
-        get_model_name_from_path,
-    )
+    from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
    from llava.utils import disable_torch_init
     LLAVA_AVAILABLE = True
 except Exception as e:
     LLAVA_AVAILABLE = False
-    print(f"[WARN] LLaVA not available: {e}")
+    warn(f"LLaVA not available: {e}")
 
 try:
     from transformers import TextIteratorStreamer, StoppingCriteria
     TRANSFORMERS_AVAILABLE = True
 except Exception as e:
     TRANSFORMERS_AVAILABLE = False
-    print(f"[WARN] transformers not available: {e}")
+    warn(f"transformers not available: {e}")
 
 # ====== HF Hub logging (optional) ======
 try:
@@ -66,7 +80,7 @@ if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
         api = HfApi()
         repo_name = os.environ.get("LOG_REPO", "")
     except Exception as e:
-        print(f"[HF Hub] init failed: {e}")
+        warn(f"[HF Hub] init failed: {e}")
         api = None
         repo_name = ""
 
@@ -91,7 +105,6 @@ STYLE_HINT = (
 )
 
 # ===================== Utilities =====================
-
 def _safe_upload(path: str):
     if api and repo_name and path and os.path.isfile(path):
         try:
@@ -102,7 +115,7 @@ def _safe_upload(path: str):
                 repo_type="dataset",
             )
         except Exception as e:
-            print(f"[upload] failed for {path}: {e}")
+            warn(f"[upload] failed for {path}: {e}")
 
 def _conv_log_path() -> str:
     t = datetime.datetime.now()
@@ -136,12 +149,6 @@ def load_image_any(image_input: Union[str, dict]) -> Image.Image:
     raise ValueError("Unsupported image input format")
 
 def _normalize_whitespace(text: str) -> str:
-    """
-    Tidies up unnecessary whitespace/blank lines:
-    - Strips leading/trailing whitespace from each line
-    - Collapses runs of spaces into a single space
-    - Collapses 3+ blank lines into one blank line
-    """
     text = text.replace("\r\n", "\n").replace("\r", "\n")
     lines = [re.sub(r"[ \t]+", " ", ln.strip()) for ln in text.split("\n")]
     text = "\n".join(lines).strip()
@@ -149,14 +156,10 @@ def _normalize_whitespace(text: str) -> str:
     return text
 
 def _postprocess_min(text: str) -> str:
-    # Whitespace/format cleanup only
     return _normalize_whitespace(text)
 
 # ====== Safe Stop Criterion (conv separator) ======
 class SafeKeywordsStoppingCriteria(StoppingCriteria):
-    """
-    Token matching based on conv.sep/sep2; no tensor → bool error.
-    """
     def __init__(self, keyword: str, tokenizer):
         self.tokenizer = tokenizer
         tok = tokenizer(keyword, add_special_tokens=False, return_tensors="pt").input_ids[0]
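Note: the rest of the class is unchanged and elided here, down to the `return torch.equal(tail, kw)` line in the next hunk. For orientation, a token-matching criterion of this shape typically completes as in the sketch below (class and attribute names are illustrative, not the repo's exact code):

    import torch
    from transformers import StoppingCriteria

    class TailKeywordCriteria(StoppingCriteria):
        def __init__(self, keyword: str, tokenizer):
            # Token ids of the stop keyword, without special tokens.
            self.kw = tokenizer(keyword, add_special_tokens=False, return_tensors="pt").input_ids[0]

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            if input_ids.shape[1] < self.kw.shape[0]:
                return False
            # Compare the tail of the generated ids against the keyword ids.
            # torch.equal returns a Python bool, which avoids the tensor → bool error.
            tail = input_ids[0, -self.kw.shape[0]:].to(self.kw.device)
            return torch.equal(tail, self.kw)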
@@ -174,7 +177,6 @@ class SafeKeywordsStoppingCriteria(StoppingCriteria):
         return torch.equal(tail, kw)
 
 # ===================== Core Generation =====================
-
 class InferenceDemo:
     def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
         if not LLAVA_AVAILABLE:
@@ -183,7 +185,6 @@ class InferenceDemo:
         self.tokenizer, self.model, self.image_processor, self.context_len = (
             tokenizer_, model_, image_processor_, context_len_
         )
-        # Fixed template for parity
         self.conv_mode = "llava_v1"
         self.conversation = conv_templates[self.conv_mode].copy()
         self.num_frames = getattr(args, "num_frames", 16)
@@ -200,19 +201,16 @@ class ChatSessionManager:
         self.chatbot = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
     def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
         self.init_if_needed(args, model_path, tokenizer, model, image_processor, context_len)
-        # Fresh template on every call (a new turn, as in the demo)
         self.chatbot.conversation = conv_templates[self.chatbot.conv_mode].copy()
         return self.chatbot
 
 chat_manager = ChatSessionManager()
 
 def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
-    # DEMO PARITY: no wrapping, a single image token for a single image
     inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
     chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
     chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
     prompt = chatbot.conversation.get_prompt()
-
     input_ids = tokenizer_image_token(
         prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
     ).unsqueeze(0).to(device)
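Note: with the fixed "llava_v1" template, the prompt built here looks roughly like the sketch below (the exact template text lives in llava.conversation; this is illustrative only):

    user_text = "Interpret this ECG."
    inp = f"<image>\n{user_text}"
    # After append_message(...) + get_prompt(), the string is approximately:
    #   "USER: <image>\nInterpret this ECG. ASSISTANT:"
    # tokenizer_image_token then splices IMAGE_TOKEN_INDEX placeholders where
    # <image> appears, so the model knows where to inject the image features.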
@@ -227,31 +225,29 @@ def generate_response(
     max_new_tokens: Optional[int] = None,
     conv_mode_override: Optional[str] = None,
     repetition_penalty: Optional[float] = None,
-    det_seed: Optional[int] = None,  # None → stochastic (like the demo)
+    det_seed: Optional[int] = None,
 ):
     if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
         return {"error": "Required libraries not available (llava/transformers)"}
     if not message_text or image_input is None:
         return {"error": "Both 'message' and 'image' are required"}
 
-    # Defaults → demo
     if temperature is None: temperature = 0.05
     if top_p is None: top_p = 1.0
     if max_new_tokens is None: max_new_tokens = 4096
-    if repetition_penalty is None: repetition_penalty = 1.0  # no effect
+    if repetition_penalty is None: repetition_penalty = 1.0
+
+    dbg(f"[gen] temperature={temperature} top_p={top_p} max_new_tokens={max_new_tokens} rep={repetition_penalty} seed={det_seed}")
 
-    # Chat session
     chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
     if conv_mode_override and conv_mode_override in conv_templates:
         chatbot.conversation = conv_templates[conv_mode_override].copy()
 
-    # Load the image
     try:
         pil_img = load_image_any(image_input)
     except Exception as e:
         return {"error": f"Failed to load image: {e}"}
 
-    # Hash + path for logging
     img_hash, img_path = "NA", None
     try:
         buf = BytesIO(); pil_img.save(buf, format="JPEG"); raw = buf.getvalue()
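Note: the hash/path computation is elided between these hunks. It follows the usual content-addressing pattern; a sketch, assuming a SHA-256 digest over the encoded JPEG bytes (digest truncation and log directory are assumptions):

    import hashlib, os

    img_hash = hashlib.sha256(raw).hexdigest()[:16]     # stable id for the image content
    img_path = os.path.join("/tmp", f"{img_hash}.jpg")  # reused if the same image repeats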
@@ -262,37 +258,55 @@ def generate_response(
         if not os.path.isfile(img_path):
             pil_img.save(img_path)
     except Exception as e:
-        print(f"[log] save image failed: {e}")
+        warn(f"[log] save image failed: {e}")
 
-    # Device/dtype
     device = next(chatbot.model.parameters()).device
-    dtype = torch.float16  # demo: half
+    dtype = torch.float16
 
-    # Image preprocessing → tensor
+    # Image preprocessing → tensor (with fallback)
     try:
+        dbg(f"[pre] PIL image size={pil_img.size}, mode={pil_img.mode}, processor={type(chatbot.image_processor)}")
         processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
+        dbg("[pre] process_images ok")
+
         if isinstance(processed, (list, tuple)) and len(processed) > 0:
             image_tensor = processed[0]
         elif isinstance(processed, torch.Tensor):
             image_tensor = processed[0] if processed.ndim == 4 else processed
         else:
-            return {"error": "Image processing returned empty"}
+            raise ValueError("Image processing returned empty")
+
         if image_tensor.ndim == 3:
-            image_tensor = image_tensor.unsqueeze(0)  # (1,C,H,W)
-        image_tensor = image_tensor.to(device=device, dtype=dtype)  # demo: half + device
+            image_tensor = image_tensor.unsqueeze(0)
+        image_tensor = image_tensor.to(device=device, dtype=dtype)
+        dbg(f"[pre] tensor shape={tuple(image_tensor.shape)} dtype={image_tensor.dtype} device={image_tensor.device}")
     except Exception as e:
-        return {"error": f"Image processing failed: {e}"}
+        warn(f"[pre] process_images failed: {e} → falling back to manual CLIP preprocessing.")
+        try:
+            from torchvision import transforms
+            from torchvision.transforms import InterpolationMode
+            preprocess = transforms.Compose([
+                transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.48145466, 0.4578275, 0.40821073],
+                    std=[0.26862954, 0.26130258, 0.27577711]
+                ),
+            ])
+            image_tensor = preprocess(pil_img).unsqueeze(0).to(device=device, dtype=dtype)
+            dbg("[pre] manual CLIP preprocess fallback ok → tensor shape=" + str(tuple(image_tensor.shape)))
+        except Exception as ee:
+            return {"error": f"Image processing failed (and fallback failed): {ee}"}
 
-    # Add STYLE_HINT and build the prompt
     msg = (message_text or "").strip()
     msg = f"{msg}\n\n{STYLE_HINT}"
+    dbg(f"[prompt] conv_sep_style={chatbot.conversation.sep_style} sep_len={len(chatbot.conversation.sep)}")
     _, input_ids = _build_prompt_and_ids(chatbot, msg, device)
 
-    # Stop string (conv separator) → safe criterion
     stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
     stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)
 
-    # Seed (stochastic if not provided → like the demo)
     if det_seed is not None:
         try:
             s = int(det_seed)
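Note: the body of the seed block is elided below. Full deterministic seeding usually looks like this sketch (the handler may set fewer of these; this is an assumption):

    import random
    import numpy as np
    import torch

    s = 1234  # example seed
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)  # seed every GPU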
@@ -303,26 +317,21 @@ def generate_response(
         except Exception:
             pass
 
-    # Streamer (as in the demo)
-    streamer = TextIteratorStreamer(
-        chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
-    )
+    streamer = TextIteratorStreamer(chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True)
 
-    # Generate kwargs — demo settings
     gen_kwargs = dict(
         inputs=input_ids,
         images=image_tensor,
         streamer=streamer,
-        do_sample=True,                                 # DEMO
-        temperature=float(temperature),                 # DEMO default 0.05
-        top_p=float(top_p),                             # DEMO default 1.0
-        max_new_tokens=int(max_new_tokens),             # DEMO slider
-        repetition_penalty=float(repetition_penalty),   # default 1.0 → no effect
+        do_sample=True,
+        temperature=float(temperature),
+        top_p=float(top_p),
+        max_new_tokens=int(max_new_tokens),
+        repetition_penalty=float(repetition_penalty),
         use_cache=False,
-        stopping_criteria=[stopping],  # DEMO-like stopping
+        stopping_criteria=[stopping],
     )
 
-    # Generation (background thread) + collect the stream
     try:
         t = Thread(target=chatbot.model.generate, kwargs=gen_kwargs)
         t.start()
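Note: this is the standard TextIteratorStreamer pattern: generate() blocks, so it runs in a worker thread while the caller drains the streamer. A self-contained sketch with a placeholder model (not this repo's):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tok = AutoTokenizer.from_pretrained("gpt2")   # placeholder model
    lm = AutoModelForCausalLM.from_pretrained("gpt2")
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    inputs = tok("The ECG shows", return_tensors="pt")

    # The worker thread produces tokens; the streamer yields decoded text pieces.
    Thread(target=lm.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32)).start()
    text = "".join(piece for piece in streamer)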
@@ -330,12 +339,11 @@ def generate_response(
         for piece in streamer:
             chunks.append(piece)
         text = "".join(chunks)
-        text = _postprocess_min(text)  # whitespace/format cleanup only
+        text = _postprocess_min(text)
         chatbot.conversation.messages[-1][-1] = text
     except Exception as e:
         return {"error": f"Generation failed: {e}"}
 
-    # Log
     try:
         row = {
             "time": datetime.datetime.now().isoformat(),
@@ -349,12 +357,11 @@ def generate_response(
             f.write(json.dumps(row, ensure_ascii=False) + "\n")
         _safe_upload(_conv_log_path()); _safe_upload(img_path or "")
     except Exception as e:
-        print(f"[log] failed: {e}")
+        warn(f"[log] failed: {e}")
 
     return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
 
 # ===================== Public API =====================
-
 def query(payload: dict):
     """HF Endpoint entry (demo-like)."""
     global model_initialized, tokenizer, model, image_processor, context_len, args
@@ -369,11 +376,10 @@ def query(payload: dict):
     if not message.strip(): return {"error": "Missing 'message' text"}
     if image is None: return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}
 
-    # Demo defaults — the payload may override them
     temperature = float(payload.get("temperature", 0.05))
     top_p = float(payload.get("top_p", 1.0))
     max_new_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 4096))))
-    repetition_penalty = float(payload.get("repetition_penalty", 1.0))  # inert default
+    repetition_penalty = float(payload.get("repetition_penalty", 1.0))
 
     conv_mode_override = payload.get("conv_mode", None)
     det_seed = payload.get("det_seed", None)
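Note: put together, query() accepts a payload like the following (values are illustrative; the image may also be passed as "image" or "img"):

    payload = {
        "message": "Interpret this ECG.",
        "image_url": "https://example.com/ecg.png",  # placeholder URL
        "temperature": 0.05,         # demo default
        "top_p": 1.0,
        "max_new_tokens": 4096,      # aliases: max_output_tokens, max_tokens
        "repetition_penalty": 1.0,   # 1.0 → no effect
        "det_seed": 42,              # omit for demo-like stochastic output
    }
    result = query(payload)          # → {"status": "success", "response": "...", ...}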
@@ -413,13 +419,12 @@ def get_model_info():
     }
 
 # ===================== Init & Session =====================
-
 class _Args:
     def __init__(self):
         self.model_path = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
         self.model_base = None
         self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
-        self.conv_mode = "llava_v1"  # Fixed for parity
+        self.conv_mode = "llava_v1"
         self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "4096"))
         self.num_frames = 16
         self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
@@ -429,21 +434,53 @@ class _Args:
 def initialize_model():
     global tokenizer, model, image_processor, context_len, args
     if not LLAVA_AVAILABLE:
-        print("[init] LLaVA not available; cannot init.")
+        warn("[init] LLaVA not available; cannot init.")
         return False
     try:
         args = _Args()
+        dbg(f"[init] HF_MODEL_ID={args.model_path} | LOAD_8BIT={args.load_8bit} | LOAD_4BIT={args.load_4bit}")
         model_name = get_model_name_from_path(args.model_path)
+
         tokenizer_, model_, image_processor_, context_len_ = load_pretrained_model(
             args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
         )
-        # demo: the model usually runs on cuda
+        dbg(f"[init] load_pretrained_model ok | tokenizer={type(tokenizer_)} | model={type(model_)} | image_processor={type(image_processor_)} | context_len={context_len_}")
+
         try:
             _ = next(model_.parameters()).device
         except Exception:
             if torch.cuda.is_available():
                 model_ = model_.to(torch.device("cuda"))
         model_.eval()
+        dbg(f"[init] device={next(model_.parameters()).device}, cuda_available={torch.cuda.is_available()}")
+
+        # --- image_processor fallback chain ---
+        try:
+            if image_processor_ is None:
+                dbg("[init] image_processor is None → trying AutoProcessor fallback…")
+                try:
+                    from transformers import AutoProcessor
+                    image_processor_ = AutoProcessor.from_pretrained(args.model_path)
+                    dbg("[init] image_processor: loaded via AutoProcessor.from_pretrained(model_path).")
+                except Exception as _e1:
+                    dbg(f"[init] AutoProcessor failed: {_e1} → trying CLIPImageProcessor fallback…")
+                    from transformers import CLIPImageProcessor
+                    image_processor_ = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
+                    warn("[init] image_processor: using the CLIPImageProcessor(openai/clip-vit-large-patch14) fallback.")
+        except Exception as _e:
+            warn(f"[init] image_processor fallback failed: {_e}")
+
+        # --- image_processor introspection ---
+        try:
+            ip = image_processor_
+            if ip is not None:
+                crop_sz = getattr(getattr(ip, "crop_size", None), "height", None) or getattr(ip, "crop_size", None)
+                size_sz = getattr(getattr(ip, "size", None), "height", None) or getattr(ip, "size", None)
+                dbg(f"[init] image_processor crop_size={crop_sz} size={size_sz} class={ip.__class__.__name__}")
+            else:
+                warn("[init] image_processor is still None (fallback failed too).")
+        except Exception as e_ip:
+            warn(f"[init] image_processor inspect error: {e_ip}")
 
         globals()["tokenizer"] = tokenizer_
         globals()["model"] = model_
@@ -454,11 +491,10 @@ def initialize_model():
         print("[init] model/tokenizer/image_processor loaded.")
         return True
     except Exception as e:
-        print(f"[init] failed: {e}")
+        warn(f"[init] failed: {e}")
         return False
 
 # ===================== HF EndpointHandler =====================
-
 class EndpointHandler:
     """Hugging Face Endpoint compatible class"""
     def __init__(self, model_dir):
@@ -474,24 +510,21 @@ class EndpointHandler:
         return get_model_info()
 
 if __name__ == "__main__":
-    print("Handler ready (Demo Parity + Style Hint + whitespace post-process). Use `EndpointHandler` or `query`.")
-
+    print("Handler ready (Demo Parity + Style Hint + whitespace post-process + fallbacks + debug). Use `EndpointHandler` or `query`.")
 
 # ===================== Minimal FastAPI Wrapper =====================
 try:
-    from fastapi import FastAPI, Body
+    from fastapi import FastAPI
     from pydantic import BaseModel
-    from typing import Any, Dict
     FASTAPI_AVAILABLE = True
 except Exception as e:
     FASTAPI_AVAILABLE = False
-    print(f"[WARN] fastapi/pydantic not available: {e}")
+    warn(f"fastapi/pydantic not available: {e}")
 
 if FASTAPI_AVAILABLE:
     app = FastAPI(title="PULSE ECG Handler API", version="1.0.0")
 
     class QueryIn(BaseModel):
-        # Compatible with Hugging Face Endpoint style payloads
         message: str | None = None
         query: str | None = None
         prompt: str | None = None
@@ -523,10 +556,35 @@ if FASTAPI_AVAILABLE:
     async def _info():
         return get_model_info()
 
+    @app.get("/debug")
+    async def _debug():
+        try:
+            dev = str(next(model.parameters()).device) if model else "Unknown"
+        except Exception:
+            dev = "Unknown"
+
+        try:
+            ip = image_processor
+            ip_cls = ip.__class__.__name__ if ip else None
+            crop_sz = getattr(getattr(ip, "crop_size", None), "height", None) or getattr(ip, "crop_size", None)
+            size_sz = getattr(getattr(ip, "size", None), "height", None) or getattr(ip, "size", None)
+        except Exception:
+            ip_cls, crop_sz, size_sz = None, None, None
+
+        return {
+            "debug": bool(DEBUG),
+            "llava_available": LLAVA_AVAILABLE,
+            "transformers_available": TRANSFORMERS_AVAILABLE,
+            "device": dev,
+            "context_len": context_len,
+            "image_processor_class": ip_cls,
+            "image_processor_crop_size": crop_sz,
+            "image_processor_size": size_sz,
+            "model_path": args.model_path if args else None,
+        }
+
     @app.post("/query")
     async def _query(payload: QueryIn):
-        # Drop empty fields; pass straight through to the handler.query interface
         return query({k: v for k, v in payload.dict().items() if v is not None})
-
 else:
     app = None  # running `uvicorn handler:app` will raise an import error
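Note: with the wrapper mounted, the service can be exercised end to end; a hypothetical session (host/port and image URL are placeholders):

    import requests

    base = "http://localhost:8000"
    print(requests.get(f"{base}/debug").json())    # availability, device, processor info
    r = requests.post(f"{base}/query", json={
        "message": "Interpret this ECG.",
        "image_url": "https://example.com/ecg.png",
        "det_seed": 42,                            # optional: reproducible output
    })
    print(r.json().get("response"))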
 