CanerDedeoglu commited on
Commit
2777fd6
·
verified ·
1 Parent(s): 3a8fcc6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +224 -127
handler.py CHANGED
@@ -1,21 +1,27 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- PULSE ECG Handler — Deterministic JSON→Narrative (age+sex aware)
4
- - Model still processes image (LLaVA/transformers)
5
- - output_mode="json" → returns structured JSON (single model call)
6
- - output_mode="report_en" JSON + table + narrative (derived deterministically from JSON; still single model call)
7
- - output_mode="narrative" classic narrative paragraph (model free-form)
8
-
9
- Notes:
10
- - For "json" and "report_en" modes we prompt the model with a strict JSON schema hint.
11
- - Age group ("0-15" | "15-65" | "65+") and sex ("male" | "female") are accepted from payload
12
- and used only in deterministic narrative rendering (not sent to the model).
 
 
 
 
 
13
  """
14
 
15
  import os
16
  import re
17
  import json
18
  import base64
 
19
  import hashlib
20
  import datetime
21
  from io import BytesIO
@@ -26,7 +32,7 @@ import torch
26
  from PIL import Image
27
  import requests
28
 
29
- # ==== Debug helpers ====
30
  def _env_bool(name: str, default: bool = False) -> bool:
31
  v = os.getenv(name)
32
  if v is None:
@@ -42,7 +48,7 @@ def dbg(*args, **kwargs):
42
  def warn(*args, **kwargs):
43
  print("[WARN]", *args, **kwargs)
44
 
45
- # ==== LLaVA & Transformers ====
46
  try:
47
  from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
48
  from llava.conversation import conv_templates, SeparatorStyle
@@ -61,7 +67,7 @@ except Exception as e:
61
  TRANSFORMERS_AVAILABLE = False
62
  warn(f"transformers not available: {e}")
63
 
64
- # ==== HF Hub logging (optional) ====
65
  try:
66
  from huggingface_hub import HfApi, login
67
  HF_HUB_AVAILABLE = True
@@ -77,13 +83,12 @@ if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
77
  repo_name = os.environ.get("LOG_REPO", "")
78
  except Exception as e:
79
  warn(f"[HF Hub] init failed: {e}")
80
- api = None
81
- repo_name = ""
82
 
83
  LOGDIR = "./logs"
84
  os.makedirs(LOGDIR, exist_ok=True)
85
 
86
- # ==== Global state ====
87
  tokenizer = None
88
  model = None
89
  image_processor = None
@@ -91,7 +96,7 @@ context_len = None
91
  args = None
92
  model_initialized = False
93
 
94
- # ==== Prompts ====
95
  STYLE_HINT = (
96
  "Write one concise narrative paragraph that covers rhythm, heart rate, cardiac axis, "
97
  "P waves and PR interval, QRS morphology and duration, ST segments, T waves, and QT/QTc. "
@@ -100,28 +105,26 @@ STYLE_HINT = (
100
  "followed by a succinct, comma-separated summary of the key diagnoses."
101
  )
102
 
 
103
  JSON_SCHEMA_HINT_EN = """
104
- Return ONLY a valid JSON object that matches EXACTLY this schema:
 
 
105
  {
106
- "heart_rate_bpm": int | null,
107
- "rhythm": "string",
108
- "qrs_axis": "string",
109
- "p_waves": "string",
110
- "pr_interval_ms": int | null,
111
- "qrs_duration_ms": int | null,
112
- "t_waves": "string",
113
- "qtc_ms": int | null,
114
- "qtc_comment": "string",
115
- "additional_comments": "string"
116
  }
