CanerDedeoglu committed on
Commit
61bc5b4
·
verified ·
1 Parent(s): f9e643a

SECTION_ORDER added

Browse files
Files changed (1) hide show
  1. handler.py +76 -14
handler.py CHANGED
@@ -9,6 +9,7 @@ PULSE ECG Handler - Deterministic ECG Analysis Model (app.py uyumlu)
9
  """
10
 
11
  import os
 
12
  import datetime
13
  import torch
14
  import hashlib
@@ -146,11 +147,7 @@ def get_conv_vote_filename():
146
  def vote_last_response(state, vote_type, model_selector):
147
  try:
148
  with open(get_conv_vote_filename(), "a") as fout:
149
- data = {
150
- "type": vote_type,
151
- "model": model_selector,
152
- "state": state,
153
- }
154
  fout.write(json.dumps(data) + "\n")
155
  _safe_upload(get_conv_vote_filename())
156
  except Exception as e:
@@ -158,7 +155,7 @@ def vote_last_response(state, vote_type, model_selector):
158
 
159
  # Yalın uzantı listeleri (sorunlu formatlar çıkarıldı)
160
  IMAGE_EXTS = {"jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "jfif"}
161
- # HEIC/HEIF: pillow-heif yoksa desteklemeyelim
162
  try:
163
  import pillow_heif # noqa: F401
164
  IMAGE_EXTS.update({"heic", "heif"})
@@ -234,6 +231,73 @@ def process_image_input(image_input):
234
  return process_base64_image(image_input["image"])
235
  raise ValueError("Unsupported image input format")
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  # ---------- Oturum / Konuşma ----------
238
 
239
  class InferenceDemo(object):
@@ -311,7 +375,6 @@ def _build_prompt(chatbot, user_text: str) -> str:
311
  use_wrap = False
312
 
313
  if use_wrap:
314
- # <im_start><image></im_end>\n + metin
315
  inp = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
316
  else:
317
  inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
@@ -335,17 +398,14 @@ def generate_response(message_text,
335
  if not LLAVA_AVAILABLE:
336
  return {"error": "LLaVA modules not available"}
337
 
338
- # Zorunlu girişler
339
  if not message_text or not image_input:
340
  return {"error": "Both message text and image are required"}
341
 
342
- # Chatbot al
343
  chatbot = chat_manager.get_chatbot(
344
  args, args.model_path if args else "PULSE-ECG/PULSE-7B",
345
  tokenizer, model, image_processor, context_len
346
  )
347
 
348
- # İsteğe bağlı conv override
349
  if conv_mode_override and conv_mode_override in conv_templates:
350
  chatbot.conversation = conv_templates[conv_mode_override].copy()
351
  else:
@@ -381,11 +441,10 @@ def generate_response(message_text,
381
  processed = process_images([image], chatbot.image_processor, chatbot.model.config)
382
 
383
  if isinstance(processed, torch.Tensor):
384
- # (C,H,W) / (B,C,H,W) / (B,T,C,H,W)
385
  if processed.ndim == 3:
386
  image_tensor = processed.unsqueeze(0) # (1,C,H,W)
387
  elif processed.ndim == 4:
388
- image_tensor = processed # (B,C,H,W)
389
  elif processed.ndim == 5:
390
  b, t, c, h, w = processed.shape
391
  image_tensor = processed.reshape(b * t, c, h, w) # (B*T,C,H,W)
@@ -436,7 +495,7 @@ def generate_response(message_text,
436
  images=image_tensor,
437
  do_sample=False, # deterministik
438
  max_new_tokens=int(max_output_tokens),
439
- min_new_tokens=600, # en az bu kadar üret (step başlıkları garanti)
440
  repetition_penalty=float(repetition_penalty),
441
  use_cache=False,
442
  pad_token_id=eos_id,
@@ -449,6 +508,9 @@ def generate_response(message_text,
449
  gen = outputs[0][input_ids.shape[1]:]
450
  response = chatbot.tokenizer.decode(gen, skip_special_tokens=True)
451
 
 
 
 
452
  # Konuşmaya yerleştir
453
  if chatbot.conversation.messages and isinstance(chatbot.conversation.messages[-1], list):
454
  chatbot.conversation.messages[-1][-1] = response
@@ -609,7 +671,7 @@ def initialize_model():
609
  self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
610
  self.load_4bit = bool(int(os.getenv("LOAD_4BIT", "0")))
611
  self.debug = bool(int(os.getenv("DEBUG", "0")))
612
- # args globaline ata
613
  globals()["args"] = Args()
614
 
615
  model_name = get_model_name_from_path(args.model_path)
 
9
  """
