CanerDedeoglu committed on
Commit
5ea9d94
·
verified ·
1 Parent(s): 61bc5b4

no_repeat_ngram_size added

Browse files
Files changed (1) hide show
  1. handler.py +55 -8
handler.py CHANGED
@@ -6,6 +6,7 @@ PULSE ECG Handler - Deterministic ECG Analysis Model (app.py uyumlu)
6
  - Model dtype/device ile uyumlu görüntü tensörü (3D/4D/5D destekli)
7
  - Sağlam URL/base64 işleme, güvenli logging, opsiyonel HF upload
8
  - Zorunlu başlık şablonu + min_new_tokens ile tam Step 1–9 çıktısı
 
9
  """
10
 
11
  import os
@@ -115,7 +116,9 @@ DEFAULT_ECG_PROMPT = (
115
  "Step 8: T Wave Analysis\n"
116
  "Step 9: QT/QTc Interval Analysis\n"
117
  "Structured Clinical Impression:\n"
118
- "If a section is normal, write 'Normal' and give a brief justification."
 
 
119
  )
120
 
121
  # ---------- Yardımcılar ----------
@@ -135,7 +138,8 @@ def _safe_upload(path):
135
  def get_conv_log_filename():
136
  t = datetime.datetime.now()
137
  name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
138
- os.makedirs(os.path.dirname(name), exist_ok=True)
 
139
  return name
140
 
141
  def get_conv_vote_filename():
@@ -153,9 +157,8 @@ def vote_last_response(state, vote_type, model_selector):
153
  except Exception as e:
154
  print(f"Failed to record vote: {e}")
155
 
156
- # Yalın uzantı listeleri (sorunlu formatlar çıkarıldı)
157
  IMAGE_EXTS = {"jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "jfif"}
158
- # HEIC/HEIF: pillow-heif yoksa destekleme
159
  try:
160
  import pillow_heif # noqa: F401
161
  IMAGE_EXTS.update({"heic", "heif"})
@@ -225,7 +228,6 @@ def process_image_input(image_input):
225
  return load_image(image_input)
226
  if os.path.exists(image_input):
227
  return load_image(image_input)
228
- # muhtemelen base64
229
  return process_base64_image(image_input)
230
  if isinstance(image_input, dict) and "image" in image_input:
231
  return process_base64_image(image_input["image"])
@@ -298,6 +300,49 @@ def _enforce_section_template(text: str) -> str:
298
 
299
  return "\n\n".join(filled)
300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  # ---------- Oturum / Konuşma ----------
302
 
303
  class InferenceDemo(object):
@@ -471,7 +516,7 @@ def generate_response(message_text,
471
  prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
472
  ).unsqueeze(0).to(model_device)
473
 
474
- # Stop kriteri (app.py uyumlu)
475
  stopping_criteria = _stop_criteria_from_conv(chatbot, input_ids)
476
 
477
  # Deterministik üretim
@@ -495,7 +540,8 @@ def generate_response(message_text,
495
  images=image_tensor,
496
  do_sample=False, # deterministik
497
  max_new_tokens=int(max_output_tokens),
498
- min_new_tokens=800, # en az bu kadar üret (step başlıkları garanti)
 
499
  repetition_penalty=float(repetition_penalty),
500
  use_cache=False,
501
  pad_token_id=eos_id,
@@ -508,8 +554,9 @@ def generate_response(message_text,
508
  gen = outputs[0][input_ids.shape[1]:]
509
  response = chatbot.tokenizer.decode(gen, skip_special_tokens=True)
510
 
511
- # ŞABLON ZORLAMA: Step1–9 + Structured
512
  response = _enforce_section_template(response)
 
513
 
514
  # Konuşmaya yerleştir
515
  if chatbot.conversation.messages and isinstance(chatbot.conversation.messages[-1], list):
 
6
  - Model dtype/device ile uyumlu görüntü tensörü (3D/4D/5D destekli)
7
  - Sağlam URL/base64 işleme, güvenli logging, opsiyonel HF upload
8
  - Zorunlu başlık şablonu + min_new_tokens ile tam Step 1–9 çıktısı
9
+ - Tekrarları engelleme (no_repeat_ngram_size) + post-format dedup
10
  """
11
 
12
  import os
 
116
  "Step 8: T Wave Analysis\n"
117
  "Step 9: QT/QTc Interval Analysis\n"