117
- Rules:
118
- - Output MUST be valid JSON with no extra text before or after.
119
- - Units: use integers for bpm and ms where applicable.
120
- - If unknown, use null for numeric fields and empty string for text fields.
121
- - Use standard cardiology terminology in English.
122
  """
123
 
124
- # ===================== Utilities =====================
125
  def _safe_upload(path: str):
126
  if api and repo_name and path and os.path.isfile(path):
127
  try:
@@ -139,6 +142,9 @@ def _conv_log_path() -> str:
139
  return os.path.join(LOGDIR, f"{t.year:04d}-{t.month:02d}-{t.day:02d}-user_conv.json")
140
 
141
  def load_image_any(image_input: Union[str, dict]) -> Image.Image:
 
 
 
142
  if isinstance(image_input, str):
143
  s = image_input.strip()
144
  if s.startswith(("http://", "https://")):
@@ -165,8 +171,111 @@ def _normalize_whitespace(text: str) -> str:
165
  def _postprocess_min(text: str) -> str:
166
  return _normalize_whitespace(text)
167
 
168
- # ====== Vision helpers ======
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  def get_vision_expected_size(m, default: int = 336) -> int:
 
 
 
170
  try:
171
  vt = m.get_vision_tower()
172
  vt_cfg = getattr(getattr(vt, "vision_tower", vt), "config", None)
@@ -182,30 +291,31 @@ def get_vision_expected_size(m, default: int = 336) -> int:
182
  return default
183
 
184
  def force_processor_size(proc, size: int):
 
185
  try:
186
  if hasattr(proc, "size"):
187
  if isinstance(proc.size, dict):
188
  proc.size["shortest_edge"] = size
189
  else:
190
  try:
191
- proc.size.shortest_edge = size
192
  except Exception:
193
  proc.size = {"shortest_edge": size}
194
  if hasattr(proc, "crop_size"):
195
  if isinstance(proc.crop_size, dict):
196
  proc.crop_size["height"] = size
197
- proc.crop_size["width"] = size
198
  else:
199
  try:
200
- proc.crop_size.height = size
201
- proc.crop_size.width = size
202
  except Exception:
203
  proc.crop_size = {"height": size, "width": size}
204
  dbg(f"[processor] forced size={size}")
205
  except Exception as e:
206
  warn(f"[processor] force size failed: {e}")
207
 
208
- # ====== Stop Criteria ======
209
  class SafeKeywordsStoppingCriteria(StoppingCriteria):
210
  def __init__(self, keyword: str, tokenizer):
211
  tok = tokenizer(keyword, add_special_tokens=False, return_tensors="pt").input_ids[0]
@@ -220,7 +330,7 @@ class SafeKeywordsStoppingCriteria(StoppingCriteria):
220
  tail = out[-n:]
221
  return torch.equal(tail, self.kw_ids.to(tail.device))
222
 
223
- # ===================== Core =====================
224
  class InferenceDemo:
225
  def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
226
  if not LLAVA_AVAILABLE:
@@ -260,10 +370,10 @@ def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
260
  ).unsqueeze(0).to(device)
261
  return prompt, input_ids
262
 
263
- # ===================== Deterministic renderers =====================
264
  def render_ecg_table_en(d: Dict[str, Any]) -> str:
265
  lines = ["ECG ANALYSIS", "────────────"]
266
- if "heart_rate_bpm" in d and d["heart_rate_bpm"] is not None:
267
  lines.append(f"Heart rate : {d['heart_rate_bpm']} beats/min")
268
  if "rhythm" in d:
269
  lines.append(f"Rhythm : {d['rhythm']}")
@@ -271,24 +381,20 @@ def render_ecg_table_en(d: Dict[str, Any]) -> str:
271
  lines.append(f"QRS axis : {d['qrs_axis']}")
272
  if "p_waves" in d:
273
  lines.append(f"P waves : {d['p_waves']}")
274
- if "pr_interval_ms" in d and d["pr_interval_ms"] is not None:
275
  lines.append(f"PR interval : {d['pr_interval_ms']} ms")
276
- if "qrs_duration_ms" in d and d["qrs_duration_ms"] is not None:
277
  lines.append(f"QRS duration : {d['qrs_duration_ms']} ms")
278
  if "t_waves" in d:
279
  lines.append(f"T waves : {d['t_waves']}")
280
- if "qtc_ms" in d and d["qtc_ms"] is not None:
281
- qtc_c = d.get("qtc_comment", "").strip()
282
- qtc_c = qtc_c if qtc_c else "—"
283
  lines.append(f"QTc : {qtc_c} ({d['qtc_ms']} ms)")
284
- lines.append("")
285
- lines.append("Additional comments")
286
- lines.append("──────────────────")
287
- lines.append(d.get("additional_comments", "").strip())
288
  return "\n".join(lines)
289
 
290
  def render_ecg_narrative_en(d: Dict[str, Any]) -> str:
291
- """Deterministic narrative based on JSON + age_group + sex"""
292
  hr = d.get("heart_rate_bpm")
293
  rhythm = d.get("rhythm")
294
  axis = d.get("qrs_axis")
@@ -327,9 +433,17 @@ def render_ecg_narrative_en(d: Dict[str, Any]) -> str:
327
  elif sex:
328
  para.append(f"The patient is {sex}.")
329
 
 
330
  if rhythm:
331
- para.append(f"The electrocardiogram shows {rhythm.lower()}.")
 
 
 
 
 
 
332
 
 
333
  if isinstance(hr, int):
334
  if hr < hr_low:
335
  hr_comment = "bradycardia"
@@ -339,12 +453,11 @@ def render_ecg_narrative_en(d: Dict[str, Any]) -> str:
339
  hr_comment = "within normal range"
340
  para.append(f"The heart rate is {hr} bpm ({hr_comment}).")
341
 
 
342
  if axis:
343
  para.append(f"The QRS axis is {axis.lower()}.")
344
-
345
  if p:
346
  para.append(f"P waves are {p.lower()}.")
347
-
348
  if isinstance(pr, int):
349
  if pr < pr_low:
350
  pr_comment = "short PR interval"
@@ -353,17 +466,11 @@ def render_ecg_narrative_en(d: Dict[str, Any]) -> str:
353
  else:
354
  pr_comment = "within normal range"
355
  para.append(f"PR interval is {pr} ms ({pr_comment}).")
356
-
357
  if isinstance(qrs_dur, int):
358
- if qrs_dur >= qrs_limit:
359
- qrs_comment = "prolonged QRS (possible conduction delay)"
360
- else:
361
- qrs_comment = "normal QRS duration"
362
  para.append(f"QRS duration is {qrs_dur} ms ({qrs_comment}).")
363
-
364
  if t:
365
  para.append(f"T waves: {t}.")
366
-
367
  if isinstance(qtc, int):
368
  if sex == "male":
369
  if qtc > qtc_male:
@@ -393,6 +500,7 @@ def render_ecg_narrative_en(d: Dict[str, Any]) -> str:
393
 
394
  paragraph = " ".join(para).strip()
395
 
 
396
  sci_bits = []
397
  if rhythm: sci_bits.append(rhythm)
398
  if axis: sci_bits.append(f"QRS axis: {axis}")
@@ -403,7 +511,7 @@ def render_ecg_narrative_en(d: Dict[str, Any]) -> str:
403
 
404
  return paragraph + "\n\n" + "Structured clinical impression: " + ", ".join(sci_bits)
405
 
406
- # ===================== Generation =====================
407
  def generate_response(
408
  message_text: str,
409
  image_input,
@@ -428,36 +536,28 @@ def generate_response(
428
  if max_new_tokens is None: max_new_tokens = 4096
429
  if repetition_penalty is None: repetition_penalty = 1.0
430
 
 
 
 
 
 
 
 
431
  dbg(f"[gen] temp={temperature} top_p={top_p} max_new={max_new_tokens} rep={repetition_penalty} mode={output_mode}")
432
 
433
  chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
434
  if conv_mode_override and conv_mode_override in conv_templates:
435
  chatbot.conversation = conv_templates[conv_mode_override].copy()
436
 
437
- # Load image
438
  try:
439
  pil_img = load_image_any(image_input)
440
  except Exception as e:
441
  return {"error": f"Failed to load image: {e}"}
442
 
443
- # Save image (log)
444
- img_hash, img_path = "NA", None
445
- try:
446
- buf = BytesIO(); pil_img.save(buf, format="JPEG"); raw = buf.getvalue()
447
- img_hash = hashlib.md5(raw).hexdigest()
448
- t = datetime.datetime.now()
449
- img_path = os.path.join(LOGDIR, "serve_images", f"{t.year:04d}-{t.month:02d}-{t.day:02d}", f"{img_hash}.jpg")
450
- os.makedirs(os.path.dirname(img_path), exist_ok=True)
451
- if not os.path.isfile(img_path):
452
- pil_img.save(img_path)
453
- except Exception as e:
454
- warn(f"[log] save image failed: {e}")
455
-
456
  device = next(chatbot.model.parameters()).device
457
  dtype = torch.float16
458
 
459
- # Preprocess image → tensor
460
- expected_size = get_vision_expected_size(chatbot.model, default=336)
461
  image_tensor = None
462
  try:
463
  if hasattr(chatbot.image_processor, "preprocess"):
@@ -471,7 +571,6 @@ def generate_response(
471
  else:
472
  raise AttributeError("processor has no preprocess")
473
  except Exception:
474
- # Fallback chain: process_images → manual CLIP norm
475
  try:
476
  processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
477
  if isinstance(processed, (list, tuple)) and len(processed) > 0:
@@ -486,6 +585,7 @@ def generate_response(
486
  except Exception:
487
  from torchvision import transforms
488
  from torchvision.transforms import InterpolationMode
 
489
  preprocess = transforms.Compose([
490
  transforms.Resize(expected_size, interpolation=InterpolationMode.BICUBIC),
491
  transforms.CenterCrop(expected_size),
@@ -500,13 +600,14 @@ def generate_response(
500
  if image_tensor is None:
501
  return {"error": "Image processing failed (no tensor produced)"}
502
 
503
- # Prompt selection
504
  base_msg = (message_text or "").strip()
505
  if output_mode in ("json", "report_en"):
506
  msg = f"{base_msg}\n\n{JSON_SCHEMA_HINT_EN}"
507
- else: # narrative
508
  msg = f"{base_msg}\n\n{STYLE_HINT}"
509
 
 
510
  _, input_ids = _build_prompt_and_ids(chatbot, msg, device)
511
 
512
  stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
@@ -522,12 +623,13 @@ def generate_response(
522
  except Exception:
523
  pass
524
 
 
525
  streamer = TextIteratorStreamer(chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True)
526
  gen_kwargs = dict(
527
  inputs=input_ids,
528
  images=image_tensor,
529
  streamer=streamer,
530
- do_sample=True,
531
  temperature=float(temperature),
532
  top_p=float(top_p),
533
  max_new_tokens=int(max_new_tokens),
@@ -536,7 +638,6 @@ def generate_response(
536
  stopping_criteria=[stopping],
537
  )
538
 
539
- # Generate
540
  try:
541
  t = Thread(target=chatbot.model.generate, kwargs=gen_kwargs)
542
  t.start()
@@ -548,36 +649,28 @@ def generate_response(
548
  except Exception as e:
549
  return {"error": f"Generation failed: {e}"}
550
 
551
- # Log
552
- try:
553
- row = {
554
- "time": datetime.datetime.now().isoformat(),
555
- "type": "chat",
556
- "model": "PULSE-7B",
557
- "state": [(message_text, text)],
558
- "image_hash": img_hash,
559
- "image_path": img_path or "",
560
- }
561
- with open(_conv_log_path(), "a", encoding="utf-8") as f:
562
- f.write(json.dumps(row, ensure_ascii=False) + "\n")
563
- _safe_upload(_conv_log_path()); _safe_upload(img_path or "")
564
- except Exception as e:
565
- warn(f"[log] failed: {e}")
566
-
567
- # Output modes
568
  if output_mode == "narrative":
569
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
570
 
571
- # For json & report_en we need to parse JSON once
572
  try:
573
  start = text.find("{"); end = text.rfind("}")
574
  if start == -1 or end == -1 or end <= start:
575
- return {"error": "JSON block not found", "raw": text}
576
  data = json.loads(text[start:end+1])
577
- except Exception as e:
578
- return {"error": f"JSON parse failed: {e}", "raw": text}
 
 
 
 
 
 
 
 
579
 
580
- # Inject patient metadata (not sent to model; used for deterministic narrative)
581
  if patient_age_group:
582
  data["patient_age_group"] = patient_age_group
583
  if patient_sex:
@@ -598,7 +691,7 @@ def generate_response(
598
  # Fallback
599
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
600
 
601
- # ===================== Public API =====================
602
  def query(payload: dict):
603
  global model_initialized, tokenizer, model, image_processor, context_len, args
604
  if not model_initialized:
@@ -621,7 +714,7 @@ def query(payload: dict):
621
  det_seed = payload.get("det_seed", None)
622
  output_mode = payload.get("output_mode", "narrative")
623
 
624
- # Optional patient meta
625
  patient_age_group = payload.get("patient_age_group")
626
  patient_sex = payload.get("patient_sex")
627
 
@@ -663,7 +756,7 @@ def get_model_info():
663
  "device": str(next(model.parameters()).device) if model else "Unknown",
664
  }
665
 
666
- # ===================== Init & Session =====================
667
  class _Args:
668
  def __init__(self):
669
  self.model_path = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
@@ -689,6 +782,7 @@ def initialize_model():
689
  tokenizer_, model_, image_processor_, context_len_ = load_pretrained_model(
690
  args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
691
  )
 
692
 
693
  try:
694
  _ = next(model_.parameters()).device
@@ -696,33 +790,37 @@ def initialize_model():
696
  if torch.cuda.is_available():
697
  model_ = model_.to(torch.device("cuda"))
698
  model_.eval()
 
699
 
700
  expected_size = get_vision_expected_size(model_, default=336)
701
- if image_processor_ is None:
702
- try:
703
- from transformers import AutoProcessor
704
- image_processor_ = AutoProcessor.from_pretrained(args.model_path)
705
- except Exception:
706
- from transformers import CLIPImageProcessor
707
- clip_id = "openai/clip-vit-large-patch14-336" if expected_size >= 336 else "openai/clip-vit-large-patch14"
708
- image_processor_ = CLIPImageProcessor.from_pretrained(clip_id)
709
- force_processor_size(image_processor_, expected_size)
 
 
710
 
 
711
  globals()["tokenizer"] = tokenizer_
712
  globals()["model"] = model_
713
  globals()["image_processor"] = image_processor_
714
  globals()["context_len"] = context_len_
715
 
716
  chat_manager.init_if_needed(args, args.model_path, tokenizer_, model_, image_processor_, context_len_)
717
- print("[init] model/tokenizer/image_processor loaded.]")
718
  return True
719
  except Exception as e:
720
  warn(f"[init] failed: {e}")
721
  return False
722
 
723
- # ===================== EndpointHandler =====================
724
  class EndpointHandler:
725
- """Hugging Face Endpoint compatible"""
726
  def __init__(self, model_dir):
727
  self.model_dir = model_dir
728
  print(f"EndpointHandler initialized with model_dir: {model_dir}")
@@ -736,9 +834,9 @@ class EndpointHandler:
736
  return get_model_info()
737
 
738
  if __name__ == "__main__":
739
- print("Handler ready (Deterministic JSON→Narrative, age+sex aware). Use `EndpointHandler` or `query`.")
740
 
741
- # ===================== FastAPI Wrapper =====================
742
  try:
743
  from fastapi import FastAPI
744
  from pydantic import BaseModel
@@ -748,7 +846,7 @@ except Exception as e:
748
  warn(f"fastapi/pydantic not available: {e}")
749
 
750
  if FASTAPI_AVAILABLE:
751
- app = FastAPI(title="PULSE ECG Handler API", version="1.2.0")
752
 
753
  class QueryIn(BaseModel):
754
  message: str | None = None
@@ -801,5 +899,4 @@ if FASTAPI_AVAILABLE:
801
  data["output_mode"] = "report_en"
802
  return query(data)
803
  else:
804
- app = None
805
-
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ PULSE ECG Handler — Deterministic JSON Table + Narrative (age+sex aware) with Robust Fallbacks
4
+
5
+ Modes
6
+ - output_mode="json" returns structured JSON (single model call)
7
+ - output_mode="report_en" returns JSON + table + deterministic narrative (single model call)
8
+ - output_mode="narrative" → classic free-form model narrative (STYLE_HINT used)
9
+
10
+ Highlights
11
+ - Age group ("0-15" | "15-65" | "65+") and sex ("male" | "female") are accepted in payload and are
12
+ used only in deterministic narrative rendering (not sent to the model).
13
+ - Robust JSON parsing:
14
+ 1) direct JSON slice
15
+ 2) cleanup pseudo-JSON (_coerce_pseudo_json)
16
+ 3) regex-based field extraction from free text (_extract_fields_from_text)
17
+ - Safe stop criteria, dynamic vision-size processor, logging hooks (optional HF Hub upload).
18
  """
