frabbani committed on
Commit
82a0e99
·
1 Parent(s): 3ade5b9

Fix fact extraction - pass raw data for simple tools

Browse files
Files changed (2) hide show
  1. evaluation/llm_eval.py +204 -0
  2. server.py +68 -28
evaluation/llm_eval.py CHANGED
@@ -238,6 +238,210 @@ def extract_numbers_from_text(text: str) -> Dict[str, Any]:
238
  return numbers
239
 
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  def extract_numbers_from_chart(chart_data: Dict) -> Dict[str, Any]:
242
  """
243
  Extract numerical values from chart data returned by tools.
 
238
  return numbers
239
 
240
 
241
+ def extract_medication_names(text: str, expected_meds: List[str]) -> Dict[str, Any]:
242
+ """
243
+ Extract medication names from LLM response and compare against expected list.
244
+
245
+ Uses fuzzy matching since LLM might abbreviate or paraphrase.
246
+ """
247
+ text_lower = text.lower()
248
+
249
+ found = []
250
+ missing = []
251
+
252
+ for med in expected_meds:
253
+ med_lower = med.lower()
254
+ # Extract the drug name (first word or two before dosage)
255
+ # "Metformin 500 MG Oral Tablet" -> "metformin"
256
+ drug_name = med_lower.split()[0] if med_lower else ""
257
+
258
+ # Check if drug name appears in text
259
+ if drug_name and drug_name in text_lower:
260
+ found.append(med)
261
+ # Also check for common abbreviations or alternate names
262
+ elif any(part in text_lower for part in med_lower.split()[:2] if len(part) > 3):
263
+ found.append(med)
264
+ else:
265
+ missing.append(med)
266
+
267
+ return {
268
+ "found": found,
269
+ "missing": missing,
270
+ "found_count": len(found),
271
+ "expected_count": len(expected_meds),
272
+ "accuracy": len(found) / len(expected_meds) if expected_meds else 1.0
273
+ }
274
+
275
+
276
+ def extract_condition_names(text: str, expected_conditions: List[str]) -> Dict[str, Any]:
277
+ """
278
+ Extract condition names from LLM response and compare against expected list.
279
+ """
280
+ text_lower = text.lower()
281
+
282
+ found = []
283
+ missing = []
284
+
285
+ for condition in expected_conditions:
286
+ cond_lower = condition.lower()
287
+ # Check for key words from condition name
288
+ key_words = [w for w in cond_lower.split() if len(w) > 4]
289
+
290
+ # If any significant word from condition appears in text
291
+ if any(word in text_lower for word in key_words):
292
+ found.append(condition)
293
+ # Also check for common abbreviations
294
+ elif "diabetes" in cond_lower and ("diabetes" in text_lower or "diabetic" in text_lower):
295
+ found.append(condition)
296
+ elif "hypertension" in cond_lower and ("hypertension" in text_lower or "blood pressure" in text_lower or "htn" in text_lower):
297
+ found.append(condition)
298
+ elif "cholesterol" in cond_lower and ("cholesterol" in text_lower or "hyperlipidemia" in text_lower):
299
+ found.append(condition)
300
+ else:
301
+ missing.append(condition)
302
+
303
+ return {
304
+ "found": found,
305
+ "missing": missing,
306
+ "found_count": len(found),
307
+ "expected_count": len(expected_conditions),
308
+ "accuracy": len(found) / len(expected_conditions) if expected_conditions else 1.0
309
+ }
310
+
311
+
312
+ def extract_allergy_names(text: str, expected_allergies: List[str]) -> Dict[str, Any]:
313
+ """
314
+ Extract allergy/allergen names from LLM response.
315
+ """
316
+ text_lower = text.lower()
317
+
318
+ found = []
319
+ missing = []
320
+
321
+ for allergy in expected_allergies:
322
+ allergy_lower = allergy.lower()
323
+ # Check if allergen name appears
324
+ if allergy_lower in text_lower:
325
+ found.append(allergy)
326
+ # Check key words
327
+ elif any(word in text_lower for word in allergy_lower.split() if len(word) > 3):
328
+ found.append(allergy)
329
+ else:
330
+ missing.append(allergy)
331
+
332
+ return {
333
+ "found": found,
334
+ "missing": missing,
335
+ "found_count": len(found),
336
+ "expected_count": len(expected_allergies),
337
+ "accuracy": len(found) / len(expected_allergies) if expected_allergies else 1.0
338
+ }
339
+
340
+
341
+ @dataclass
342
+ class TextComparisonResult:
343
+ """Result of comparing LLM text response against expected items."""
344
+ case_id: str
345
+ query: str
346
+ query_type: str
347
+ success: bool
348
+ expected_items: List[str] = field(default_factory=list)
349
+ found_items: List[str] = field(default_factory=list)
350
+ missing_items: List[str] = field(default_factory=list)
351
+ accuracy: float = 0.0
352
+ raw_response: str = ""
353
+
354
+
355
+ async def evaluate_text_query(
356
+ patient_id: str,
357
+ query: str,
358
+ query_type: str,
359
+ expected_items: List[str],
360
+ case_id: str = ""
361
+ ) -> TextComparisonResult:
362
+ """
363
+ Evaluate LLM response for text-based queries (medications, conditions, allergies).
364
+ """
365
+ result = TextComparisonResult(
366
+ case_id=case_id,
367
+ query=query,
368
+ query_type=query_type,
369
+ success=False,
370
+ expected_items=expected_items
371
+ )
372
+
373
+ # Call the agent
374
+ llm_response = await call_agent_endpoint(patient_id, query, timeout=90.0)
375
+
376
+ if llm_response.error:
377
+ result.raw_response = f"Error: {llm_response.error}"
378
+ return result
379
+
380
+ result.raw_response = llm_response.raw_response
381
+
382
+ # Compare based on query type
383
+ if query_type == "medication_list":
384
+ comparison = extract_medication_names(llm_response.raw_response, expected_items)
385
+ elif query_type == "condition_list":
386
+ comparison = extract_condition_names(llm_response.raw_response, expected_items)
387
+ elif query_type == "allergy_list":
388
+ comparison = extract_allergy_names(llm_response.raw_response, expected_items)
389
+ else:
390
+ # Generic text comparison
391
+ comparison = extract_condition_names(llm_response.raw_response, expected_items)
392
+
393
+ result.found_items = comparison["found"]
394
+ result.missing_items = comparison["missing"]
395
+ result.accuracy = comparison["accuracy"]
396
+ result.success = comparison["accuracy"] >= 0.7 # 70% threshold
397
+
398
+ return result
399
+
400
+
401
+ def aggregate_text_results(results: List[TextComparisonResult]) -> Dict[str, Any]:
402
+ """Aggregate text evaluation results."""
403
+ if not results:
404
+ return {"total_cases": 0, "message": "No text cases evaluated"}
405
+
406
+ total = len(results)
407
+ successful = sum(1 for r in results if r.success)
408
+
409
+ total_expected = sum(len(r.expected_items) for r in results)
410
+ total_found = sum(len(r.found_items) for r in results)
411
+
412
+ by_type = {}
413
+ for r in results:
414
+ if r.query_type not in by_type:
415
+ by_type[r.query_type] = {"total": 0, "passed": 0, "accuracy_sum": 0}
416
+ by_type[r.query_type]["total"] += 1
417
+ by_type[r.query_type]["passed"] += 1 if r.success else 0
418
+ by_type[r.query_type]["accuracy_sum"] += r.accuracy
419
+
420
+ # Compute average accuracy per type
421
+ for qtype in by_type:
422
+ by_type[qtype]["avg_accuracy"] = f"{by_type[qtype]['accuracy_sum'] / by_type[qtype]['total'] * 100:.1f}%"
423
+
424
+ return {
425
+ "total_cases": total,
426
+ "successful_cases": successful,
427
+ "failed_cases": total - successful,
428
+ "success_rate": f"{(successful/total*100):.1f}%",
429
+ "total_expected_items": total_expected,
430
+ "total_found_items": total_found,
431
+ "item_recall": f"{(total_found/total_expected*100):.1f}%" if total_expected > 0 else "N/A",
432
+ "by_type": by_type,
433
+ "failed_details": [
434
+ {
435
+ "case_id": r.case_id,
436
+ "query_type": r.query_type,
437
+ "accuracy": f"{r.accuracy:.0%}",
438
+ "missing": r.missing_items[:5]
439
+ }
440
+ for r in results if not r.success
441
+ ]
442
+ }
443
+
444
+
445
  def extract_numbers_from_chart(chart_data: Dict) -> Dict[str, Any]:
446
  """