118
  "Structured Clinical Impression:\n"
119
+ "If a section is normal, write 'Normal' and give a brief justification. "
120
+ "Each section must be 1–3 concise sentences. Do not repeat identical statements. "
121
+ "Write the final diagnostic impression only once in 'Structured Clinical Impression' and do not restate it elsewhere."
122
  )
123
 
124
  # ---------- Yardımcılar ----------
 
138
  def get_conv_log_filename():
139
  t = datetime.datetime.now()
140
  name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
141
+ os.makedirs(os.path.dirname(name), exist_ok=True
142
+ )
143
  return name
144
 
145
  def get_conv_vote_filename():
 
157
  except Exception as e:
158
  print(f"Failed to record vote: {e}")
159
 
160
+ # Yalın uzantı listeleri
161
  IMAGE_EXTS = {"jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "jfif"}
 
162
  try:
163
  import pillow_heif # noqa: F401
164
  IMAGE_EXTS.update({"heic", "heif"})
 
228
  return load_image(image_input)
229
  if os.path.exists(image_input):
230
  return load_image(image_input)
 
231
  return process_base64_image(image_input)
232
  if isinstance(image_input, dict) and "image" in image_input:
233
  return process_base64_image(image_input["image"])
 
300
 
301
  return "\n\n".join(filled)
302
 
303
+ def _sent_split(s: str):
304
+ return [x.strip() for x in re.split(r'(?<=[.!?])\s+', s.strip()) if x.strip()]
305
+
306
+ def _norm_key(s: str):
307
+ return re.sub(r'\W+', ' ', s.lower()).strip()
308
+
309
+ def _dedupe_and_clip_sections(text: str) -> str:
310
+ """
311
+ Şablon oluşmuş metni alır, her bölümde tekrar eden cümleleri siler,
312
+ uzunluğu kısaltır (Steps: ≤3 cümle, Impression: ≤6 cümle) ve birleştirir.
313
+ """
314
+ pieces = _SECTION_RE.split(text)
315
+ found = {}
316
+ i = 1
317
+ while i + 1 < len(pieces):
318
+ heading = pieces[i].strip()
319
+ content = pieces[i + 1].strip()
320
+ for canonical in SECTION_ORDER:
321
+ if heading.lower().startswith(canonical.lower().rstrip(":")):
322
+ found[canonical] = content
323
+ break
324
+ i += 2
325
+
326
+ out_sections = []
327
+ for sec in SECTION_ORDER:
328
+ body = (found.get(sec, "") or "").strip()
329
+ sents = _sent_split(body)
330
+
331
+ seen = set()
332
+ deduped = []
333
+ for s in sents:
334
+ k = _norm_key(s)
335
+ if k not in seen:
336
+ seen.add(k)
337
+ deduped.append(s)
338
+
339
+ limit = 3 if sec.startswith("Step") else 6
340
+ limited = deduped[:limit] if deduped else []
341
+ out_body = " ".join(limited) if limited else body
342
+ out_sections.append(f"{sec}\n{out_body}" if out_body else f"{sec}\n")
343
+
344
+ return "\n\n".join(out_sections)
345
+
346
  # ---------- Oturum / Konuşma ----------
347
 
348
  class InferenceDemo(object):
 
516
  prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
517
  ).unsqueeze(0).to(model_device)
518
 
519
+ # Stop kriteri
520
  stopping_criteria = _stop_criteria_from_conv(chatbot, input_ids)
521
 
522
  # Deterministik üretim
 
540
  images=image_tensor,
541
  do_sample=False, # deterministik
542
  max_new_tokens=int(max_output_tokens),
543
+ min_new_tokens=350, # 800 -> 350 (tekrar riskini azalt)
544
+ no_repeat_ngram_size=5, # tekrar bloklarını engelle
545
  repetition_penalty=float(repetition_penalty),
546
  use_cache=False,
547
  pad_token_id=eos_id,
 
554
  gen = outputs[0][input_ids.shape[1]:]
555
  response = chatbot.tokenizer.decode(gen, skip_special_tokens=True)
556
 
557
+ # ŞABLON ZORLAMA + tekrar kırpma
558
  response = _enforce_section_template(response)
559
+ response = _dedupe_and_clip_sections(response)
560
 
561
  # Konuşmaya yerleştir
562
  if chatbot.conversation.messages and isinstance(chatbot.conversation.messages[-1], list):