19
 
20
  import os
21
  import re
22
  import json
23
  import base64
24
+ import math
25
  import hashlib
26
  import datetime
27
  from io import BytesIO
 
32
  from PIL import Image
33
  import requests
34
 
35
+ # ========= Debug Helpers =========
36
  def _env_bool(name: str, default: bool = False) -> bool:
37
  v = os.getenv(name)
38
  if v is None:
 
48
  def warn(*args, **kwargs):
49
  print("[WARN]", *args, **kwargs)
50
 
51
+ # ========= LLaVA & Transformers =========
52
  try:
53
  from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
54
  from llava.conversation import conv_templates, SeparatorStyle
 
67
  TRANSFORMERS_AVAILABLE = False
68
  warn(f"transformers not available: {e}")
69
 
70
+ # ========= (Optional) HF Hub logging =========
71
  try:
72
  from huggingface_hub import HfApi, login
73
  HF_HUB_AVAILABLE = True
 
83
  repo_name = os.environ.get("LOG_REPO", "")
84
  except Exception as e:
85
  warn(f"[HF Hub] init failed: {e}")
86
+ api, repo_name = None, ""
 
87
 
88
  LOGDIR = "./logs"
89
  os.makedirs(LOGDIR, exist_ok=True)
90
 
91
+ # ========= Global State =========
92
  tokenizer = None
