CanerDedeoglu commited on
Commit
7133210
·
verified ·
1 Parent(s): 2777fd6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +141 -566
handler.py CHANGED
@@ -1,73 +1,57 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- PULSE ECG Handler — Deterministic JSON → Table + Narrative (age+sex aware) with Robust Fallbacks
4
-
5
- Modes
6
- - output_mode="json" → returns structured JSON (single model call)
7
- - output_mode="report_en" → returns JSON + table + deterministic narrative (single model call)
8
- - output_mode="narrative" → classic free-form model narrative (STYLE_HINT used)
9
-
10
- Highlights
11
- - Age group ("0-15" | "15-65" | "65+") and sex ("male" | "female") are accepted in payload and are
12
- used only in deterministic narrative rendering (not sent to the model).
13
- - Robust JSON parsing:
14
- 1) direct JSON slice
15
- 2) cleanup pseudo-JSON (_coerce_pseudo_json)
16
- 3) regex-based field extraction from free text (_extract_fields_from_text)
17
- - Safe stop criteria, dynamic vision-size processor, logging hooks (optional HF Hub upload).
18
  """
19
 
20
  import os
21
  import re
22
  import json
23
  import base64
24
- import math
25
  import hashlib
26
  import datetime
27
  from io import BytesIO
28
  from threading import Thread
29
- from typing import Optional, Union, Any, Dict
30
 
31
  import torch
32
  from PIL import Image
33
  import requests
34
 
35
- # ========= Debug Helpers =========
36
- def _env_bool(name: str, default: bool = False) -> bool:
37
- v = os.getenv(name)
38
- if v is None:
39
- return default
40
- return str(v).strip().lower() in {"1", "true", "yes", "y", "on"}
41
-
42
- DEBUG = _env_bool("DEBUG", False)
43
-
44
- def dbg(*args, **kwargs):
45
- if DEBUG:
46
- print("[DEBUG]", *args, **kwargs)
47
-
48
- def warn(*args, **kwargs):
49
- print("[WARN]", *args, **kwargs)
50
-
51
- # ========= LLaVA & Transformers =========
52
  try:
53
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
 
 
 
54
  from llava.conversation import conv_templates, SeparatorStyle
55
  from llava.model.builder import load_pretrained_model
56
- from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
 
 
 
 
57
  from llava.utils import disable_torch_init
58
  LLAVA_AVAILABLE = True
59
  except Exception as e:
60
  LLAVA_AVAILABLE = False
61
- warn(f"LLaVA not available: {e}")
62
 
63
  try:
64
  from transformers import TextIteratorStreamer, StoppingCriteria
65
  TRANSFORMERS_AVAILABLE = True
66
  except Exception as e:
67
  TRANSFORMERS_AVAILABLE = False
68
- warn(f"transformers not available: {e}")
69
 
70
- # ========= (Optional) HF Hub logging =========
71
  try:
72
  from huggingface_hub import HfApi, login
73
  HF_HUB_AVAILABLE = True
@@ -82,13 +66,14 @@ if HF_HUB_AVAILABLE and "HF_TOKEN" in os.environ:
82
  api = HfApi()
83
  repo_name = os.environ.get("LOG_REPO", "")
84
  except Exception as e:
85
- warn(f"[HF Hub] init failed: {e}")
86
- api, repo_name = None, ""
 
87
 
88
  LOGDIR = "./logs"
89
  os.makedirs(LOGDIR, exist_ok=True)
90
 
91
- # ========= Global State =========
92
  tokenizer = None
93
  model = None
94
  image_processor = None
@@ -96,7 +81,7 @@ context_len = None
96
  args = None
97
  model_initialized = False
98
 
99
- # ========= Prompts =========
100
  STYLE_HINT = (
101
  "Write one concise narrative paragraph that covers rhythm, heart rate, cardiac axis, "
102
  "P waves and PR interval, QRS morphology and duration, ST segments, T waves, and QT/QTc. "
@@ -105,26 +90,8 @@ STYLE_HINT = (
105
  "followed by a succinct, comma-separated summary of the key diagnoses."
106
  )
107
 
108
- # Example-only schema (no type hints). The model copies this structure.
109
- JSON_SCHEMA_HINT_EN = """
110
- Return ONLY a valid JSON object. Do not include comments, types, or extra text.
111
- If a value is unknown, use null (for numbers) or "" (for strings).
112
-
113
- {
114
- "heart_rate_bpm": 100,
115
- "rhythm": "Sinus rhythm",
116
- "qrs_axis": "Normal",
117
- "p_waves": "Normal",
118
- "pr_interval_ms": 160,
119
- "qrs_duration_ms": 90,
120
- "t_waves": "Normal",
121
- "qtc_ms": 420,
122
- "qtc_comment": "Normal",
123
- "additional_comments": ""
124
- }
125
- """
126
 
127
- # ========= Utilities =========
128
  def _safe_upload(path: str):
129
  if api and repo_name and path and os.path.isfile(path):
130
  try:
@@ -135,7 +102,7 @@ def _safe_upload(path: str):
135
  repo_type="dataset",
136
  )
137
  except Exception as e:
138
- warn(f"[upload] failed for {path}: {e}")
139
 
140
  def _conv_log_path() -> str:
141
  t = datetime.datetime.now()
@@ -143,7 +110,11 @@ def _conv_log_path() -> str:
143
 
144
  def load_image_any(image_input: Union[str, dict]) -> Image.Image:
145
  """
