ubden commited on
Commit
2f02401
·
verified ·
1 Parent(s): f9bf71d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +43 -16
handler.py CHANGED
@@ -350,9 +350,9 @@ class EndpointHandler:
350
  # Create ECG-specific prompt that mimics visual analysis
351
  ecg_context = f"Analyzing an ECG image ({image.size[0]}x{image.size[1]} pixels). "
352
 
353
- # Use very simple, direct prompt that works with PULSE-7B
354
- # Just keep the original query - model is trained for ECG analysis
355
- text = text # Keep original query exactly as is
356
  else:
357
  # Simple string input
358
  text = str(inputs)
@@ -360,9 +360,9 @@ class EndpointHandler:
360
  if not text:
361
  return [{"generated_text": "Hey, I need some text to work with! Please provide an input."}]
362
 
363
- # Get generation parameters - using PULSE-7B demo's optimal settings
364
  parameters = data.get("parameters", {})
365
- max_new_tokens = min(parameters.get("max_new_tokens", 4096), 8192) # Demo uses 4096 default, 8192 max
366
  temperature = parameters.get("temperature", 0.05) # Demo uses 0.05 for precise medical analysis
367
  top_p = parameters.get("top_p", 1.0) # Demo uses 1.0 for full vocabulary access
368
  do_sample = parameters.get("do_sample", True) # Demo uses sampling
@@ -381,7 +381,7 @@ class EndpointHandler:
381
  result = self.pipe(
382
  text,
383
  max_new_tokens=max_new_tokens,
384
- min_new_tokens=100, # Force detailed analysis like demo
385
  temperature=temperature,
386
  top_p=top_p,
387
  do_sample=do_sample,
@@ -463,7 +463,7 @@ class EndpointHandler:
463
  input_ids,
464
  attention_mask=attention_mask,
465
  max_new_tokens=max_new_tokens,
466
- min_new_tokens=50, # Ensure substantial response
467
  temperature=temperature,
468
  top_p=top_p,
469
  do_sample=do_sample,
@@ -484,14 +484,45 @@ class EndpointHandler:
484
  # Aggressive cleanup of artifacts
485
  generated_text = generated_text.replace("</s>", "").strip()
486
 
487
- # Remove training-style artifacts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  cleanup_patterns = [
489
  "In this task",
490
  "I'm asking the respondent",
491
- "The respondent should",
492
- "Answer: 7)",
493
- "Any abnormalities or pathological findings.",
494
- "The main features and diagnosis in this ECG image are"
495
  ]
496
 
497
  for pattern in cleanup_patterns:
@@ -500,10 +531,6 @@ class EndpointHandler:
500
  generated_text = parts[0].strip()
501
  break
502
 
503
- # Clean up Answer: prefix if it exists
504
- if generated_text.startswith("Answer:"):
505
- generated_text = generated_text[7:].strip()
506
-
507
  # Only provide fallback if response is truly empty or malformed
508
  if len(generated_text) < 10 or generated_text.startswith("7)"):
509
  print("⚠️ Malformed response detected, providing fallback...")
 
350
  # Create ECG-specific prompt that mimics visual analysis
351
  ecg_context = f"Analyzing an ECG image ({image.size[0]}x{image.size[1]} pixels). "
352
 
353
+ # Use demo's exact approach - no additional context, just the query
354
+ # Model is trained to understand ECG images from text queries
355
+ pass # Keep text exactly as received
356
  else:
357
  # Simple string input
358
  text = str(inputs)
 
360
  if not text:
361
  return [{"generated_text": "Hey, I need some text to work with! Please provide an input."}]
362
 
363
+ # Get generation parameters - using PULSE-7B demo's exact settings
364
  parameters = data.get("parameters", {})
365
+ max_new_tokens = min(parameters.get("max_new_tokens", 1024), 8192) # Demo uses 1024 default
366
  temperature = parameters.get("temperature", 0.05) # Demo uses 0.05 for precise medical analysis
367
  top_p = parameters.get("top_p", 1.0) # Demo uses 1.0 for full vocabulary access
368
  do_sample = parameters.get("do_sample", True) # Demo uses sampling
 
381
  result = self.pipe(
382
  text,
383
  max_new_tokens=max_new_tokens,
384
+ min_new_tokens=200, # Force very detailed analysis to match demo
385
  temperature=temperature,
386
  top_p=top_p,
387
  do_sample=do_sample,
 
463
  input_ids,
464
  attention_mask=attention_mask,
465
  max_new_tokens=max_new_tokens,
466
+ min_new_tokens=200, # Force detailed response like demo
467
  temperature=temperature,
468
  top_p=top_p,
469
  do_sample=do_sample,
 
484
  # Aggressive cleanup of artifacts
485
  generated_text = generated_text.replace("</s>", "").strip()
486
 
487
+ # Clean up parenthetical Answer format
488
+ if generated_text.startswith("(Answer:") and ")" in generated_text:
489
+ # Extract and expand the concise answer
490
+ end_paren = generated_text.find(")")
491
+ answer_content = generated_text[8:end_paren].strip() # Remove "(Answer:"
492
+
493
+ # Expand the concise answer into full medical interpretation
494
+ if "sinus rhythm" in answer_content.lower():
495
+ parts = [part.strip() for part in answer_content.split(",")]
496
+ expanded_parts = []
497
+
498
+ for part in parts:
499
+ if "sinus rhythm" in part.lower():
500
+ expanded_parts.append("The electrocardiogram (ECG) reveals a sinus rhythm, indicating a normal heart rate and rhythm.")
501
+ elif "inferior infarct" in part.lower():
502
+ expanded_parts.append("The ECG shows signs of an inferior infarct, indicating myocardial damage in the inferior region.")
503
+ elif "anterior" in part.lower() and "infarct" in part.lower():
504
+ expanded_parts.append("There are signs of a possible acute anterior infarct.")
505
+ elif "fascicular block" in part.lower() or "block" in part.lower():
506
+ expanded_parts.append("The ECG suggests possible left anterior fascicular block, which may indicate a conduction abnormality in the heart's electrical system.")
507
+ elif "hypertrophy" in part.lower():
508
+ expanded_parts.append(f"There are signs of possible {part.lower()}.")
509
+
510
+ if expanded_parts:
511
+ generated_text = " ".join(expanded_parts)
512
+ else:
513
+ generated_text = f"The ECG shows {answer_content.lower()}."
514
+ else:
515
+ generated_text = f"The ECG shows {answer_content.lower()}."
516
+
517
+ # Clean up other artifacts
518
+ elif generated_text.startswith("Answer:"):
519
+ generated_text = generated_text[7:].strip()
520
+
521
+ # Remove training artifacts
522
  cleanup_patterns = [
523
  "In this task",
524
  "I'm asking the respondent",
525
+ "The respondent should"
 
 
 
526
  ]
527
 
528
  for pattern in cleanup_patterns:
 
531
  generated_text = parts[0].strip()
532
  break
533
 
 
 
 
 
534
  # Only provide fallback if response is truly empty or malformed
535
  if len(generated_text) < 10 or generated_text.startswith("7)"):
536
  print("⚠️ Malformed response detected, providing fallback...")