multimodal_previsit / evaluation /test_generator.py
frabbani
Fix fact extraction - pass raw data for simple tools.......
8daa8bf
#!/usr/bin/env python3
"""
Test Case Generator for Pre-Visit Summary Evaluation
Generates test cases from Synthea patient data with known ground truth.
"""
import sqlite3
import random
from datetime import datetime, timedelta
from typing import List, Dict, Any
import os
# SQLite database file the generator reads from; the DB_PATH environment
# variable overrides the default relative path.
DB_PATH = os.environ.get("DB_PATH", "data/fhir.db")
def get_db():
    """Open a fresh SQLite connection to ``DB_PATH``.

    Rows produced by the connection are ``sqlite3.Row`` objects, so callers
    can index result columns by name (e.g. ``row["cnt"]``). The caller owns
    the connection and is responsible for closing it.
    """
    connection = sqlite3.connect(DB_PATH)
    connection.row_factory = sqlite3.Row  # enable name-based column access
    return connection
def get_test_patients(limit: int = 10) -> List[Dict]:
"""Get patients that have sufficient data for testing."""
conn = get_db()
try:
# Find patients with good data coverage
cursor = conn.execute("""
SELECT p.id, p.given_name, p.family_name, p.birth_date, p.gender,
(SELECT COUNT(*) FROM conditions WHERE patient_id = p.id) as condition_count,
(SELECT COUNT(*) FROM medications WHERE patient_id = p.id) as med_count,
(SELECT COUNT(*) FROM observations WHERE patient_id = p.id) as obs_count,
(SELECT COUNT(*) FROM allergies WHERE patient_id = p.id) as allergy_count,
(SELECT COUNT(*) FROM immunizations WHERE patient_id = p.id) as imm_count,
(SELECT COUNT(*) FROM procedures WHERE patient_id = p.id) as proc_count,
(SELECT COUNT(*) FROM encounters WHERE patient_id = p.id) as enc_count
FROM patients p
WHERE (SELECT COUNT(*) FROM observations WHERE patient_id = p.id) > 10
ORDER BY obs_count DESC
LIMIT ?
""", (limit,))
patients = []
for row in cursor.fetchall():
patients.append({
"patient_id": row["id"],
"name": f"{row['given_name']} {row['family_name']}",
"birth_date": row["birth_date"],
"gender": row["gender"],
"data_counts": {
"conditions": row["condition_count"],
"medications": row["med_count"],
"observations": row["obs_count"],
"allergies": row["allergy_count"],
"immunizations": row["imm_count"],
"procedures": row["proc_count"],
"encounters": row["enc_count"]
}
})
return patients
finally:
conn.close()
def generate_vital_trend_cases(patient_id: str, days: int = 30) -> List[Dict]:
"""Generate test cases for vital sign trends (BP, heart rate, etc.)."""
test_cases = []
vital_types = [
("blood_pressure", ["8480-6", "8462-4"], ["systolic", "diastolic"]),
("heart_rate", ["8867-4"], ["heart_rate"]),
("weight", ["29463-7"], ["weight"]),
("temperature", ["8310-5"], ["temperature"]),
("oxygen_saturation", ["2708-6"], ["oxygen_saturation"]),
]
conn = get_db()
try:
for vital_name, codes, labels in vital_types:
# Check if patient has this vital data
placeholders = ",".join(["?" for _ in codes])
cursor = conn.execute(f"""
SELECT COUNT(*) as cnt FROM observations
WHERE patient_id = ? AND code IN ({placeholders})
""", [patient_id] + codes)
count = cursor.fetchone()["cnt"]
if count >= 3: # Need at least 3 readings for meaningful test
test_cases.append({
"case_id": f"{patient_id}_vital_{vital_name}",
"patient_id": patient_id,
"query_type": "vital_trend",
"query": f"Show me my {vital_name.replace('_', ' ')} chart",
"parameters": {
"vital_type": vital_name,
"days": days,
"codes": codes,
"labels": labels
}
})
finally:
conn.close()
return test_cases
def generate_medication_cases(patient_id: str) -> List[Dict]:
    """Generate test cases for medication queries.

    Returns nothing when the patient has no ``medications`` rows. Otherwise
    returns an "all medications" case, followed by an "active only" case when
    at least one row has status ``'active'``.
    """
    conn = get_db()
    try:
        med_counts = conn.execute("""
            SELECT COUNT(*) as total,
            SUM(CASE WHEN status = 'active' THEN 1 ELSE 0 END) as active
            FROM medications WHERE patient_id = ?
        """, (patient_id,)).fetchone()
        total_meds = med_counts["total"]
        active_meds = med_counts["active"]
    finally:
        conn.close()

    if total_meds <= 0:
        return []

    cases = [{
        "case_id": f"{patient_id}_meds_all",
        "patient_id": patient_id,
        "query_type": "medication_list",
        "query": "What medications am I taking?",
        "parameters": {"status": None},
    }]
    # SUM() over zero rows is NULL; the guard above ensures rows exist here.
    if active_meds > 0:
        cases.append({
            "case_id": f"{patient_id}_meds_active",
            "patient_id": patient_id,
            "query_type": "medication_list",
            "query": "What are my current active medications?",
            "parameters": {"status": "active"},
        })
    return cases