146
- Supports: http(s) URL, local path, base64 (with or without data URL prefix), or {"image": <...>}
 
 
 
 
147
  """
148
  if isinstance(image_input, str):
149
  s = image_input.strip()
@@ -153,15 +124,24 @@ def load_image_any(image_input: Union[str, dict]) -> Image.Image:
153
  return Image.open(BytesIO(r.content)).convert("RGB")
154
  if os.path.exists(s):
155
  return Image.open(s).convert("RGB")
 
156
  if s.startswith("data:image"):
157
  s = s.split(",", 1)[1]
158
  raw = base64.b64decode(s)
159
  return Image.open(BytesIO(raw)).convert("RGB")
 
160
  if isinstance(image_input, dict) and "image" in image_input:
161
  return load_image_any(image_input["image"])
 
162
  raise ValueError("Unsupported image input format")
163
 
164
  def _normalize_whitespace(text: str) -> str:
 
 
 
 
 
 
165
  text = text.replace("\r\n", "\n").replace("\r", "\n")
166
  lines = [re.sub(r"[ \t]+", " ", ln.strip()) for ln in text.split("\n")]
167
  text = "\n".join(lines).strip()
@@ -169,168 +149,32 @@ def _normalize_whitespace(text: str) -> str:
169
  return text
170
 
171
  def _postprocess_min(text: str) -> str:
 
172
  return _normalize_whitespace(text)
173
 
174
- def _coerce_pseudo_json(text: str) -> str:
175
- """
176
- Coerce pseudo-JSON (e.g., 'int | none', 'none', Python booleans) into valid JSON string.
177
- """
178
- if not isinstance(text, str):
179
- return ""
180
- s = text
181
-
182
- # Keep only the outermost JSON object if stray tokens are around
183
- i, j = s.find("{"), s.rfind("}")
184
- if i != -1 and j != -1 and j > i:
185
- s = s[i:j+1]
186
-
187
- # Remove type-like hints → replace with valid JSON placeholders
188
- s = re.sub(r':\s*int\s*\|\s*none', ': null', s, flags=re.I)
189
- s = re.sub(r':\s*string\s*\|\s*none', ': ""', s, flags=re.I)
190
-
191
- # Python/other tokens → JSON
192
- s = re.sub(r'\bNone\b|\bnone\b', 'null', s, flags=re.I)
193
- s = re.sub(r'\bTrue\b', 'true', s)
194
- s = re.sub(r'\bFalse\b', 'false', s)
195
-
196
- # Strip inline comments
197
- s = re.sub(r'//.*', '', s) # JS style
198
- s = re.sub(r'#.*', '', s) # Python style
199
-
200
- # Collapse repeated commas
201
- s = re.sub(r',\s*,+', ',', s)
202
-
203
- return s.strip()
204
-
205
- def _to_int_or_none(x: Optional[str]) -> Optional[int]:
206
- if x is None:
207
- return None
208
- x = x.strip()
209
- if not x:
210
- return None
211
- try:
212
- v = int(float(x))
213
- if math.isnan(v):
214
- return None
215
- return v
216
- except Exception:
217
- return None
218
-
219
- def _extract_fields_from_text(text: str) -> Dict[str, Any]:
220
- """
221
- Extract fields from free text when model failed to return valid JSON.
222
- Missing numeric fields -> None; missing text -> "".
223
- """
224
- if not isinstance(text, str):
225
- text = str(text or "")
226
-
227
- def rex(pattern, flags=re.I):
228
- m = re.search(pattern, text, flags)
229
- return m.group(1).strip() if m else None
230
-
231
- # bpm
232
- hr = rex(r"(?:heart\s*rate|hr)\s*[:=]?\s*(\d{1,3})\s*(?:bpm|beats?/min)?")
233
- if hr is None:
234
- hr = rex(r"\b(\d{2,3})\s*(?:bpm|beats?/min)\b")
235
-
236
- # PR/QRS/QTc ms
237
- pr = rex(r"\bPR\s*(?:interval)?\s*[:=]?\s*(\d{2,4})\s*ms\b")
238
- qrs = rex(r"\bQRS\s*(?:duration)?\s*[:=]?\s*(\d{2,4})\s*ms\b")
239
- qtc = rex(r"\bQTc?\s*[:=]?\s*(\d{2,4})\s*ms\b")
240
-
241
- # Axis
242
- axis = rex(r"\bQRS\s*axis\s*[:=]?\s*([+\-]?\d+°|normal|left|right|indeterminate)\b")
243
-
244
- # Rhythm
245
- rhythm = rex(r"\brhythm\s*[:=]?\s*([A-Za-z \-]+)")
246
- if rhythm is None:
247
- rhythm = rex(r"\b(sinus\s+(?:tachycardia|bradycardia|rhythm)|atrial fibrillation|afib|atrial flutter|junctional rhythm)\b")
248
-
249
- # P / T waves
250
- p_waves = rex(r"\bP\s*waves?\s*[:=]?\s*([A-Za-z0-9, \-]+)")
251
- t_waves = rex(r"\bT\s*waves?\s*[:=]?\s*([A-Za-z0-9, \-]+)")
252
-
253
- # QTc comment
254
- qtc_comment = rex(r"\bQTc\s*(?:comment|status)?\s*[:=]?\s*([A-Za-z \-]+)")
255
-
256
- # Additional
257
- additional = rex(r"(?:Additional\s*comments|Notes?)\s*[:\-]?\s*([\s\S]{0,300})")
258
- if not additional:
259
- additional = rex(r"\b(ST[- ](?:elevation|depression)|S1Q3T3|early repolarization|strain pattern)\b(?:[^\n\r]{0,120})")
260
-
261
- return {
262
- "heart_rate_bpm": _to_int_or_none(hr),
263
- "rhythm": (rhythm or "").strip(),
264
- "qrs_axis": (axis or "").strip(),
265
- "p_waves": (p_waves or "").strip(),
266
- "pr_interval_ms": _to_int_or_none(pr),
267
- "qrs_duration_ms": _to_int_or_none(qrs),
268
- "t_waves": (t_waves or "").strip(),
269
- "qtc_ms": _to_int_or_none(qtc),
270
- "qtc_comment": (qtc_comment or "").strip(),
271
- "additional_comments": (additional or "").strip(),
272
- }
273
-
274
- # ========= Vision helpers =========
275
- def get_vision_expected_size(m, default: int = 336) -> int:
276
  """
