no_repeat_ngram_size added
Browse files- handler.py +55 -8
handler.py
CHANGED
|
@@ -6,6 +6,7 @@ PULSE ECG Handler - Deterministic ECG Analysis Model (app.py uyumlu)
|
|
| 6 |
- Model dtype/device ile uyumlu görüntü tensörü (3D/4D/5D destekli)
|
| 7 |
- Sağlam URL/base64 işleme, güvenli logging, opsiyonel HF upload
|
| 8 |
- Zorunlu başlık şablonu + min_new_tokens ile tam Step 1–9 çıktısı
|
|
|
|
| 9 |
"""
|
| 10 |
|
| 11 |
import os
|
|
@@ -115,7 +116,9 @@ DEFAULT_ECG_PROMPT = (
|
|
| 115 |
"Step 8: T Wave Analysis\n"
|
| 116 |
"Step 9: QT/QTc Interval Analysis\n"
|
| 117 |
"Structured Clinical Impression:\n"
|
| 118 |
-
"If a section is normal, write 'Normal' and give a brief justification."
|
|
|
|
|
|
|
| 119 |
)
|
| 120 |
|
| 121 |
# ---------- Yardımcılar ----------
|
|
@@ -135,7 +138,8 @@ def _safe_upload(path):
|
|
| 135 |
def get_conv_log_filename():
|
| 136 |
t = datetime.datetime.now()
|
| 137 |
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
|
| 138 |
-
os.makedirs(os.path.dirname(name), exist_ok=True
|
|
|
|
| 139 |
return name
|
| 140 |
|
| 141 |
def get_conv_vote_filename():
|
|
@@ -153,9 +157,8 @@ def vote_last_response(state, vote_type, model_selector):
|
|
| 153 |
except Exception as e:
|
| 154 |
print(f"Failed to record vote: {e}")
|
| 155 |
|
| 156 |
-
# Yalın uzantı listeleri
|
| 157 |
IMAGE_EXTS = {"jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "jfif"}
|
| 158 |
-
# HEIC/HEIF: pillow-heif yoksa destekleme
|
| 159 |
try:
|
| 160 |
import pillow_heif # noqa: F401
|
| 161 |
IMAGE_EXTS.update({"heic", "heif"})
|
|
@@ -225,7 +228,6 @@ def process_image_input(image_input):
|
|
| 225 |
return load_image(image_input)
|
| 226 |
if os.path.exists(image_input):
|
| 227 |
return load_image(image_input)
|
| 228 |
-
# muhtemelen base64
|
| 229 |
return process_base64_image(image_input)
|
| 230 |
if isinstance(image_input, dict) and "image" in image_input:
|
| 231 |
return process_base64_image(image_input["image"])
|
|
@@ -298,6 +300,49 @@ def _enforce_section_template(text: str) -> str:
|
|
| 298 |
|
| 299 |
return "\n\n".join(filled)
|
| 300 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
# ---------- Oturum / Konuşma ----------
|
| 302 |
|
| 303 |
class InferenceDemo(object):
|
|
@@ -471,7 +516,7 @@ def generate_response(message_text,
|
|
| 471 |
prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
| 472 |
).unsqueeze(0).to(model_device)
|
| 473 |
|
| 474 |
-
# Stop kriteri
|
| 475 |
stopping_criteria = _stop_criteria_from_conv(chatbot, input_ids)
|
| 476 |
|
| 477 |
# Deterministik üretim
|
|
@@ -495,7 +540,8 @@ def generate_response(message_text,
|
|
| 495 |
images=image_tensor,
|
| 496 |
do_sample=False, # deterministik
|
| 497 |
max_new_tokens=int(max_output_tokens),
|
| 498 |
-
min_new_tokens=
|
|
|
|
| 499 |
repetition_penalty=float(repetition_penalty),
|
| 500 |
use_cache=False,
|
| 501 |
pad_token_id=eos_id,
|
|
@@ -508,8 +554,9 @@ def generate_response(message_text,
|
|
| 508 |
gen = outputs[0][input_ids.shape[1]:]
|
| 509 |
response = chatbot.tokenizer.decode(gen, skip_special_tokens=True)
|
| 510 |
|
| 511 |
-
# ŞABLON ZORLAMA
|
| 512 |
response = _enforce_section_template(response)
|
|
|
|
| 513 |
|
| 514 |
# Konuşmaya yerleştir
|
| 515 |
if chatbot.conversation.messages and isinstance(chatbot.conversation.messages[-1], list):
|
|
|
|
| 6 |
- Model dtype/device ile uyumlu görüntü tensörü (3D/4D/5D destekli)
|
| 7 |
- Sağlam URL/base64 işleme, güvenli logging, opsiyonel HF upload
|
| 8 |
- Zorunlu başlık şablonu + min_new_tokens ile tam Step 1–9 çıktısı
|
| 9 |
+
- Tekrarları engelleme (no_repeat_ngram_size) + post-format dedup
|
| 10 |
"""
|
| 11 |
|
| 12 |
import os
|
|
|
|
| 116 |
"Step 8: T Wave Analysis\n"
|
| 117 |
"Step 9: QT/QTc Interval Analysis\n"
|
| 118 |
"Structured Clinical Impression:\n"
|
| 119 |
+
"If a section is normal, write 'Normal' and give a brief justification. "
|
| 120 |
+
"Each section must be 1–3 concise sentences. Do not repeat identical statements. "
|
| 121 |
+
"Write the final diagnostic impression only once in 'Structured Clinical Impression' and do not restate it elsewhere."
|
| 122 |
)
|
| 123 |
|
| 124 |
# ---------- Yardımcılar ----------
|
|
|
|
| 138 |
def get_conv_log_filename():
    """Return today's conversation-log path under LOGDIR, creating its directory.

    The filename encodes the current local date as YYYY-MM-DD so logs roll
    over daily; the parent directory is created idempotently before returning.
    """
    now = datetime.datetime.now()
    filename = f"{now.year}-{now.month:02d}-{now.day:02d}-user_conv.json"
    path = os.path.join(LOGDIR, filename)
    # exist_ok=True keeps repeated calls within the same day side-effect free.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    return path
