Spaces:
Sleeping
Sleeping
frabbani committed on
Commit ·
82a0e99
1
Parent(s): 3ade5b9
Fix fact extraction - pass raw data for simple tools
Browse files- evaluation/llm_eval.py +204 -0
- server.py +68 -28
evaluation/llm_eval.py
CHANGED
|
@@ -238,6 +238,210 @@ def extract_numbers_from_text(text: str) -> Dict[str, Any]:
|
|
| 238 |
return numbers
|
| 239 |
|
| 240 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
def extract_numbers_from_chart(chart_data: Dict) -> Dict[str, Any]:
|
| 242 |
"""
|
| 243 |
Extract numerical values from chart data returned by tools.
|
|
|
|
| 238 |
return numbers
|
| 239 |
|
| 240 |
|
| 241 |
+
def extract_medication_names(text: str, expected_meds: List[str]) -> Dict[str, Any]:
    """
    Check which expected medications are mentioned in an LLM response.

    Matching is a case-insensitive substring test and is deliberately lenient,
    since the LLM might abbreviate or paraphrase: a medication counts as found
    when its first word appears anywhere in ``text``, or when one of its first
    two words is longer than three characters and appears in ``text``.

    Args:
        text: Raw LLM response text.
        expected_meds: Medication display names, e.g.
            "Metformin 500 MG Oral Tablet".

    Returns:
        Dict with "found" and "missing" (lists of the original names, in
        input order), "found_count", "expected_count", and "accuracy"
        (found / expected, or 1.0 when nothing was expected).
    """
    text_lower = text.lower()

    found = []
    missing = []

    for med in expected_meds:
        # Split once. A blank or whitespace-only name yields no words and must
        # be reported as missing instead of raising IndexError on words[0].
        words = med.lower().split()
        # "Metformin 500 MG Oral Tablet" -> "metformin"
        drug_name = words[0] if words else ""

        if drug_name and drug_name in text_lower:
            found.append(med)
        # Fall back to the leading two words, ignoring short tokens such as
        # dosages ("500") or units ("mg") that would match almost anything.
        elif any(part in text_lower for part in words[:2] if len(part) > 3):
            found.append(med)
        else:
            missing.append(med)

    return {
        "found": found,
        "missing": missing,
        "found_count": len(found),
        "expected_count": len(expected_meds),
        "accuracy": len(found) / len(expected_meds) if expected_meds else 1.0,
    }
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def extract_condition_names(text: str, expected_conditions: List[str]) -> Dict[str, Any]:
    """
    Check which expected conditions are mentioned in an LLM response.

    A condition counts as found when any of its words longer than four
    characters occurs (case-insensitively, as a substring) in ``text``, or
    when the condition name contains one of a few known terms and the text
    contains an accepted synonym for that term.

    Returns a dict with "found", "missing", "found_count", "expected_count"
    and "accuracy" (1.0 when no conditions were expected).
    """
    haystack = text.lower()

    # (term looked for in the condition name, spellings accepted in the text)
    synonyms = (
        ("diabetes", ("diabetes", "diabetic")),
        ("hypertension", ("hypertension", "blood pressure", "htn")),
        ("cholesterol", ("cholesterol", "hyperlipidemia")),
    )

    def is_mentioned(name: str) -> bool:
        lowered = name.lower()
        # Any significant word of the condition name appearing in the text.
        if any(token in haystack for token in lowered.split() if len(token) > 4):
            return True
        # Otherwise accept a known synonym or abbreviation.
        return any(
            term in lowered and any(alias in haystack for alias in aliases)
            for term, aliases in synonyms
        )

    found = []
    missing = []
    for name in expected_conditions:
        (found if is_mentioned(name) else missing).append(name)

    expected_count = len(expected_conditions)
    return {
        "found": found,
        "missing": missing,
        "found_count": len(found),
        "expected_count": expected_count,
        "accuracy": len(found) / expected_count if expected_conditions else 1.0,
    }
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def extract_allergy_names(text: str, expected_allergies: List[str]) -> Dict[str, Any]:
    """
    Check which expected allergens are mentioned in an LLM response.

    An allergen counts as found when its full lower-cased name is a substring
    of the lower-cased ``text``, or when any of its words longer than three
    characters is.

    Returns a dict with "found", "missing", "found_count", "expected_count"
    and "accuracy" (1.0 when no allergies were expected).
    """
    response = text.lower()

    found = []
    missing = []

    for allergen in expected_allergies:
        needle = allergen.lower()
        # Whole-name match first, then fall back to the significant words.
        hit = needle in response or any(
            token in response for token in needle.split() if len(token) > 3
        )
        bucket = found if hit else missing
        bucket.append(allergen)

    expected_count = len(expected_allergies)
    return {
        "found": found,
        "missing": missing,
        "found_count": len(found),
        "expected_count": expected_count,
        "accuracy": len(found) / expected_count if expected_allergies else 1.0,
    }
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
@dataclass
class TextComparisonResult:
    """
    Outcome of checking one LLM text answer against a list of expected items.

    Only the case identifiers and the pass/fail flag are required; the item
    lists default to fresh empty lists and the remaining fields to neutral
    values, so a result can be created before the LLM has been called.
    """
    # Identification of the evaluated case.
    case_id: str
    query: str
    query_type: str
    # Overall pass/fail verdict for the case.
    success: bool
    # Items that should be mentioned, and how they split into found/missing.
    expected_items: List[str] = field(default_factory=list)
    found_items: List[str] = field(default_factory=list)
    missing_items: List[str] = field(default_factory=list)
    # Fraction of expected items that were found.
    accuracy: float = 0.0
    # Text returned by the LLM (or an error description).
    raw_response: str = ""
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
async def evaluate_text_query(
    patient_id: str,
    query: str,
    query_type: str,
    expected_items: List[str],
    case_id: str = ""
) -> TextComparisonResult:
    """
    Send a text query to the agent and score the reply against expected items.

    The reply is scored with the matcher for ``query_type`` (medication,
    condition or allergy list); any other query type falls back to the
    condition matcher. A case passes when at least 70% of the expected items
    are found. If the agent reports an error, the result is returned as a
    failure with the error recorded in ``raw_response``.
    """
    outcome = TextComparisonResult(
        case_id=case_id,
        query=query,
        query_type=query_type,
        success=False,
        expected_items=expected_items,
    )

    agent_reply = await call_agent_endpoint(patient_id, query, timeout=90.0)

    # Bail out early: there is nothing to score when the call failed.
    if agent_reply.error:
        outcome.raw_response = f"Error: {agent_reply.error}"
        return outcome

    reply_text = agent_reply.raw_response
    outcome.raw_response = reply_text

    # One matcher per known query type; unknown types use the generic
    # keyword comparison implemented by the condition matcher.
    matchers = {
        "medication_list": extract_medication_names,
        "condition_list": extract_condition_names,
        "allergy_list": extract_allergy_names,
    }
    matcher = matchers.get(query_type, extract_condition_names)
    comparison = matcher(reply_text, expected_items)

    score = comparison["accuracy"]
    outcome.found_items = comparison["found"]
    outcome.missing_items = comparison["missing"]
    outcome.accuracy = score
    outcome.success = score >= 0.7  # 70% threshold

    return outcome
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def aggregate_text_results(results: List[TextComparisonResult]) -> Dict[str, Any]:
    """
    Summarise a list of text evaluation results.

    Returns overall counts and a success rate, item-level recall across all
    cases, per-query-type statistics (total, passed, accuracy sum and average
    accuracy), and a short detail record for every failed case. For an empty
    input only a zero count and a message are returned.
    """
    if not results:
        return {"total_cases": 0, "message": "No text cases evaluated"}

    case_count = len(results)
    failures = [r for r in results if not r.success]
    pass_count = case_count - len(failures)

    expected_total = 0
    found_total = 0
    by_type = {}

    # Single pass: accumulate item totals and per-type counters together.
    for r in results:
        expected_total += len(r.expected_items)
        found_total += len(r.found_items)

        stats = by_type.setdefault(
            r.query_type, {"total": 0, "passed": 0, "accuracy_sum": 0}
        )
        stats["total"] += 1
        if r.success:
            stats["passed"] += 1
        stats["accuracy_sum"] += r.accuracy

    # Average accuracy per type, formatted as a percentage string.
    for stats in by_type.values():
        mean_pct = stats["accuracy_sum"] / stats["total"] * 100
        stats["avg_accuracy"] = f"{mean_pct:.1f}%"

    if expected_total > 0:
        item_recall = f"{(found_total/expected_total*100):.1f}%"
    else:
        item_recall = "N/A"

    failed_details = []
    for r in failures:
        failed_details.append({
            "case_id": r.case_id,
            "query_type": r.query_type,
            "accuracy": f"{r.accuracy:.0%}",
            "missing": r.missing_items[:5],  # cap the list to keep reports short
        })

    return {
        "total_cases": case_count,
        "successful_cases": pass_count,
        "failed_cases": case_count - pass_count,
        "success_rate": f"{(pass_count/case_count*100):.1f}%",
        "total_expected_items": expected_total,
        "total_found_items": found_total,
        "item_recall": item_recall,
        "by_type": by_type,
        "failed_details": failed_details,
    }
|
| 443 |
+
|
| 444 |
+
|
| 445 |
def extract_numbers_from_chart(chart_data: Dict) -> Dict[str, Any]:
|
| 446 |
"""
|
| 447 |
Extract numerical values from chart data returned by tools.
|
server.py
CHANGED
|
@@ -698,32 +698,27 @@ async def run_evaluation(
|
|
| 698 |
extract_numbers_from_text,
|
| 699 |
compare_llm_response,
|
| 700 |
aggregate_llm_results,
|
| 701 |
-
LLMComparisonResult
|
|
|
|
|
|
|
| 702 |
)
|
| 703 |
|
| 704 |
print("\nRunning FULL LLM evaluation (this calls actual MedGemma)...")
|
| 705 |
-
print("Note: Only testing vital_trend queries (charts) for number accuracy\n")
|
| 706 |
|
| 707 |
-
#
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
if not vital_cases:
|
| 711 |
-
return {
|
| 712 |
-
"success": False,
|
| 713 |
-
"error": "No vital trend test cases found"
|
| 714 |
-
}
|
| 715 |
|
|
|
|
| 716 |
llm_results = []
|
| 717 |
|
| 718 |
-
for i, test_case in enumerate(vital_cases[:
|
| 719 |
patient_id = test_case["patient_id"]
|
| 720 |
query = test_case["query"]
|
| 721 |
case_id = test_case["case_id"]
|
| 722 |
expected = compute_expected_values(test_case)
|
| 723 |
|
| 724 |
-
print(f" [{i+1}/{min(
|
| 725 |
|
| 726 |
-
# Call actual LLM
|
| 727 |
llm_response = await call_agent_endpoint(patient_id, query, timeout=90.0)
|
| 728 |
|
| 729 |
if llm_response.error:
|
|
@@ -735,17 +730,12 @@ async def run_evaluation(
|
|
| 735 |
errors=[llm_response.error]
|
| 736 |
))
|
| 737 |
else:
|
| 738 |
-
# Extract numbers from chart (ground truth) and text (LLM said)
|
| 739 |
chart_nums = extract_numbers_from_chart(llm_response.chart_data)
|
| 740 |
text_nums = extract_numbers_from_text(llm_response.raw_response)
|
| 741 |
|
| 742 |
-
# Debug: show first 300 chars of LLM response
|
| 743 |
-
print(f" LLM response (first 300 chars):")
|
| 744 |
-
print(f" {llm_response.raw_response[:300].replace(chr(10), ' ')}")
|
| 745 |
print(f" Chart numbers: {chart_nums}")
|
| 746 |
print(f" Text numbers: {text_nums}")
|
| 747 |
|
| 748 |
-
# Compare
|
| 749 |
result = compare_llm_response(llm_response, expected)
|
| 750 |
result.case_id = case_id
|
| 751 |
llm_results.append(result)
|
|
@@ -757,26 +747,76 @@ async def run_evaluation(
|
|
| 757 |
for err in result.errors[:3]:
|
| 758 |
print(f" - {err}")
|
| 759 |
|
| 760 |
-
#
|
| 761 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 762 |
|
| 763 |
print("\n" + "="*60)
|
| 764 |
print("LLM RESPONSE ACCURACY REPORT")
|
| 765 |
print("="*60)
|
| 766 |
-
|
| 767 |
-
print(
|
| 768 |
-
print(f"
|
| 769 |
-
print(f"Success Rate: {
|
| 770 |
-
print(f"Number
|
| 771 |
-
|
| 772 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 773 |
print("="*60)
|
| 774 |
|
| 775 |
return {
|
| 776 |
"success": True,
|
| 777 |
"mode": "llm",
|
| 778 |
"patients_tested": patients,
|
| 779 |
-
"metrics":
|
|
|
|
|
|
|
|
|
|
| 780 |
}
|
| 781 |
|
| 782 |
elif mode == "agent":
|
|
|
|
| 698 |
extract_numbers_from_text,
|
| 699 |
compare_llm_response,
|
| 700 |
aggregate_llm_results,
|
| 701 |
+
LLMComparisonResult,
|
| 702 |
+
evaluate_text_query,
|
| 703 |
+
aggregate_text_results
|
| 704 |
)
|
| 705 |
|
| 706 |
print("\nRunning FULL LLM evaluation (this calls actual MedGemma)...")
|
|
|
|
| 707 |
|
| 708 |
+
# === PART 1: NUMERIC EVALUATION (Vitals) ===
|
| 709 |
+
print("\n--- PART 1: NUMERIC ACCURACY (Vital Charts) ---\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 710 |
|
| 711 |
+
vital_cases = [tc for tc in test_cases if tc["query_type"] == "vital_trend"]
|
| 712 |
llm_results = []
|
| 713 |
|
| 714 |
+
for i, test_case in enumerate(vital_cases[:4]): # Limit to 4
|
| 715 |
patient_id = test_case["patient_id"]
|
| 716 |
query = test_case["query"]
|
| 717 |
case_id = test_case["case_id"]
|
| 718 |
expected = compute_expected_values(test_case)
|
| 719 |
|
| 720 |
+
print(f" [{i+1}/{min(4, len(vital_cases))}] {query[:50]}...")
|
| 721 |
|
|
|
|
| 722 |
llm_response = await call_agent_endpoint(patient_id, query, timeout=90.0)
|
| 723 |
|
| 724 |
if llm_response.error:
|
|
|
|
| 730 |
errors=[llm_response.error]
|
| 731 |
))
|
| 732 |
else:
|
|
|
|
| 733 |
chart_nums = extract_numbers_from_chart(llm_response.chart_data)
|
| 734 |
text_nums = extract_numbers_from_text(llm_response.raw_response)
|
| 735 |
|
|
|
|
|
|
|
|
|
|
| 736 |
print(f" Chart numbers: {chart_nums}")
|
| 737 |
print(f" Text numbers: {text_nums}")
|
| 738 |
|
|
|
|
| 739 |
result = compare_llm_response(llm_response, expected)
|
| 740 |
result.case_id = case_id
|
| 741 |
llm_results.append(result)
|
|
|
|
| 747 |
for err in result.errors[:3]:
|
| 748 |
print(f" - {err}")
|
| 749 |
|
| 750 |
+
# === PART 2: TEXT EVALUATION (Medications, Conditions, Allergies) ===
|
| 751 |
+
print("\n--- PART 2: TEXT ACCURACY (Medications, Conditions, Allergies) ---\n")
|
| 752 |
+
|
| 753 |
+
text_cases = [tc for tc in test_cases if tc["query_type"] in ["medication_list", "condition_list", "allergy_list"]]
|
| 754 |
+
text_results = []
|
| 755 |
+
|
| 756 |
+
for i, test_case in enumerate(text_cases[:4]): # Limit to 4
|
| 757 |
+
patient_id = test_case["patient_id"]
|
| 758 |
+
query = test_case["query"]
|
| 759 |
+
query_type = test_case["query_type"]
|
| 760 |
+
case_id = test_case["case_id"]
|
| 761 |
+
expected = compute_expected_values(test_case)
|
| 762 |
+
|
| 763 |
+
# Get expected items list based on query type
|
| 764 |
+
if query_type == "medication_list":
|
| 765 |
+
expected_items = expected.get("medication_names", [])
|
| 766 |
+
elif query_type == "condition_list":
|
| 767 |
+
expected_items = expected.get("condition_names", [])
|
| 768 |
+
elif query_type == "allergy_list":
|
| 769 |
+
expected_items = expected.get("substances", [])
|
| 770 |
+
else:
|
| 771 |
+
expected_items = []
|
| 772 |
+
|
| 773 |
+
print(f" [{i+1}/{min(4, len(text_cases))}] {query[:50]}...")
|
| 774 |
+
print(f" Expected {len(expected_items)} items: {[x[:30] for x in expected_items[:3]]}...")
|
| 775 |
+
|
| 776 |
+
result = await evaluate_text_query(
|
| 777 |
+
patient_id, query, query_type, expected_items, case_id
|
| 778 |
+
)
|
| 779 |
+
text_results.append(result)
|
| 780 |
+
|
| 781 |
+
if result.success:
|
| 782 |
+
print(f" ✓ PASS ({result.accuracy:.0%} - found {len(result.found_items)}/{len(expected_items)})")
|
| 783 |
+
else:
|
| 784 |
+
print(f" ✗ FAIL ({result.accuracy:.0%} - found {len(result.found_items)}/{len(expected_items)})")
|
| 785 |
+
if result.missing_items:
|
| 786 |
+
print(f" Missing: {result.missing_items[:3]}")
|
| 787 |
+
|
| 788 |
+
# === AGGREGATE RESULTS ===
|
| 789 |
+
numeric_summary = aggregate_llm_results(llm_results)
|
| 790 |
+
text_summary = aggregate_text_results(text_results) if text_results else {}
|
| 791 |
|
| 792 |
print("\n" + "="*60)
|
| 793 |
print("LLM RESPONSE ACCURACY REPORT")
|
| 794 |
print("="*60)
|
| 795 |
+
|
| 796 |
+
print("\n📊 NUMERIC ACCURACY (Vital Charts):")
|
| 797 |
+
print(f" Test Cases: {numeric_summary['total_cases']}")
|
| 798 |
+
print(f" Success Rate: {numeric_summary['success_rate']}")
|
| 799 |
+
print(f" Number Accuracy: {numeric_summary['number_accuracy']}")
|
| 800 |
+
|
| 801 |
+
if text_summary:
|
| 802 |
+
print("\n📝 TEXT ACCURACY (Medications, Conditions, Allergies):")
|
| 803 |
+
print(f" Test Cases: {text_summary['total_cases']}")
|
| 804 |
+
print(f" Success Rate: {text_summary['success_rate']}")
|
| 805 |
+
print(f" Item Recall: {text_summary['item_recall']}")
|
| 806 |
+
if text_summary.get('by_type'):
|
| 807 |
+
for qtype, stats in text_summary['by_type'].items():
|
| 808 |
+
print(f" {qtype}: {stats['passed']}/{stats['total']} passed ({stats['avg_accuracy']})")
|
| 809 |
+
|
| 810 |
print("="*60)
|
| 811 |
|
| 812 |
return {
|
| 813 |
"success": True,
|
| 814 |
"mode": "llm",
|
| 815 |
"patients_tested": patients,
|
| 816 |
+
"metrics": {
|
| 817 |
+
"numeric": numeric_summary,
|
| 818 |
+
"text": text_summary
|
| 819 |
+
}
|
| 820 |
}
|
| 821 |
|
| 822 |
elif mode == "agent":
|