277
- Return expected image size for the model vision tower if available.
278
  """
279
- try:
280
- vt = m.get_vision_tower()
281
- vt_cfg = getattr(getattr(vt, "vision_tower", vt), "config", None)
282
- if vt_cfg is None:
283
- return default
284
- if getattr(vt_cfg, "image_size", None):
285
- return int(vt_cfg.image_size)
286
- vc = getattr(vt_cfg, "vision_config", None)
287
- if vc and getattr(vc, "image_size", None):
288
- return int(vc.image_size)
289
- except Exception as e:
290
- dbg(f"[get_vision_expected_size] fallback default={default} because: {e}")
291
- return default
292
-
293
- def force_processor_size(proc, size: int):
294
- """Force processor resize/crop to target size safely."""
295
- try:
296
- if hasattr(proc, "size"):
297
- if isinstance(proc.size, dict):
298
- proc.size["shortest_edge"] = size
299
- else:
300
- try:
301
- proc.size.shortest_edge = size # type: ignore[attr-defined]
302
- except Exception:
303
- proc.size = {"shortest_edge": size}
304
- if hasattr(proc, "crop_size"):
305
- if isinstance(proc.crop_size, dict):
306
- proc.crop_size["height"] = size
307
- proc.crop_size["width"] = size
308
- else:
309
- try:
310
- proc.crop_size.height = size # type: ignore[attr-defined]
311
- proc.crop_size.width = size # type: ignore[attr-defined]
312
- except Exception:
313
- proc.crop_size = {"height": size, "width": size}
314
- dbg(f"[processor] forced size={size}")
315
- except Exception as e:
316
- warn(f"[processor] force size failed: {e}")
317
-
318
- # ========= Safe Stopper =========
319
- class SafeKeywordsStoppingCriteria(StoppingCriteria):
320
  def __init__(self, keyword: str, tokenizer):
 
321
  tok = tokenizer(keyword, add_special_tokens=False, return_tensors="pt").input_ids[0]
322
- self.kw_ids = tok
 
323
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
324
  if input_ids is None or input_ids.shape[0] == 0:
325
  return False
326
- out = input_ids[0]
327
  n = self.kw_ids.shape[0]
328
  if out.shape[0] < n:
329
  return False
330
  tail = out[-n:]
331
- return torch.equal(tail, self.kw_ids.to(tail.device))
 
 
 
332
 
333
- # ========= Core Session =========
334
  class InferenceDemo:
335
  def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
336
  if not LLAVA_AVAILABLE:
@@ -339,6 +183,7 @@ class InferenceDemo:
339
  self.tokenizer, self.model, self.image_processor, self.context_len = (
340
  tokenizer_, model_, image_processor_, context_len_
341
  )
 
342
  self.conv_mode = "llava_v1"
343
  self.conversation = conv_templates[self.conv_mode].copy()
344
  self.num_frames = getattr(args, "num_frames", 16)
@@ -355,163 +200,24 @@ class ChatSessionManager:
355
  self.chatbot = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
356
  def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
357
  self.init_if_needed(args, model_path, tokenizer, model, image_processor, context_len)
 
358
  self.chatbot.conversation = conv_templates[self.chatbot.conv_mode].copy()
359
  return self.chatbot
360
 
361
  chat_manager = ChatSessionManager()
362
 
363
  def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
 
364
  inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
365
  chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
366
  chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
367
  prompt = chatbot.conversation.get_prompt()
 
368
  input_ids = tokenizer_image_token(
369
  prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
370
  ).unsqueeze(0).to(device)
371
  return prompt, input_ids
372
 
373
- # ========= Deterministic Renderers =========
374
- def render_ecg_table_en(d: Dict[str, Any]) -> str:
375
- lines = ["ECG ANALYSIS", "────────────"]
376
- if d.get("heart_rate_bpm") is not None:
377
- lines.append(f"Heart rate : {d['heart_rate_bpm']} beats/min")
378
- if "rhythm" in d:
379
- lines.append(f"Rhythm : {d['rhythm']}")
380
- if "qrs_axis" in d:
381
- lines.append(f"QRS axis : {d['qrs_axis']}")
382
- if "p_waves" in d:
383
- lines.append(f"P waves : {d['p_waves']}")
384
- if d.get("pr_interval_ms") is not None:
385
- lines.append(f"PR interval : {d['pr_interval_ms']} ms")
386
- if d.get("qrs_duration_ms") is not None:
387
- lines.append(f"QRS duration : {d['qrs_duration_ms']} ms")
388
- if "t_waves" in d:
389
- lines.append(f"T waves : {d['t_waves']}")
390
- if d.get("qtc_ms") is not None:
391
- qtc_c = (d.get("qtc_comment") or "").strip() or "—"
392
- lines.append(f"QTc : {qtc_c} ({d['qtc_ms']} ms)")
393
- lines += ["", "Additional comments", "──────────────────", (d.get("additional_comments") or "").strip()]
394
- return "\n".join(lines)
395
-
396
- def render_ecg_narrative_en(d: Dict[str, Any]) -> str:
397
- """Deterministic narrative based on JSON + age_group + sex with 'Structured clinical impression' at the end."""
398
- hr = d.get("heart_rate_bpm")
399
- rhythm = d.get("rhythm")
400
- axis = d.get("qrs_axis")
401
- p = d.get("p_waves")
402
- pr = d.get("pr_interval_ms")
403
- qrs_dur = d.get("qrs_duration_ms")
404
- t = d.get("t_waves")
405
- qtc = d.get("qtc_ms")
406
- extra = d.get("additional_comments")
407
- age_group = d.get("patient_age_group") # "0-15" | "15-65" | "65+"
408
- sex = d.get("patient_sex") # "male" | "female"
409
-
410
- # thresholds by age group
411
- if age_group == "0-15":
412
- hr_low, hr_high = 70, 120
413
- pr_low, pr_high = 110, 180
414
- qrs_limit = 100
415
- qtc_male, qtc_female = 460, 470
416
- elif age_group == "65+":
417
- hr_low, hr_high = 50, 100
418
- pr_low, pr_high = 120, 220
419
- qrs_limit = 120
420
- qtc_male, qtc_female = 460, 480
421
- else: # default 15-65
422
- hr_low, hr_high = 60, 100
423
- pr_low, pr_high = 120, 200
424
- qrs_limit = 120
425
- qtc_male, qtc_female = 450, 470
426
-
427
- para = []
428
- # patient context
429
- if age_group and sex:
430
- para.append(f"The patient is a {age_group} years {sex}.")
431
- elif age_group:
432
- para.append(f"The patient belongs to the {age_group} years age group.")
433
- elif sex:
434
- para.append(f"The patient is {sex}.")
435
-
436
- # Rhythm with age-adjusted normalization for sinus tachycardia
437
- if rhythm:
438
- if rhythm.lower() == "sinus tachycardia" and isinstance(hr, int) and hr_low <= hr <= hr_high:
439
- para.append(
440
- f"The electrocardiogram shows sinus rhythm, normal for age. "
441
- f"Although labelled as sinus tachycardia, the heart rate of {hr} bpm is within the normal range for this age group."
442
- )
443
- else:
444
- para.append(f"The electrocardiogram shows {rhythm.lower()}.")
445
-
446
- # Heart rate comment
447
- if isinstance(hr, int):
448
- if hr < hr_low:
449
- hr_comment = "bradycardia"
450
- elif hr > hr_high:
451
- hr_comment = "tachycardia"
452
- else:
453
- hr_comment = "within normal range"
454
- para.append(f"The heart rate is {hr} bpm ({hr_comment}).")
455
-
456
- # Axis / P / PR / QRS / T / QTc
457
- if axis:
458
- para.append(f"The QRS axis is {axis.lower()}.")
459
- if p:
460
- para.append(f"P waves are {p.lower()}.")
461
- if isinstance(pr, int):
462
- if pr < pr_low:
463
- pr_comment = "short PR interval"
464
- elif pr > pr_high:
465
- pr_comment = "prolonged PR interval"
466
- else:
467
- pr_comment = "within normal range"
468
- para.append(f"PR interval is {pr} ms ({pr_comment}).")
469
- if isinstance(qrs_dur, int):
470
- qrs_comment = "normal QRS duration" if qrs_dur < qrs_limit else "prolonged QRS (possible conduction delay)"
471
- para.append(f"QRS duration is {qrs_dur} ms ({qrs_comment}).")
472
- if t:
473
- para.append(f"T waves: {t}.")
474
- if isinstance(qtc, int):
475
- if sex == "male":
476
- if qtc > qtc_male:
477
- qtc_comment = "prolonged for male"
478
- elif qtc < 350:
479
- qtc_comment = "shortened"
480
- else:
481
- qtc_comment = "normal for male"
482
- elif sex == "female":
483
- if qtc > qtc_female:
484
- qtc_comment = "prolonged for female"
485
- elif qtc < 360:
486
- qtc_comment = "shortened"
487
- else:
488
- qtc_comment = "normal for female"
489
- else:
490
- if qtc > max(qtc_male, qtc_female):
491
- qtc_comment = "prolonged"
492
- elif qtc < 350:
493
- qtc_comment = "shortened"
494
- else:
495
- qtc_comment = "normal"
496
- para.append(f"QTc is {qtc} ms ({qtc_comment}).")
497
-
498
- if isinstance(extra, str) and extra.strip():
499
- para.append(extra.strip())
500
-
501
- paragraph = " ".join(para).strip()
502
-
503
- # Structured clinical impression (deterministic summary)
504
- sci_bits = []
505
- if rhythm: sci_bits.append(rhythm)
506
- if axis: sci_bits.append(f"QRS axis: {axis}")
507
- if isinstance(pr, int): sci_bits.append(f"PR {pr} ms")
508
- if isinstance(qrs_dur, int): sci_bits.append(f"QRS {qrs_dur} ms")
509
- if isinstance(qtc, int): sci_bits.append(f"QTc {qtc} ms")
510
- if isinstance(extra, str) and extra.strip(): sci_bits.append(extra.strip())
511
-
512
- return paragraph + "\n\n" + "Structured clinical impression: " + ", ".join(sci_bits)
513
-
514
- # ========= Generation =========
515
  def generate_response(
516
  message_text: str,
517
  image_input,
@@ -521,98 +227,72 @@ def generate_response(
521
  max_new_tokens: Optional[int] = None,
522
  conv_mode_override: Optional[str] = None,
523
  repetition_penalty: Optional[float] = None,
524
- det_seed: Optional[int] = None,
525
- output_mode: str = "narrative", # "narrative" | "json" | "report_en"
526
- patient_age_group: Optional[str] = None,
527
- patient_sex: Optional[str] = None,
528
  ):
529
  if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
530
  return {"error": "Required libraries not available (llava/transformers)"}
531
  if not message_text or image_input is None:
532
  return {"error": "Both 'message' and 'image' are required"}
533
 
 
534
  if temperature is None: temperature = 0.05
535
  if top_p is None: top_p = 1.0
536
  if max_new_tokens is None: max_new_tokens = 4096
537
- if repetition_penalty is None: repetition_penalty = 1.0
538
-
539
- # Deterministic settings for schema modes
540
- if output_mode in ("json", "report_en"):
541
- temperature = 0.0
542
- top_p = 1.0
543
- repetition_penalty = 1.0
544
- max_new_tokens = min(int(max_new_tokens), 1024)
545
-
546
- dbg(f"[gen] temp={temperature} top_p={top_p} max_new={max_new_tokens} rep={repetition_penalty} mode={output_mode}")
547
 
 
548
  chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
549
  if conv_mode_override and conv_mode_override in conv_templates:
550
  chatbot.conversation = conv_templates[conv_mode_override].copy()
551
 
552
- # Load image → tensor
553
  try:
554
  pil_img = load_image_any(image_input)
555
  except Exception as e:
556
  return {"error": f"Failed to load image: {e}"}
557
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
  device = next(chatbot.model.parameters()).device
559
- dtype = torch.float16
560
 
561
- image_tensor = None
562
  try:
563
- if hasattr(chatbot.image_processor, "preprocess"):
564
- px = chatbot.image_processor.preprocess(pil_img, return_tensors="pt")
565
- image_tensor = px.get("pixel_values", px)
566
- if not isinstance(image_tensor, torch.Tensor):
567
- image_tensor = image_tensor["pixel_values"]
568
- if image_tensor.ndim == 3:
569
- image_tensor = image_tensor.unsqueeze(0)
570
- image_tensor = image_tensor.to(device=device, dtype=dtype)
571
  else:
572
- raise AttributeError("processor has no preprocess")
573
- except Exception:
574
- try:
575
- processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
576
- if isinstance(processed, (list, tuple)) and len(processed) > 0:
577
- image_tensor = processed[0]
578
- elif isinstance(processed, torch.Tensor):
579
- image_tensor = processed[0] if processed.ndim == 4 else processed
580
- else:
581
- raise ValueError("process_images returned empty")
582
- if image_tensor.ndim == 3:
583
- image_tensor = image_tensor.unsqueeze(0)
584
- image_tensor = image_tensor.to(device=device, dtype=dtype)
585
- except Exception:
586
- from torchvision import transforms
587
- from torchvision.transforms import InterpolationMode
588
- expected_size = get_vision_expected_size(chatbot.model, default=336)
589
- preprocess = transforms.Compose([
590
- transforms.Resize(expected_size, interpolation=InterpolationMode.BICUBIC),
591
- transforms.CenterCrop(expected_size),
592
- transforms.ToTensor(),
593
- transforms.Normalize(
594
- mean=[0.48145466, 0.4578275, 0.40821073],
595
- std=[0.26862954, 0.26130258, 0.27577711]
596
- ),
597
- ])
598
- image_tensor = preprocess(pil_img).unsqueeze(0).to(device=device, dtype=dtype)
599
-
600
- if image_tensor is None:
601
- return {"error": "Image processing failed (no tensor produced)"}
602
-
603
- # Build prompt
604
- base_msg = (message_text or "").strip()
605
- if output_mode in ("json", "report_en"):
606
- msg = f"{base_msg}\n\n{JSON_SCHEMA_HINT_EN}"
607
- else: # "narrative"
608
- msg = f"{base_msg}\n\n{STYLE_HINT}"
609
-
610
- dbg(f"[prompt] mode={output_mode}")
611
  _, input_ids = _build_prompt_and_ids(chatbot, msg, device)
612
 
 
613
  stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
614
  stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)
615
 
 
616
  if det_seed is not None:
617
  try:
618
  s = int(det_seed)
@@ -623,76 +303,60 @@ def generate_response(
623
  except Exception:
624
  pass
625
 
626
- # Generate with streamer
627
- streamer = TextIteratorStreamer(chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True)
 
 
 
 
628
  gen_kwargs = dict(
629
  inputs=input_ids,
630
  images=image_tensor,
631
  streamer=streamer,
632
- do_sample=(temperature > 0.0),
633
- temperature=float(temperature),
634
- top_p=float(top_p),
635
- max_new_tokens=int(max_new_tokens),
636
- repetition_penalty=float(repetition_penalty),
637
  use_cache=False,
638
- stopping_criteria=[stopping],
639
  )
640
 
 
641
  try:
642
  t = Thread(target=chatbot.model.generate, kwargs=gen_kwargs)
643
  t.start()
644
  chunks = []
645
  for piece in streamer:
646
  chunks.append(piece)
647
- text = _postprocess_min("".join(chunks))
 
648
  chatbot.conversation.messages[-1][-1] = text
649
  except Exception as e:
650
  return {"error": f"Generation failed: {e}"}
651
 
652
- # output_mode handlers
653
- if output_mode == "narrative":
654
- return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
655
-
656
- # For json & report_en → parse once, with robust fallbacks
657
  try:
658
- start = text.find("{"); end = text.rfind("}")
659
- if start == -1 or end == -1 or end <= start:
660
- raise ValueError("JSON braces not found")
661
- data = json.loads(text[start:end+1])
662
- data["_parse_mode"] = "direct"
663
- except Exception:
664
- cleaned = _coerce_pseudo_json(text)
665
- try:
666
- data = json.loads(cleaned)
667
- data["_parse_mode"] = "cleaned"
668
- except Exception:
669
- # Last resort: extract with regex from free text
670
- data = _extract_fields_from_text(text)
671
- data["_parse_mode"] = "extracted"
672
-
673
- # Inject patient meta (local only)
674
- if patient_age_group:
675
- data["patient_age_group"] = patient_age_group
676
- if patient_sex:
677
- data["patient_sex"] = patient_sex
678
-
679
- if output_mode == "json":
680
- return {"status": "success", "response": data, "conversation_id": id(chatbot.conversation)}
681
-
682
- if output_mode == "report_en":
683
- narrative = render_ecg_narrative_en(data)
684
- table_txt = render_ecg_table_en(data)
685
- return {
686
- "status": "success",
687
- "report": {"table_text": table_txt, "json": data, "narrative": narrative},
688
- "conversation_id": id(chatbot.conversation)
689
  }
 
 
 
 
 
690
 
691
- # Fallback
692
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
693
 
694
- # ========= Public API =========
 
695
  def query(payload: dict):
 
696
  global model_initialized, tokenizer, model, image_processor, context_len, args
697
  if not model_initialized:
698
  if not initialize_model():
@@ -705,19 +369,14 @@ def query(payload: dict):
705
  if not message.strip(): return {"error": "Missing 'message' text"}
706
  if image is None: return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}
707
 
 
708
  temperature = float(payload.get("temperature", 0.05))
709
  top_p = float(payload.get("top_p", 1.0))
710
  max_new_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 4096))))
711
- repetition_penalty = float(payload.get("repetition_penalty", 1.0))
712
 
713
  conv_mode_override = payload.get("conv_mode", None)
714
  det_seed = payload.get("det_seed", None)
715
- output_mode = payload.get("output_mode", "narrative")
716
-
717
- # Optional patient meta (local use only)
718
- patient_age_group = payload.get("patient_age_group")
719
- patient_sex = payload.get("patient_sex")
720
-
721
  if det_seed is not None:
722
  try: det_seed = int(det_seed)
723
  except Exception: det_seed = None
@@ -731,9 +390,6 @@ def query(payload: dict):
731
  conv_mode_override=conv_mode_override,
732
  repetition_penalty=repetition_penalty,
733
  det_seed=det_seed,
734
- output_mode=output_mode,
735
- patient_age_group=patient_age_group,
736
- patient_sex=patient_sex,
737
  )
738
  except Exception as e:
739
  return {"error": f"Query failed: {e}"}
@@ -756,13 +412,14 @@ def get_model_info():
756
  "device": str(next(model.parameters()).device) if model else "Unknown",
757
  }
758
 
759
- # ========= Init & Session =========
 
760
  class _Args:
761
  def __init__(self):
762
  self.model_path = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
763
  self.model_base = None
764
  self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
765
- self.conv_mode = "llava_v1"
766
  self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "4096"))
767
  self.num_frames = 16
768
  self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
@@ -772,40 +429,22 @@ class _Args:
772
  def initialize_model():
773
  global tokenizer, model, image_processor, context_len, args
774
  if not LLAVA_AVAILABLE:
775
- warn("[init] LLaVA not available; cannot init.")
776
  return False
777
  try:
778
  args = _Args()
779
- dbg(f"[init] HF_MODEL_ID={args.model_path} | LOAD_8BIT={args.load_8bit} | LOAD_4BIT={args.load_4bit}")
780
  model_name = get_model_name_from_path(args.model_path)
781
-
782
  tokenizer_, model_, image_processor_, context_len_ = load_pretrained_model(
783
  args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
784
  )
785
- dbg(f"[init] loaded model/tokenizer/processor | context_len={context_len_}")
786
-
787
  try:
788
  _ = next(model_.parameters()).device
789
  except Exception:
790
  if torch.cuda.is_available():
791
  model_ = model_.to(torch.device("cuda"))
792
  model_.eval()
793
- dbg(f"[init] device={next(model_.parameters()).device}, cuda={torch.cuda.is_available()}")
794
 
795
- expected_size = get_vision_expected_size(model_, default=336)
796
- try:
797
- if image_processor_ is None:
798
- from transformers import AutoProcessor, CLIPImageProcessor
799
- try:
800
- image_processor_ = AutoProcessor.from_pretrained(args.model_path)
801
- except Exception:
802
- clip_id = "openai/clip-vit-large-patch14-336" if expected_size >= 336 else "openai/clip-vit-large-patch14"
803
- image_processor_ = CLIPImageProcessor.from_pretrained(clip_id)
804
- force_processor_size(image_processor_, expected_size)
805
- except Exception as e_ip:
806
- warn(f"[init] image_processor fallback/size set failed: {e_ip}")
807
-
808
- # publish
809
  globals()["tokenizer"] = tokenizer_
810
  globals()["model"] = model_
811
  globals()["image_processor"] = image_processor_
@@ -815,12 +454,13 @@ def initialize_model():
815
  print("[init] model/tokenizer/image_processor loaded.")
816
  return True
817
  except Exception as e:
818
- warn(f"[init] failed: {e}")
819
  return False
820
 
821
- # ========= HF EndpointHandler =========
 
822
  class EndpointHandler:
823
- """Hugging Face Endpoint compatible."""
824
  def __init__(self, model_dir):
825
  self.model_dir = model_dir
826
  print(f"EndpointHandler initialized with model_dir: {model_dir}")
@@ -834,69 +474,4 @@ class EndpointHandler:
834
  return get_model_info()
835
 
836
  if __name__ == "__main__":
837
- print("Handler ready (Deterministic JSON→Narrative with robust fallbacks, age+sex aware). Use `EndpointHandler` or `query`.")
838
-
839
- # ========= Optional FastAPI Wrapper =========
840
- try:
841
- from fastapi import FastAPI
842
- from pydantic import BaseModel
843
- FASTAPI_AVAILABLE = True
844
- except Exception as e:
845
- FASTAPI_AVAILABLE = False
846
- warn(f"fastapi/pydantic not available: {e}")
847
-
848
- if FASTAPI_AVAILABLE:
849
- app = FastAPI(title="PULSE ECG Handler API", version="1.4.0")
850
-
851
- class QueryIn(BaseModel):
852
- message: str | None = None
853
- query: str | None = None
854
- prompt: str | None = None
855
- istem: str | None = None
856
- image: str | Dict[str, Any] | None = None
857
- image_url: str | None = None
858
- img: str | None = None
859
- temperature: float | None = None
860
- top_p: float | None = None
861
- max_output_tokens: int | None = None
862
- max_new_tokens: int | None = None
863
- max_tokens: int | None = None
864
- repetition_penalty: float | None = None
865
- conv_mode: str | None = None
866
- det_seed: int | None = None
867
- output_mode: str | None = None
868
- patient_age_group: str | None = None
869
- patient_sex: str | None = None
870
-
871
- @app.on_event("startup")
872
- async def _startup():
873
- global model_initialized
874
- if not model_initialized:
875
- model_initialized = initialize_model()
876
- print(f"[startup] model_initialized={model_initialized}")
877
-
878
- @app.get("/health")
879
- async def _health():
880
- return health_check()
881
-
882
- @app.get("/info")
883
- async def _info():
884
- return get_model_info()
885
-
886
- @app.post("/query")
887
- async def _query(payload: QueryIn):
888
- return query({k: v for k, v in payload.dict().items() if v is not None})
889
-
890
- @app.post("/analyze/json")
891
- async def analyze_json(payload: QueryIn):
892
- data = {k: v for k, v in payload.dict().items() if v is not None}
893
- data["output_mode"] = "json"
894
- return query(data)
895
-
896
- @app.post("/analyze/report-en")
897
- async def analyze_report_en(payload: QueryIn):
898
- data = {k: v for k, v in payload.dict().items() if v is not None}
899
- data["output_mode"] = "report_en"
900
- return query(data)
901
- else:
902
- app = None # uvicorn handler:app would fail if FastAPI is not installed
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ PULSE ECG Handler — Demo Parity + Style Hint
4
+ - Demo app.py ile aynı üretim ayarları:
5
+ do_sample=True, temperature=0.05, top_p=1.0, max_new_tokens=4096
6
+ - Stopping: konuşma ayırıcıda (conv.sep/sep2) güvenli token-eşleşmeli kriter
7
+ - Görsel tensörü: .half() ve model cihazında
8
+ - Streamer: TextIteratorStreamer (demo gibi), thread ile generate
9
+ - Seed/deterministic KAPALI (göndermezseniz); demo gibi stokastik
10
+ - STYLE_HINT: demo üslubuna (narratif + sonda tek satır structured impression) yaklaşmak için
11
+ - Post-process: YALNIZCA whitespace/biçim normalizasyonu (yönetim/öneri cümleleri korunur)
 
 
 
 
 
 
12
  """
