#!/usr/bin/env python3
"""
Expected Values Calculator

Computes ground truth values directly from the database for each test case type.
These are the values we expect the LLM to report.
"""

import sqlite3
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
import os
import statistics

DB_PATH = os.getenv("DB_PATH", "data/fhir.db")


def get_db() -> sqlite3.Connection:
    """Open a connection to the SQLite database at DB_PATH.

    Rows are returned as ``sqlite3.Row`` so columns can be read by name.
    The caller is responsible for closing the connection.
    """
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    return conn


def _date_only(value: Optional[str]) -> Optional[str]:
    """Return the first 10 characters of a stored date string, or None if empty.

    Presumably the stored values are ISO timestamps, so this keeps the
    YYYY-MM-DD part; verify against the database loader.
    """
    return value[:10] if value else None


def _valued_points(rows) -> tuple:
    """Return ``(values, dates)`` for the rows that actually carry a value.

    Rows whose ``value_quantity`` is NULL are skipped in BOTH lists so that
    ``values[i]`` and ``dates[i]`` always describe the same observation.
    """
    values: List[Any] = []
    dates: List[Optional[str]] = []
    for row in rows:
        quantity = row["value_quantity"]
        if quantity is None:
            continue
        values.append(quantity)
        dates.append(_date_only(row["effective_date"]))
    return values, dates


def compute_vital_trend_expected(patient_id: str, vital_type: str, codes: List[str],
                                 labels: List[str], days: int = 30) -> Dict[str, Any]:
    """
    Compute expected values for vital trend queries.

    For each ``(code, label)`` pair, reads the patient's observations from the
    last ``days`` days (oldest first) and reports min, max, avg, count, the
    latest value and the date range under ``result["metrics"][label]``.
    Labels with no valued observations in the window are omitted.

    ``codes`` and ``labels`` are paired positionally with ``zip``; surplus
    entries in the longer list are ignored.
    """
    conn = get_db()
    try:
        cutoff_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")

        result: Dict[str, Any] = {
            "query_type": "vital_trend",
            "vital_type": vital_type,
            "days": days,
            "metrics": {},
        }

        for code, label in zip(codes, labels):
            rows = conn.execute("""
                SELECT value_quantity, effective_date
                FROM observations
                WHERE patient_id = ? AND code = ? AND effective_date >= ?
                ORDER BY effective_date ASC
            """, (patient_id, code, cutoff_date)).fetchall()

            # Filter values and dates together: dropping NULL values from one
            # list only would shift every later date against its value.
            values, dates = _valued_points(rows)
            if not values:
                continue

            result["metrics"][label] = {
                "min": round(min(values), 1),
                "max": round(max(values), 1),
                "avg": round(statistics.mean(values), 1),
                "count": len(values),
                "latest": round(values[-1], 1),  # rows are ordered oldest -> newest
                "earliest_date": dates[0],
                "latest_date": dates[-1],
                "all_values": [round(v, 1) for v in values],
                "all_dates": dates,
            }

        return result
    finally:
        conn.close()


def compute_medication_expected(patient_id: str, status: Optional[str] = None) -> Dict[str, Any]:
    """
    Compute expected values for medication queries.

    Returns the patient's medications (newest ``start_date`` first), optionally
    restricted to one ``status``. A falsy ``status`` means "no filter".
    """
    if status:
        query = """
            SELECT code, display, status, start_date
            FROM medications
            WHERE patient_id = ? AND status = ?
            ORDER BY start_date DESC
        """
        args: tuple = (patient_id, status)
    else:
        query = """
            SELECT code, display, status, start_date
            FROM medications
            WHERE patient_id = ?
            ORDER BY start_date DESC
        """
        args = (patient_id,)

    conn = get_db()
    try:
        medications = [
            {
                "code": row["code"],
                "display": row["display"],
                "status": row["status"],
                "start_date": _date_only(row["start_date"]),
            }
            for row in conn.execute(query, args).fetchall()
        ]
    finally:
        conn.close()

    return {
        "query_type": "medication_list",
        "status_filter": status,
        "count": len(medications),
        "medications": medications,
        "medication_names": [m["display"] for m in medications],
    }


