Update handler.py
Browse files- handler.py +43 -16
handler.py
CHANGED
|
@@ -350,9 +350,9 @@ class EndpointHandler:
|
|
| 350 |
# Create ECG-specific prompt that mimics visual analysis
|
| 351 |
ecg_context = f"Analyzing an ECG image ({image.size[0]}x{image.size[1]} pixels). "
|
| 352 |
|
| 353 |
-
# Use
|
| 354 |
-
#
|
| 355 |
-
|
| 356 |
else:
|
| 357 |
# Simple string input
|
| 358 |
text = str(inputs)
|
|
@@ -360,9 +360,9 @@ class EndpointHandler:
|
|
| 360 |
if not text:
|
| 361 |
return [{"generated_text": "Hey, I need some text to work with! Please provide an input."}]
|
| 362 |
|
| 363 |
-
# Get generation parameters - using PULSE-7B demo's
|
| 364 |
parameters = data.get("parameters", {})
|
| 365 |
-
max_new_tokens = min(parameters.get("max_new_tokens",
|
| 366 |
temperature = parameters.get("temperature", 0.05) # Demo uses 0.05 for precise medical analysis
|
| 367 |
top_p = parameters.get("top_p", 1.0) # Demo uses 1.0 for full vocabulary access
|
| 368 |
do_sample = parameters.get("do_sample", True) # Demo uses sampling
|
|
@@ -381,7 +381,7 @@ class EndpointHandler:
|
|
| 381 |
result = self.pipe(
|
| 382 |
text,
|
| 383 |
max_new_tokens=max_new_tokens,
|
| 384 |
-
|
| 385 |
temperature=temperature,
|
| 386 |
top_p=top_p,
|
| 387 |
do_sample=do_sample,
|
|
@@ -463,7 +463,7 @@ class EndpointHandler:
|
|
| 463 |
input_ids,
|
| 464 |
attention_mask=attention_mask,
|
| 465 |
max_new_tokens=max_new_tokens,
|
| 466 |
-
min_new_tokens=
|
| 467 |
temperature=temperature,
|
| 468 |
top_p=top_p,
|
| 469 |
do_sample=do_sample,
|
|
@@ -484,14 +484,45 @@ class EndpointHandler:
|
|
| 484 |
# Aggressive cleanup of artifacts
|
| 485 |
generated_text = generated_text.replace("</s>", "").strip()
|
| 486 |
|
| 487 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
cleanup_patterns = [
|
| 489 |
"In this task",
|
| 490 |
"I'm asking the respondent",
|
| 491 |
-
"The respondent should"
|
| 492 |
-
"Answer: 7)",
|
| 493 |
-
"Any abnormalities or pathological findings.",
|
| 494 |
-
"The main features and diagnosis in this ECG image are"
|
| 495 |
]
|
| 496 |
|
| 497 |
for pattern in cleanup_patterns:
|
|
@@ -500,10 +531,6 @@ class EndpointHandler:
|
|
| 500 |
generated_text = parts[0].strip()
|
| 501 |
break
|
| 502 |
|
| 503 |
-
# Clean up Answer: prefix if it exists
|
| 504 |
-
if generated_text.startswith("Answer:"):
|
| 505 |
-
generated_text = generated_text[7:].strip()
|
| 506 |
-
|
| 507 |
# Only provide fallback if response is truly empty or malformed
|
| 508 |
if len(generated_text) < 10 or generated_text.startswith("7)"):
|
| 509 |
print("⚠️ Malformed response detected, providing fallback...")
|
|
|
|
| 350 |
# Create ECG-specific prompt that mimics visual analysis
|
| 351 |
ecg_context = f"Analyzing an ECG image ({image.size[0]}x{image.size[1]} pixels). "
|
| 352 |
|
| 353 |
+
# Use demo's exact approach - no additional context, just the query
|
| 354 |
+
# Model is trained to understand ECG images from text queries
|
| 355 |
+
pass # Keep text exactly as received
|
| 356 |
else:
|
| 357 |
# Simple string input
|
| 358 |
text = str(inputs)
|
|
|
|
| 360 |
if not text:
|
| 361 |
return [{"generated_text": "Hey, I need some text to work with! Please provide an input."}]
|
| 362 |
|
| 363 |
+
# Get generation parameters - using PULSE-7B demo's exact settings
|
| 364 |
parameters = data.get("parameters", {})
|
| 365 |
+
max_new_tokens = min(parameters.get("max_new_tokens", 1024), 8192) # Demo uses 1024 default
|
| 366 |
temperature = parameters.get("temperature", 0.05) # Demo uses 0.05 for precise medical analysis
|
| 367 |
top_p = parameters.get("top_p", 1.0) # Demo uses 1.0 for full vocabulary access
|
| 368 |
do_sample = parameters.get("do_sample", True) # Demo uses sampling
|
|
|
|
| 381 |
result = self.pipe(
|
| 382 |
text,
|
| 383 |
max_new_tokens=max_new_tokens,
|
| 384 |
+
min_new_tokens=200, # Force very detailed analysis to match demo
|
| 385 |
temperature=temperature,
|
| 386 |
top_p=top_p,
|
| 387 |
do_sample=do_sample,
|
|
|
|
| 463 |
input_ids,
|
| 464 |
attention_mask=attention_mask,
|
| 465 |
max_new_tokens=max_new_tokens,
|
| 466 |
+
min_new_tokens=200, # Force detailed response like demo
|
| 467 |
temperature=temperature,
|
| 468 |
top_p=top_p,
|
| 469 |
do_sample=do_sample,
|
|
|
|
| 484 |
# Aggressive cleanup of artifacts
|
| 485 |
generated_text = generated_text.replace("</s>", "").strip()
|
| 486 |
|
| 487 |
+
# Clean up parenthetical Answer format
|
| 488 |
+
if generated_text.startswith("(Answer:") and ")" in generated_text:
|
| 489 |
+
# Extract and expand the concise answer
|
| 490 |
+
end_paren = generated_text.find(")")
|
| 491 |
+
answer_content = generated_text[8:end_paren].strip() # Remove "(Answer:"
|
| 492 |
+
|
| 493 |
+
# Expand the concise answer into full medical interpretation
|
| 494 |
+
if "sinus rhythm" in answer_content.lower():
|
| 495 |
+
parts = [part.strip() for part in answer_content.split(",")]
|
| 496 |
+
expanded_parts = []
|
| 497 |
+
|
| 498 |
+
for part in parts:
|
| 499 |
+
if "sinus rhythm" in part.lower():
|
| 500 |
+
expanded_parts.append("The electrocardiogram (ECG) reveals a sinus rhythm, indicating a normal heart rate and rhythm.")
|
| 501 |
+
elif "inferior infarct" in part.lower():
|
| 502 |
+
expanded_parts.append("The ECG shows signs of an inferior infarct, indicating myocardial damage in the inferior region.")
|
| 503 |
+
elif "anterior" in part.lower() and "infarct" in part.lower():
|
| 504 |
+
expanded_parts.append("There are signs of a possible acute anterior infarct.")
|
| 505 |
+
elif "fascicular block" in part.lower() or "block" in part.lower():
|
| 506 |
+
expanded_parts.append("The ECG suggests possible left anterior fascicular block, which may indicate a conduction abnormality in the heart's electrical system.")
|
| 507 |
+
elif "hypertrophy" in part.lower():
|
| 508 |
+
expanded_parts.append(f"There are signs of possible {part.lower()}.")
|
| 509 |
+
|
| 510 |
+
if expanded_parts:
|
| 511 |
+
generated_text = " ".join(expanded_parts)
|
| 512 |
+
else:
|
| 513 |
+
generated_text = f"The ECG shows {answer_content.lower()}."
|
| 514 |
+
else:
|
| 515 |
+
generated_text = f"The ECG shows {answer_content.lower()}."
|
| 516 |
+
|
| 517 |
+
# Clean up other artifacts
|
| 518 |
+
elif generated_text.startswith("Answer:"):
|
| 519 |
+
generated_text = generated_text[7:].strip()
|
| 520 |
+
|
| 521 |
+
# Remove training artifacts
|
| 522 |
cleanup_patterns = [
|
| 523 |
"In this task",
|
| 524 |
"I'm asking the respondent",
|
| 525 |
+
"The respondent should"
|
|
|
|
|
|
|
|
|
|
| 526 |
]
|
| 527 |
|
| 528 |
for pattern in cleanup_patterns:
|
|
|
|
| 531 |
generated_text = parts[0].strip()
|
| 532 |
break
|
| 533 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 534 |
# Only provide fallback if response is truly empty or malformed
|
| 535 |
if len(generated_text) < 10 or generated_text.startswith("7)"):
|
| 536 |
print("⚠️ Malformed response detected, providing fallback...")
|