13
 
14
  import os
15
  import re
16
  import json
17
  import base64
 
18
  import hashlib
19
  import datetime
20
  from io import BytesIO
21
  from threading import Thread
22
+ from typing import Optional, Union
23
 
24
  import torch
25
  from PIL import Image
26
  import requests
27
 
28
+ # ====== LLaVA & Transformers ======
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  try:
30
+ from llava.constants import (
31
+ IMAGE_TOKEN_INDEX,
32
+ DEFAULT_IMAGE_TOKEN,
33
+ )
34
  from llava.conversation import conv_templates, SeparatorStyle
35
  from llava.model.builder import load_pretrained_model
36
+ from llava.mm_utils import (
37
+ tokenizer_image_token,
38
+ process_images,
39
+ get_model_name_from_path,
40
+ )
41
  from llava.utils import disable_torch_init
42
  LLAVA_AVAILABLE = True
43
  except Exception as e:
44
  LLAVA_AVAILABLE = False
45
+ print(f"[WARN] LLaVA not available: {e}")
46
 
47
  try:
48
  from transformers import TextIteratorStreamer, StoppingCriteria
49
  TRANSFORMERS_AVAILABLE = True
50
  except Exception as e:
51
  TRANSFORMERS_AVAILABLE = False
52
+ print(f"[WARN] transformers not available: {e}")
53
 