def generate_condition_cases(patient_id: str) -> List[Dict]:
    """Generate test cases for condition queries.

    Yields a single ``condition_list`` case when the patient has at least one
    row in ``conditions``; otherwise an empty list.
    """
    conn = get_db()
    try:
        condition_total = conn.execute("""
            SELECT COUNT(*) as cnt FROM conditions WHERE patient_id = ?
        """, (patient_id,)).fetchone()["cnt"]
    finally:
        conn.close()

    if condition_total <= 0:
        return []
    return [{
        "case_id": f"{patient_id}_conditions",
        "patient_id": patient_id,
        "query_type": "condition_list",
        "query": "What are my medical conditions?",
        "parameters": {},
    }]
def generate_allergy_cases(patient_id: str) -> List[Dict]:
    """Generate test cases for allergy queries.

    Yields a single ``allergy_list`` case when the patient has at least one
    row in ``allergies``; otherwise an empty list.
    """
    conn = get_db()
    try:
        allergy_total = conn.execute("""
            SELECT COUNT(*) as cnt FROM allergies WHERE patient_id = ?
        """, (patient_id,)).fetchone()["cnt"]
    finally:
        conn.close()

    if allergy_total <= 0:
        return []
    return [{
        "case_id": f"{patient_id}_allergies",
        "patient_id": patient_id,
        "query_type": "allergy_list",
        "query": "What are my allergies?",
        "parameters": {},
    }]
def generate_immunization_cases(patient_id: str) -> List[Dict]:
    """Generate test cases for immunization queries.

    Yields a single ``immunization_list`` case when the patient has at least
    one row in ``immunizations``; otherwise an empty list.
    """
    conn = get_db()
    try:
        immunization_total = conn.execute("""
            SELECT COUNT(*) as cnt FROM immunizations WHERE patient_id = ?
        """, (patient_id,)).fetchone()["cnt"]
    finally:
        conn.close()

    if immunization_total <= 0:
        return []
    return [{
        "case_id": f"{patient_id}_immunizations",
        "patient_id": patient_id,
        "query_type": "immunization_list",
        "query": "What immunizations have I had?",
        "parameters": {},
    }]
def generate_procedure_cases(patient_id: str) -> List[Dict]:
    """Generate test cases for procedure/surgical history queries.

    Yields a single ``procedure_list`` case when the patient has at least one
    row in ``procedures``; otherwise an empty list.
    """
    conn = get_db()
    try:
        procedure_total = conn.execute("""
            SELECT COUNT(*) as cnt FROM procedures WHERE patient_id = ?
        """, (patient_id,)).fetchone()["cnt"]
    finally:
        conn.close()

    if procedure_total <= 0:
        return []
    return [{
        "case_id": f"{patient_id}_procedures",
        "patient_id": patient_id,
        "query_type": "procedure_list",
        "query": "What procedures or surgeries have I had?",
        "parameters": {},
    }]