447
  Extract numerical values from chart data returned by tools.
server.py CHANGED
@@ -698,32 +698,27 @@ async def run_evaluation(
698
  extract_numbers_from_text,
699
  compare_llm_response,
700
  aggregate_llm_results,
701
- LLMComparisonResult
 
 
702
  )
703
 
704
  print("\nRunning FULL LLM evaluation (this calls actual MedGemma)...")
705
- print("Note: Only testing vital_trend queries (charts) for number accuracy\n")
706
 
707
- # Filter to just vital trend cases (charts with numbers)
708
- vital_cases = [tc for tc in test_cases if tc["query_type"] == "vital_trend"]
709
-
710
- if not vital_cases:
711
- return {
712
- "success": False,
713
- "error": "No vital trend test cases found"
714
- }
715
 
 
716
  llm_results = []
717
 
718
- for i, test_case in enumerate(vital_cases[:5]): # Limit to 5 for speed
719
  patient_id = test_case["patient_id"]
720
  query = test_case["query"]
721
  case_id = test_case["case_id"]
722
  expected = compute_expected_values(test_case)
723
 
724
- print(f" [{i+1}/{min(5, len(vital_cases))}] {query[:50]}...")
725
 
726
- # Call actual LLM
727
  llm_response = await call_agent_endpoint(patient_id, query, timeout=90.0)