54
+ # ====== HF Hub logging (opsiyonel) ======
55
  try:
56
  from huggingface_hub import HfApi, login
57
  HF_HUB_AVAILABLE = True
 
66
  api = HfApi()
67
  repo_name = os.environ.get("LOG_REPO", "")
68
  except Exception as e:
69
+ print(f"[HF Hub] init failed: {e}")
70
+ api = None
71
+ repo_name = ""
72
 
73
  LOGDIR = "./logs"
74
  os.makedirs(LOGDIR, exist_ok=True)
75
 
76
+ # ====== Global State ======
77
  tokenizer = None
78
  model = None
79
  image_processor = None
 
81
  args = None
82
  model_initialized = False
83
 
84
+ # ====== Style Hint (demo benzeri üslup) ======
85
  STYLE_HINT = (
86
  "Write one concise narrative paragraph that covers rhythm, heart rate, cardiac axis, "
87
  "P waves and PR interval, QRS morphology and duration, ST segments, T waves, and QT/QTc. "
 
90
  "followed by a succinct, comma-separated summary of the key diagnoses."
91
  )
92
 
93
+ # ===================== Utilities =====================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
 
95
  def _safe_upload(path: str):
96
  if api and repo_name and path and os.path.isfile(path):
97
  try:
 
102
  repo_type="dataset",
