CanerDedeoglu committed on
Commit 3a8fcc6 · verified · 1 Parent(s): e20ad3e

Update handler.py

Files changed (1)
  1. handler.py +254 -268
handler.py CHANGED
@@ -1,21 +1,15 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- PULSE ECG Handler — Demo Parity + Style Hint + Robust Fallbacks + Debug + Dynamic Vision Size + JSON/Report (EN)
4
- - Generation settings aligned with demo app.py:
5
- do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
6
- - Stopping: safe keyword match on conversation separator (conv.sep/sep2)
7
- - Image tensor: .half() on model device
8
- - Streamer: TextIteratorStreamer with background thread (demo-like)
9
- - Stochastic by default (seed/deterministic OFF unless provided)
10
- - STYLE_HINT: narrative + single-line 'Structured clinical impression:' ending
11
- - Post-process: whitespace cleanup only
12
- - Extras:
13
- * DEBUG helpers (ENV: DEBUG=1)
14
- * Dynamic vision size (vision tower -> processor + preprocess/fallback)
15
- * image_processor fallback (AutoProcessor → CLIPImageProcessor)
16
- * process_images fallback (torchvision + CLIP norm)
17
- * FastAPI wrapper: /health, /info, /query, /debug, /analyze/json, /analyze/report-en
18
- * JSON schema (EN) and report renderer (table text + narrative)
19
  """
20
 
21
  import os
@@ -32,7 +26,7 @@ import torch
32
  from PIL import Image
33
  import requests
34
 
35
- # ====== Debug Helpers ======
36
  def _env_bool(name: str, default: bool = False) -> bool:
37
  v = os.getenv(name)
38
  if v is None:
@@ -48,7 +42,7 @@ def dbg(*args, **kwargs):
48
  def warn(*args, **kwargs):
49
  print("[WARN]", *args, **kwargs)
50
 
51
- # ====== LLaVA & Transformers ======
52
  try:
53
  from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
54
  from llava.conversation import conv_templates, SeparatorStyle
@@ -67,7 +61,7 @@ except Exception as e:
67
  TRANSFORMERS_AVAILABLE = False
68
  warn(f"transformers not available: {e}")
69
 
70
- # ====== HF Hub logging (optional) ======
71
  try:
72
  from huggingface_hub import HfApi, login
73
  HF_HUB_AVAILABLE = True
@@ -89,7 +83,7 @@ if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
89
  LOGDIR = "./logs"
90
  os.makedirs(LOGDIR, exist_ok=True)
91
 
92
- # ====== Global State ======
93
  tokenizer = None
94
  model = None
95
  image_processor = None
@@ -97,7 +91,7 @@ context_len = None
97
  args = None
98
  model_initialized = False
99
 
100
- # ====== Style Hint (demo-like narrative) ======
101
  STYLE_HINT = (
102
  "Write one concise narrative paragraph that covers rhythm, heart rate, cardiac axis, "
103
  "P waves and PR interval, QRS morphology and duration, ST segments, T waves, and QT/QTc. "
@@ -106,27 +100,24 @@ STYLE_HINT = (
106
  "followed by a succinct, comma-separated summary of the key diagnoses."
107
  )
108
 
109
- # ====== JSON Schema (EN) for strict machine-readable output ======
110
  JSON_SCHEMA_HINT_EN = """
111
  Return ONLY a valid JSON object that matches EXACTLY this schema:
112
-
113
  {
114
- "heart_rate_bpm": int, // e.g., 128
115
- "rhythm": "string", // e.g., "Sinus tachycardia"
116
- "qrs_axis": "string", // e.g., "Normal (+16°)"
117
- "p_waves": "string", // e.g., "Normal"
118
- "pr_interval_ms": int, // e.g., 160
119
- "qrs_duration_ms": int, // e.g., 84
120
- "t_waves": "string", // e.g., "Negative in DIII, aVF, V1–V4"
121
- "qtc_ms": int, // e.g., 467
122
- "qtc_comment": "string", // e.g., "Mildly prolonged"
123
- "additional_comments": "string" // e.g., "S1Q3T3 pattern and anterior T-wave inversions present."
124
  }
125
-
126
  Rules:
127
  - Output MUST be valid JSON with no extra text before or after.
128
- - Units: use numbers for bpm and ms (integers only).
129
- - If unknown, use null (ints may be null).
130
  - Use standard cardiology terminology in English.
131
  """
132
 
@@ -148,13 +139,6 @@ def _conv_log_path() -> str:
148
  return os.path.join(LOGDIR, f"{t.year:04d}-{t.month:02d}-{t.day:02d}-user_conv.json")
149
 
150
  def load_image_any(image_input: Union[str, dict]) -> Image.Image:
151
- """
152
- Supported:
153
- - URL (http/https)
154
- - local file path
155
- - base64 (optionally with data URL prefix)
156
- - {"image": <base64|dataurl>}
157
- """
158
  if isinstance(image_input, str):
159
  s = image_input.strip()
160
  if s.startswith(("http://", "https://")):
@@ -163,15 +147,12 @@ def load_image_any(image_input: Union[str, dict]) -> Image.Image:
163
  return Image.open(BytesIO(r.content)).convert("RGB")
164
  if os.path.exists(s):
165
  return Image.open(s).convert("RGB")
166
- # base64 (maybe dataurl)
167
  if s.startswith("data:image"):
168
  s = s.split(",", 1)[1]
169
  raw = base64.b64decode(s)
170
  return Image.open(BytesIO(raw)).convert("RGB")
171
-
172
  if isinstance(image_input, dict) and "image" in image_input:
173
  return load_image_any(image_input["image"])
174
-
175
  raise ValueError("Unsupported image input format")
176
 
177
  def _normalize_whitespace(text: str) -> str:
@@ -184,11 +165,8 @@ def _normalize_whitespace(text: str) -> str:
184
  def _postprocess_min(text: str) -> str:
185
  return _normalize_whitespace(text)
186
 
187
- # ====== Vision helpers (dynamic size) ======
188
  def get_vision_expected_size(m, default: int = 336) -> int:
189
- """
190
- Returns expected image size for the model's vision tower (e.g., 336).
191
- """
192
  try:
193
  vt = m.get_vision_tower()
194
  vt_cfg = getattr(getattr(vt, "vision_tower", vt), "config", None)
@@ -204,51 +182,45 @@ def get_vision_expected_size(m, default: int = 336) -> int:
204
  return default
205
 
206
  def force_processor_size(proc, size: int):
207
- """Force processor resize/crop to target size safely."""
208
  try:
209
- # size
210
  if hasattr(proc, "size"):
211
  if isinstance(proc.size, dict):
212
  proc.size["shortest_edge"] = size
213
  else:
214
  try:
215
- proc.size.shortest_edge = size # type: ignore[attr-defined]
216
  except Exception:
217
  proc.size = {"shortest_edge": size}
218
- # crop_size
219
  if hasattr(proc, "crop_size"):
220
  if isinstance(proc.crop_size, dict):
221
  proc.crop_size["height"] = size
222
- proc.crop_size["width"] = size
223
  else:
224
  try:
225
- proc.crop_size.height = size # type: ignore[attr-defined]
226
- proc.crop_size.width = size # type: ignore[attr-defined]
227
  except Exception:
228
  proc.crop_size = {"height": size, "width": size}
229
  dbg(f"[processor] forced size={size}")
230
  except Exception as e:
231
  warn(f"[processor] force size failed: {e}")
232
 
233
- # ====== Safe Stop Criteria (conv separator) ======
234
  class SafeKeywordsStoppingCriteria(StoppingCriteria):
235
  def __init__(self, keyword: str, tokenizer):
236
- self.tokenizer = tokenizer
237
  tok = tokenizer(keyword, add_special_tokens=False, return_tensors="pt").input_ids[0]
238
- self.kw_ids = tok # shape: (n,)
239
-
240
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
241
  if input_ids is None or input_ids.shape[0] == 0:
242
  return False
243
- out = input_ids[0] # assume bsz=1
244
  n = self.kw_ids.shape[0]
245
  if out.shape[0] < n:
246
  return False
247
  tail = out[-n:]
248
- kw = self.kw_ids.to(tail.device)
249
- return torch.equal(tail, kw)
250
 
251
- # ===================== Core Generation =====================
252
  class InferenceDemo:
253
  def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
254
  if not LLAVA_AVAILABLE:
@@ -288,6 +260,150 @@ def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
288
  ).unsqueeze(0).to(device)
289
  return prompt, input_ids
290
291
  def generate_response(
292
  message_text: str,
293
  image_input,
@@ -298,7 +414,9 @@ def generate_response(
298
  conv_mode_override: Optional[str] = None,
299
  repetition_penalty: Optional[float] = None,
300
  det_seed: Optional[int] = None,
301
- output_mode: Optional[str] = "narrative", # "narrative" | "json" | "report_en"
302
  ):
303
  if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
304
  return {"error": "Required libraries not available (llava/transformers)"}
@@ -310,63 +428,19 @@ def generate_response(
310
  if max_new_tokens is None: max_new_tokens = 4096
311
  if repetition_penalty is None: repetition_penalty = 1.0
312
 
313
- dbg(f"[gen] temperature={temperature} top_p={top_p} max_new_tokens={max_new_tokens} rep={repetition_penalty} seed={det_seed} mode={output_mode}")
314
-
315
- if output_mode == "report_en":
316
- # Ensure a session exists so we can safely expose a conversation_id
317
- try:
318
- _cb = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
319
- conv_id = id(_cb.conversation)
320
- except Exception:
321
- conv_id = None
322
-
323
- # 1) Produce strict JSON (machine-readable)
324
- first = generate_response(
325
- message_text=message_text,
326
- image_input=image_input,
327
- temperature=temperature, top_p=top_p,
328
- max_new_tokens=max_new_tokens,
329
- conv_mode_override=conv_mode_override,
330
- repetition_penalty=repetition_penalty,
331
- det_seed=det_seed,
332
- output_mode="json",
333
- )
334
- if not isinstance(first, dict) or "response" not in first or not isinstance(first["response"], dict):
335
- return first
336
- data = first["response"]
337
-
338
- # 2) Produce short narrative (human-readable)
339
- second = generate_response(
340
- message_text=message_text,
341
- image_input=image_input,
342
- temperature=temperature, top_p=top_p,
343
- max_new_tokens=min(int(max_new_tokens), 512),
344
- conv_mode_override=conv_mode_override,
345
- repetition_penalty=repetition_penalty,
346
- det_seed=det_seed,
347
- output_mode="narrative",
348
- )
349
- narrative = second.get("response") if isinstance(second, dict) else None
350
-
351
- table_txt = render_ecg_table_en(data)
352
- return {
353
- "status": "success",
354
- "report": {"table_text": table_txt, "json": data, "narrative": narrative},
355
- "conversation_id": conv_id
356
- }
357
-
358
 
359
  chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
360
  if conv_mode_override and conv_mode_override in conv_templates:
361
  chatbot.conversation = conv_templates[conv_mode_override].copy()
362
 
363
- # Load image (PIL)
364
  try:
365
  pil_img = load_image_any(image_input)
366
  except Exception as e:
367
  return {"error": f"Failed to load image: {e}"}
368
 
369
- # Save image log (optional)
370
  img_hash, img_path = "NA", None
371
  try:
372
  buf = BytesIO(); pil_img.save(buf, format="JPEG"); raw = buf.getvalue()
@@ -382,10 +456,8 @@ def generate_response(
382
  device = next(chatbot.model.parameters()).device
383
  dtype = torch.float16
384
 
385
- # === Image preprocessing → tensor (dynamic size) ===
386
  expected_size = get_vision_expected_size(chatbot.model, default=336)
387
- dbg(f"[pre] dynamic expected_size={expected_size} | processor={type(chatbot.image_processor)}")
388
-
389
  image_tensor = None
390
  try:
391
  if hasattr(chatbot.image_processor, "preprocess"):
@@ -396,11 +468,10 @@ def generate_response(
396
  if image_tensor.ndim == 3:
397
  image_tensor = image_tensor.unsqueeze(0)
398
  image_tensor = image_tensor.to(device=device, dtype=dtype)
399
- dbg(f"[pre] processor.preprocess ok → {tuple(image_tensor.shape)}")
400
  else:
401
  raise AttributeError("processor has no preprocess")
402
- except Exception as e_pre:
403
- warn(f"[pre] processor.preprocess not used: {e_pre}process_images fallback…")
404
  try:
405
  processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
406
  if isinstance(processed, (list, tuple)) and len(processed) > 0:
@@ -409,13 +480,10 @@ def generate_response(
409
  image_tensor = processed[0] if processed.ndim == 4 else processed
410
  else:
411
  raise ValueError("process_images returned empty")
412
-
413
  if image_tensor.ndim == 3:
414
  image_tensor = image_tensor.unsqueeze(0)
415
  image_tensor = image_tensor.to(device=device, dtype=dtype)
416
- dbg(f"[pre] process_images ok → {tuple(image_tensor.shape)}")
417
- except Exception as e_proc:
418
- warn(f"[pre] process_images failed: {e_proc} → manual CLIP fallback (dynamic size).")
419
  from torchvision import transforms
420
  from torchvision.transforms import InterpolationMode
421
  preprocess = transforms.Compose([
@@ -428,19 +496,17 @@ def generate_response(
428
  ),
429
  ])
430
  image_tensor = preprocess(pil_img).unsqueeze(0).to(device=device, dtype=dtype)
431
- dbg(f"[pre] manual fallback ok → {tuple(image_tensor.shape)}")
432
 
433
  if image_tensor is None:
434
  return {"error": "Image processing failed (no tensor produced)"}
435
 
436
- # ===== Build message according to output_mode =====
437
  base_msg = (message_text or "").strip()
438
- if output_mode == "json":
439
  msg = f"{base_msg}\n\n{JSON_SCHEMA_HINT_EN}"
440
- else: # "narrative"
441
  msg = f"{base_msg}\n\n{STYLE_HINT}"
442
 
443
- dbg(f"[prompt] conv_sep_style={chatbot.conversation.sep_style} sep_len={len(chatbot.conversation.sep)}")
444
  _, input_ids = _build_prompt_and_ids(chatbot, msg, device)
445
 
446
  stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
@@ -457,7 +523,6 @@ def generate_response(
457
  pass
458
 
459
  streamer = TextIteratorStreamer(chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True)
460
-
461
  gen_kwargs = dict(
462
  inputs=input_ids,
463
  images=image_tensor,
@@ -471,19 +536,19 @@ def generate_response(
471
  stopping_criteria=[stopping],
472
  )
473
 
 
474
  try:
475
  t = Thread(target=chatbot.model.generate, kwargs=gen_kwargs)
476
  t.start()
477
  chunks = []
478
  for piece in streamer:
479
  chunks.append(piece)
480
- text = "".join(chunks)
481
- text = _postprocess_min(text)
482
  chatbot.conversation.messages[-1][-1] = text
483
  except Exception as e:
484
  return {"error": f"Generation failed: {e}"}
485
 
486
- # Logging
487
  try:
488
  row = {
489
  "time": datetime.datetime.now().isoformat(),
@@ -499,24 +564,42 @@ def generate_response(
499
  except Exception as e:
500
  warn(f"[log] failed: {e}")
501
 
502
- # If JSON mode, parse and return as object
503
  if output_mode == "json":
504
- try:
505
- start = text.find("{"); end = text.rfind("}")
506
- if start != -1 and end != -1 and end > start:
507
- obj = json.loads(text[start:end+1])
508
- else:
509
- return {"error": "JSON block not found", "raw": text}
510
- except Exception as e:
511
- return {"error": f"JSON parse failed: {e}", "raw": text}
512
- return {"status": "success", "response": obj, "conversation_id": id(chatbot.conversation)}
513
 
514
- # Default narrative
515
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
516
 
517
  # ===================== Public API =====================
518
  def query(payload: dict):
519
- """HF Endpoint entry (demo-like)."""
520
  global model_initialized, tokenizer, model, image_processor, context_len, args
521
  if not model_initialized:
522
  if not initialize_model():
@@ -536,7 +619,11 @@ def query(payload: dict):
536
 
537
  conv_mode_override = payload.get("conv_mode", None)
538
  det_seed = payload.get("det_seed", None)
539
- output_mode = payload.get("output_mode", "narrative") # "narrative" | "json" | "report_en"
540
 
541
  if det_seed is not None:
542
  try: det_seed = int(det_seed)
@@ -552,6 +639,8 @@ def query(payload: dict):
552
  repetition_penalty=repetition_penalty,
553
  det_seed=det_seed,
554
  output_mode=output_mode,
555
  )
556
  except Exception as e:
557
  return {"error": f"Query failed: {e}"}
@@ -600,7 +689,6 @@ def initialize_model():
600
  tokenizer_, model_, image_processor_, context_len_ = load_pretrained_model(
601
  args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
602
  )
603
- dbg(f"[init] load_pretrained_model ok | tokenizer={type(tokenizer_)} | model={type(model_)} | image_processor={type(image_processor_)} | context_len={context_len_}")
604
 
605
  try:
606
  _ = next(model_.parameters()).device
@@ -608,53 +696,17 @@ def initialize_model():
608
  if torch.cuda.is_available():
609
  model_ = model_.to(torch.device("cuda"))
610
  model_.eval()
611
- dbg(f"[init] device={next(model_.parameters()).device}, cuda_available={torch.cuda.is_available()}")
612
 
613
- # Vision tower expected image size
614
  expected_size = get_vision_expected_size(model_, default=336)
615
- dbg(f"[init] vision expected image_size={expected_size}")
616
-
617
- # image_processor fallback chain
618
- try:
619
- if image_processor_ is None:
620
- dbg("[init] image_processor None → AutoProcessor(model_path)…")
621
- try:
622
- from transformers import AutoProcessor
623
- image_processor_ = AutoProcessor.from_pretrained(args.model_path)
624
- dbg("[init] image_processor: AutoProcessor.from_pretrained(model_path) loaded.")
625
- except Exception as _e1:
626
- dbg(f"[init] AutoProcessor(model_path) failed: {_e1}")
627
- try:
628
- from transformers import AutoProcessor
629
- clip_id = "openai/clip-vit-large-patch14-336" if expected_size >= 336 else "openai/clip-vit-large-patch14"
630
- image_processor_ = AutoProcessor.from_pretrained(clip_id)
631
- dbg(f"[init] AutoProcessor({clip_id}) loaded.")
632
- except Exception as _e2:
633
- from transformers import CLIPImageProcessor
634
- clip_id = "openai/clip-vit-large-patch14-336" if expected_size >= 336 else "openai/clip-vit-large-patch14"
635
- image_processor_ = CLIPImageProcessor.from_pretrained(clip_id)
636
- warn(f"[init] CLIPImageProcessor({clip_id}) fallback in use.")
637
- except Exception as _e:
638
- warn(f"[init] image_processor fallback chain failed: {_e}")
639
-
640
- # Force processor sizes to match tower
641
- try:
642
- if image_processor_ is not None:
643
- force_processor_size(image_processor_, expected_size)
644
- except Exception as e_ip:
645
- warn(f"[init] processor size set error: {e_ip}")
646
-
647
- # Processor introspection
648
- try:
649
- ip = image_processor_
650
- if ip is not None:
651
- crop_sz = getattr(getattr(ip, "crop_size", None), "height", None) or getattr(ip, "crop_size", None)
652
- size_sz = getattr(getattr(ip, "size", None), "shortest_edge", None) or getattr(ip, "size", None)
653
- dbg(f"[init] image_processor crop_size={crop_sz} size={size_sz} class={ip.__class__.__name__}")
654
- else:
655
- warn("[init] image_processor still None (fallback failed).")
656
- except Exception as e_ip2:
657
- warn(f"[init] image_processor inspect error: {e_ip2}")
658
 
659
  globals()["tokenizer"] = tokenizer_
660
  globals()["model"] = model_
@@ -662,51 +714,15 @@ def initialize_model():
662
  globals()["context_len"] = context_len_
663
 
664
  chat_manager.init_if_needed(args, args.model_path, tokenizer_, model_, image_processor_, context_len_)
665
- print("[init] model/tokenizer/image_processor loaded.")
666
  return True
667
  except Exception as e:
668
  warn(f"[init] failed: {e}")
669
  return False
670
 
671
- # ===================== Report rendering (EN) =====================
672
- def render_ecg_table_en(d: Dict[str, Any]) -> str:
673
- def g(k, default="—"):
674
- v = d.get(k, None)
675
- if v is None: return default
676
- return str(v)
677
-
678
- hr = g("heart_rate_bpm")
679
- rhythm = g("rhythm")
680
- axis = g("qrs_axis")
681
- p = g("p_waves")
682
- pr = g("pr_interval_ms")
683
- qrs_dur = g("qrs_duration_ms")
684
- t = g("t_waves")
685
- qtc = g("qtc_ms")
686
- qtc_c = g("qtc_comment")
687
- extra = g("additional_comments")
688
-
689
- lines = [
690
- "ECG ANALYSIS",
691
- "────────────",
692
- f"Heart rate : {hr} beats/min",
693
- f"Rhythm : {rhythm}",
694
- f"QRS axis : {axis}",
695
- f"P waves : {p}",
696
- f"PR interval : {pr} ms",
697
- f"QRS duration : {qrs_dur} ms",
698
- f"T waves : {t}",
699
- f"QTc : {qtc_c} ({qtc} ms)",
700
- "",
701
- "Additional comments",
702
- "──────────────────",
703
- f"{extra}"
704
- ]
705
- return "\n".join(lines)
706
-
707
- # ===================== HF EndpointHandler =====================
708
  class EndpointHandler:
709
- """Hugging Face Endpoint-compatible wrapper."""
710
  def __init__(self, model_dir):
711
  self.model_dir = model_dir
712
  print(f"EndpointHandler initialized with model_dir: {model_dir}")
@@ -720,9 +736,9 @@ class EndpointHandler:
720
  return get_model_info()
721
 
722
  if __name__ == "__main__":
723
- print("Handler ready (Demo Parity + Style Hint + whitespace post-process + dynamic size + fallbacks + debug + JSON/Report-EN). Use `EndpointHandler` or `query`.")
724
 
725
- # ===================== Minimal FastAPI Wrapper =====================
726
  try:
727
  from fastapi import FastAPI
728
  from pydantic import BaseModel
@@ -732,7 +748,7 @@ except Exception as e:
732
  warn(f"fastapi/pydantic not available: {e}")
733
 
734
  if FASTAPI_AVAILABLE:
735
- app = FastAPI(title="PULSE ECG Handler API", version="1.1.0")
736
 
737
  class QueryIn(BaseModel):
738
  message: str | None = None
@@ -750,7 +766,9 @@ if FASTAPI_AVAILABLE:
750
  repetition_penalty: float | None = None
751
  conv_mode: str | None = None
752
  det_seed: int | None = None
753
- output_mode: str | None = None # "narrative" | "json" | "report_en"
754
 
755
  @app.on_event("startup")
756
  async def _startup():
@@ -767,39 +785,6 @@ if FASTAPI_AVAILABLE:
767
  async def _info():
768
  return get_model_info()
769
 
770
- @app.get("/debug")
771
- async def _debug():
772
- try:
773
- dev = str(next(model.parameters()).device) if model else "Unknown"
774
- except Exception:
775
- dev = "Unknown"
776
-
777
- try:
778
- ip = image_processor
779
- ip_cls = ip.__class__.__name__ if ip else None
780
- crop_sz = getattr(getattr(ip, "crop_size", None), "height", None) or getattr(ip, "crop_size", None)
781
- size_short = getattr(getattr(ip, "size", None), "shortest_edge", None) or getattr(ip, "size", None)
782
- except Exception:
783
- ip_cls, crop_sz, size_short = None, None, None
784
-
785
- try:
786
- ve = get_vision_expected_size(model, default=None) if model else None
787
- except Exception:
788
- ve = None
789
-
790
- return {
791
- "debug": bool(DEBUG),
792
- "llava_available": LLAVA_AVAILABLE,
793
- "transformers_available": TRANSFORMERS_AVAILABLE,
794
- "device": dev,
795
- "context_len": context_len,
796
- "image_processor_class": ip_cls,
797
- "image_processor_crop_size": crop_sz,
798
- "image_processor_size": {"shortest_edge": size_short},
799
- "vision_expected_image_size": ve,
800
- "model_path": args.model_path if args else None,
801
- }
802
-
803
  @app.post("/query")
804
  async def _query(payload: QueryIn):
805
  return query({k: v for k, v in payload.dict().items() if v is not None})
@@ -816,4 +801,5 @@ if FASTAPI_AVAILABLE:
816
  data["output_mode"] = "report_en"
817
  return query(data)
818
  else:
819
- app = None # Running "uvicorn handler:app" will raise import error if FastAPI missing
 
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ PULSE ECG Handler — Deterministic JSON→Narrative (age+sex aware)
4
+ - Model still processes image (LLaVA/transformers)
5
+ - output_mode="json" → returns structured JSON (single model call)
6
+ - output_mode="report_en" JSON + table + narrative (derived deterministically from JSON; still single model call)
7
+ - output_mode="narrative" classic narrative paragraph (model free-form)
8
+
9
+ Notes:
10
+ - For "json" and "report_en" modes we prompt the model with a strict JSON schema hint.
11
+ - Age group ("0-15" | "15-65" | "65+") and sex ("male" | "female") are accepted from payload
12
+ and used only in deterministic narrative rendering (not sent to the model).
13
  """
14
 
15
  import os
 
26
  from PIL import Image
27
  import requests
28
 
29
+ # ==== Debug helpers ====
30
  def _env_bool(name: str, default: bool = False) -> bool:
31
  v = os.getenv(name)
32
  if v is None:
 
42
  def warn(*args, **kwargs):
43
  print("[WARN]", *args, **kwargs)
44
 
45
+ # ==== LLaVA & Transformers ====
46
  try:
47
  from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
48
  from llava.conversation import conv_templates, SeparatorStyle
 
61
  TRANSFORMERS_AVAILABLE = False
62
  warn(f"transformers not available: {e}")
63
 
64
+ # ==== HF Hub logging (optional) ====
65
  try:
66
  from huggingface_hub import HfApi, login
67
  HF_HUB_AVAILABLE = True
 
83
  LOGDIR = "./logs"
84
  os.makedirs(LOGDIR, exist_ok=True)
85
 
86
+ # ==== Global state ====
87
  tokenizer = None
88
  model = None
89
  image_processor = None
 
91
  args = None
92
  model_initialized = False
93
 
94
+ # ==== Prompts ====
95
  STYLE_HINT = (
96
  "Write one concise narrative paragraph that covers rhythm, heart rate, cardiac axis, "
97
  "P waves and PR interval, QRS morphology and duration, ST segments, T waves, and QT/QTc. "
 
100
  "followed by a succinct, comma-separated summary of the key diagnoses."
101
  )
102
 
 
103
  JSON_SCHEMA_HINT_EN = """
104
  Return ONLY a valid JSON object that matches EXACTLY this schema:
 
105
  {
106
+ "heart_rate_bpm": int | null,
107
+ "rhythm": "string",
108
+ "qrs_axis": "string",
109
+ "p_waves": "string",
110
+ "pr_interval_ms": int | null,
111
+ "qrs_duration_ms": int | null,
112
+ "t_waves": "string",
113
+ "qtc_ms": int | null,
114
+ "qtc_comment": "string",
115
+ "additional_comments": "string"
116
  }
 
117
  Rules:
118
  - Output MUST be valid JSON with no extra text before or after.
119
+ - Units: use integers for bpm and ms where applicable.
120
+ - If unknown, use null for numeric fields and empty string for text fields.
121
  - Use standard cardiology terminology in English.
122
  """
123
 
 
139
  return os.path.join(LOGDIR, f"{t.year:04d}-{t.month:02d}-{t.day:02d}-user_conv.json")
140
 
141
  def load_image_any(image_input: Union[str, dict]) -> Image.Image:
142
  if isinstance(image_input, str):
143
  s = image_input.strip()
144
  if s.startswith(("http://", "https://")):
 
147
  return Image.open(BytesIO(r.content)).convert("RGB")
148
  if os.path.exists(s):
149
  return Image.open(s).convert("RGB")
 
150
  if s.startswith("data:image"):
151
  s = s.split(",", 1)[1]
152
  raw = base64.b64decode(s)
153
  return Image.open(BytesIO(raw)).convert("RGB")
 
154
  if isinstance(image_input, dict) and "image" in image_input:
155
  return load_image_any(image_input["image"])
 
156
  raise ValueError("Unsupported image input format")
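
The docstring that used to enumerate the accepted inputs was dropped in this commit; for reference, all four forms still round-trip through this function (a sketch with illustrative values):

img = load_image_any("https://example.com/ecg.png")        # URL
img = load_image_any("/data/ecg.png")                      # local file path
img = load_image_any("data:image/png;base64,iVBORw0K...")  # data URL / raw base64
img = load_image_any({"image": "<base64 or data URL>"})    # dict wrapper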
157
 
158
  def _normalize_whitespace(text: str) -> str:
 
165
  def _postprocess_min(text: str) -> str:
166
  return _normalize_whitespace(text)
167
 
168
+ # ====== Vision helpers ======
169
  def get_vision_expected_size(m, default: int = 336) -> int:
 
170
  try:
171
  vt = m.get_vision_tower()
172
  vt_cfg = getattr(getattr(vt, "vision_tower", vt), "config", None)
 
182
  return default
183
 
184
  def force_processor_size(proc, size: int):
 
185
  try:
 
186
  if hasattr(proc, "size"):
187
  if isinstance(proc.size, dict):
188
  proc.size["shortest_edge"] = size
189
  else:
190
  try:
191
+ proc.size.shortest_edge = size
192
  except Exception:
193
  proc.size = {"shortest_edge": size}
 
194
  if hasattr(proc, "crop_size"):
195
  if isinstance(proc.crop_size, dict):
196
  proc.crop_size["height"] = size
197
+ proc.crop_size["width"] = size
198
  else:
199
  try:
200
+ proc.crop_size.height = size
201
+ proc.crop_size.width = size
202
  except Exception:
203
  proc.crop_size = {"height": size, "width": size}
204
  dbg(f"[processor] forced size={size}")
205
  except Exception as e:
206
  warn(f"[processor] force size failed: {e}")
207
 
208
+ # ====== Stop Criteria ======
209
  class SafeKeywordsStoppingCriteria(StoppingCriteria):
210
  def __init__(self, keyword: str, tokenizer):
 
211
  tok = tokenizer(keyword, add_special_tokens=False, return_tensors="pt").input_ids[0]
212
+ self.kw_ids = tok
 
213
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
214
  if input_ids is None or input_ids.shape[0] == 0:
215
  return False
216
+ out = input_ids[0]
217
  n = self.kw_ids.shape[0]
218
  if out.shape[0] < n:
219
  return False
220
  tail = out[-n:]
221
+ return torch.equal(tail, self.kw_ids.to(tail.device))
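
In generate_response below, the keyword is the conversation separator. Condensed from the call sites visible in this diff (the constructor call itself sits in a collapsed hunk, so the middle line is presumed):

# sketch: stop generation once the token tail equals the conversation separator
stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)
gen_kwargs = dict(inputs=input_ids, images=image_tensor, stopping_criteria=[stopping])  # plus sampling params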
 
222
 
223
+ # ===================== Core =====================
224
  class InferenceDemo:
225
  def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
226
  if not LLAVA_AVAILABLE:
 
260
  ).unsqueeze(0).to(device)
261
  return prompt, input_ids
262
 
263
+ # ===================== Deterministic renderers =====================
264
+ def render_ecg_table_en(d: Dict[str, Any]) -> str:
265
+ lines = ["ECG ANALYSIS", "────────────"]
266
+ if "heart_rate_bpm" in d and d["heart_rate_bpm"] is not None:
267
+ lines.append(f"Heart rate : {d['heart_rate_bpm']} beats/min")
268
+ if "rhythm" in d:
269
+ lines.append(f"Rhythm : {d['rhythm']}")
270
+ if "qrs_axis" in d:
271
+ lines.append(f"QRS axis : {d['qrs_axis']}")
272
+ if "p_waves" in d:
273
+ lines.append(f"P waves : {d['p_waves']}")
274
+ if "pr_interval_ms" in d and d["pr_interval_ms"] is not None:
275
+ lines.append(f"PR interval : {d['pr_interval_ms']} ms")
276
+ if "qrs_duration_ms" in d and d["qrs_duration_ms"] is not None:
277
+ lines.append(f"QRS duration : {d['qrs_duration_ms']} ms")
278
+ if "t_waves" in d:
279
+ lines.append(f"T waves : {d['t_waves']}")
280
+ if "qtc_ms" in d and d["qtc_ms"] is not None:
281
+ qtc_c = d.get("qtc_comment", "").strip()
282
+ qtc_c = qtc_c if qtc_c else "—"
283
+ lines.append(f"QTc : {qtc_c} ({d['qtc_ms']} ms)")
284
+ lines.append("")
285
+ lines.append("Additional comments")
286
+ lines.append("──────────────────")
287
+ lines.append((d.get("additional_comments") or "").strip())
288
+ return "\n".join(lines)
289
+
290
+ def render_ecg_narrative_en(d: Dict[str, Any]) -> str:
291
+ """Deterministic narrative based on JSON + age_group + sex"""
292
+ hr = d.get("heart_rate_bpm")
293
+ rhythm = d.get("rhythm")
294
+ axis = d.get("qrs_axis")
295
+ p = d.get("p_waves")
296
+ pr = d.get("pr_interval_ms")
297
+ qrs_dur = d.get("qrs_duration_ms")
298
+ t = d.get("t_waves")
299
+ qtc = d.get("qtc_ms")
300
+ extra = d.get("additional_comments")
301
+ age_group = d.get("patient_age_group") # "0-15" | "15-65" | "65+"
302
+ sex = d.get("patient_sex") # "male" | "female"
303
+
304
+ # thresholds by age group
305
+ if age_group == "0-15":
306
+ hr_low, hr_high = 70, 120
307
+ pr_low, pr_high = 110, 180
308
+ qrs_limit = 100
309
+ qtc_male, qtc_female = 460, 470
310
+ elif age_group == "65+":
311
+ hr_low, hr_high = 50, 100
312
+ pr_low, pr_high = 120, 220
313
+ qrs_limit = 120
314
+ qtc_male, qtc_female = 460, 480
315
+ else: # default 15-65
316
+ hr_low, hr_high = 60, 100
317
+ pr_low, pr_high = 120, 200
318
+ qrs_limit = 120
319
+ qtc_male, qtc_female = 450, 470
320
+
321
+ para = []
322
+ # patient context
323
+ if age_group and sex:
324
+ para.append(f"The patient is a {age_group} years {sex}.")
325
+ elif age_group:
326
+ para.append(f"The patient belongs to the {age_group} years age group.")
327
+ elif sex:
328
+ para.append(f"The patient is {sex}.")
329
+
330
+ if rhythm:
331
+ para.append(f"The electrocardiogram shows {rhythm.lower()}.")
332
+
333
+ if isinstance(hr, int):
334
+ if hr < hr_low:
335
+ hr_comment = "bradycardia"
336
+ elif hr > hr_high:
337
+ hr_comment = "tachycardia"
338
+ else:
339
+ hr_comment = "within normal range"
340
+ para.append(f"The heart rate is {hr} bpm ({hr_comment}).")
341
+
342
+ if axis:
343
+ para.append(f"The QRS axis is {axis.lower()}.")
344
+
345
+ if p:
346
+ para.append(f"P waves are {p.lower()}.")
347
+
348
+ if isinstance(pr, int):
349
+ if pr < pr_low:
350
+ pr_comment = "short PR interval"
351
+ elif pr > pr_high:
352
+ pr_comment = "prolonged PR interval"
353
+ else:
354
+ pr_comment = "within normal range"
355
+ para.append(f"PR interval is {pr} ms ({pr_comment}).")
356
+
357
+ if isinstance(qrs_dur, int):
358
+ if qrs_dur >= qrs_limit:
359
+ qrs_comment = "prolonged QRS (possible conduction delay)"
360
+ else:
361
+ qrs_comment = "normal QRS duration"
362
+ para.append(f"QRS duration is {qrs_dur} ms ({qrs_comment}).")
363
+
364
+ if t:
365
+ para.append(f"T waves: {t}.")
366
+
367
+ if isinstance(qtc, int):
368
+ if sex == "male":
369
+ if qtc > qtc_male:
370
+ qtc_comment = "prolonged for male"
371
+ elif qtc < 350:
372
+ qtc_comment = "shortened"
373
+ else:
374
+ qtc_comment = "normal for male"
375
+ elif sex == "female":
376
+ if qtc > qtc_female:
377
+ qtc_comment = "prolonged for female"
378
+ elif qtc < 360:
379
+ qtc_comment = "shortened"
380
+ else:
381
+ qtc_comment = "normal for female"
382
+ else:
383
+ if qtc > max(qtc_male, qtc_female):
384
+ qtc_comment = "prolonged"
385
+ elif qtc < 350:
386
+ qtc_comment = "shortened"
387
+ else:
388
+ qtc_comment = "normal"
389
+ para.append(f"QTc is {qtc} ms ({qtc_comment}).")
390
+
391
+ if isinstance(extra, str) and extra.strip():
392
+ para.append(extra.strip())
393
+
394
+ paragraph = " ".join(para).strip()
395
+
396
+ sci_bits = []
397
+ if rhythm: sci_bits.append(rhythm)
398
+ if axis: sci_bits.append(f"QRS axis: {axis}")
399
+ if isinstance(pr, int): sci_bits.append(f"PR {pr} ms")
400
+ if isinstance(qrs_dur, int): sci_bits.append(f"QRS {qrs_dur} ms")
401
+ if isinstance(qtc, int): sci_bits.append(f"QTc {qtc} ms")
402
+ if isinstance(extra, str) and extra.strip(): sci_bits.append(extra.strip())
403
+
404
+ return paragraph + "\n\n" + "Structured clinical impression: " + ", ".join(sci_bits)
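
Tracing the same sample object with patient_age_group="15-65" and patient_sex="male" through this renderer would give (illustrative, using the 60-100 bpm, 120-200 ms, and 450 ms male QTc thresholds above):

The patient is a male in the 15-65 years age group. The electrocardiogram shows sinus tachycardia. The heart rate is 128 bpm (tachycardia). The QRS axis is normal (+16°). P waves are normal. PR interval is 160 ms (within normal range). QRS duration is 84 ms (normal QRS duration). T waves: Negative in DIII, aVF, V1–V4. QTc is 467 ms (prolonged for male). S1Q3T3 pattern and anterior T-wave inversions present.

Structured clinical impression: Sinus tachycardia, QRS axis: Normal (+16°), PR 160 ms, QRS 84 ms, QTc 467 ms, S1Q3T3 pattern and anterior T-wave inversions present.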
405
+
406
+ # ===================== Generation =====================
407
  def generate_response(
408
  message_text: str,
409
  image_input,
 
414
  conv_mode_override: Optional[str] = None,
415
  repetition_penalty: Optional[float] = None,
416
  det_seed: Optional[int] = None,
417
+ output_mode: str = "narrative", # "narrative" | "json" | "report_en"
418
+ patient_age_group: Optional[str] = None,
419
+ patient_sex: Optional[str] = None,
420
  ):
421
  if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
422
  return {"error": "Required libraries not available (llava/transformers)"}
 
428
  if max_new_tokens is None: max_new_tokens = 4096
429
  if repetition_penalty is None: repetition_penalty = 1.0
430
 
431
+ dbg(f"[gen] temp={temperature} top_p={top_p} max_new={max_new_tokens} rep={repetition_penalty} mode={output_mode}")
 
432
 
433
  chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
434
  if conv_mode_override and conv_mode_override in conv_templates:
435
  chatbot.conversation = conv_templates[conv_mode_override].copy()
436
 
437
+ # Load image
438
  try:
439
  pil_img = load_image_any(image_input)
440
  except Exception as e:
441
  return {"error": f"Failed to load image: {e}"}
442
 
443
+ # Save image (log)
444
  img_hash, img_path = "NA", None
445
  try:
446
  buf = BytesIO(); pil_img.save(buf, format="JPEG"); raw = buf.getvalue()
 
456
  device = next(chatbot.model.parameters()).device
457
  dtype = torch.float16
458
 
459
+ # Preprocess image → tensor
460
  expected_size = get_vision_expected_size(chatbot.model, default=336)
 
 
461
  image_tensor = None
462
  try:
463
  if hasattr(chatbot.image_processor, "preprocess"):
 
468
  if image_tensor.ndim == 3:
469
  image_tensor = image_tensor.unsqueeze(0)
470
  image_tensor = image_tensor.to(device=device, dtype=dtype)
 
471
  else:
472
  raise AttributeError("processor has no preprocess")
473
+ except Exception:
474
+ # Fallback chain: process_images → manual CLIP norm
475
  try:
476
  processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
477
  if isinstance(processed, (list, tuple)) and len(processed) > 0:
 
480
  image_tensor = processed[0] if processed.ndim == 4 else processed
481
  else:
482
  raise ValueError("process_images returned empty")
 
483
  if image_tensor.ndim == 3:
484
  image_tensor = image_tensor.unsqueeze(0)
485
  image_tensor = image_tensor.to(device=device, dtype=dtype)
486
+ except Exception:
 
 
487
  from torchvision import transforms
488
  from torchvision.transforms import InterpolationMode
489
  preprocess = transforms.Compose([
 
496
  ),
497
  ])
498
  image_tensor = preprocess(pil_img).unsqueeze(0).to(device=device, dtype=dtype)
 
499
 
500
  if image_tensor is None:
501
  return {"error": "Image processing failed (no tensor produced)"}
502
 
503
+ # Prompt selection
504
  base_msg = (message_text or "").strip()
505
+ if output_mode in ("json", "report_en"):
506
  msg = f"{base_msg}\n\n{JSON_SCHEMA_HINT_EN}"
507
+ else: # narrative
508
  msg = f"{base_msg}\n\n{STYLE_HINT}"
509
 
 
510
  _, input_ids = _build_prompt_and_ids(chatbot, msg, device)
511
 
512
  stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
 
523
  pass
524
 
525
  streamer = TextIteratorStreamer(chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True)
 
526
  gen_kwargs = dict(
527
  inputs=input_ids,
528
  images=image_tensor,
 
536
  stopping_criteria=[stopping],
537
  )
538
 
539
+ # Generate
540
  try:
541
  t = Thread(target=chatbot.model.generate, kwargs=gen_kwargs)
542
  t.start()
543
  chunks = []
544
  for piece in streamer:
545
  chunks.append(piece)
546
+ text = _postprocess_min("".join(chunks))
 
547
  chatbot.conversation.messages[-1][-1] = text
548
  except Exception as e:
549
  return {"error": f"Generation failed: {e}"}
550
 
551
+ # Log
552
  try:
553
  row = {
554
  "time": datetime.datetime.now().isoformat(),
 
564
  except Exception as e:
565
  warn(f"[log] failed: {e}")
566
 
567
+ # Output modes
568
+ if output_mode == "narrative":
569
+ return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
570
+
571
+ # For json & report_en we need to parse JSON once
572
+ try:
573
+ start = text.find("{"); end = text.rfind("}")
574
+ if start == -1 or end == -1 or end <= start:
575
+ return {"error": "JSON block not found", "raw": text}
576
+ data = json.loads(text[start:end+1])
577
+ except Exception as e:
578
+ return {"error": f"JSON parse failed: {e}", "raw": text}
579
+
580
+ # Inject patient metadata (not sent to model; used for deterministic narrative)
581
+ if patient_age_group:
582
+ data["patient_age_group"] = patient_age_group
583
+ if patient_sex:
584
+ data["patient_sex"] = patient_sex
585
+
586
  if output_mode == "json":
587
+ return {"status": "success", "response": data, "conversation_id": id(chatbot.conversation)}
588
 
589
+ if output_mode == "report_en":
590
+ narrative = render_ecg_narrative_en(data)
591
+ table_txt = render_ecg_table_en(data)
592
+ return {
593
+ "status": "success",
594
+ "report": {"table_text": table_txt, "json": data, "narrative": narrative},
595
+ "conversation_id": id(chatbot.conversation)
596
+ }
597
+
598
+ # Fallback
599
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
600
 
601
  # ===================== Public API =====================
602
  def query(payload: dict):
 
603
  global model_initialized, tokenizer, model, image_processor, context_len, args
604
  if not model_initialized:
605
  if not initialize_model():
 
619
 
620
  conv_mode_override = payload.get("conv_mode", None)
621
  det_seed = payload.get("det_seed", None)
622
+ output_mode = payload.get("output_mode", "narrative")
623
+
624
+ # Optional patient meta
625
+ patient_age_group = payload.get("patient_age_group")
626
+ patient_sex = payload.get("patient_sex")
627
 
628
  if det_seed is not None:
629
  try: det_seed = int(det_seed)
 
639
  repetition_penalty=repetition_penalty,
640
  det_seed=det_seed,
641
  output_mode=output_mode,
642
+ patient_age_group=patient_age_group,
643
+ patient_sex=patient_sex,
644
  )
645
  except Exception as e:
646
  return {"error": f"Query failed: {e}"}
 
689
  tokenizer_, model_, image_processor_, context_len_ = load_pretrained_model(
690
  args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
691
  )
 
692
 
693
  try:
694
  _ = next(model_.parameters()).device
 
696
  if torch.cuda.is_available():
697
  model_ = model_.to(torch.device("cuda"))
698
  model_.eval()
 
699
 
 
700
  expected_size = get_vision_expected_size(model_, default=336)
701
+ if image_processor_ is None:
702
+ try:
703
+ from transformers import AutoProcessor
704
+ image_processor_ = AutoProcessor.from_pretrained(args.model_path)
705
+ except Exception:
706
+ from transformers import CLIPImageProcessor
707
+ clip_id = "openai/clip-vit-large-patch14-336" if expected_size >= 336 else "openai/clip-vit-large-patch14"
708
+ image_processor_ = CLIPImageProcessor.from_pretrained(clip_id)
709
+ force_processor_size(image_processor_, expected_size)
710
 
711
  globals()["tokenizer"] = tokenizer_
712
  globals()["model"] = model_
 
714
  globals()["context_len"] = context_len_
715
 
716
  chat_manager.init_if_needed(args, args.model_path, tokenizer_, model_, image_processor_, context_len_)
717
+ print("[init] model/tokenizer/image_processor loaded.]")
718
  return True
719
  except Exception as e:
720
  warn(f"[init] failed: {e}")
721
  return False
722
 
723
+ # ===================== EndpointHandler =====================
724
  class EndpointHandler:
725
+ """Hugging Face Endpoint compatible"""
726
  def __init__(self, model_dir):
727
  self.model_dir = model_dir
728
  print(f"EndpointHandler initialized with model_dir: {model_dir}")
 
736
  return get_model_info()
737
 
738
  if __name__ == "__main__":
739
+ print("Handler ready (Deterministic JSON→Narrative, age+sex aware). Use `EndpointHandler` or `query`.")
740
 
741
+ # ===================== FastAPI Wrapper =====================
742
  try:
743
  from fastapi import FastAPI
744
  from pydantic import BaseModel
 
748
  warn(f"fastapi/pydantic not available: {e}")
749
 
750
  if FASTAPI_AVAILABLE:
751
+ app = FastAPI(title="PULSE ECG Handler API", version="1.2.0")
752
 
753
  class QueryIn(BaseModel):
754
  message: str | None = None
 
766
  repetition_penalty: float | None = None
767
  conv_mode: str | None = None
768
  det_seed: int | None = None
769
+ output_mode: str | None = None
770
+ patient_age_group: str | None = None
771
+ patient_sex: str | None = None
772
 
773
  @app.on_event("startup")
774
  async def _startup():
 
785
  async def _info():
786
  return get_model_info()
787
 
788
  @app.post("/query")
789
  async def _query(payload: QueryIn):
790
  return query({k: v for k, v in payload.dict().items() if v is not None})
 
801
  data["output_mode"] = "report_en"
802
  return query(data)
803
  else:
804
+ app = None
805
+
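
A quick smoke test against the wrapper, assuming it is served locally via uvicorn handler:app. The /analyze/report-en path comes from the routes above; whether that route reuses the QueryIn model is not visible in this diff, so treat the request body as a sketch with illustrative values:

import requests

resp = requests.post(
    "http://localhost:8000/analyze/report-en",
    json={
        "message": "Analyze this 12-lead ECG.",
        "image": "https://example.com/ecg.png",
        "patient_age_group": "65+",
        "patient_sex": "female",
    },
    timeout=120,
)
print(resp.json()["report"]["table_text"])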