728
 
729
  if llm_response.error:
@@ -735,17 +730,12 @@ async def run_evaluation(
735
  errors=[llm_response.error]
736
  ))
737
  else:
738
- # Extract numbers from chart (ground truth) and text (LLM said)
739
  chart_nums = extract_numbers_from_chart(llm_response.chart_data)
740
  text_nums = extract_numbers_from_text(llm_response.raw_response)
741
 
742
- # Debug: show first 300 chars of LLM response
743
- print(f" LLM response (first 300 chars):")
744
- print(f" {llm_response.raw_response[:300].replace(chr(10), ' ')}")
745
  print(f" Chart numbers: {chart_nums}")
746
  print(f" Text numbers: {text_nums}")
747
 
748
- # Compare
749
  result = compare_llm_response(llm_response, expected)
750
  result.case_id = case_id
751
  llm_results.append(result)
@@ -757,26 +747,76 @@ async def run_evaluation(
757
  for err in result.errors[:3]:
758
  print(f" - {err}")
759
 
760
- # Aggregate LLM results
761
- llm_summary = aggregate_llm_results(llm_results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
762
 
763
  print("\n" + "="*60)
764
  print("LLM RESPONSE ACCURACY REPORT")
765
  print("="*60)
766
- print(f"Test Cases: {llm_summary['total_cases']}")
767
- print(f"Successful: {llm_summary['successful_cases']}")
768
- print(f"Failed: {llm_summary['failed_cases']}")
769
- print(f"Success Rate: {llm_summary['success_rate']}")
770
- print(f"Number Checks: {llm_summary['total_number_checks']}")
771
- print(f"Correct Numbers: {llm_summary['correct_numbers']}")
772
- print(f"Number Accuracy: {llm_summary['number_accuracy']}")
 
 
 
 
 
 
 
 
773
  print("="*60)
774
 
775
  return {
776
  "success": True,
777
  "mode": "llm",
778
  "patients_tested": patients,
779
- "metrics": llm_summary
 
 
 
780
  }
781
 
782
  elif mode == "agent":
 
698
  extract_numbers_from_text,
699
  compare_llm_response,
700
  aggregate_llm_results,
701
+ LLMComparisonResult,
702
+ evaluate_text_query,
703
+ aggregate_text_results
704
  )
705
 
706
  print("\nRunning FULL LLM evaluation (this calls actual MedGemma)...")
 
707
 
708
+ # === PART 1: NUMERIC EVALUATION (Vitals) ===
709
+ print("\n--- PART 1: NUMERIC ACCURACY (Vital Charts) ---\n")
 
 
 
 
 
 
710
 
711
+ vital_cases = [tc for tc in test_cases if tc["query_type"] == "vital_trend"]
712
  llm_results = []
713
 
714
+ for i, test_case in enumerate(vital_cases[:4]): # Limit to 4
715
  patient_id = test_case["patient_id"]
716
  query = test_case["query"]
717
  case_id = test_case["case_id"]
718
  expected = compute_expected_values(test_case)
719
 
720
+ print(f" [{i+1}/{min(4, len(vital_cases))}] {query[:50]}...")
721
 
 
722
  llm_response = await call_agent_endpoint(patient_id, query, timeout=90.0)
723
 
724
  if llm_response.error:
 
730
  errors=[llm_response.error]
731
  ))
