multimodal_previsit / evaluation /expected_values.py
frabbani
Fix fact extraction - pass raw data for simple tools..........
9b331e2
#!/usr/bin/env python3
"""
Expected Values Calculator
Computes ground truth values directly from the database for each test case type.
These are the values we expect the LLM to report.
"""
import sqlite3
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
import os
import statistics
# Location of the SQLite FHIR database; the DB_PATH environment variable
# overrides the default relative path. Read once, at import time.
DB_PATH = os.environ.get("DB_PATH", "data/fhir.db")


def get_db():
    """
    Open a fresh connection to the database at ``DB_PATH``.

    Rows are returned as ``sqlite3.Row`` so callers can index columns by name.
    The caller owns the connection and is responsible for closing it.
    """
    connection = sqlite3.connect(DB_PATH)
    connection.row_factory = sqlite3.Row
    return connection
def compute_vital_trend_expected(patient_id: str, vital_type: str, codes: List[str],
                                 labels: List[str], days: int = 30) -> Dict[str, Any]:
    """
    Compute expected values for vital trend queries.

    For each (code, label) pair, fetches the patient's observations with that
    code whose ``effective_date`` is on or after ``now - days`` (ascending by
    date) and summarizes the non-null numeric values.

    Args:
        patient_id: Patient whose observations are queried.
        vital_type: Echoed back in the result; not used in the query.
        codes: Observation codes to summarize, paired positionally with
            ``labels`` (``zip`` stops at the shorter of the two lists).
        labels: Keys used in ``result["metrics"]``.
        days: Look-back window in days.

    Returns:
        Dict with ``query_type``, ``vital_type``, ``days`` and ``metrics``.
        ``metrics[label]`` holds min/max/avg/count/latest (rounded to 1 dp),
        ``earliest_date``/``latest_date``, and the parallel lists
        ``all_values``/``all_dates``. Labels with no non-null values are
        omitted from ``metrics``.
    """
    conn = get_db()
    try:
        cutoff_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
        result = {
            "query_type": "vital_trend",
            "vital_type": vital_type,
            "days": days,
            "metrics": {}
        }
        for code, label in zip(codes, labels):
            cursor = conn.execute("""
                SELECT value_quantity, effective_date
                FROM observations
                WHERE patient_id = ? AND code = ? AND effective_date >= ?
                ORDER BY effective_date ASC
            """, (patient_id, code, cutoff_date))
            rows = cursor.fetchall()
            # Filter values and dates together so that all_values/all_dates
            # stay index-aligned and latest/latest_date (and earliest_date)
            # describe the same observation. Rows with a NULL value are
            # dropped from both lists. effective_date is non-NULL here because
            # the WHERE clause compares it against the cutoff.
            readings = [
                (r["value_quantity"], r["effective_date"][:10])
                for r in rows
                if r["value_quantity"] is not None
            ]
            if not readings:
                continue
            values = [value for value, _ in readings]
            dates = [date for _, date in readings]
            result["metrics"][label] = {
                "min": round(min(values), 1),
                "max": round(max(values), 1),
                "avg": round(statistics.mean(values), 1),
                "count": len(values),
                "latest": round(values[-1], 1),  # rows are ascending by date
                "earliest_date": dates[0],
                "latest_date": dates[-1],
                "all_values": [round(v, 1) for v in values],
                "all_dates": dates
            }
        return result
    finally:
        conn.close()
def compute_medication_expected(patient_id: str, status: Optional[str] = None) -> Dict[str, Any]:
    """
    Build the ground-truth medication list for a patient.

    Args:
        patient_id: Patient whose ``medications`` rows are read.
        status: When truthy, only rows with exactly this status are included.

    Returns:
        Dict with the status filter used, the row count, one dict per
        medication (start date trimmed to its first 10 characters), and the
        list of display names. Rows are ordered by ``start_date`` descending.
    """
    db = get_db()
    try:
        # Pick the statement and its bind parameters up front, then run once.
        if status:
            sql = """
                SELECT code, display, status, start_date
                FROM medications
                WHERE patient_id = ? AND status = ?
                ORDER BY start_date DESC
            """
            args = (patient_id, status)
        else:
            sql = """
                SELECT code, display, status, start_date
                FROM medications
                WHERE patient_id = ?
                ORDER BY start_date DESC
            """
            args = (patient_id,)

        meds = [
            {
                "code": rec["code"],
                "display": rec["display"],
                "status": rec["status"],
                "start_date": rec["start_date"][:10] if rec["start_date"] else None,
            }
            for rec in db.execute(sql, args).fetchall()
        ]
        names = [med["display"] for med in meds]
        return {
            "query_type": "medication_list",
            "status_filter": status,
            "count": len(meds),
            "medications": meds,
            "medication_names": names,
        }
    finally:
        db.close()