103
  )
104
  except Exception as e:
105
+ print(f"[upload] failed for {path}: {e}")
106
 
107
  def _conv_log_path() -> str:
108
  t = datetime.datetime.now()
 
110
 
111
  def load_image_any(image_input: Union[str, dict]) -> Image.Image:
112
  """
113
+ Desteklenen:
114
+ - URL (http/https)
115
+ - yerel dosya yolu
116
+ - base64 (opsiyonel data URL prefix ile)
117
+ - {"image": <base64|dataurl>}
118
  """
119
  if isinstance(image_input, str):
120
  s = image_input.strip()
 
124
  return Image.open(BytesIO(r.content)).convert("RGB")
125
  if os.path.exists(s):
126
  return Image.open(s).convert("RGB")
127
+ # base64 (dataurl olabilir)
128
  if s.startswith("data:image"):
129
  s = s.split(",", 1)[1]
130
  raw = base64.b64decode(s)
131
  return Image.open(BytesIO(raw)).convert("RGB")
132
+
133
  if isinstance(image_input, dict) and "image" in image_input:
134
  return load_image_any(image_input["image"])
135
+
136
  raise ValueError("Unsupported image input format")
137
 
138
  def _normalize_whitespace(text: str) -> str:
139
+ """
140
+ Gereksiz boşluk/boş satırları toparlar:
141
+ - Satır başı/sonu boşluklarını siler
142
+ - Birden çok boşluğu tek boşluğa indirger
143
+ - 3+ boş satırı 1 boş satıra indirger
144
+ """
145
  text = text.replace("\r\n", "\n").replace("\r", "\n")