def compute_condition_expected(patient_id: str) -> Dict[str, Any]:
    """
    Compute expected values for condition queries.

    Returns the patient's conditions ordered by ``onset_date``, newest first.
    """
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT code, display, clinical_status, onset_date
            FROM conditions
            WHERE patient_id = ?
            ORDER BY onset_date DESC
        """, (patient_id,)).fetchall()

        conditions = [
            {
                "code": row["code"],
                "display": row["display"],
                "clinical_status": row["clinical_status"],
                "onset_date": _date_only(row["onset_date"]),
            }
            for row in rows
        ]
    finally:
        conn.close()

    return {
        "query_type": "condition_list",
        "count": len(conditions),
        "conditions": conditions,
        "condition_names": [c["display"] for c in conditions],
    }


def compute_allergy_expected(patient_id: str) -> Dict[str, Any]:
    """
    Compute expected values for allergy queries.

    Returns the patient's allergies in whatever order the database yields them
    (the query has no ORDER BY).
    """
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT substance, reaction_display, criticality, category
            FROM allergies
            WHERE patient_id = ?
        """, (patient_id,)).fetchall()

        allergies = [
            {
                "substance": row["substance"],
                "reaction": row["reaction_display"],
                "criticality": row["criticality"],
                "category": row["category"],
            }
            for row in rows
        ]
    finally:
        conn.close()

    return {
        "query_type": "allergy_list",
        "count": len(allergies),
        "allergies": allergies,
        "substances": [a["substance"] for a in allergies],
    }


def compute_immunization_expected(patient_id: str) -> Dict[str, Any]:
    """
    Compute expected values for immunization queries.

    Returns the patient's immunizations ordered by ``occurrence_date``, newest first.
    """
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT vaccine_code, vaccine_display, status, occurrence_date
            FROM immunizations
            WHERE patient_id = ?
            ORDER BY occurrence_date DESC
        """, (patient_id,)).fetchall()

        immunizations = [
            {
                "vaccine_code": row["vaccine_code"],
                "vaccine_display": row["vaccine_display"],
                "status": row["status"],
                "occurrence_date": _date_only(row["occurrence_date"]),
            }
            for row in rows
        ]
    finally:
        conn.close()

    return {
        "query_type": "immunization_list",
        "count": len(immunizations),
        "immunizations": immunizations,
        "vaccine_names": [i["vaccine_display"] for i in immunizations],
    }


def compute_procedure_expected(patient_id: str) -> Dict[str, Any]:
    """
    Compute expected values for procedure queries.

    Returns the patient's procedures ordered by ``performed_date``, newest first.
    """
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT code, display, status, performed_date
            FROM procedures
            WHERE patient_id = ?
            ORDER BY performed_date DESC
        """, (patient_id,)).fetchall()

        procedures = [
            {
                "code": row["code"],
                "display": row["display"],
                "status": row["status"],
                "performed_date": _date_only(row["performed_date"]),
            }
            for row in rows
        ]
    finally:
        conn.close()

    return {
        "query_type": "procedure_list",
        "count": len(procedures),
        "procedures": procedures,
        "procedure_names": [p["display"] for p in procedures],
    }


