CanerDedeoglu committed
Commit f3e0682 · verified · 1 Parent(s): 18308a9

Update handler.py

Files changed (1): handler.py +175 -46

handler.py CHANGED
@@ -1,20 +1,21 @@
  # -*- coding: utf-8 -*-
  """
- PULSE ECG Handler — Demo Parity + Style Hint + Robust Fallbacks + Debug + Dynamic Vision Size
- - Same generation settings as the demo app.py:
  do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
- - Stopping: safe token-matched criterion on the conversation separator (conv.sep/sep2)
- - Image tensor: .half() and on the model's device
- - Streamer: TextIteratorStreamer (like the demo), generate on a thread
- - Seed/deterministic OFF (unless you send one); stochastic like the demo
- - STYLE_HINT: demo-like style (narrative + a single-line structured impression at the end)
- - Post-process: whitespace/format cleanup only
- - Extras:
- * DEBUG helpers (ENV: DEBUG=1)
- * Dynamic vision size: vision tower -> processor + preprocess/fallback
  * image_processor fallback (AutoProcessor → CLIPImageProcessor)
  * process_images fallback (torchvision + CLIP norm)
- * FastAPI wrapper: /health, /info, /query, /debug
  """

  import os
@@ -66,7 +67,7 @@ except Exception as e:
  TRANSFORMERS_AVAILABLE = False
  warn(f"transformers not available: {e}")

- # ====== HF Hub logging (optional) ======
  try:
  from huggingface_hub import HfApi, login
  HF_HUB_AVAILABLE = True
@@ -96,7 +97,7 @@ context_len = None
  args = None
  model_initialized = False

- # ====== Style Hint (demo-like style) ======
  STYLE_HINT = (
  "Write one concise narrative paragraph that covers rhythm, heart rate, cardiac axis, "
  "P waves and PR interval, QRS morphology and duration, ST segments, T waves, and QT/QTc. "
@@ -105,6 +106,30 @@ STYLE_HINT = (
  "followed by a succinct, comma-separated summary of the key diagnoses."
  )

  # ===================== Utilities =====================
  def _safe_upload(path: str):
  if api and repo_name and path and os.path.isfile(path):
@@ -124,10 +149,10 @@ def _conv_log_path() -> str:

  def load_image_any(image_input: Union[str, dict]) -> Image.Image:
  """
- Supported:
  - URL (http/https)
- - local file path
- - base64 (optionally with a data URL prefix)
  - {"image": <base64|dataurl>}
  """
  if isinstance(image_input, str):
@@ -138,7 +163,7 @@ def load_image_any(image_input: Union[str, dict]) -> Image.Image:
  return Image.open(BytesIO(r.content)).convert("RGB")
  if os.path.exists(s):
  return Image.open(s).convert("RGB")
- # base64 (possibly a data URL)
  if s.startswith("data:image"):
  s = s.split(",", 1)[1]
  raw = base64.b64decode(s)
@@ -162,8 +187,7 @@ def _postprocess_min(text: str) -> str:
  # ====== Vision helpers (dynamic size) ======
  def get_vision_expected_size(m, default: int = 336) -> int:
  """
- Returns the input size the model's vision tower expects (e.g., 336).
- LLaVA/CLIP configs usually include `image_size`.
  """
  try:
  vt = m.get_vision_tower()
@@ -180,7 +204,7 @@ def get_vision_expected_size(m, default: int = 336) -> int:
  return default

  def force_processor_size(proc, size: int):
- """Safely forces the processor's resize/crop fields to the target size."""
  try:
  # size
  if hasattr(proc, "size"):
@@ -206,7 +230,7 @@ def force_processor_size(proc, size: int):
  except Exception as e:
  warn(f"[processor] force size failed: {e}")

- # ====== Safe stop criterion (conv separator) ======
  class SafeKeywordsStoppingCriteria(StoppingCriteria):
  def __init__(self, keyword: str, tokenizer):
  self.tokenizer = tokenizer
@@ -241,7 +265,7 @@ class ChatSessionManager:
  def __init__(self):
  self.chatbot = None
  self.args = None
- self.model_path = None
  def init_if_needed(self, args, model_path, tokenizer, model, image_processor, context_len):
  if self.chatbot is None:
  self.args = args
@@ -274,6 +298,7 @@ def generate_response(
  conv_mode_override: Optional[str] = None,
  repetition_penalty: Optional[float] = None,
  det_seed: Optional[int] = None,
  ):
  if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
  return {"error": "Required libraries not available (llava/transformers)"}
@@ -285,17 +310,54 @@ def generate_response(
  if max_new_tokens is None: max_new_tokens = 4096
  if repetition_penalty is None: repetition_penalty = 1.0

- dbg(f"[gen] temperature={temperature} top_p={top_p} max_new_tokens={max_new_tokens} rep={repetition_penalty} seed={det_seed}")

  chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
  if conv_mode_override and conv_mode_override in conv_templates:
  chatbot.conversation = conv_templates[conv_mode_override].copy()

  try:
  pil_img = load_image_any(image_input)
  except Exception as e:
  return {"error": f"Failed to load image: {e}"}

  img_hash, img_path = "NA", None
  try:
  buf = BytesIO(); pil_img.save(buf, format="JPEG"); raw = buf.getvalue()
@@ -311,18 +373,16 @@ def generate_response(
  device = next(chatbot.model.parameters()).device
  dtype = torch.float16

- # === Image preprocessing → tensor (dynamic size) ===
  expected_size = get_vision_expected_size(chatbot.model, default=336)
  dbg(f"[pre] dynamic expected_size={expected_size} | processor={type(chatbot.image_processor)}")

- # 3.1) Use processor.preprocess when available (the most stable path)
  image_tensor = None
  try:
  if hasattr(chatbot.image_processor, "preprocess"):
  px = chatbot.image_processor.preprocess(pil_img, return_tensors="pt")
  image_tensor = px.get("pixel_values", px)
  if not isinstance(image_tensor, torch.Tensor):
- # Some processors may return a nested dict
  image_tensor = image_tensor["pixel_values"]
  if image_tensor.ndim == 3:
  image_tensor = image_tensor.unsqueeze(0)
@@ -331,8 +391,7 @@ def generate_response(
  else:
  raise AttributeError("processor has no preprocess")
  except Exception as e_pre:
- warn(f"[pre] processor.preprocess not used: {e_pre} → will try process_images…")
- # 3.2) LLaVA's process_images path
  try:
  processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
  if isinstance(processed, (list, tuple)) and len(processed) > 0:
@@ -347,8 +406,7 @@ def generate_response(
  image_tensor = image_tensor.to(device=device, dtype=dtype)
  dbg(f"[pre] process_images ok → {tuple(image_tensor.shape)}")
  except Exception as e_proc:
- warn(f"[pre] process_images failed: {e_proc} → manual CLIP fallback (dynamic size) will be used.")
- # 3.3) Manual CLIP fallback (dynamic expected_size)
  from torchvision import transforms
  from torchvision.transforms import InterpolationMode
  preprocess = transforms.Compose([
@@ -366,8 +424,13 @@ def generate_response(
  if image_tensor is None:
  return {"error": "Image processing failed (no tensor produced)"}

- msg = (message_text or "").strip()
- msg = f"{msg}\n\n{STYLE_HINT}"
  dbg(f"[prompt] conv_sep_style={chatbot.conversation.sep_style} sep_len={len(chatbot.conversation.sep)}")
  _, input_ids = _build_prompt_and_ids(chatbot, msg, device)

@@ -411,6 +474,7 @@ def generate_response(
  except Exception as e:
  return {"error": f"Generation failed: {e}"}

  try:
  row = {
  "time": datetime.datetime.now().isoformat(),
@@ -426,6 +490,19 @@ def generate_response(
  except Exception as e:
  warn(f"[log] failed: {e}")

  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}

  # ===================== Public API =====================
@@ -450,6 +527,8 @@ def query(payload: dict):

  conv_mode_override = payload.get("conv_mode", None)
  det_seed = payload.get("det_seed", None)
  if det_seed is not None:
  try: det_seed = int(det_seed)
  except Exception: det_seed = None
@@ -463,6 +542,7 @@ def query(payload: dict):
  conv_mode_override=conv_mode_override,
  repetition_penalty=repetition_penalty,
  det_seed=det_seed,
  )
  except Exception as e:
  return {"error": f"Query failed: {e}"}
@@ -521,41 +601,41 @@ def initialize_model():
  model_.eval()
  dbg(f"[init] device={next(model_.parameters()).device}, cuda_available={torch.cuda.is_available()}")

- # --- get the image_size expected by the vision tower ---
  expected_size = get_vision_expected_size(model_, default=336)
  dbg(f"[init] vision expected image_size={expected_size}")

- # --- image_processor fallback chain (model path > AutoProcessor > CLIP 224/336) ---
  try:
  if image_processor_ is None:
- dbg("[init] image_processor None → trying AutoProcessor(model_path)…")
  try:
  from transformers import AutoProcessor
  image_processor_ = AutoProcessor.from_pretrained(args.model_path)
- dbg("[init] image_processor: AutoProcessor.from_pretrained(model_path) loaded.")
  except Exception as _e1:
  dbg(f"[init] AutoProcessor(model_path) failed: {_e1}")
  try:
  from transformers import AutoProcessor
  clip_id = "openai/clip-vit-large-patch14-336" if expected_size >= 336 else "openai/clip-vit-large-patch14"
  image_processor_ = AutoProcessor.from_pretrained(clip_id)
- dbg(f"[init] AutoProcessor({clip_id}) loaded.")
  except Exception as _e2:
  from transformers import CLIPImageProcessor
  clip_id = "openai/clip-vit-large-patch14-336" if expected_size >= 336 else "openai/clip-vit-large-patch14"
  image_processor_ = CLIPImageProcessor.from_pretrained(clip_id)
- warn(f"[init] using CLIPImageProcessor({clip_id}) fallback.")
  except Exception as _e:
  warn(f"[init] image_processor fallback chain failed: {_e}")

- # --- match the processor's sizes to the vision tower ---
  try:
  if image_processor_ is not None:
  force_processor_size(image_processor_, expected_size)
  except Exception as e_ip:
  warn(f"[init] processor size set error: {e_ip}")

- # --- image_processor introspection ---
  try:
  ip = image_processor_
  if ip is not None:
@@ -563,7 +643,7 @@ def initialize_model():
  size_sz = getattr(getattr(ip, "size", None), "shortest_edge", None) or getattr(ip, "size", None)
  dbg(f"[init] image_processor crop_size={crop_sz} size={size_sz} class={ip.__class__.__name__}")
  else:
- warn("[init] image_processor is still None (fallback also failed).")
  except Exception as e_ip2:
  warn(f"[init] image_processor inspect error: {e_ip2}")

@@ -579,9 +659,45 @@ def initialize_model():
  warn(f"[init] failed: {e}")
  return False

  # ===================== HF EndpointHandler =====================
  class EndpointHandler:
- """Hugging Face Endpoint-compatible class"""
  def __init__(self, model_dir):
  self.model_dir = model_dir
  print(f"EndpointHandler initialized with model_dir: {model_dir}")
@@ -595,7 +711,7 @@ class EndpointHandler:
  return get_model_info()

  if __name__ == "__main__":
- print("Handler ready (Demo Parity + Style Hint + whitespace post-process + dynamic size + fallbacks + debug). Use `EndpointHandler` or `query`.")

  # ===================== Minimal FastAPI Wrapper =====================
  try:
@@ -607,7 +723,7 @@ except Exception as e:
  warn(f"fastapi/pydantic not available: {e}")

  if FASTAPI_AVAILABLE:
- app = FastAPI(title="PULSE ECG Handler API", version="1.0.0")

  class QueryIn(BaseModel):
  message: str | None = None
@@ -625,6 +741,7 @@ if FASTAPI_AVAILABLE:
  repetition_penalty: float | None = None
  conv_mode: str | None = None
  det_seed: int | None = None

  @app.on_event("startup")
  async def _startup():
@@ -677,5 +794,17 @@ if FASTAPI_AVAILABLE:
  @app.post("/query")
  async def _query(payload: QueryIn):
  return query({k: v for k, v in payload.dict().items() if v is not None})
  else:
- app = None # running `uvicorn handler:app` raises an import error
handler.py (updated)

  # -*- coding: utf-8 -*-
  """
+ PULSE ECG Handler — Demo Parity + Style Hint + Robust Fallbacks + Debug + Dynamic Vision Size + JSON/Report (EN)
+ - Generation settings aligned with demo app.py:
  do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
+ - Stopping: safe keyword match on conversation separator (conv.sep/sep2)
+ - Image tensor: .half() on model device
+ - Streamer: TextIteratorStreamer with background thread (demo-like)
+ - Stochastic by default (seed/deterministic OFF unless provided)
+ - STYLE_HINT: narrative + single-line 'Structured clinical impression:' ending
+ - Post-process: whitespace cleanup only
+ - Extras:
+ * DEBUG helpers (ENV: DEBUG=1)
+ * Dynamic vision size (vision tower -> processor + preprocess/fallback)
  * image_processor fallback (AutoProcessor → CLIPImageProcessor)
  * process_images fallback (torchvision + CLIP norm)
+ * FastAPI wrapper: /health, /info, /query, /debug, /analyze/json, /analyze/report-en
+ * JSON schema (EN) and report renderer (table text + narrative)
  """

  import os
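For orientation, a minimal sketch of calling the updated handler once the model is initialized. Only output_mode and its three values are confirmed by this diff; the "image" payload key and base64 form are assumptions inferred from load_image_any's {"image": <base64|dataurl>} support:

import base64

# Hypothetical invocation sketch — payload field names beyond "output_mode"
# are assumptions, and "ecg.jpg" is a placeholder path.
with open("ecg.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("ascii")

payload = {
    "message": "Analyze this 12-lead ECG.",
    "image": img_b64,                 # assumed key; see load_image_any below
    "output_mode": "json",            # "narrative" | "json" | "report_en"
}
result = query(payload)               # public API defined later in this file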
 
  TRANSFORMERS_AVAILABLE = False
  warn(f"transformers not available: {e}")

+ # ====== HF Hub logging (optional) ======
  try:
  from huggingface_hub import HfApi, login
  HF_HUB_AVAILABLE = True

  args = None
  model_initialized = False

+ # ====== Style Hint (demo-like narrative) ======
  STYLE_HINT = (
  "Write one concise narrative paragraph that covers rhythm, heart rate, cardiac axis, "
  "P waves and PR interval, QRS morphology and duration, ST segments, T waves, and QT/QTc. "

  "followed by a succinct, comma-separated summary of the key diagnoses."
  )

+ # ====== JSON Schema (EN) for strict machine-readable output ======
+ JSON_SCHEMA_HINT_EN = """
+ Return ONLY a valid JSON object that matches EXACTLY this schema:
+
+ {
+ "heart_rate_bpm": int, // e.g., 128
+ "rhythm": "string", // e.g., "Sinus tachycardia"
+ "qrs_axis": "string", // e.g., "Normal (+16°)"
+ "p_waves": "string", // e.g., "Normal"
+ "pr_interval_ms": int, // e.g., 160
+ "qrs_duration_ms": int, // e.g., 84
+ "t_waves": "string", // e.g., "Negative in DIII, aVF, V1–V4"
+ "qtc_ms": int, // e.g., 467
+ "qtc_comment": "string", // e.g., "Mildly prolonged"
+ "additional_comments": "string" // e.g., "S1Q3T3 pattern and anterior T-wave inversions present."
+ }
+
+ Rules:
+ - Output MUST be valid JSON with no extra text before or after.
+ - Units: use numbers for bpm and ms (integers only).
+ - If unknown, use null (ints may be null).
+ - Use standard cardiology terminology in English.
+ """
+
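The json mode's recovery path (shown later in generate_response) slices from the first "{" to the last "}" before json.loads. A standalone sketch of that recovery with a light key check against the schema above — the helper name and the strictness of the check are assumptions:

import json

ECG_KEYS = {
    "heart_rate_bpm", "rhythm", "qrs_axis", "p_waves", "pr_interval_ms",
    "qrs_duration_ms", "t_waves", "qtc_ms", "qtc_comment", "additional_comments",
}

def parse_ecg_json(text: str) -> dict:
    # Mirror the handler's brace-slicing approach, then verify every schema
    # key is present (values may be null per the rules above).
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end <= start:
        raise ValueError("JSON block not found")
    obj = json.loads(text[start:end + 1])
    missing = ECG_KEYS - obj.keys()
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    return obj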
  # ===================== Utilities =====================
  def _safe_upload(path: str):
  if api and repo_name and path and os.path.isfile(path):

  def load_image_any(image_input: Union[str, dict]) -> Image.Image:
  """
+ Supported:
  - URL (http/https)
+ - local file path
+ - base64 (optionally with data URL prefix)
  - {"image": <base64|dataurl>}
  """
  if isinstance(image_input, str):

  return Image.open(BytesIO(r.content)).convert("RGB")
  if os.path.exists(s):
  return Image.open(s).convert("RGB")
+ # base64 (maybe dataurl)
  if s.startswith("data:image"):
  s = s.split(",", 1)[1]
  raw = base64.b64decode(s)

  # ====== Vision helpers (dynamic size) ======
  def get_vision_expected_size(m, default: int = 336) -> int:
  """
+ Returns expected image size for the model's vision tower (e.g., 336).
  """
  try:
  vt = m.get_vision_tower()

  return default
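The hunk elides the body between get_vision_tower() and return default; a sketch of the lookup it plausibly performs, assuming a LLaVA-style tower whose CLIP vision config exposes an integer image_size (224 or 336 for ViT-L/14):

def _expected_size_sketch(m, default: int = 336) -> int:
    # Assumed attribute chain: vision tower -> .config -> .image_size.
    try:
        vt = m.get_vision_tower()
        size = getattr(getattr(vt, "config", None), "image_size", None)
        if isinstance(size, int):
            return size
    except Exception:
        pass
    return default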

  def force_processor_size(proc, size: int):
+ """Force processor resize/crop to target size safely."""
  try:
  # size
  if hasattr(proc, "size"):

  except Exception as e:
  warn(f"[processor] force size failed: {e}")

+ # ====== Safe Stop Criteria (conv separator) ======
  class SafeKeywordsStoppingCriteria(StoppingCriteria):
  def __init__(self, keyword: str, tokenizer):
  self.tokenizer = tokenizer

  def __init__(self):
  self.chatbot = None
  self.args = None
+ self.model_path = None
  def init_if_needed(self, args, model_path, tokenizer, model, image_processor, context_len):
  if self.chatbot is None:
  self.args = args

  conv_mode_override: Optional[str] = None,
  repetition_penalty: Optional[float] = None,
  det_seed: Optional[int] = None,
+ output_mode: Optional[str] = "narrative", # "narrative" | "json" | "report_en"
  ):
  if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
  return {"error": "Required libraries not available (llava/transformers)"}

  if max_new_tokens is None: max_new_tokens = 4096
  if repetition_penalty is None: repetition_penalty = 1.0

+ dbg(f"[gen] temperature={temperature} top_p={top_p} max_new_tokens={max_new_tokens} rep={repetition_penalty} seed={det_seed} mode={output_mode}")
+
+ # For "report_en", compose by calling json + narrative branches
+ if output_mode == "report_en":
+ first = generate_response(
+ message_text=message_text,
+ image_input=image_input,
+ temperature=temperature, top_p=top_p,
+ max_new_tokens=max_new_tokens,
+ conv_mode_override=conv_mode_override,
+ repetition_penalty=repetition_penalty,
+ det_seed=det_seed,
+ output_mode="json",
+ )
+ if not isinstance(first, dict) or "response" not in first or not isinstance(first["response"], dict):
+ return first
+ data = first["response"]
+
+ second = generate_response(
+ message_text=message_text,
+ image_input=image_input,
+ temperature=temperature, top_p=top_p,
+ max_new_tokens=min(int(max_new_tokens), 512),
+ conv_mode_override=conv_mode_override,
+ repetition_penalty=repetition_penalty,
+ det_seed=det_seed,
+ output_mode="narrative",
+ )
+ narrative = second.get("response") if isinstance(second, dict) else None
+
+ table_txt = render_ecg_table_en(data)
+ return {
+ "status": "success",
+ "report": {"table_text": table_txt, "json": data, "narrative": narrative},
+ "conversation_id": first.get("conversation_id"), # from the JSON branch; `chatbot` is not bound yet at this point, and the narrative branch already logged
+ }
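For reference, a successful report_en response is shaped as below (a sketch; the example values are the JSON schema's own, and conversation_id is propagated from the JSON branch):

expected = {
    "status": "success",
    "report": {
        "table_text": "ECG ANALYSIS\n────────────\nHeart rate : 128 beats/min ...",
        "json": {"heart_rate_bpm": 128, "rhythm": "Sinus tachycardia"},  # etc.
        "narrative": "One concise paragraph ending with a structured impression.",
    },
    "conversation_id": None,  # filled from the JSON branch's conversation_id
}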

  chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
  if conv_mode_override and conv_mode_override in conv_templates:
  chatbot.conversation = conv_templates[conv_mode_override].copy()

+ # Load image (PIL)
  try:
  pil_img = load_image_any(image_input)
  except Exception as e:
  return {"error": f"Failed to load image: {e}"}

+ # Save image log (optional)
  img_hash, img_path = "NA", None
  try:
  buf = BytesIO(); pil_img.save(buf, format="JPEG"); raw = buf.getvalue()

  device = next(chatbot.model.parameters()).device
  dtype = torch.float16

+ # === Image preprocessing → tensor (dynamic size) ===
  expected_size = get_vision_expected_size(chatbot.model, default=336)
  dbg(f"[pre] dynamic expected_size={expected_size} | processor={type(chatbot.image_processor)}")

  image_tensor = None
  try:
  if hasattr(chatbot.image_processor, "preprocess"):
  px = chatbot.image_processor.preprocess(pil_img, return_tensors="pt")
  image_tensor = px.get("pixel_values", px)
  if not isinstance(image_tensor, torch.Tensor):
  image_tensor = image_tensor["pixel_values"]
  if image_tensor.ndim == 3:
  image_tensor = image_tensor.unsqueeze(0)

  else:
  raise AttributeError("processor has no preprocess")
  except Exception as e_pre:
+ warn(f"[pre] processor.preprocess not used: {e_pre} → process_images fallback…")
  try:
  processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
  if isinstance(processed, (list, tuple)) and len(processed) > 0:

  image_tensor = image_tensor.to(device=device, dtype=dtype)
  dbg(f"[pre] process_images ok → {tuple(image_tensor.shape)}")
  except Exception as e_proc:
+ warn(f"[pre] process_images failed: {e_proc} → manual CLIP fallback (dynamic size).")
  from torchvision import transforms
  from torchvision.transforms import InterpolationMode
  preprocess = transforms.Compose([
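The diff truncates this Compose; a sketch of how such a manual CLIP-style fallback is typically completed — the continuation is elided here, so the exact transforms are an assumption, though the mean/std are the standard OpenAI CLIP constants:

# Assumed completion of the truncated fallback (not the committed code):
preprocess = transforms.Compose([
    transforms.Resize(expected_size, interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(expected_size),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),   # CLIP image mean
        std=(0.26862954, 0.26130258, 0.27577711),   # CLIP image std
    ),
])
image_tensor = preprocess(pil_img).unsqueeze(0).to(device=device, dtype=dtype)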
 
  if image_tensor is None:
  return {"error": "Image processing failed (no tensor produced)"}

+ # ===== Build message according to output_mode =====
+ base_msg = (message_text or "").strip()
+ if output_mode == "json":
+ msg = f"{base_msg}\n\n{JSON_SCHEMA_HINT_EN}"
+ else: # "narrative"
+ msg = f"{base_msg}\n\n{STYLE_HINT}"
+
  dbg(f"[prompt] conv_sep_style={chatbot.conversation.sep_style} sep_len={len(chatbot.conversation.sep)}")
  _, input_ids = _build_prompt_and_ids(chatbot, msg, device)

  except Exception as e:
  return {"error": f"Generation failed: {e}"}

+ # Logging
  try:
  row = {
  "time": datetime.datetime.now().isoformat(),

  except Exception as e:
  warn(f"[log] failed: {e}")

+ # If JSON mode, parse and return as object
+ if output_mode == "json":
+ try:
+ start = text.find("{"); end = text.rfind("}")
+ if start != -1 and end != -1 and end > start:
+ obj = json.loads(text[start:end+1])
+ else:
+ return {"error": "JSON block not found", "raw": text}
+ except Exception as e:
+ return {"error": f"JSON parse failed: {e}", "raw": text}
+ return {"status": "success", "response": obj, "conversation_id": id(chatbot.conversation)}
+
+ # Default narrative
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}

  # ===================== Public API =====================

  conv_mode_override = payload.get("conv_mode", None)
  det_seed = payload.get("det_seed", None)
+ output_mode = payload.get("output_mode", "narrative") # "narrative" | "json" | "report_en"
+
  if det_seed is not None:
  try: det_seed = int(det_seed)
  except Exception: det_seed = None

  conv_mode_override=conv_mode_override,
  repetition_penalty=repetition_penalty,
  det_seed=det_seed,
+ output_mode=output_mode,
  )
  except Exception as e:
  return {"error": f"Query failed: {e}"}

  model_.eval()
  dbg(f"[init] device={next(model_.parameters()).device}, cuda_available={torch.cuda.is_available()}")

+ # Vision tower expected image size
  expected_size = get_vision_expected_size(model_, default=336)
  dbg(f"[init] vision expected image_size={expected_size}")

+ # image_processor fallback chain
  try:
  if image_processor_ is None:
+ dbg("[init] image_processor None → AutoProcessor(model_path)…")
  try:
  from transformers import AutoProcessor
  image_processor_ = AutoProcessor.from_pretrained(args.model_path)
+ dbg("[init] image_processor: AutoProcessor.from_pretrained(model_path) loaded.")
  except Exception as _e1:
  dbg(f"[init] AutoProcessor(model_path) failed: {_e1}")
  try:
  from transformers import AutoProcessor
  clip_id = "openai/clip-vit-large-patch14-336" if expected_size >= 336 else "openai/clip-vit-large-patch14"
  image_processor_ = AutoProcessor.from_pretrained(clip_id)
+ dbg(f"[init] AutoProcessor({clip_id}) loaded.")
  except Exception as _e2:
  from transformers import CLIPImageProcessor
  clip_id = "openai/clip-vit-large-patch14-336" if expected_size >= 336 else "openai/clip-vit-large-patch14"
  image_processor_ = CLIPImageProcessor.from_pretrained(clip_id)
+ warn(f"[init] CLIPImageProcessor({clip_id}) fallback in use.")
  except Exception as _e:
  warn(f"[init] image_processor fallback chain failed: {_e}")

+ # Force processor sizes to match tower
  try:
  if image_processor_ is not None:
  force_processor_size(image_processor_, expected_size)
  except Exception as e_ip:
  warn(f"[init] processor size set error: {e_ip}")

+ # Processor introspection
  try:
  ip = image_processor_
  if ip is not None:

  size_sz = getattr(getattr(ip, "size", None), "shortest_edge", None) or getattr(ip, "size", None)
  dbg(f"[init] image_processor crop_size={crop_sz} size={size_sz} class={ip.__class__.__name__}")
  else:
+ warn("[init] image_processor still None (fallback failed).")
  except Exception as e_ip2:
  warn(f"[init] image_processor inspect error: {e_ip2}")

  warn(f"[init] failed: {e}")
  return False

+ # ===================== Report rendering (EN) =====================
+ def render_ecg_table_en(d: Dict[str, Any]) -> str:
+ def g(k, default="—"):
+ v = d.get(k, None)
+ if v is None: return default
+ return str(v)
+
+ hr = g("heart_rate_bpm")
+ rhythm = g("rhythm")
+ axis = g("qrs_axis")
+ p = g("p_waves")
+ pr = g("pr_interval_ms")
+ qrs_dur = g("qrs_duration_ms")
+ t = g("t_waves")
+ qtc = g("qtc_ms")
+ qtc_c = g("qtc_comment")
+ extra = g("additional_comments")
+
+ lines = [
+ "ECG ANALYSIS",
+ "────────────",
+ f"Heart rate : {hr} beats/min",
+ f"Rhythm : {rhythm}",
+ f"QRS axis : {axis}",
+ f"P waves : {p}",
+ f"PR interval : {pr} ms",
+ f"QRS duration : {qrs_dur} ms",
+ f"T waves : {t}",
+ f"QTc : {qtc_c} ({qtc} ms)",
+ "",
+ "Additional comments",
+ "──────────────────",
+ f"{extra}"
+ ]
+ return "\n".join(lines)
+
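A usage sketch: feeding the schema's example values through the renderer yields the plain-text table below (spacing follows the f-strings as shown in this hunk):

sample = {
    "heart_rate_bpm": 128, "rhythm": "Sinus tachycardia",
    "qrs_axis": "Normal (+16°)", "p_waves": "Normal",
    "pr_interval_ms": 160, "qrs_duration_ms": 84,
    "t_waves": "Negative in DIII, aVF, V1–V4",
    "qtc_ms": 467, "qtc_comment": "Mildly prolonged",
    "additional_comments": "S1Q3T3 pattern and anterior T-wave inversions present.",
}
print(render_ecg_table_en(sample))
# ECG ANALYSIS
# ────────────
# Heart rate : 128 beats/min
# Rhythm : Sinus tachycardia
# QRS axis : Normal (+16°)
# P waves : Normal
# PR interval : 160 ms
# QRS duration : 84 ms
# T waves : Negative in DIII, aVF, V1–V4
# QTc : Mildly prolonged (467 ms)
#
# Additional comments
# ──────────────────
# S1Q3T3 pattern and anterior T-wave inversions present.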
  # ===================== HF EndpointHandler =====================
  class EndpointHandler:
+ """Hugging Face Endpoint-compatible wrapper."""
  def __init__(self, model_dir):
  self.model_dir = model_dir
  print(f"EndpointHandler initialized with model_dir: {model_dir}")

  return get_model_info()

  if __name__ == "__main__":
+ print("Handler ready (Demo Parity + Style Hint + whitespace post-process + dynamic size + fallbacks + debug + JSON/Report-EN). Use `EndpointHandler` or `query`.")

  # ===================== Minimal FastAPI Wrapper =====================
  try:

  warn(f"fastapi/pydantic not available: {e}")

  if FASTAPI_AVAILABLE:
+ app = FastAPI(title="PULSE ECG Handler API", version="1.1.0")

  class QueryIn(BaseModel):
  message: str | None = None

  repetition_penalty: float | None = None
  conv_mode: str | None = None
  det_seed: int | None = None
+ output_mode: str | None = None # "narrative" | "json" | "report_en"

  @app.on_event("startup")
  async def _startup():

  @app.post("/query")
  async def _query(payload: QueryIn):
  return query({k: v for k, v in payload.dict().items() if v is not None})
+
+ @app.post("/analyze/json")
+ async def analyze_json(payload: QueryIn):
+ data = {k: v for k, v in payload.dict().items() if v is not None}
+ data["output_mode"] = "json"
+ return query(data)
+
+ @app.post("/analyze/report-en")
+ async def analyze_report_en(payload: QueryIn):
+ data = {k: v for k, v in payload.dict().items() if v is not None}
+ data["output_mode"] = "report_en"
+ return query(data)
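A client-side sketch against the new endpoints, assuming a local uvicorn run on the default port; the "image" field is the same assumption as in the earlier payload sketch:

import base64
import requests  # third-party client, not part of this handler

with open("ecg.jpg", "rb") as f:  # placeholder path
    img_b64 = base64.b64encode(f.read()).decode("ascii")

resp = requests.post(
    "http://127.0.0.1:8000/analyze/report-en",   # assumed host/port
    json={"message": "Analyze this 12-lead ECG.", "image": img_b64},
    timeout=600,
)
print(resp.json()["report"]["table_text"])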
  else:
+ app = None # Running "uvicorn handler:app" will raise import error if FastAPI missing