146
  lines = [re.sub(r"[ \t]+", " ", ln.strip()) for ln in text.split("\n")]
147
  text = "\n".join(lines).strip()
 
149
  return text
150
 
151
  def _postprocess_min(text: str) -> str:
152
+ # Yalnızca whitespace/biçim temizliği
153
  return _normalize_whitespace(text)
154
 
155
+ # ====== Güvenli Stop Kriteri (conv separator) ======
156
+ class SafeKeywordsStoppingCriteria(StoppingCriteria):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  """
158
+ conv.sep/sep2 bazlı token eşleşmesi; tensör bool hatası yok.
159
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  def __init__(self, keyword: str, tokenizer):
161
+ self.tokenizer = tokenizer
162
  tok = tokenizer(keyword, add_special_tokens=False, return_tensors="pt").input_ids[0]
163
+ self.kw_ids = tok # shape: (n,)
164
+
165
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
166
  if input_ids is None or input_ids.shape[0] == 0:
167
  return False
168
+ out = input_ids[0] # assume bsz=1
169
  n = self.kw_ids.shape[0]
170
  if out.shape[0] < n:
171
  return False
172
  tail = out[-n:]
173
+ kw = self.kw_ids.to(tail.device)
174
+ return torch.equal(tail, kw)
175
+
176
+ # ===================== Core Generation =====================
177
 
 
178
  class InferenceDemo:
179
  def __init__(self, args, model_path, tokenizer_, model_, image_processor_, context_len_):
180
  if not LLAVA_AVAILABLE:
 
183
  self.tokenizer, self.model, self.image_processor, self.context_len = (
184
  tokenizer_, model_, image_processor_, context_len_
185
  )
186
+ # Parite için sabit şablon
187
  self.conv_mode = "llava_v1"
188
  self.conversation = conv_templates[self.conv_mode].copy()
189
  self.num_frames = getattr(args, "num_frames", 16)
 
200
  self.chatbot = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
201
  def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
202
  self.init_if_needed(args, model_path, tokenizer, model, image_processor, context_len)
203
+ # Her çağrıda taze template (demo gibi yeni tur)
204
  self.chatbot.conversation = conv_templates[self.chatbot.conv_mode].copy()
205
  return self.chatbot
206
 
207
  chat_manager = ChatSessionManager()
208
 
209
  def _build_prompt_and_ids(chatbot, user_text: str, device: torch.device):
210
+ # DEMO PARİTE: sarım yok, tek görüntü için tek image token
211
  inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
212
  chatbot.conversation.append_message(chatbot.conversation.roles[0], inp)
213
  chatbot.conversation.append_message(chatbot.conversation.roles[1], None)
214
  prompt = chatbot.conversation.get_prompt()
215
+
216
  input_ids = tokenizer_image_token(
217
  prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
218
  ).unsqueeze(0).to(device)
219
  return prompt, input_ids
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  def generate_response(
222
  message_text: str,
223
  image_input,
 
227
  max_new_tokens: Optional[int] = None,
228
  conv_mode_override: Optional[str] = None,
229
  repetition_penalty: Optional[float] = None,
230
+ det_seed: Optional[int] = None, # None → stokastik (demo gibi)
 
 
 
231
  ):
232
  if not (LLAVA_AVAILABLE and TRANSFORMERS_AVAILABLE):
233
  return {"error": "Required libraries not available (llava/transformers)"}
234
  if not message_text or image_input is None:
235
  return {"error": "Both 'message' and 'image' are required"}
236
 
237
+ # Varsayılanlar → demo
238
  if temperature is None: temperature = 0.05
239
  if top_p is None: top_p = 1.0
240
  if max_new_tokens is None: max_new_tokens = 4096
241
+ if repetition_penalty is None: repetition_penalty = 1.0 # etkisiz
 
 
 
 
 
 
 
 
 
242
 
243
+ # Chat session
244
  chatbot = chat_manager.get_chatbot(args, args.model_path, tokenizer, model, image_processor, context_len)
245
  if conv_mode_override and conv_mode_override in conv_templates:
246
  chatbot.conversation = conv_templates[conv_mode_override].copy()
247
 
248
+ # Görüntü yükle
249
  try:
250
  pil_img = load_image_any(image_input)
251
  except Exception as e:
252
  return {"error": f"Failed to load image: {e}"}
253
 
254
+ # Log için hash+path
255
+ img_hash, img_path = "NA", None
256
+ try:
257
+ buf = BytesIO(); pil_img.save(buf, format="JPEG"); raw = buf.getvalue()
258
+ img_hash = hashlib.md5(raw).hexdigest()
259
+ t = datetime.datetime.now()
260
+ img_path = os.path.join(LOGDIR, "serve_images", f"{t.year:04d}-{t.month:02d}-{t.day:02d}", f"{img_hash}.jpg")
261
+ os.makedirs(os.path.dirname(img_path), exist_ok=True)
262
+ if not os.path.isfile(img_path):
263
+ pil_img.save(img_path)
264
+ except Exception as e:
265
+ print(f"[log] save image failed: {e}")
266
+
267
+ # Cihaz/dtype
268
  device = next(chatbot.model.parameters()).device
269
+ dtype = torch.float16 # demo: half
270
 
271
+ # Görüntü ön-işleme → tensör
272
  try:
273
+ processed = process_images([pil_img], chatbot.image_processor, chatbot.model.config)
274
+ if isinstance(processed, (list, tuple)) and len(processed) > 0:
275
+ image_tensor = processed[0]
276
+ elif isinstance(processed, torch.Tensor):
277
+ image_tensor = processed[0] if processed.ndim == 4 else processed
 
 
 
278
  else:
279
+ return {"error": "Image processing returned empty"}
280
+ if image_tensor.ndim == 3:
281
+ image_tensor = image_tensor.unsqueeze(0) # (1,C,H,W)
282
+ image_tensor = image_tensor.to(device=device, dtype=dtype) # demo: half + device
283
+ except Exception as e:
284
+ return {"error": f"Image processing failed: {e}"}
285
+
286
+ # STYLE_HINT ekle ve prompt hazırla
287
+ msg = (message_text or "").strip()
288
+ msg = f"{msg}\n\n{STYLE_HINT}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  _, input_ids = _build_prompt_and_ids(chatbot, msg, device)
290
 
291
+ # Stop string (conv separator) → güvenli kriter
292
  stop_str = chatbot.conversation.sep if chatbot.conversation.sep_style != SeparatorStyle.TWO else chatbot.conversation.sep2
293
  stopping = SafeKeywordsStoppingCriteria(stop_str, chatbot.tokenizer)
294
 
295
+ # Seed (gönderilmediyse stokastik → demo gibi)
296
  if det_seed is not None:
297
  try:
298
  s = int(det_seed)
 
303
  except Exception:
304
  pass
305
 
306
+ # Streamer (demo gibi)
307
+ streamer = TextIteratorStreamer(
308
+ chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
309
+ )
310
+
311
+ # Generate kwargs — demo ayarları
312
  gen_kwargs = dict(
313
  inputs=input_ids,
314
  images=image_tensor,
315
  streamer=streamer,
316
+ do_sample=True, # DEMO
317
+ temperature=float(temperature), # DEMO default 0.05
318
+ top_p=float(top_p), # DEMO default 1.0
319
+ max_new_tokens=int(max_new_tokens), # DEMO slider
320
+ repetition_penalty=float(repetition_penalty), # default 1.0 → etkisiz
321
  use_cache=False,
322
+ stopping_criteria=[stopping], # DEMO-benzeri durdurma
323
  )
324
 
325
+ # Üretim (arka thread) + akışı topla
326
  try:
327
  t = Thread(target=chatbot.model.generate, kwargs=gen_kwargs)
328
  t.start()
329
  chunks = []
330
  for piece in streamer:
331
  chunks.append(piece)
332
+ text = "".join(chunks)
333
+ text = _postprocess_min(text) # yalnızca whitespace/format temizliği
334
  chatbot.conversation.messages[-1][-1] = text
335
  except Exception as e:
336
  return {"error": f"Generation failed: {e}"}
337
 
338
+ # Log
 
 
 
 
339
  try:
340
+ row = {
341
+ "time": datetime.datetime.now().isoformat(),
342
+ "type": "chat",
343
+ "model": "PULSE-7B",
344
+ "state": [(message_text, text)],
345
+ "image_hash": img_hash,
346
+ "image_path": img_path or "",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  }
348
+ with open(_conv_log_path(), "a", encoding="utf-8") as f:
349
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
350
+ _safe_upload(_conv_log_path()); _safe_upload(img_path or "")
351
+ except Exception as e:
352
+ print(f"[log] failed: {e}")
353
 
 
354
  return {"status": "success", "response": text, "conversation_id": id(chatbot.conversation)}
355
 
356
+ # ===================== Public API =====================
357
+
358
  def query(payload: dict):
359
+ """HF Endpoint entry (demo-like)."""
360
  global model_initialized, tokenizer, model, image_processor, context_len, args
361
  if not model_initialized:
362
  if not initialize_model():
 
369
  if not message.strip(): return {"error": "Missing 'message' text"}
370
  if image is None: return {"error": "Missing 'image'. Use 'image', 'image_url', or 'img'."}
371
 
372
+ # Demo varsayılanları — payload override edebilir
373
  temperature = float(payload.get("temperature", 0.05))
374
  top_p = float(payload.get("top_p", 1.0))
375
  max_new_tokens = int(payload.get("max_output_tokens", payload.get("max_new_tokens", payload.get("max_tokens", 4096))))
376
+ repetition_penalty = float(payload.get("repetition_penalty", 1.0)) # etkisiz default
377
 
378
  conv_mode_override = payload.get("conv_mode", None)
379
  det_seed = payload.get("det_seed", None)
 
 
 
 
 
 
380
  if det_seed is not None:
381
  try: det_seed = int(det_seed)
382
  except Exception: det_seed = None
 
390
  conv_mode_override=conv_mode_override,
391
  repetition_penalty=repetition_penalty,
392
  det_seed=det_seed,
 
 
 
393
  )
394
  except Exception as e:
395
  return {"error": f"Query failed: {e}"}
 
412
  "device": str(next(model.parameters()).device) if model else "Unknown",
413
  }
414
 
415
+ # ===================== Init & Session =====================
416
+
417
  class _Args:
418
  def __init__(self):
419
  self.model_path = os.getenv("HF_MODEL_ID", "PULSE-ECG/PULSE-7B")
420
  self.model_base = None
421
  self.num_gpus = int(os.getenv("NUM_GPUS", "1"))
422
+ self.conv_mode = "llava_v1" # Parite için sabit
423
  self.max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", "4096"))
424
  self.num_frames = 16
425
  self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
 
429
  def initialize_model():
430
  global tokenizer, model, image_processor, context_len, args
431
  if not LLAVA_AVAILABLE:
432
+ print("[init] LLaVA not available; cannot init.")
433
  return False
434
  try:
435
  args = _Args()
 
436
  model_name = get_model_name_from_path(args.model_path)
 
437
  tokenizer_, model_, image_processor_, context_len_ = load_pretrained_model(
438
  args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
439
  )
440
+ # demo: model genelde cuda’da çalıştırır
 
441
  try:
442
  _ = next(model_.parameters()).device
443
  except Exception:
444
  if torch.cuda.is_available():
445
  model_ = model_.to(torch.device("cuda"))
446
  model_.eval()
 
447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  globals()["tokenizer"] = tokenizer_
449
  globals()["model"] = model_
450
  globals()["image_processor"] = image_processor_
 
454
  print("[init] model/tokenizer/image_processor loaded.")
455
  return True
456
  except Exception as e:
457
+ print(f"[init] failed: {e}")
458
  return False
459
 
460
+ # ===================== HF EndpointHandler =====================
461
+
462
  class EndpointHandler:
463
+ """Hugging Face Endpoint uyumlu sınıf"""
464
  def __init__(self, model_dir):
465
  self.model_dir = model_dir
466
  print(f"EndpointHandler initialized with model_dir: {model_dir}")
 
474
  return get_model_info()
475
 
476
  if __name__ == "__main__":
477
+ print("Handler ready (Demo Parity + Style Hint + whitespace post-process). Use `EndpointHandler` or `query`.")