Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Test Case Generator for Pre-Visit Summary Evaluation | |
| Generates test cases from Synthea patient data with known ground truth. | |
| """ | |
| import sqlite3 | |
| import random | |
| from datetime import datetime, timedelta | |
| from typing import List, Dict, Any | |
| import os | |
# Location of the SQLite FHIR database; the DB_PATH environment variable overrides the default.
DB_PATH = os.environ.get("DB_PATH", "data/fhir.db")
def get_db():
    """Open and return a fresh SQLite connection to ``DB_PATH``.

    The connection's ``row_factory`` is set to ``sqlite3.Row`` so result
    columns can be read by name. The connection is returned open; callers
    in this module close it themselves in a ``finally`` clause.
    """
    connection = sqlite3.connect(DB_PATH)
    # Name-based column access (row["cnt"], row["id"], ...) relies on this.
    connection.row_factory = sqlite3.Row
    return connection
def get_test_patients(limit: int = 10) -> List[Dict]:
    """Get patients that have sufficient data for testing.

    A patient qualifies when they have more than 10 rows in ``observations``.
    Qualifying patients are ordered by observation count (descending). Ties
    are broken by patient id, so repeated runs over the same database select
    the same patients.

    Args:
        limit: Maximum number of patients to return (passed to SQL ``LIMIT``).

    Returns:
        One dict per patient with ``patient_id``, ``name``, ``birth_date``,
        ``gender`` and a ``data_counts`` dict of per-table row counts.
    """
    conn = get_db()
    try:
        # Find patients with good data coverage.
        cursor = conn.execute("""
            SELECT p.id, p.given_name, p.family_name, p.birth_date, p.gender,
                   (SELECT COUNT(*) FROM conditions WHERE patient_id = p.id) as condition_count,
                   (SELECT COUNT(*) FROM medications WHERE patient_id = p.id) as med_count,
                   (SELECT COUNT(*) FROM observations WHERE patient_id = p.id) as obs_count,
                   (SELECT COUNT(*) FROM allergies WHERE patient_id = p.id) as allergy_count,
                   (SELECT COUNT(*) FROM immunizations WHERE patient_id = p.id) as imm_count,
                   (SELECT COUNT(*) FROM procedures WHERE patient_id = p.id) as proc_count,
                   (SELECT COUNT(*) FROM encounters WHERE patient_id = p.id) as enc_count
            FROM patients p
            WHERE (SELECT COUNT(*) FROM observations WHERE patient_id = p.id) > 10
            ORDER BY obs_count DESC, p.id
            LIMIT ?
        """, (limit,))

        def display_name(row) -> str:
            # Skip NULL/empty name parts so a missing given or family name
            # does not render as the literal text "None".
            parts = (row["given_name"], row["family_name"])
            return " ".join(str(part) for part in parts if part)

        return [
            {
                "patient_id": row["id"],
                "name": display_name(row),
                "birth_date": row["birth_date"],
                "gender": row["gender"],
                "data_counts": {
                    "conditions": row["condition_count"],
                    "medications": row["med_count"],
                    "observations": row["obs_count"],
                    "allergies": row["allergy_count"],
                    "immunizations": row["imm_count"],
                    "procedures": row["proc_count"],
                    "encounters": row["enc_count"],
                },
            }
            for row in cursor.fetchall()
        ]
    finally:
        conn.close()
def generate_vital_trend_cases(patient_id: str, days: int = 30,
                               min_readings: int = 3) -> List[Dict]:
    """Generate test cases for vital sign trends (BP, heart rate, etc.).

    A case is emitted for a vital only when the patient has at least
    ``min_readings`` observations for *each* code that makes up that vital.
    For blood pressure this means enough systolic AND enough diastolic
    readings, not ``min_readings`` rows spread across both codes combined.

    Args:
        patient_id: Value matched against ``observations.patient_id``.
        days: Window length copied into each case's ``parameters``.
            NOTE(review): the readiness count below spans the patient's whole
            history and is not limited to this window. A case can therefore
            be generated even when no readings fall inside ``days``. Confirm
            whether that is intended.
        min_readings: Minimum number of observations required per code
            (default 3, matching the original threshold).

    Returns:
        One ``vital_trend`` case dict per qualifying vital, in the order
        listed in ``vital_types``.
    """
    test_cases = []
    # (case name, observation codes, labels for each code)
    vital_types = [
        ("blood_pressure", ["8480-6", "8462-4"], ["systolic", "diastolic"]),
        ("heart_rate", ["8867-4"], ["heart_rate"]),
        ("weight", ["29463-7"], ["weight"]),
        ("temperature", ["8310-5"], ["temperature"]),
        ("oxygen_saturation", ["2708-6"], ["oxygen_saturation"]),
    ]
    conn = get_db()
    try:
        for vital_name, codes, labels in vital_types:
            placeholders = ",".join("?" * len(codes))
            cursor = conn.execute(f"""
                SELECT code, COUNT(*) as cnt FROM observations
                WHERE patient_id = ? AND code IN ({placeholders})
                GROUP BY code
            """, [patient_id, *codes])
            counts = {row["code"]: row["cnt"] for row in cursor.fetchall()}
            # Every component code must have enough readings on its own for
            # a meaningful trend.
            if not all(counts.get(code, 0) >= min_readings for code in codes):
                continue
            test_cases.append({
                "case_id": f"{patient_id}_vital_{vital_name}",
                "patient_id": patient_id,
                "query_type": "vital_trend",
                "query": f"Show me my {vital_name.replace('_', ' ')} chart",
                "parameters": {
                    "vital_type": vital_name,
                    "days": days,
                    "codes": codes,
                    "labels": labels,
                },
            })
    finally:
        conn.close()
    return test_cases
def generate_medication_cases(patient_id: str) -> List[Dict]:
    """Generate test cases for medication queries.

    Emits an "all medications" case when the patient has any medication
    rows. When at least one of those rows has ``status = 'active'``, it also
    emits an "active medications" case.

    Args:
        patient_id: Value matched against ``medications.patient_id``.

    Returns:
        Zero, one, or two ``medication_list`` case dicts.
    """
    test_cases = []
    conn = get_db()
    try:
        # SQLite's SUM() over zero rows yields NULL. COALESCE keeps "active"
        # an integer, so the comparison below never evaluates ``None > 0``
        # (which raises TypeError).
        row = conn.execute("""
            SELECT COUNT(*) as total,
                   COALESCE(SUM(CASE WHEN status = 'active' THEN 1 ELSE 0 END), 0) as active
            FROM medications WHERE patient_id = ?
        """, (patient_id,)).fetchone()
        total, active = row["total"], row["active"]
    finally:
        conn.close()

    if total > 0:
        # All medications
        test_cases.append({
            "case_id": f"{patient_id}_meds_all",
            "patient_id": patient_id,
            "query_type": "medication_list",
            "query": "What medications am I taking?",
            "parameters": {"status": None},
        })
        # Active only; a subset of the rows above, so nested under total > 0.
        if active > 0:
            test_cases.append({
                "case_id": f"{patient_id}_meds_active",
                "patient_id": patient_id,
                "query_type": "medication_list",
                "query": "What are my current active medications?",
                "parameters": {"status": "active"},
            })
    return test_cases
def generate_condition_cases(patient_id: str) -> List[Dict]:
    """Build the condition-list case for one patient.

    Args:
        patient_id: Value matched against ``conditions.patient_id``.

    Returns:
        A one-element list holding a ``condition_list`` case when the patient
        has at least one condition row; otherwise an empty list.
    """
    conn = get_db()
    try:
        n_conditions = conn.execute(
            "SELECT COUNT(*) as cnt FROM conditions WHERE patient_id = ?",
            (patient_id,),
        ).fetchone()["cnt"]
    finally:
        conn.close()

    if n_conditions <= 0:
        return []
    return [{
        "case_id": f"{patient_id}_conditions",
        "patient_id": patient_id,
        "query_type": "condition_list",
        "query": "What are my medical conditions?",
        "parameters": {},
    }]
def generate_allergy_cases(patient_id: str) -> List[Dict]:
    """Build the allergy-list case for one patient, if they have any allergies.

    Args:
        patient_id: Value matched against ``allergies.patient_id``.

    Returns:
        ``[case]`` with ``query_type == "allergy_list"`` when at least one
        allergy row exists for the patient; ``[]`` otherwise.
    """
    conn = get_db()
    try:
        # EXISTS yields 1 or 0, which is equivalent to testing COUNT(*) > 0.
        has_allergies = conn.execute(
            "SELECT EXISTS(SELECT 1 FROM allergies WHERE patient_id = ?)",
            (patient_id,),
        ).fetchone()[0]
    finally:
        conn.close()

    cases: List[Dict] = []
    if has_allergies:
        cases.append({
            "case_id": f"{patient_id}_allergies",
            "patient_id": patient_id,
            "query_type": "allergy_list",
            "query": "What are my allergies?",
            "parameters": {},
        })
    return cases
def generate_immunization_cases(patient_id: str) -> List[Dict]:
    """Produce the immunization-history case for a patient, if applicable.

    Args:
        patient_id: Value matched against ``immunizations.patient_id``.

    Returns:
        ``[case]`` with ``query_type == "immunization_list"`` when the patient
        has one or more immunization rows; ``[]`` otherwise.
    """
    db = get_db()
    try:
        cursor = db.execute(
            "SELECT COUNT(*) as cnt FROM immunizations WHERE patient_id = ?",
            (patient_id,),
        )
        shot_count = cursor.fetchone()["cnt"]
    finally:
        db.close()

    case = {
        "case_id": f"{patient_id}_immunizations",
        "patient_id": patient_id,
        "query_type": "immunization_list",
        "query": "What immunizations have I had?",
        "parameters": {},
    }
    return [case] if shot_count > 0 else []
def generate_procedure_cases(patient_id: str) -> List[Dict]:
    """Build the procedure/surgical-history case for one patient.

    Args:
        patient_id: Value matched against ``procedures.patient_id``.

    Returns:
        A list containing one ``procedure_list`` case when the patient has any
        procedure rows; an empty list otherwise.
    """
    conn = get_db()
    try:
        count_row = conn.execute(
            "SELECT COUNT(*) as cnt FROM procedures WHERE patient_id = ?",
            (patient_id,),
        ).fetchone()
    finally:
        conn.close()

    procedure_cases: List[Dict] = []
    if count_row["cnt"] > 0:
        procedure_cases.append(dict(
            case_id=f"{patient_id}_procedures",
            patient_id=patient_id,
            query_type="procedure_list",
            query="What procedures or surgeries have I had?",
            parameters={},
        ))
    return procedure_cases
def generate_encounter_cases(patient_id: str) -> List[Dict]:
    """Build the recent-visits case for one patient.

    Args:
        patient_id: Value matched against ``encounters.patient_id``.

    Returns:
        ``[case]`` for an ``encounter_list`` query with ``{"limit": 5}``
        when the patient has at least one encounter; ``[]`` otherwise.
    """
    conn = get_db()
    try:
        # Fetching a single matching row is enough to know the count is > 0.
        first_hit = conn.execute(
            "SELECT 1 FROM encounters WHERE patient_id = ? LIMIT 1",
            (patient_id,),
        ).fetchone()
    finally:
        conn.close()

    if first_hit is None:
        return []
    return [{
        "case_id": f"{patient_id}_encounters",
        "patient_id": patient_id,
        "query_type": "encounter_list",
        "query": "Show me my recent visits",
        "parameters": {"limit": 5},
    }]
def generate_lab_cases(patient_id: str) -> List[Dict]:
    """Generate lab-trend cases for the labs the patient has repeat results for.

    A case is emitted for each tracked lab code with at least two matching
    rows in ``observations`` for this patient.

    Args:
        patient_id: Value matched against ``observations.patient_id``.

    Returns:
        ``lab_trend`` case dicts in the fixed order A1c, cholesterol, glucose
        (only those that qualify).
    """
    # (short name used in case ids, observation code, display label)
    tracked_labs = (
        ("a1c", "4548-4", "HbA1c"),
        ("cholesterol", "2093-3", "Total Cholesterol"),
        ("glucose", "2345-7", "Glucose"),
    )
    cases: List[Dict] = []
    conn = get_db()
    try:
        for key, lab_code, label in tracked_labs:
            readings = conn.execute(
                "SELECT COUNT(*) as cnt FROM observations "
                "WHERE patient_id = ? AND code = ?",
                (patient_id, lab_code),
            ).fetchone()["cnt"]
            # A trend needs at least two data points.
            if readings < 2:
                continue
            cases.append({
                "case_id": f"{patient_id}_lab_{key}",
                "patient_id": patient_id,
                "query_type": "lab_trend",
                "query": f"Show me my {label} history",
                "parameters": {"lab_type": key, "code": lab_code},
            })
    finally:
        conn.close()
    return cases
def generate_all_test_cases(num_patients: int = 10) -> List[Dict]:
    """Assemble the full test suite across the selected patients.

    Patients come from ``get_test_patients(num_patients)``. For each patient,
    every per-category generator runs in a fixed order: vitals, medications,
    conditions, allergies, immunizations, procedures, encounters, labs.

    Args:
        num_patients: Maximum number of patients to draw cases from.

    Returns:
        The concatenated list of case dicts, grouped by patient.
    """
    per_patient_generators = (
        generate_vital_trend_cases,
        generate_medication_cases,
        generate_condition_cases,
        generate_allergy_cases,
        generate_immunization_cases,
        generate_procedure_cases,
        generate_encounter_cases,
        generate_lab_cases,
    )
    collected: List[Dict] = []
    for patient in get_test_patients(num_patients):
        patient_id = patient["patient_id"]
        for build_cases in per_patient_generators:
            collected.extend(build_cases(patient_id))
    return collected
def get_test_summary(test_cases: List[Dict]) -> Dict:
    """Summarize a list of generated cases.

    Args:
        test_cases: Case dicts, each carrying ``query_type`` and ``patient_id``.

    Returns:
        ``{"total_cases": int, "by_type": {query_type: count},
        "by_patient": {patient_id: count}}``. Keys in the two tallies appear
        in first-seen order.
    """
    total = len(test_cases)
    type_tally: Dict[Any, int] = {}
    patient_tally: Dict[Any, int] = {}
    for entry in test_cases:
        for bucket, key in ((type_tally, entry["query_type"]),
                            (patient_tally, entry["patient_id"])):
            bucket[key] = bucket.get(key, 0) + 1
    return {
        "total_cases": total,
        "by_type": type_tally,
        "by_patient": patient_tally,
    }
if __name__ == "__main__":
    # Smoke-run the generator against the configured database and print a report.
    print("Generating test cases...")
    generated = generate_all_test_cases(num_patients=5)
    report = get_test_summary(generated)

    print(f"\nTotal test cases: {report['total_cases']}")
    print("\nBy query type:")
    by_type = report["by_type"]
    for qtype in sorted(by_type):
        count = by_type[qtype]
        print(f" {qtype}: {count}")

    print("\nSample test case:")
    if generated:
        import json
        print(json.dumps(generated[0], indent=2))