def compute_encounter_expected(patient_id: str, limit: int = 5) -> Dict[str, Any]:
    """
    Compute expected values for encounter queries.

    Returns at most ``limit`` encounters ordered by ``period_start``, newest first.
    """
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT type_display, reason_display, period_start, period_end, class_display
            FROM encounters
            WHERE patient_id = ?
            ORDER BY period_start DESC
            LIMIT ?
        """, (patient_id, limit)).fetchall()

        encounters = [
            {
                "type": row["type_display"],
                "reason": row["reason_display"],
                "class": row["class_display"],
                "start_date": _date_only(row["period_start"]),
                "end_date": _date_only(row["period_end"]),
            }
            for row in rows
        ]
    finally:
        conn.close()

    return {
        "query_type": "encounter_list",
        "count": len(encounters),
        "limit": limit,
        "encounters": encounters,
    }


def compute_lab_trend_expected(patient_id: str, lab_type: str, code: str,
                               periods: int = 4) -> Dict[str, Any]:
    """
    Compute expected values for lab trend queries.

    Reads the ``periods`` most recent observations for ``code`` (newest first).
    ``count`` is the number of those rows that carry a value; ``metrics`` is
    only present when at least one does. ``unit`` comes from the most recent
    returned row, whether or not that row has a value.
    """
    conn = get_db()
    try:
        rows = conn.execute("""
            SELECT value_quantity, effective_date, value_unit
            FROM observations
            WHERE patient_id = ? AND code = ?
            ORDER BY effective_date DESC
            LIMIT ?
        """, (patient_id, code, periods)).fetchall()
    finally:
        conn.close()

    # Keep values and dates index-aligned, and tolerate a NULL effective_date
    # (this query has no date filter, so NULL dates can be returned).
    values, dates = _valued_points(rows)
    unit = rows[0]["value_unit"] if rows else None

    result: Dict[str, Any] = {
        "query_type": "lab_trend",
        "lab_type": lab_type,
        "code": code,
        "unit": unit,
        "count": len(values),
    }

    if values:
        result["metrics"] = {
            "min": round(min(values), 1),
            "max": round(max(values), 1),
            "avg": round(statistics.mean(values), 1),
            "latest": round(values[0], 1),  # Most recent (rows are newest first)
            "latest_date": dates[0],
            "all_values": [round(v, 1) for v in values],
            "all_dates": dates,
        }

    return result


def compute_expected_values(test_case: Dict) -> Dict[str, Any]:
    """
    Compute expected values for any test case type.

    Routes to the appropriate computation function based on
    ``test_case["query_type"]``. ``test_case["patient_id"]`` is required;
    ``test_case["parameters"]`` may be missing or None. An unknown query type
    yields ``{"error": ...}`` rather than raising.
    """
    query_type = test_case["query_type"]
    patient_id = test_case["patient_id"]
    # `or {}` also covers an explicit `"parameters": None`.
    params = test_case.get("parameters") or {}

    if query_type == "vital_trend":
        return compute_vital_trend_expected(
            patient_id,
            params["vital_type"],
            params["codes"],
            params["labels"],
            params.get("days", 30),
        )
    if query_type == "medication_list":
        return compute_medication_expected(patient_id, params.get("status"))
    if query_type == "condition_list":
        return compute_condition_expected(patient_id)
    if query_type == "allergy_list":
        return compute_allergy_expected(patient_id)
    if query_type == "immunization_list":
        return compute_immunization_expected(patient_id)
    if query_type == "procedure_list":
        return compute_procedure_expected(patient_id)
    if query_type == "encounter_list":
        return compute_encounter_expected(patient_id, params.get("limit", 5))
    if query_type == "lab_trend":
        return compute_lab_trend_expected(
            patient_id,
            params["lab_type"],
            params["code"],
            params.get("periods", 4),
        )
    return {"error": f"Unknown query type: {query_type}"}


if __name__ == "__main__":
    # Test with a sample case
    from test_generator import generate_all_test_cases
    import json

    print("Generating test cases...")
    cases = generate_all_test_cases(num_patients=1)

    print(f"\nComputing expected values for {len(cases)} test cases...")
    for case in cases[:3]:  # Show first 3
        print(f"\n{'='*60}")
        print(f"Case: {case['case_id']}")
        print(f"Query: {case['query']}")

        expected = compute_expected_values(case)
        print("Expected values:")
        print(json.dumps(expected, indent=2, default=str))