def compute_condition_expected(patient_id: str) -> Dict[str, Any]:
    """
    Build the ground-truth condition list for a patient.

    Reads every ``conditions`` row for ``patient_id`` ordered by
    ``onset_date`` descending. Onset dates are trimmed to their first 10
    characters.
    """
    db = get_db()
    try:
        records = db.execute("""
            SELECT code, display, clinical_status, onset_date
            FROM conditions
            WHERE patient_id = ?
            ORDER BY onset_date DESC
        """, (patient_id,)).fetchall()

        conditions = [
            {
                "code": rec["code"],
                "display": rec["display"],
                "clinical_status": rec["clinical_status"],
                "onset_date": rec["onset_date"][:10] if rec["onset_date"] else None,
            }
            for rec in records
        ]
        return {
            "query_type": "condition_list",
            "count": len(conditions),
            "conditions": conditions,
            "condition_names": [item["display"] for item in conditions],
        }
    finally:
        db.close()
def compute_allergy_expected(patient_id: str) -> Dict[str, Any]:
    """
    Build the ground-truth allergy list for a patient.

    Reads every ``allergies`` row for ``patient_id`` in the order SQLite
    returns them (the query has no ORDER BY). ``reaction_display`` is exposed
    under the key ``reaction``.
    """
    db = get_db()
    try:
        records = db.execute("""
            SELECT substance, reaction_display, criticality, category
            FROM allergies
            WHERE patient_id = ?
        """, (patient_id,)).fetchall()

        entries = [
            {
                "substance": rec["substance"],
                "reaction": rec["reaction_display"],
                "criticality": rec["criticality"],
                "category": rec["category"],
            }
            for rec in records
        ]
        substances = [entry["substance"] for entry in entries]
        return {
            "query_type": "allergy_list",
            "count": len(entries),
            "allergies": entries,
            "substances": substances,
        }
    finally:
        db.close()
def compute_immunization_expected(patient_id: str) -> Dict[str, Any]:
    """
    Build the ground-truth immunization list for a patient.

    Reads every ``immunizations`` row for ``patient_id`` ordered by
    ``occurrence_date`` descending. Occurrence dates are trimmed to their
    first 10 characters.
    """
    def day_part(stamp):
        # Leave missing/empty values as None; otherwise keep the leading date part.
        return stamp[:10] if stamp else None

    db = get_db()
    try:
        records = db.execute("""
            SELECT vaccine_code, vaccine_display, status, occurrence_date
            FROM immunizations
            WHERE patient_id = ?
            ORDER BY occurrence_date DESC
        """, (patient_id,)).fetchall()

        shots = [
            {
                "vaccine_code": rec["vaccine_code"],
                "vaccine_display": rec["vaccine_display"],
                "status": rec["status"],
                "occurrence_date": day_part(rec["occurrence_date"]),
            }
            for rec in records
        ]
        return {
            "query_type": "immunization_list",
            "count": len(shots),
            "immunizations": shots,
            "vaccine_names": [shot["vaccine_display"] for shot in shots],
        }
    finally:
        db.close()
def compute_procedure_expected(patient_id: str) -> Dict[str, Any]:
    """
    Build the ground-truth procedure list for a patient.

    Reads every ``procedures`` row for ``patient_id`` ordered by
    ``performed_date`` descending. Performed dates are trimmed to their
    first 10 characters.
    """
    db = get_db()
    try:
        records = db.execute("""
            SELECT code, display, status, performed_date
            FROM procedures
            WHERE patient_id = ?
            ORDER BY performed_date DESC
        """, (patient_id,)).fetchall()

        procs = []
        for rec in records:
            performed = rec["performed_date"]
            procs.append({
                "code": rec["code"],
                "display": rec["display"],
                "status": rec["status"],
                "performed_date": performed[:10] if performed else None,
            })

        names = [proc["display"] for proc in procs]
        return {
            "query_type": "procedure_list",
            "count": len(procs),
            "procedures": procs,
            "procedure_names": names,
        }
    finally:
        db.close()
def compute_encounter_expected(patient_id: str, limit: int = 5) -> Dict[str, Any]:
    """
    Build the ground-truth list of a patient's most recent encounters.

    Args:
        patient_id: Patient whose ``encounters`` rows are read.
        limit: Maximum number of rows returned (newest ``period_start`` first).

    Returns:
        Dict with the row count, the limit used, and one dict per encounter.
        Start and end dates are trimmed to their first 10 characters.
    """
    def day_part(stamp):
        # Leave missing/empty values as None; otherwise keep the leading date part.
        return stamp[:10] if stamp else None

    db = get_db()
    try:
        records = db.execute("""
            SELECT type_display, reason_display, period_start, period_end, class_display
            FROM encounters
            WHERE patient_id = ?
            ORDER BY period_start DESC
            LIMIT ?
        """, (patient_id, limit)).fetchall()

        visits = [
            {
                "type": rec["type_display"],
                "reason": rec["reason_display"],
                "class": rec["class_display"],
                "start_date": day_part(rec["period_start"]),
                "end_date": day_part(rec["period_end"]),
            }
            for rec in records
        ]
        return {
            "query_type": "encounter_list",
            "count": len(visits),
            "limit": limit,
            "encounters": visits,
        }
    finally:
        db.close()