93
  model = None
94
  image_processor = None
 
96
  args = None
97
  model_initialized = False
98
 
99
+ # ========= Prompts =========
100
  STYLE_HINT = (
101
  "Write one concise narrative paragraph that covers rhythm, heart rate, cardiac axis, "
102
  "P waves and PR interval, QRS morphology and duration, ST segments, T waves, and QT/QTc. "
 
105
  "followed by a succinct, comma-separated summary of the key diagnoses."
106
  )
107
 
108
+ # Example-only schema (no type hints). The model copies this structure.
109
  JSON_SCHEMA_HINT_EN = """
110
+ Return ONLY a valid JSON object. Do not include comments, types, or extra text.
111
+ If a value is unknown, use null (for numbers) or "" (for strings).
112
+
113
  {
114
+ "heart_rate_bpm": 100,
115
+ "rhythm": "Sinus rhythm",
116
+ "qrs_axis": "Normal",
117
+ "p_waves": "Normal",
118
+ "pr_interval_ms": 160,
119
+ "qrs_duration_ms": 90,
120
+ "t_waves": "Normal",
121
+ "qtc_ms": 420,
122
+ "qtc_comment": "Normal",
123
+ "additional_comments": ""
124
  }
 
 
 
 
 
125
  """
126
 
127
+ # ========= Utilities =========
128
  def _safe_upload(path: str):
