Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Expected Values Calculator | |
| Computes ground truth values directly from the database for each test case type. | |
| These are the values we expect the LLM to report. | |
| """ | |
| import sqlite3 | |
| from datetime import datetime, timedelta | |
| from typing import Dict, List, Any, Optional | |
| import os | |
| import statistics | |
# Location of the SQLite database file. Taken from the DB_PATH environment
# variable when it is set; otherwise falls back to "data/fhir.db".
DB_PATH = os.getenv("DB_PATH", "data/fhir.db")
def get_db():
    """Open the SQLite database at ``DB_PATH`` and return the connection.

    The connection's ``row_factory`` is set to ``sqlite3.Row`` so callers can
    read result columns by name (``row["code"]``). The caller is responsible
    for closing the returned connection.
    """
    connection = sqlite3.connect(DB_PATH)
    connection.row_factory = sqlite3.Row
    return connection
def compute_vital_trend_expected(patient_id: str, vital_type: str, codes: List[str],
                                 labels: List[str], days: int = 30) -> Dict[str, Any]:
    """
    Compute expected values for vital trend queries.

    For each ``(code, label)`` pair, reads the patient's observations whose
    ``effective_date`` is on or after the cutoff (now minus ``days`` days) and
    summarises the readings that actually carry a value.

    Args:
        patient_id: Patient whose observations are read.
        vital_type: Name of the vital; copied into the result unchanged.
        codes: Observation codes to query, one per metric.
        labels: Keys under which each code's summary is stored in
            ``result["metrics"]``; paired positionally with ``codes``.
        days: Size of the look-back window in days.

    Returns:
        A dict with ``query_type``, ``vital_type``, ``days`` and ``metrics``.
        ``metrics[label]`` holds min, max, avg, count, latest, earliest_date,
        latest_date, all_values and all_dates. A label with no valued
        readings in the window is omitted from ``metrics``.
    """
    conn = get_db()
    try:
        # The cutoff is derived from the local clock, so results depend on
        # when the function is run.
        cutoff_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
        result = {
            "query_type": "vital_trend",
            "vital_type": vital_type,
            "days": days,
            "metrics": {}
        }
        # zip() stops at the shorter of the two lists; surplus codes or labels
        # are ignored.
        for code, label in zip(codes, labels):
            rows = conn.execute("""
                SELECT value_quantity, effective_date
                FROM observations
                WHERE patient_id = ? AND code = ? AND effective_date >= ?
                ORDER BY effective_date ASC
            """, (patient_id, code, cutoff_date)).fetchall()

            # Filter value and date together so that all_values[i] and
            # all_dates[i] always describe the same observation, and so that
            # earliest_date/latest_date refer to readings that are counted.
            readings = [
                (row["value_quantity"], row["effective_date"][:10])
                for row in rows
                if row["value_quantity"] is not None
            ]
            if not readings:
                continue

            values = [value for value, _ in readings]
            dates = [date for _, date in readings]
            result["metrics"][label] = {
                "min": round(min(values), 1),
                "max": round(max(values), 1),
                "avg": round(statistics.mean(values), 1),
                "count": len(values),
                "latest": round(values[-1], 1),  # rows are ordered oldest first
                "earliest_date": dates[0],
                "latest_date": dates[-1],
                "all_values": [round(v, 1) for v in values],
                "all_dates": dates
            }
        return result
    finally:
        conn.close()
def compute_medication_expected(patient_id: str, status: Optional[str] = None) -> Dict[str, Any]:
    """
    Build the ground-truth medication list for one patient.

    Reads the patient's rows from the ``medications`` table ordered by
    ``start_date`` descending. When ``status`` is truthy only rows with that
    status are read; otherwise every row for the patient is returned.

    Returns:
        A dict with ``query_type``, the ``status_filter`` that was passed in,
        ``count``, the per-row ``medications`` details (``start_date`` cut to
        its first 10 characters, or ``None`` when empty) and
        ``medication_names`` (the ``display`` of each row, in the same order).
    """
    if status:
        sql = """
            SELECT code, display, status, start_date
            FROM medications
            WHERE patient_id = ? AND status = ?
            ORDER BY start_date DESC
        """
        args = (patient_id, status)
    else:
        sql = """
            SELECT code, display, status, start_date
            FROM medications
            WHERE patient_id = ?
            ORDER BY start_date DESC
        """
        args = (patient_id,)

    conn = get_db()
    try:
        fetched = conn.execute(sql, args).fetchall()
        medications = [
            {
                "code": entry["code"],
                "display": entry["display"],
                "status": entry["status"],
                "start_date": entry["start_date"][:10] if entry["start_date"] else None
            }
            for entry in fetched
        ]
        names = [med["display"] for med in medications]
        return {
            "query_type": "medication_list",
            "status_filter": status,
            "count": len(medications),
            "medications": medications,
            "medication_names": names
        }
    finally:
        conn.close()
def compute_condition_expected(patient_id: str) -> Dict[str, Any]:
    """
    Build the ground-truth condition list for one patient.

    Reads every row of the ``conditions`` table for ``patient_id`` ordered by
    ``onset_date`` descending.

    Returns:
        A dict with ``query_type``, ``count``, the per-row ``conditions``
        details (``onset_date`` cut to its first 10 characters, or ``None``
        when empty) and ``condition_names`` (each row's ``display``).
    """
    def _date_part(stamp):
        # Keep only the first 10 characters; empty or NULL becomes None.
        return stamp[:10] if stamp else None

    conn = get_db()
    try:
        fetched = conn.execute("""
            SELECT code, display, clinical_status, onset_date
            FROM conditions
            WHERE patient_id = ?
            ORDER BY onset_date DESC
        """, (patient_id,)).fetchall()
        conditions = [
            {
                "code": entry["code"],
                "display": entry["display"],
                "clinical_status": entry["clinical_status"],
                "onset_date": _date_part(entry["onset_date"])
            }
            for entry in fetched
        ]
        return {
            "query_type": "condition_list",
            "count": len(conditions),
            "conditions": conditions,
            "condition_names": [item["display"] for item in conditions]
        }
    finally:
        conn.close()
def compute_allergy_expected(patient_id: str) -> Dict[str, Any]:
    """
    Build the ground-truth allergy list for one patient.

    Reads every row of the ``allergies`` table for ``patient_id``. The query
    has no ORDER BY, so rows come back in whatever order SQLite yields them.

    Returns:
        A dict with ``query_type``, ``count``, the per-row ``allergies``
        details (the ``reaction_display`` column is exposed as ``reaction``)
        and ``substances`` (each row's ``substance``).
    """
    conn = get_db()
    try:
        fetched = conn.execute("""
            SELECT substance, reaction_display, criticality, category
            FROM allergies
            WHERE patient_id = ?
        """, (patient_id,)).fetchall()
        allergies = [
            {
                "substance": entry["substance"],
                "reaction": entry["reaction_display"],
                "criticality": entry["criticality"],
                "category": entry["category"]
            }
            for entry in fetched
        ]
        substances = [item["substance"] for item in allergies]
        return {
            "query_type": "allergy_list",
            "count": len(allergies),
            "allergies": allergies,
            "substances": substances
        }
    finally:
        conn.close()
def compute_immunization_expected(patient_id: str) -> Dict[str, Any]:
    """
    Build the ground-truth immunization list for one patient.

    Reads every row of the ``immunizations`` table for ``patient_id`` ordered
    by ``occurrence_date`` descending.

    Returns:
        A dict with ``query_type``, ``count``, the per-row ``immunizations``
        details (``occurrence_date`` cut to its first 10 characters, or
        ``None`` when empty) and ``vaccine_names`` (each row's
        ``vaccine_display``).
    """
    def _date_part(stamp):
        # Keep only the first 10 characters; empty or NULL becomes None.
        return stamp[:10] if stamp else None

    conn = get_db()
    try:
        fetched = conn.execute("""
            SELECT vaccine_code, vaccine_display, status, occurrence_date
            FROM immunizations
            WHERE patient_id = ?
            ORDER BY occurrence_date DESC
        """, (patient_id,)).fetchall()
        immunizations = [
            {
                "vaccine_code": entry["vaccine_code"],
                "vaccine_display": entry["vaccine_display"],
                "status": entry["status"],
                "occurrence_date": _date_part(entry["occurrence_date"])
            }
            for entry in fetched
        ]
        return {
            "query_type": "immunization_list",
            "count": len(immunizations),
            "immunizations": immunizations,
            "vaccine_names": [item["vaccine_display"] for item in immunizations]
        }
    finally:
        conn.close()
def compute_procedure_expected(patient_id: str) -> Dict[str, Any]:
    """
    Build the ground-truth procedure list for one patient.

    Reads every row of the ``procedures`` table for ``patient_id`` ordered by
    ``performed_date`` descending.

    Returns:
        A dict with ``query_type``, ``count``, the per-row ``procedures``
        details (``performed_date`` cut to its first 10 characters, or
        ``None`` when empty) and ``procedure_names`` (each row's ``display``).
    """
    def _date_part(stamp):
        # Keep only the first 10 characters; empty or NULL becomes None.
        return stamp[:10] if stamp else None

    conn = get_db()
    try:
        fetched = conn.execute("""
            SELECT code, display, status, performed_date
            FROM procedures
            WHERE patient_id = ?
            ORDER BY performed_date DESC
        """, (patient_id,)).fetchall()
        procedures = [
            {
                "code": entry["code"],
                "display": entry["display"],
                "status": entry["status"],
                "performed_date": _date_part(entry["performed_date"])
            }
            for entry in fetched
        ]
        return {
            "query_type": "procedure_list",
            "count": len(procedures),
            "procedures": procedures,
            "procedure_names": [item["display"] for item in procedures]
        }
    finally:
        conn.close()
def compute_encounter_expected(patient_id: str, limit: int = 5) -> Dict[str, Any]:
    """
    Build the ground-truth encounter list for one patient.

    Reads at most ``limit`` rows of the ``encounters`` table for
    ``patient_id``, ordered by ``period_start`` descending.

    Returns:
        A dict with ``query_type``, ``count`` (rows actually returned), the
        ``limit`` that was requested and the per-row ``encounters`` details.
        ``start_date`` and ``end_date`` are the first 10 characters of
        ``period_start`` / ``period_end``, or ``None`` when empty.
    """
    def _date_part(stamp):
        # Keep only the first 10 characters; empty or NULL becomes None.
        return stamp[:10] if stamp else None

    conn = get_db()
    try:
        fetched = conn.execute("""
            SELECT type_display, reason_display, period_start, period_end, class_display
            FROM encounters
            WHERE patient_id = ?
            ORDER BY period_start DESC
            LIMIT ?
        """, (patient_id, limit)).fetchall()
        encounters = [
            {
                "type": entry["type_display"],
                "reason": entry["reason_display"],
                "class": entry["class_display"],
                "start_date": _date_part(entry["period_start"]),
                "end_date": _date_part(entry["period_end"])
            }
            for entry in fetched
        ]
        return {
            "query_type": "encounter_list",
            "count": len(encounters),
            "limit": limit,
            "encounters": encounters
        }
    finally:
        conn.close()
def compute_lab_trend_expected(patient_id: str, lab_type: str, code: str,
                               periods: int = 4) -> Dict[str, Any]:
    """
    Compute expected values for lab trend queries.

    Reads the ``periods`` most recent observations (by ``effective_date``)
    for the given patient and code, then summarises the ones that carry a
    value.

    Args:
        patient_id: Patient whose observations are read.
        lab_type: Name of the lab; copied into the result unchanged.
        code: Observation code to query.
        periods: Maximum number of most-recent rows to read. Rows without a
            value still count toward this limit, so fewer than ``periods``
            values may be summarised.

    Returns:
        A dict with ``query_type``, ``lab_type``, ``code``, ``unit`` (taken
        from the most recent row, or ``None`` when there are no rows) and
        ``count`` (number of valued readings). When at least one valued
        reading exists, ``metrics`` holds min, max, avg, latest, latest_date,
        all_values and all_dates, newest first.
    """
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT value_quantity, effective_date, value_unit
            FROM observations
            WHERE patient_id = ? AND code = ?
            ORDER BY effective_date DESC
            LIMIT ?
        """, (patient_id, code, periods)).fetchall()

        # Filter value and date together so that all_values[i] and
        # all_dates[i] always describe the same observation. This query does
        # not exclude NULL effective_date, so guard the slice as well.
        readings = [
            (
                row["value_quantity"],
                row["effective_date"][:10] if row["effective_date"] else None,
            )
            for row in rows
            if row["value_quantity"] is not None
        ]
        values = [value for value, _ in readings]
        dates = [date for _, date in readings]
        unit = rows[0]["value_unit"] if rows else None

        result = {
            "query_type": "lab_trend",
            "lab_type": lab_type,
            "code": code,
            "unit": unit,
            "count": len(values)
        }
        if values:
            result["metrics"] = {
                "min": round(min(values), 1),
                "max": round(max(values), 1),
                "avg": round(statistics.mean(values), 1),
                "latest": round(values[0], 1),  # rows are ordered newest first
                "latest_date": dates[0],
                "all_values": [round(v, 1) for v in values],
                "all_dates": dates
            }
        return result
    finally:
        conn.close()
def compute_expected_values(test_case: Dict) -> Dict[str, Any]:
    """
    Compute expected values for any test case type.

    Looks up ``test_case["query_type"]`` and delegates to the matching
    ``compute_*_expected`` function, passing ``test_case["patient_id"]`` and
    whatever that function needs from the optional ``"parameters"`` dict.

    Returns:
        The delegate's result, or ``{"error": "Unknown query type: ..."}``
        when the query type is not recognised.

    Raises:
        KeyError: if ``query_type`` or ``patient_id`` is missing, or if a
            required parameter of the selected query type is missing.
    """
    query_type = test_case["query_type"]
    patient_id = test_case["patient_id"]
    params = test_case.get("parameters", {})

    # Each entry is a zero-argument callable so that parameters are only
    # looked up for the query type that was actually requested.
    handlers = {
        "vital_trend": lambda: compute_vital_trend_expected(
            patient_id,
            params["vital_type"],
            params["codes"],
            params["labels"],
            params.get("days", 30)
        ),
        "medication_list": lambda: compute_medication_expected(
            patient_id, params.get("status")
        ),
        "condition_list": lambda: compute_condition_expected(patient_id),
        "allergy_list": lambda: compute_allergy_expected(patient_id),
        "immunization_list": lambda: compute_immunization_expected(patient_id),
        "procedure_list": lambda: compute_procedure_expected(patient_id),
        "encounter_list": lambda: compute_encounter_expected(
            patient_id, params.get("limit", 5)
        ),
        "lab_trend": lambda: compute_lab_trend_expected(
            patient_id,
            params["lab_type"],
            params["code"],
            params.get("periods", 4)
        ),
    }

    handler = handlers.get(query_type)
    if handler is None:
        return {"error": f"Unknown query type: {query_type}"}
    return handler()
if __name__ == "__main__":
    # Manual smoke test: generate test cases for one patient and print the
    # expected values computed for the first three of them.
    from test_generator import generate_all_test_cases
    import json

    print("Generating test cases...")
    cases = generate_all_test_cases(num_patients=1)
    print(f"\nComputing expected values for {len(cases)} test cases...")

    separator = "=" * 60
    for sample in cases[:3]:  # Show first 3
        print(f"\n{separator}")
        print(f"Case: {sample['case_id']}")
        print(f"Query: {sample['query']}")
        computed = compute_expected_values(sample)
        print("Expected values:")
        # default=str stringifies anything json cannot serialise natively.
        print(json.dumps(computed, indent=2, default=str))