def compute_lab_trend_expected(patient_id: str, lab_type: str, code: str,
                               periods: int = 4) -> Dict[str, Any]:
    """
    Compute expected values for lab trend queries.

    Fetches up to ``periods`` of the patient's observations with ``code``,
    most recent ``effective_date`` first, and summarizes their non-null
    numeric values.

    Args:
        patient_id: Patient whose observations are queried.
        lab_type: Echoed back in the result; not used in the query.
        code: Observation code to summarize.
        periods: Maximum number of observation rows fetched. Rows with a NULL
            value are then skipped, so ``count`` may be smaller.

    Returns:
        Dict with ``query_type``, ``lab_type``, ``code``, ``unit`` and
        ``count``. When at least one value exists, it also has ``metrics``
        (min/max/avg/latest rounded to 1 dp, ``latest_date``, and the
        index-aligned ``all_values``/``all_dates``, newest first).
    """
    conn = get_db()
    try:
        cursor = conn.execute("""
            SELECT value_quantity, effective_date, value_unit
            FROM observations
            WHERE patient_id = ? AND code = ?
            ORDER BY effective_date DESC
            LIMIT ?
        """, (patient_id, code, periods))
        rows = cursor.fetchall()
        # Keep each value paired with its own date. This stops latest/latest_date
        # from describing different rows when the newest row has a NULL value.
        # There is no date filter in the query, so guard a NULL effective_date
        # instead of slicing None.
        readings = [
            (r["value_quantity"],
             r["effective_date"][:10] if r["effective_date"] else None)
            for r in rows
            if r["value_quantity"] is not None
        ]
        values = [value for value, _ in readings]
        dates = [date for _, date in readings]
        # Report the first unit actually present rather than the top row's,
        # which may be NULL.
        unit = next((r["value_unit"] for r in rows if r["value_unit"] is not None), None)
        result = {
            "query_type": "lab_trend",
            "lab_type": lab_type,
            "code": code,
            "unit": unit,
            "count": len(values)
        }
        if values:
            result["metrics"] = {
                "min": round(min(values), 1),
                "max": round(max(values), 1),
                "avg": round(statistics.mean(values), 1),
                "latest": round(values[0], 1),  # Most recent (rows are DESC)
                "latest_date": dates[0],
                "all_values": [round(v, 1) for v in values],
                "all_dates": dates
            }
        return result
    finally:
        conn.close()
def compute_expected_values(test_case: Dict) -> Dict[str, Any]:
    """
    Dispatch a test case to the matching ``compute_*_expected`` function.

    Args:
        test_case: Must contain ``query_type`` and ``patient_id``. It may
            contain a ``parameters`` dict with query-specific arguments
            (missing required parameters raise ``KeyError``).

    Returns:
        The computed expected values, or ``{"error": "Unknown query type: ..."}``
        when ``query_type`` matches no known handler.
    """
    query_type = test_case["query_type"]
    patient_id = test_case["patient_id"]
    params = test_case.get("parameters", {})

    # Each entry defers its call (and its parameter lookups) until selected.
    handlers = (
        ("vital_trend", lambda: compute_vital_trend_expected(
            patient_id,
            params["vital_type"],
            params["codes"],
            params["labels"],
            params.get("days", 30),
        )),
        ("medication_list", lambda: compute_medication_expected(patient_id, params.get("status"))),
        ("condition_list", lambda: compute_condition_expected(patient_id)),
        ("allergy_list", lambda: compute_allergy_expected(patient_id)),
        ("immunization_list", lambda: compute_immunization_expected(patient_id)),
        ("procedure_list", lambda: compute_procedure_expected(patient_id)),
        ("encounter_list", lambda: compute_encounter_expected(patient_id, params.get("limit", 5))),
        ("lab_trend", lambda: compute_lab_trend_expected(
            patient_id,
            params["lab_type"],
            params["code"],
            params.get("periods", 4),
        )),
    )
    # Linear scan with ``==`` mirrors the original comparison chain exactly,
    # including for unhashable or unusual query_type values.
    for name, run in handlers:
        if query_type == name:
            return run()
    return {"error": f"Unknown query type: {query_type}"}
if __name__ == "__main__":
    # Manual smoke test: generate test cases for a single patient and print
    # the expected values computed for the first few of them.
    import json

    from test_generator import generate_all_test_cases

    print("Generating test cases...")
    cases = generate_all_test_cases(num_patients=1)
    print(f"\nComputing expected values for {len(cases)} test cases...")
    for case in cases[:3]:  # Show first 3
        print(f"\n{'=' * 60}")
        print(f"Case: {case['case_id']}")
        print(f"Query: {case['query']}")
        expected = compute_expected_values(case)
        # Plain string: the original f-string had no placeholders (ruff F541).
        print("Expected values:")
        print(json.dumps(expected, indent=2, default=str))