129
  if api and repo_name and path and os.path.isfile(path):
130
  try:
 
142
  return os.path.join(LOGDIR, f"{t.year:04d}-{t.month:02d}-{t.day:02d}-user_conv.json")
143
 
144
  def load_image_any(image_input: Union[str, dict]) -> Image.Image:
145
+ """
146
+ Supports: http(s) URL, local path, base64 (with or without data URL prefix), or {"image": <...>}
147
+ """
148
  if isinstance(image_input, str):
149
  s = image_input.strip()
150
  if s.startswith(("http://", "https://")):
 
171
  def _postprocess_min(text: str) -> str:
172
  return _normalize_whitespace(text)
173
 
174
+ def _coerce_pseudo_json(text: str) -> str:
175
+ """
176
+ Coerce pseudo-JSON (e.g., 'int | none', 'none', Python booleans) into valid JSON string.
177
+ """
178
+ if not isinstance(text, str):
179
+ return ""
180
+ s = text
181
+
182
+ # Keep only the outermost JSON object if stray tokens are around
183
+ i, j = s.find("{"), s.rfind("}")
184
+ if i != -1 and j != -1 and j > i:
185
+ s = s[i:j+1]
186
+
187
+ # Remove type-like hints → replace with valid JSON placeholders
188
+ s = re.sub(r':\s*int\s*\|\s*none', ': null', s, flags=re.I)
189
+ s = re.sub(r':\s*string\s*\|\s*none', ': ""', s, flags=re.I)
190
+
191
+ # Python/other tokens → JSON
192
+ s = re.sub(r'\bNone\b|\bnone\b', 'null', s, flags=re.I)
193
+ s = re.sub(r'\bTrue\b', 'true', s)
194
+ s = re.sub(r'\bFalse\b', 'false', s)
195
+
196
+ # Strip inline comments
197
+ s = re.sub(r'//.*', '', s) # JS style
198
+ s = re.sub(r'#.*', '', s) # Python style
199
+
200
+ # Collapse repeated commas
201
+ s = re.sub(r',\s*,+', ',', s)
202
+
203
+ return s.strip()
204
+
205
+ def _to_int_or_none(x: Optional[str]) -> Optional[int]:
206
+ if x is None:
207
+ return None
208
+ x = x.strip()
209
+ if not x:
210
+ return None
211
+ try:
212
+ v = int(float(x))
213
+ if math.isnan(v):
214
+ return None
215
+ return v
216
+ except Exception:
217
+ return None
218
+
219
+ def _extract_fields_from_text(text: str) -> Dict[str, Any]:
220
+ """
221
+ Extract fields from free text when model failed to return valid JSON.
222
+ Missing numeric fields -> None; missing text -> "".
223
+ """
224
+ if not isinstance(text, str):
225
+ text = str(text or "")
226
+
227
+ def rex(pattern, flags=re.I):
228
+ m = re.search(pattern, text, flags)
229
+ return m.group(1).strip() if m else None
230
+
231
+ # bpm
232
+ hr = rex(r"(?:heart\s*rate|hr)\s*[:=]?\s*(\d{1,3})\s*(?:bpm|beats?/min)?")
233
+ if hr is None:
234
+ hr = rex(r"\b(\d{2,3})\s*(?:bpm|beats?/min)\b")
235
+
236
+ # PR/QRS/QTc ms
237
+ pr = rex(r"\bPR\s*(?:interval)?\s*[:=]?\s*(\d{2,4})\s*ms\b")
238
+ qrs = rex(r"\bQRS\s*(?:duration)?\s*[:=]?\s*(\d{2,4})\s*ms\b")
239
+ qtc = rex(r"\bQTc?\s*[:=]?\s*(\d{2,4})\s*ms\b")
240
+
241
+ # Axis
242
+ axis = rex(r"\bQRS\s*axis\s*[:=]?\s*([+\-]?\d+°|normal|left|right|indeterminate)\b")
243
+
244
+ # Rhythm
245
+ rhythm = rex(r"\brhythm\s*[:=]?\s*([A-Za-z \-]+)")
246
+ if rhythm is None:
247
+ rhythm = rex(r"\b(sinus\s+(?:tachycardia|bradycardia|rhythm)|atrial fibrillation|afib|atrial flutter|junctional rhythm)\b")
248
+
249
+ # P / T waves
250
+ p_waves = rex(r"\bP\s*waves?\s*[:=]?\s*([A-Za-z0-9, \-]+)")
251
+ t_waves = rex(r"\bT\s*waves?\s*[:=]?\s*([A-Za-z0-9, \-]+)")
252
+
253
+ # QTc comment
254
+ qtc_comment = rex(r"\bQTc\s*(?:comment|status)?\s*[:=]?\s*([A-Za-z \-]+)")
255
+
256
+ # Additional
257
+ additional = rex(r"(?:Additional\s*comments|Notes?)\s*[:\-]?\s*([\s\S]{0,300})")
258
+ if not additional:
259
+ additional = rex(r"\b(ST[- ](?:elevation|depression)|S1Q3T3|early repolarization|strain pattern)\b(?:[^\n\r]{0,120})")
260
+
261
+ return {
262
+ "heart_rate_bpm": _to_int_or_none(hr),
263
+ "rhythm": (rhythm or "").strip(),
264
+ "qrs_axis": (axis or "").strip(),
265
+ "p_waves": (p_waves or "").strip(),
266
+ "pr_interval_ms": _to_int_or_none(pr),
267
+ "qrs_duration_ms": _to_int_or_none(qrs),
268
+ "t_waves": (t_waves or "").strip(),
269
+ "qtc_ms": _to_int_or_none(qtc),
270
+ "qtc_comment": (qtc_comment or "").strip(),
271
+ "additional_comments": (additional or "").strip(),
272
+ }
273
+
274
+ # ========= Vision helpers =========
275
  def get_vision_expected_size(m, default: int = 336) -> int:
276
+ """
277
+ Return expected image size for the model vision tower if available.
278
+ """
279
  try:
280
  vt = m.get_vision_tower()
281
  vt_cfg = getattr(getattr(vt, "vision_tower", vt), "config", None)
 
291
  return default
292
 
293
  def force_processor_size(proc, size: int):
294
+ """Force processor resize/crop to target size safely."""
295
  try:
296
  if hasattr(proc, "size"):
297
  if isinstance(proc.size, dict):
298
  proc.size["shortest_edge"] = size
299
  else:
300
  try:
301
+ proc.size.shortest_edge = size # type: ignore[attr-defined]
302
  except Exception:
303
  proc.size = {"shortest_edge": size}
304
  if hasattr(proc, "crop_size"):
305
  if isinstance(proc.crop_size, dict):
306
  proc.crop_size["height"] = size
307
+ proc.crop_size["width"] = size
308
  else:
309
  try:
310
+ proc.crop_size.height = size # type: ignore[attr-defined]
311
+ proc.crop_size.width = size # type: ignore[attr-defined]
312
  except Exception:
313
  proc.crop_size = {"height": size, "width": size}
314
  dbg(f"[processor] forced size={size}")
315
  except Exception as e:
316
  warn(f"[processor] force size failed: {e}")
317
 
318
+ # ========= Safe Stopper =========
319
  class SafeKeywordsStoppingCriteria(StoppingCriteria):
320
  def __init__(self, keyword: str, tokenizer):
321
  tok = tokenizer(keyword, add_special_tokens=False, return_tensors="pt").input_ids[0]
 
330
  tail = out[-n:]
331
  return torch.equal(tail, self.kw_ids.to(tail.device))
332
 
333
+ # ========= Core Session =========
334
  class InferenceDemo:
335
  def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
336
  if not LLAVA_AVAILABLE:
 
370
  ).unsqueeze(0).to(device)
371
  return prompt, input_ids
372
 
373
+ # ========= Deterministic Renderers =========
374
  def render_ecg_table_en(d: Dict[str, Any]) -> str:
375
  lines = ["ECG ANALYSIS", "────────────"]
376
+ if d.get("heart_rate_bpm") is not None:
377
  lines.append(f"Heart rate : {d['heart_rate_bpm']} beats/min")
378
  if "rhythm" in d:
379
  lines.append(f"Rhythm : {d['rhythm']}")
 
381
  lines.append(f"QRS axis : {d['qrs_axis']}")
382
  if "p_waves" in d:
383
  lines.append(f"P waves : {d['p_waves']}")
384
+ if d.get("pr_interval_ms") is not None:
385
  lines.append(f"PR interval : {d['pr_interval_ms']} ms")
386
+ if d.get("qrs_duration_ms") is not None:
387
  lines.append(f"QRS duration : {d['qrs_duration_ms']} ms")
388
  if "t_waves" in d:
389
  lines.append(f"T waves : {d['t_waves']}")