10
 
11
  import os
12
+ import re
13
  import datetime
14
  import torch
15
  import hashlib
 
147
  def vote_last_response(state, vote_type, model_selector):
148
  try:
149
  with open(get_conv_vote_filename(), "a") as fout:
150
+ data = {"type": vote_type, "model": model_selector, "state": state}
 
 
 
 
151
  fout.write(json.dumps(data) + "\n")
152
  _safe_upload(get_conv_vote_filename())
153
  except Exception as e:
 
155
 
156
  # Yalın uzantı listeleri (sorunlu formatlar çıkarıldı)
157
  IMAGE_EXTS = {"jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "jfif"}
158
+ # HEIC/HEIF: pillow-heif yoksa destekleme
159
  try:
160
  import pillow_heif # noqa: F401
161
  IMAGE_EXTS.update({"heic", "heif"})
 
231
  return process_base64_image(image_input["image"])
232
  raise ValueError("Unsupported image input format")
233
 
234
+ # ---------- Şablon dayatma (post-format) ----------
235
+
236
+ SECTION_ORDER = [
237
+ "Step 1: Rhythm Analysis",
238
+ "Step 2: Heart Rate Analysis",
239
+ "Step 3: Cardiac Axis Analysis",
240
+ "Step 4: P Wave Analysis",
241
+ "Step 5: PR Interval Analysis",
242
+ "Step 6: QRS Complex Analysis",
243
+ "Step 7: ST Segment Analysis",
244
+ "Step 8: T Wave Analysis",
245
+ "Step 9: QT/QTc Interval Analysis",
246
+ "Structured Clinical Impression:",
247
+ ]
248
+
249
+ _SECTION_RE = re.compile(
250
+ r"(Step\s*1:\s*Rhythm Analysis|"
251
+ r"Step\s*2:\s*Heart Rate Analysis|"
252
+ r"Step\s*3:\s*Cardiac Axis Analysis|"
253
+ r"Step\s*4:\s*P Wave Analysis|"
254
+ r"Step\s*5:\s*PR Interval Analysis|"
255
+ r"Step\s*6:\s*QRS Complex Analysis|"
256
+ r"Step\s*7:\s*ST Segment Analysis|"
257
+ r"Step\s*8:\s*T Wave Analysis|"
258
+ r"Step\s*9:\s*QT/QTc Interval Analysis|"
259
+ r"Structured Clinical Impression:)",
260
+ flags=re.IGNORECASE
261
+ )
262
+
263
def _enforce_section_template(text: str) -> str:
    """Re-shape raw model output into the fixed Step 1-9 + Structured template.

    The text is split on the known section headings; whatever the model
    wrote under each heading is kept, and any section the model skipped is
    filled with a neutral placeholder, so the caller always receives all
    ten sections in canonical order.  Text that appeared before the first
    heading is preserved as an appendix note on the first section.

    Args:
        text: Raw decoded model output (may be empty or miss headings).

    Returns:
        The reformatted report, sections separated by blank lines.
    """

    def _fold(s: str) -> str:
        # Case-fold and drop ALL whitespace: the regex tolerates variants
        # like "Step  1:Rhythm Analysis", so the mapping back onto the
        # canonical heading must tolerate them too (a plain startswith on
        # the raw heading would silently drop such sections).
        return re.sub(r"\s+", "", s.lower())

    # re.split with a capturing group yields:
    #   [before, heading1, content1, heading2, content2, ..., tail]
    pieces = _SECTION_RE.split(text or "")
    prefix = (pieces[0] or "").strip() if pieces else ""

    found = {}
    for i in range(1, len(pieces) - 1, 2):
        heading = _fold(pieces[i])
        content = pieces[i + 1].strip()
        for canonical in SECTION_ORDER:
            # rstrip(":") so "Structured Clinical Impression:" compares
            # without its trailing colon.
            if heading.startswith(_fold(canonical.rstrip(":"))):
                found[canonical] = content
                break

    filled = []
    for sec in SECTION_ORDER:
        val = (found.get(sec, "") or "").strip()
        if not val:
            # Missing section: fill with a neutral, clearly-generic note.
            if sec.startswith("Step"):
                val = ("Normal. No definite abnormality detected in this "
                       "section based on the provided ECG image.")
            else:
                val = ("Overall impression: No acute life-threatening "
                       "abnormality identified. Correlate clinically.")
        filled.append(f"{sec}\n{val}")

    if prefix:
        filled[0] += f"\n\n(Additional notes captured before Step 1): {prefix}"

    return "\n\n".join(filled)
300
+
301
  # ---------- Oturum / Konuşma ----------
302
 
303
  class InferenceDemo(object):
 
375
  use_wrap = False
376
 
377
  if use_wrap:
 
378
  inp = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{user_text}"
379
  else:
380
  inp = f"{DEFAULT_IMAGE_TOKEN}\n{user_text}"
 
398
  if not LLAVA_AVAILABLE:
399
  return {"error": "LLaVA modules not available"}
400
 
 
401
  if not message_text or not image_input:
402
  return {"error": "Both message text and image are required"}
403
 
 
404
  chatbot = chat_manager.get_chatbot(
405
  args, args.model_path if args else "PULSE-ECG/PULSE-7B",
406
  tokenizer, model, image_processor, context_len
407
  )
408
 
 
409
  if conv_mode_override and conv_mode_override in conv_templates:
410
  chatbot.conversation = conv_templates[conv_mode_override].copy()
411
  else:
 
441
  processed = process_images([image], chatbot.image_processor, chatbot.model.config)
442
 
443
  if isinstance(processed, torch.Tensor):
 
444
  if processed.ndim == 3:
445
  image_tensor = processed.unsqueeze(0) # (1,C,H,W)
446
  elif processed.ndim == 4:
447
+ image_tensor = processed # (B,C,H,W)
448
  elif processed.ndim == 5:
449
  b, t, c, h, w = processed.shape
450
  image_tensor = processed.reshape(b * t, c, h, w) # (B*T,C,H,W)
 
495
  images=image_tensor,
496
  do_sample=False, # deterministik
497
  max_new_tokens=int(max_output_tokens),
498
+ min_new_tokens=800, # en az bu kadar üret (step başlıkları garanti)
499
  repetition_penalty=float(repetition_penalty),
500
  use_cache=False,
501
  pad_token_id=eos_id,
 
508
  gen = outputs[0][input_ids.shape[1]:]
509
  response = chatbot.tokenizer.decode(gen, skip_special_tokens=True)
510
 
511
+ # ŞABLON ZORLAMA: Step1–9 + Structured
512
+ response = _enforce_section_template(response)
513
+
514
  # Konuşmaya yerleştir
515
  if chatbot.conversation.messages and isinstance(chatbot.conversation.messages[-1], list):
516
  chatbot.conversation.messages[-1][-1] = response
 
671
  self.load_8bit = bool(int(os.getenv("LOAD_8BIT", "0")))
672
  self.load_4bit = bool(int(os.getenv("LOAD_4BIT", "0")))
673
  self.debug = bool(int(os.getenv("DEBUG", "0")))
674
+
675
  globals()["args"] = Args()
676
 
677
  model_name = get_model_name_from_path(args.model_path)