def generate_encounter_cases(patient_id: str) -> List[Dict]:
    """Generate test cases for encounter history queries.

    Yields a single ``encounter_list`` case (asking for the 5 most recent
    visits) when the patient has at least one row in ``encounters``;
    otherwise an empty list.
    """
    conn = get_db()
    try:
        encounter_total = conn.execute("""
            SELECT COUNT(*) as cnt FROM encounters WHERE patient_id = ?
        """, (patient_id,)).fetchone()["cnt"]
    finally:
        conn.close()

    if encounter_total <= 0:
        return []
    return [{
        "case_id": f"{patient_id}_encounters",
        "patient_id": patient_id,
        "query_type": "encounter_list",
        "query": "Show me my recent visits",
        "parameters": {"limit": 5},
    }]
def generate_lab_cases(patient_id: str) -> List[Dict]:
    """Generate test cases for lab result queries.

    For each tracked lab (HbA1c, total cholesterol, glucose), emits a
    ``lab_trend`` case when the patient has at least two ``observations``
    rows with that lab's code.
    """
    lab_catalog = (
        ("a1c", "4548-4", "HbA1c"),
        ("cholesterol", "2093-3", "Total Cholesterol"),
        ("glucose", "2345-7", "Glucose"),
    )
    found = []
    conn = get_db()
    try:
        for slug, obs_code, label in lab_catalog:
            result_count = conn.execute("""
                SELECT COUNT(*) as cnt FROM observations
                WHERE patient_id = ? AND code = ?
            """, (patient_id, obs_code)).fetchone()["cnt"]
            # A trend needs at least two data points.
            if result_count < 2:
                continue
            found.append({
                "case_id": f"{patient_id}_lab_{slug}",
                "patient_id": patient_id,
                "query_type": "lab_trend",
                "query": f"Show me my {label} history",
                "parameters": {
                    "lab_type": slug,
                    "code": obs_code,
                },
            })
    finally:
        conn.close()
    return found
def generate_all_test_cases(num_patients: int = 10) -> List[Dict]:
    """Generate the complete test suite from available patients.

    For each patient returned by ``get_test_patients(num_patients)``, runs
    every per-category generator in a fixed order and concatenates the
    resulting cases.
    """
    patients = get_test_patients(num_patients)
    # Order matters: it determines the order of cases in the returned list.
    case_generators = (
        generate_vital_trend_cases,
        generate_medication_cases,
        generate_condition_cases,
        generate_allergy_cases,
        generate_immunization_cases,
        generate_procedure_cases,
        generate_encounter_cases,
        generate_lab_cases,
    )
    suite: List[Dict] = []
    for patient in patients:
        patient_id = patient["patient_id"]
        for make_cases in case_generators:
            suite.extend(make_cases(patient_id))
    return suite
def get_test_summary(test_cases: List[Dict]) -> Dict:
    """Summarize generated test cases.

    Returns a dict with the total case count plus per-``query_type`` and
    per-``patient_id`` tallies. Keys appear in first-seen order.
    """
    total = len(test_cases)
    type_tally: Dict[str, int] = {}
    patient_tally: Dict[str, int] = {}
    for tc in test_cases:
        for tally, key in ((type_tally, tc["query_type"]), (patient_tally, tc["patient_id"])):
            tally[key] = tally.get(key, 0) + 1
    return {
        "total_cases": total,
        "by_type": type_tally,
        "by_patient": patient_tally,
    }
if __name__ == "__main__":
    import json

    # Smoke-run the generator on five patients and print an overview.
    print("Generating test cases...")
    generated = generate_all_test_cases(num_patients=5)
    overview = get_test_summary(generated)

    print(f"\nTotal test cases: {overview['total_cases']}")
    print("\nBy query type:")
    for query_type, n_cases in sorted(overview["by_type"].items()):
        print(f" {query_type}: {n_cases}")

    print("\nSample test case:")
    if generated:
        print(json.dumps(generated[0], indent=2))