390
+ if d.get("qtc_ms") is not None:
391
+ qtc_c = (d.get("qtc_comment") or "").strip() or "—"
 
392
  lines.append(f"QTc : {qtc_c} ({d['qtc_ms']} ms)")
393
+ lines += ["", "Additional comments", "──────────────────", (d.get("additional_comments") or "").strip()]
 
 
 
394
  return "\n".join(lines)
395
 
396
  def render_ecg_narrative_en(d: Dict[str, Any]) -> str:
397
+ """Deterministic narrative based on JSON + age_group + sex with 'Structured clinical impression' at the end."""
398
  hr = d.get("heart_rate_bpm")
399
  rhythm = d.get("rhythm")
400
  axis = d.get("qrs_axis")
 
433
  elif sex:
434
  para.append(f"The patient is {sex}.")
435
 
436
+ # Rhythm with age-adjusted normalization for sinus tachycardia
437
  if rhythm:
438
+ if rhythm.lower() == "sinus tachycardia" and isinstance(hr, int) and hr_low <= hr <= hr_high:
439
+ para.append(
440
+ f"The electrocardiogram shows sinus rhythm, normal for age. "
441
+ f"Although labelled as sinus tachycardia, the heart rate of {hr} bpm is within the normal range for this age group."
442
+ )
443
+ else:
444
+ para.append(f"The electrocardiogram shows {rhythm.lower()}.")
445
 
446
+ # Heart rate comment
447
  if isinstance(hr, int):
448
  if hr < hr_low:
449
  hr_comment = "bradycardia"
 
453
  hr_comment = "within normal range"
454
  para.append(f"The heart rate is {hr} bpm ({hr_comment}).")
455
 
456
+ # Axis / P / PR / QRS / T / QTc
457
  if axis:
458
  para.append(f"The QRS axis is {axis.lower()}.")
 
459
  if p:
460
  para.append(f"P waves are {p.lower()}.")
 
461
  if isinstance(pr, int):
462
  if pr < pr_low:
463
  pr_comment = "short PR interval"
 
466
  else:
467
  pr_comment = "within normal range"
468
  para.append(f"PR interval is {pr} ms ({pr_comment}).")
 
469
  if isinstance(qrs_dur, int):
470
+ qrs_comment = "normal QRS duration" if qrs_dur < qrs_limit else "prolonged QRS (possible conduction delay)"
 
 
 
471
  para.append(f"QRS duration is {qrs_dur} ms ({qrs_comment}).")
 
472
  if t:
473
  para.append(f"T waves: {t}.")
 
474
  if isinstance(qtc, int):
475
  if sex == "male":
476
  if qtc > qtc_male:
 
500
 
501
  paragraph = " ".join(para).strip()
502
 
503
+ # Structured clinical impression (deterministic summary)
504
  sci_bits = []
505
  if rhythm: sci_bits.append(rhythm)
506
  if axis: sci_bits.append(f"QRS axis: {axis}")
 
511
 
512
  return paragraph + "\n\n" + "Structured clinical impression: " + ", ".join(sci_bits)
513
 