|
| 144 |
|
| 145 |
def get_conv_vote_filename():
|
|
|
|
| 157 |
except Exception as e:
|
| 158 |
print(f"Failed to record vote: {e}")
|
| 159 |
|
| 160 |
+
# Yalın uzantı listeleri
|
| 161 |
IMAGE_EXTS = {"jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "jfif"}
|
|
|
|
| 162 |
try:
|
| 163 |
import pillow_heif # noqa: F401
|
| 164 |
IMAGE_EXTS.update({"heic", "heif"})
|
|
|
|
| 228 |
return load_image(image_input)
|
| 229 |
if os.path.exists(image_input):
|
| 230 |
return load_image(image_input)
|
|
|
|
| 231 |
return process_base64_image(image_input)
|
| 232 |
if isinstance(image_input, dict) and "image" in image_input:
|
| 233 |
return process_base64_image(image_input["image"])
|
|
|
|
| 300 |
|
| 301 |
return "\n\n".join(filled)
|
| 302 |
|
| 303 |
+
def _sent_split(s: str):
|
| 304 |
+
return [x.strip() for x in re.split(r'(?<=[.!?])\s+', s.strip()) if x.strip()]
|
| 305 |
+
|
| 306 |
+
def _norm_key(s: str):
|
| 307 |
+
return re.sub(r'\W+', ' ', s.lower()).strip()
|
| 308 |
+
|
| 309 |
+
def _dedupe_and_clip_sections(text: str) -> str:
    """Remove repeated sentences from a templated report and clip each section.

    Takes text already shaped by the section template, drops duplicate
    sentences within every section (matched case-/punctuation-insensitively),
    limits section length (Step sections: at most 3 sentences, the clinical
    impression: at most 6), and rejoins the sections in canonical order.
    """
    parts = _SECTION_RE.split(text)

    # Map each recognized heading onto its body text; a heading matches a
    # canonical section when it starts with that section's title (sans colon).
    section_bodies = {}
    for idx in range(1, len(parts) - 1, 2):
        heading = parts[idx].strip()
        body = parts[idx + 1].strip()
        for canonical in SECTION_ORDER:
            if heading.lower().startswith(canonical.lower().rstrip(":")):
                section_bodies[canonical] = body
                break

    rebuilt = []
    for section in SECTION_ORDER:
        body = (section_bodies.get(section, "") or "").strip()
        sentences = _sent_split(body)

        # Keep only the first occurrence of each normalized sentence.
        unique = []
        seen_keys = set()
        for sentence in sentences:
            key = _norm_key(sentence)
            if key not in seen_keys:
                seen_keys.add(key)
                unique.append(sentence)

        cap = 3 if section.startswith("Step") else 6
        clipped = unique[:cap] if unique else []
        # Fall back to the raw body if no sentences were extracted at all.
        content = " ".join(clipped) if clipped else body
        rebuilt.append(f"{section}\n{content}" if content else f"{section}\n")

    return "\n\n".join(rebuilt)
|
| 345 |
+
|
| 346 |
# ---------- Oturum / Konuşma ----------
|
| 347 |
|
| 348 |
class InferenceDemo(object):
|
|
|
|
| 516 |
prompt, chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
| 517 |
).unsqueeze(0).to(model_device)
|
| 518 |
|
| 519 |
+
# Stop kriteri
|
| 520 |
stopping_criteria = _stop_criteria_from_conv(chatbot, input_ids)
|
| 521 |
|
| 522 |
# Deterministik üretim
|
|
|
|
| 540 |
images=image_tensor,
|
| 541 |
do_sample=False, # deterministik
|
| 542 |
max_new_tokens=int(max_output_tokens),
|
| 543 |
+
min_new_tokens=350, # 800 -> 350 (tekrar riskini azalt)
|
| 544 |
+
no_repeat_ngram_size=5, # tekrar bloklarını engelle
|
| 545 |
repetition_penalty=float(repetition_penalty),
|
| 546 |
use_cache=False,
|
| 547 |
pad_token_id=eos_id,
|
|
|
|
| 554 |
gen = outputs[0][input_ids.shape[1]:]
|
| 555 |
response = chatbot.tokenizer.decode(gen, skip_special_tokens=True)
|
| 556 |
|
| 557 |
+
# ŞABLON ZORLAMA + tekrar kırpma
|
| 558 |
response = _enforce_section_template(response)
|
| 559 |
+
response = _dedupe_and_clip_sections(response)
|
| 560 |
|
| 561 |
# Konuşmaya yerleştir
|
| 562 |
if chatbot.conversation.messages and isinstance(chatbot.conversation.messages[-1], list):
|