732
  else:
 
733
  chart_nums = extract_numbers_from_chart(llm_response.chart_data)
734
  text_nums = extract_numbers_from_text(llm_response.raw_response)
735
 
 
 
 
736
  print(f" Chart numbers: {chart_nums}")
737
  print(f" Text numbers: {text_nums}")
738
 
 
739
  result = compare_llm_response(llm_response, expected)
740
  result.case_id = case_id
741
  llm_results.append(result)
 
747
  for err in result.errors[:3]:
748
  print(f" - {err}")
749
 
750
+ # === PART 2: TEXT EVALUATION (Medications, Conditions, Allergies) ===
751
+ print("\n--- PART 2: TEXT ACCURACY (Medications, Conditions, Allergies) ---\n")
752
+
753
+ text_cases = [tc for tc in test_cases if tc["query_type"] in ["medication_list", "condition_list", "allergy_list"]]
754
+ text_results = []
755
+
756
+ for i, test_case in enumerate(text_cases[:4]): # Limit to 4
757
+ patient_id = test_case["patient_id"]
758
+ query = test_case["query"]
759
+ query_type = test_case["query_type"]
760
+ case_id = test_case["case_id"]
761
+ expected = compute_expected_values(test_case)
762
+
763
+ # Get expected items list based on query type
764
+ if query_type == "medication_list":
765
+ expected_items = expected.get("medication_names", [])
766
+ elif query_type == "condition_list":
767
+ expected_items = expected.get("condition_names", [])
768
+ elif query_type == "allergy_list":
769
+ expected_items = expected.get("substances", [])
770
+ else:
771
+ expected_items = []
772
+
773
+ print(f" [{i+1}/{min(4, len(text_cases))}] {query[:50]}...")
774
+ print(f" Expected {len(expected_items)} items: {[x[:30] for x in expected_items[:3]]}...")
775
+
776
+ result = await evaluate_text_query(
777
+ patient_id, query, query_type, expected_items, case_id
778
+ )
779
+ text_results.append(result)
780
+
781
+ if result.success:
782
+ print(f" ✓ PASS ({result.accuracy:.0%} - found {len(result.found_items)}/{len(expected_items)})")
783
+ else:
784
+ print(f" ✗ FAIL ({result.accuracy:.0%} - found {len(result.found_items)}/{len(expected_items)})")
785
+ if result.missing_items:
786
+ print(f" Missing: {result.missing_items[:3]}")
787
+
788
+ # === AGGREGATE RESULTS ===
789
+ numeric_summary = aggregate_llm_results(llm_results)
790
+ text_summary = aggregate_text_results(text_results) if text_results else {}
791
 
792
  print("\n" + "="*60)
793
  print("LLM RESPONSE ACCURACY REPORT")
794
  print("="*60)
795
+
796
+ print("\n📊 NUMERIC ACCURACY (Vital Charts):")
797
+ print(f" Test Cases: {numeric_summary['total_cases']}")
798
+ print(f" Success Rate: {numeric_summary['success_rate']}")
799
+ print(f" Number Accuracy: {numeric_summary['number_accuracy']}")
800
+
801
+ if text_summary:
802
+ print("\n📝 TEXT ACCURACY (Medications, Conditions, Allergies):")
803
+ print(f" Test Cases: {text_summary['total_cases']}")
804
+ print(f" Success Rate: {text_summary['success_rate']}")
805
+ print(f" Item Recall: {text_summary['item_recall']}")
806
+ if text_summary.get('by_type'):
807
+ for qtype, stats in text_summary['by_type'].items():
808
+ print(f" {qtype}: {stats['passed']}/{stats['total']} passed ({stats['avg_accuracy']})")
809
+
810
  print("="*60)
811
 
812
  return {
813
  "success": True,
814
  "mode": "llm",
815
  "patients_tested": patients,
816
+ "metrics": {
817
+ "numeric": numeric_summary,
818
+ "text": text_summary
819
+ }
820
  }
821
 
822
  elif mode == "agent":