514
+ # ========= Generation =========
515
  def generate_response(
516
  message_text: str,
517
  image_input,
 
536
  if max_new_tokens is None: max_new_tokens = 4096
537
  if repetition_penalty is None: repetition_penalty = 1.0
538
 
539
+ # Deterministic settings for schema modes
540
+ if output_mode in ("json", "report_en"):
541
+ temperature = 0.0
542
+ top_p = 1.0
543
+ repetition_penalty = 1.0
544
+ max_new_tokens = min(int(max_new_tokens), 1024)
545
+
546
  dbg(f"[gen] temp={temperature} top_p={top_p} max_new={max_new_tokens} rep={repetition_penalty} mode={output_mode}")
547
 
548
  chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
549
  if conv_mode_override and conv_mode_override in conv_templates:
550
  chatbot.conversation = conv_templates[conv_mode_override].copy()
551
 
552
+ # Load image → tensor
553
  try:
554
  pil_img = load_image_any(image_input)
555
  except Exception as e:
556
  return {"error": f"Failed to load image: {e}"}
557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
  device = next(chatbot.model.parameters()).device
559
  dtype = torch.float16
560
 
 
 
561
  image_tensor = None
562
  try:
563
  if hasattr(chatbot.image_processor, "preprocess"):
 
571
  else:
572
  raise AttributeError("processor has no preprocess")
573
  except Exception:
 
574
  try:
575
  processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
576
  if isinstance(processed, (list, tuple)) and len(processed) > 0:
 
585
  except Exception:
586
  from torchvision import transforms
587
  from torchvision.transforms import InterpolationMode
588
+ expected_size = get_vision_expected_size(chatbot.model, default=336)
589
  preprocess = transforms.Compose([
590
  transforms.Resize(expected_size, interpolation=InterpolationMode.BICUBIC),
591
  transforms.CenterCrop(expected_size),
 
600
  if image_tensor is None:
601
  return {"error": "Image processing failed (no tensor produced)"}
602
 
603
+ # Build prompt
604
  base_msg = (message_text or "").strip()
605
  if output_mode in ("json", "report_en"):
606
  msg = f"{base_msg}\n\n{JSON_SCHEMA_HINT_EN}"
607
+ else: # "narrative"
608
  msg = f"{base_msg}\n\n{STYLE_HINT}"
609
 
610
+ dbg(f"[prompt] mode={output_mode}")
611
  _, input_ids = _build_prompt_and_ids(chatbot, msg, device)
612
 
613
  stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
 
623
  except Exception:
624
  pass
625
 
626
+ # Generate with streamer
627
  streamer = TextIteratorStreamer(chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True)
628
  gen_kwargs = dict(
629
  inputs=input_ids,
630
  images=image_tensor,
631
  streamer=streamer,
632
+ do_sample=(temperature > 0.0),
633
  temperature=float(temperature),
634
  top_p=float(top_p),
635
  max_new_tokens=int(max_new_tokens),
 
638
  stopping_criteria=[stopping],
639
  )
640
 
 
641
  try:
642
  t = Thread(target=chatbot.model.generate, kwargs=gen_kwargs)
643
  t.start()
 
649
  except Exception as e:
650
  return {"error": f"Generation failed: {e}"}
651
 
652
+ # output_mode handlers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
  if output_mode == "narrative":
654
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
655
 
656
+ # For json & report_en parse once, with robust fallbacks
657
  try:
658
  start = text.find("{"); end = text.rfind("}")
659
  if start == -1 or end == -1 or end <= start:
660
+ raise ValueError("JSON braces not found")
661
  data = json.loads(text[start:end+1])
662
+ data["_parse_mode"] = "direct"
663
+ except Exception:
664
+ cleaned = _coerce_pseudo_json(text)
665
+ try:
666
+ data = json.loads(cleaned)
667
+ data["_parse_mode"] = "cleaned"
668
+ except Exception:
669
+ # Last resort: extract with regex from free text
670
+ data = _extract_fields_from_text(text)
671
+ data["_parse_mode"] = "extracted"
672
 
673
+ # Inject patient meta (local only)
674
  if patient_age_group:
675
  data["patient_age_group"] = patient_age_group
676
  if patient_sex:
 
691
  # Fallback
692
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
693
 
694
+ # ========= Public API =========
695
  def query(payload: dict):
696
  global model_initialized, tokenizer, model, image_processor, context_len, args
697
  if not model_initialized:
 
714
  det_seed = payload.get("det_seed", None)
715
  output_mode = payload.get("output_mode", "narrative")
716
 
717
+ # Optional patient meta (local use only)
718
  patient_age_group = payload.get("patient_age_group")
719
  patient_sex = payload.get("patient_sex")
720
 
 
756
  "device": str(next(model.parameters()).device) if model else "Unknown",
757
  }
758
 
759
+ # ========= Init & Session =========
760
  class _Args:
761
  def __init__(self):
762
  self.model_path = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
 
782
  tokenizer_, model_, image_processor_, context_len_ = load_pretrained_model(
783
  args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
784
  )
785
+ dbg(f"[init] loaded model/tokenizer/processor | context_len={context_len_}")
786
 
787
  try:
788
  _ = next(model_.parameters()).device
 
790
  if torch.cuda.is_available():
791
  model_ = model_.to(torch.device("cuda"))
792
  model_.eval()
793
+ dbg(f"[init] device={next(model_.parameters()).device}, cuda={torch.cuda.is_available()}")
794
 
795
  expected_size = get_vision_expected_size(model_, default=336)
796
+ try:
797
+ if image_processor_ is None:
798
+ from transformers import AutoProcessor, CLIPImageProcessor
799
+ try:
800
+ image_processor_ = AutoProcessor.from_pretrained(args.model_path)
801
+ except Exception:
802
+ clip_id = "openai/clip-vit-large-patch14-336" if expected_size >= 336 else "openai/clip-vit-large-patch14"
803
+ image_processor_ = CLIPImageProcessor.from_pretrained(clip_id)
804
+ force_processor_size(image_processor_, expected_size)
805
+ except Exception as e_ip:
806
+ warn(f"[init] image_processor fallback/size set failed: {e_ip}")
807
 
808
+ # publish
809
  globals()["tokenizer"] = tokenizer_
810
  globals()["model"] = model_
811
  globals()["image_processor"] = image_processor_
812
  globals()["context_len"] = context_len_
813
 
814
  chat_manager.init_if_needed(args, args.model_path, tokenizer_, model_, image_processor_, context_len_)
815
+ print("[init] model/tokenizer/image_processor loaded.")
816
  return True
817
  except Exception as e:
818
  warn(f"[init] failed: {e}")
819
  return False
820
 
821
+ # ========= HF EndpointHandler =========
822
  class EndpointHandler:
823
+ """Hugging Face Endpoint compatible."""
824
  def __init__(self, model_dir):
825
  self.model_dir = model_dir
826
  print(f"EndpointHandler initialized with model_dir: {model_dir}")
 
834
  return get_model_info()
835
 
836
  if __name__ == "__main__":
837
+ print("Handler ready (Deterministic JSON→Narrative with robust fallbacks, age+sex aware). Use `EndpointHandler` or `query`.")
838
 
839
+ # ========= Optional FastAPI Wrapper =========
840
  try:
841
  from fastapi import FastAPI
842
  from pydantic import BaseModel
 
846
  warn(f"fastapi/pydantic not available: {e}")
847
 
848
  if FASTAPI_AVAILABLE:
849
+ app = FastAPI(title="PULSE ECG Handler API", version="1.4.0")
850
 
851
  class QueryIn(BaseModel):
852
  message: str | None = None
 
899
  data["output_mode"] = "report_en"
900
  return query(data)
901
  else:
902
+ app = None # uvicorn handler:app would fail if FastAPI is not installed