Spaces:
Sleeping
Sleeping
File size: 12,345 Bytes
8daa8bf 9b331e2 8daa8bf 9b331e2 8daa8bf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 | #!/usr/bin/env python3
"""
Expected Values Calculator
Computes ground truth values directly from the database for each test case type.
These are the values we expect the LLM to report.
"""
import sqlite3
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
import os
import statistics
# SQLite database location; the DB_PATH environment variable overrides the default.
DB_PATH = os.environ.get("DB_PATH", "data/fhir.db")
def get_db(db_path: Optional[str] = None) -> sqlite3.Connection:
    """Open a SQLite connection whose rows can be indexed by column name.

    Args:
        db_path: Database to open. When omitted (``None``), the module-level
            ``DB_PATH`` is used, so existing ``get_db()`` callers are
            unaffected. Passing a path (or ``":memory:"``) lets callers and
            tests target a different database.

    Returns:
        A new ``sqlite3.Connection`` with ``row_factory`` set to
        ``sqlite3.Row``. The caller is responsible for closing it.
    """
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    # sqlite3.Row allows row["column_name"] access used throughout this module.
    conn.row_factory = sqlite3.Row
    return conn
def compute_vital_trend_expected(patient_id: str, vital_type: str, codes: List[str],
                                 labels: List[str], days: int = 30) -> Dict[str, Any]:
    """
    Compute expected values for vital trend queries.

    For each ``(code, label)`` pair, reads the patient's observations with that
    code dated on or after ``now - days`` and summarises their numeric values.

    Args:
        patient_id: Patient whose observations are read.
        vital_type: Echoed back in the result under ``"vital_type"``.
        codes: Observation codes to query, paired positionally with ``labels``
            (``zip`` stops at the shorter of the two lists).
        labels: Keys used in ``result["metrics"]``.
        days: Look-back window in days.

    Returns:
        Dict with ``query_type``, ``vital_type``, ``days`` and ``metrics``.
        ``metrics[label]`` holds min/max/avg/latest (rounded to 1 decimal),
        count, earliest/latest dates and the index-aligned ``all_values`` /
        ``all_dates`` lists in ascending date order. Labels with no numeric
        values are omitted from ``metrics``.
    """
    conn = get_db()
    try:
        # NOTE(review): the cutoff is compared as text against effective_date;
        # this assumes effective_date is stored as ISO-8601 text — confirm.
        cutoff_date = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
        result = {
            "query_type": "vital_trend",
            "vital_type": vital_type,
            "days": days,
            "metrics": {}
        }
        for code, label in zip(codes, labels):
            # Excluding NULL values in SQL keeps `values` and `dates` aligned
            # index-for-index, so every reported date belongs to a reported value.
            cursor = conn.execute("""
                SELECT value_quantity, effective_date
                FROM observations
                WHERE patient_id = ? AND code = ? AND effective_date >= ?
                  AND value_quantity IS NOT NULL
                ORDER BY effective_date ASC
            """, (patient_id, code, cutoff_date))
            rows = cursor.fetchall()
            values = [r["value_quantity"] for r in rows]
            # effective_date cannot be NULL here: the >= filter excludes NULLs.
            dates = [r["effective_date"][:10] for r in rows]
            if values:
                result["metrics"][label] = {
                    "min": round(min(values), 1),
                    "max": round(max(values), 1),
                    "avg": round(statistics.mean(values), 1),
                    "count": len(values),
                    "latest": round(values[-1], 1),  # ascending order: last is newest
                    "earliest_date": dates[0],
                    "latest_date": dates[-1],
                    "all_values": [round(v, 1) for v in values],
                    "all_dates": dates
                }
        return result
    finally:
        conn.close()
def compute_medication_expected(patient_id: str, status: Optional[str] = None) -> Dict[str, Any]:
    """
    Build the expected answer for a medication-list query.

    Args:
        patient_id: Patient whose medication rows are read.
        status: Optional status to filter on. Any falsy value (``None`` or
            ``""``) disables the filter.

    Returns:
        Dict with ``query_type`` ("medication_list"), ``status_filter`` (the
        ``status`` argument as passed), ``count``, the ``medications`` records
        ordered by ``start_date`` descending (dates cut to their first 10
        characters), and the parallel ``medication_names`` list.
    """
    sql = "SELECT code, display, status, start_date FROM medications WHERE patient_id = ?"
    args: tuple = (patient_id,)
    if status:
        # Only add the status predicate for a truthy filter value.
        sql += " AND status = ?"
        args = (patient_id, status)
    sql += " ORDER BY start_date DESC"

    conn = get_db()
    try:
        medications = []
        for row in conn.execute(sql, args).fetchall():
            started = row["start_date"]
            medications.append({
                "code": row["code"],
                "display": row["display"],
                "status": row["status"],
                "start_date": started[:10] if started else None,
            })
        names = [med["display"] for med in medications]
        return {
            "query_type": "medication_list",
            "status_filter": status,
            "count": len(medications),
            "medications": medications,
            "medication_names": names,
        }
    finally:
        conn.close()
def compute_condition_expected(patient_id: str) -> Dict[str, Any]:
    """
    Build the expected answer for a condition-list query.

    Returns:
        Dict with ``query_type`` ("condition_list"), ``count``, the
        ``conditions`` records ordered by ``onset_date`` descending (dates cut
        to their first 10 characters), and the parallel ``condition_names``.
    """
    conn = get_db()
    try:
        rows = conn.execute(
            "SELECT code, display, clinical_status, onset_date "
            "FROM conditions WHERE patient_id = ? ORDER BY onset_date DESC",
            (patient_id,),
        ).fetchall()
        conditions = [
            {
                "code": rec["code"],
                "display": rec["display"],
                "clinical_status": rec["clinical_status"],
                "onset_date": rec["onset_date"][:10] if rec["onset_date"] else None,
            }
            for rec in rows
        ]
        return {
            "query_type": "condition_list",
            "count": len(conditions),
            "conditions": conditions,
            "condition_names": [entry["display"] for entry in conditions],
        }
    finally:
        conn.close()
def compute_allergy_expected(patient_id: str) -> Dict[str, Any]:
    """
    Build the expected answer for an allergy-list query.

    The query has no ORDER BY, so records come back in whatever order SQLite
    returns them.

    Returns:
        Dict with ``query_type`` ("allergy_list"), ``count``, the ``allergies``
        records (substance, reaction, criticality, category) and the parallel
        ``substances`` list.
    """
    conn = get_db()
    try:
        cur = conn.execute(
            "SELECT substance, reaction_display, criticality, category "
            "FROM allergies WHERE patient_id = ?",
            (patient_id,),
        )
        allergies = [
            {
                "substance": rec["substance"],
                # Column is reaction_display; the output key is "reaction".
                "reaction": rec["reaction_display"],
                "criticality": rec["criticality"],
                "category": rec["category"],
            }
            for rec in cur.fetchall()
        ]
        return {
            "query_type": "allergy_list",
            "count": len(allergies),
            "allergies": allergies,
            "substances": [item["substance"] for item in allergies],
        }
    finally:
        conn.close()
def compute_immunization_expected(patient_id: str) -> Dict[str, Any]:
    """
    Build the expected answer for an immunization-list query.

    Returns:
        Dict with ``query_type`` ("immunization_list"), ``count``, the
        ``immunizations`` records ordered by ``occurrence_date`` descending
        (dates cut to their first 10 characters), and the parallel
        ``vaccine_names`` list.
    """
    def to_record(row) -> Dict[str, Any]:
        # Convert one result row into the output record shape.
        when = row["occurrence_date"]
        return {
            "vaccine_code": row["vaccine_code"],
            "vaccine_display": row["vaccine_display"],
            "status": row["status"],
            "occurrence_date": when[:10] if when else None,
        }

    conn = get_db()
    try:
        cursor = conn.execute(
            "SELECT vaccine_code, vaccine_display, status, occurrence_date "
            "FROM immunizations WHERE patient_id = ? ORDER BY occurrence_date DESC",
            (patient_id,),
        )
        immunizations = list(map(to_record, cursor.fetchall()))
        vaccine_names = [shot["vaccine_display"] for shot in immunizations]
        return {
            "query_type": "immunization_list",
            "count": len(immunizations),
            "immunizations": immunizations,
            "vaccine_names": vaccine_names,
        }
    finally:
        conn.close()
def compute_procedure_expected(patient_id: str) -> Dict[str, Any]:
    """
    Build the expected answer for a procedure-list query.

    Returns:
        Dict with ``query_type`` ("procedure_list"), ``count``, the
        ``procedures`` records ordered by ``performed_date`` descending (dates
        cut to their first 10 characters), and the parallel
        ``procedure_names`` list.
    """
    conn = get_db()
    try:
        fetched = conn.execute(
            "SELECT code, display, status, performed_date "
            "FROM procedures WHERE patient_id = ? ORDER BY performed_date DESC",
            (patient_id,),
        ).fetchall()
        procedures = []
        for rec in fetched:
            performed = rec["performed_date"]
            procedures.append({
                "code": rec["code"],
                "display": rec["display"],
                "status": rec["status"],
                "performed_date": performed[:10] if performed else None,
            })
        return {
            "query_type": "procedure_list",
            "count": len(procedures),
            "procedures": procedures,
            "procedure_names": [proc["display"] for proc in procedures],
        }
    finally:
        conn.close()
def compute_encounter_expected(patient_id: str, limit: int = 5) -> Dict[str, Any]:
    """
    Build the expected answer for an encounter-list query.

    Args:
        patient_id: Patient whose encounters are read.
        limit: Maximum number of encounters returned, taken in descending
            ``period_start`` order.

    Returns:
        Dict with ``query_type`` ("encounter_list"), ``count``, the ``limit``
        that was applied and the ``encounters`` records (type, reason, class,
        start/end dates cut to their first 10 characters).
    """
    def first_ten(text):
        # Keep the leading 10 characters of a date string; pass falsy values as None.
        return text[:10] if text else None

    conn = get_db()
    try:
        rows = conn.execute(
            "SELECT type_display, reason_display, period_start, period_end, class_display "
            "FROM encounters WHERE patient_id = ? ORDER BY period_start DESC LIMIT ?",
            (patient_id, limit),
        ).fetchall()
        encounters = [
            {
                "type": enc["type_display"],
                "reason": enc["reason_display"],
                "class": enc["class_display"],
                "start_date": first_ten(enc["period_start"]),
                "end_date": first_ten(enc["period_end"]),
            }
            for enc in rows
        ]
        return {
            "query_type": "encounter_list",
            "count": len(encounters),
            "limit": limit,
            "encounters": encounters,
        }
    finally:
        conn.close()
def compute_lab_trend_expected(patient_id: str, lab_type: str, code: str,
                               periods: int = 4) -> Dict[str, Any]:
    """
    Compute expected values for lab trend queries.

    Reads the ``periods`` most recent observations with a numeric value for
    the given code and summarises them.

    Args:
        patient_id: Patient whose observations are read.
        lab_type: Echoed back under ``"lab_type"``.
        code: Observation code to query.
        periods: Maximum number of most-recent valued observations to use.

    Returns:
        Dict with ``query_type``, ``lab_type``, ``code``, ``unit`` (from the
        newest valued row, or ``None``) and ``count``. When any values exist,
        also ``metrics`` holding min/max/avg/latest (rounded to 1 decimal),
        ``latest_date`` and the index-aligned ``all_values`` / ``all_dates``
        lists, newest first.
    """
    conn = get_db()
    try:
        # Filtering NULL values in SQL means rows without a numeric value don't
        # consume LIMIT slots, and values/dates/unit all refer to the same rows.
        cursor = conn.execute("""
            SELECT value_quantity, effective_date, value_unit
            FROM observations
            WHERE patient_id = ? AND code = ? AND value_quantity IS NOT NULL
            ORDER BY effective_date DESC
            LIMIT ?
        """, (patient_id, code, periods))
        rows = cursor.fetchall()
        values = [r["value_quantity"] for r in rows]
        # No WHERE clause on effective_date here, so guard against NULL dates.
        dates = [r["effective_date"][:10] if r["effective_date"] else None for r in rows]
        unit = rows[0]["value_unit"] if rows else None
        result = {
            "query_type": "lab_trend",
            "lab_type": lab_type,
            "code": code,
            "unit": unit,
            "count": len(values)
        }
        if values:
            result["metrics"] = {
                "min": round(min(values), 1),
                "max": round(max(values), 1),
                "avg": round(statistics.mean(values), 1),
                "latest": round(values[0], 1),  # descending order: first is newest
                "latest_date": dates[0],
                "all_values": [round(v, 1) for v in values],
                "all_dates": dates
            }
        return result
    finally:
        conn.close()
def compute_expected_values(test_case: Dict) -> Dict[str, Any]:
    """
    Dispatch a test case to the matching ``compute_*_expected`` function.

    Args:
        test_case: Must contain ``"query_type"`` and ``"patient_id"``
            (a missing key raises ``KeyError``). The optional ``"parameters"``
            dict supplies per-type arguments.

    Returns:
        The handler's expected-values dict, or
        ``{"error": "Unknown query type: <type>"}`` for an unrecognised type.
    """
    query_type = test_case["query_type"]
    patient_id = test_case["patient_id"]
    params = test_case.get("parameters", {})

    # Lambdas defer both the parameter lookups and the handler call until the
    # matching entry is actually selected.
    handlers = {
        "vital_trend": lambda: compute_vital_trend_expected(
            patient_id,
            params["vital_type"],
            params["codes"],
            params["labels"],
            params.get("days", 30),
        ),
        "medication_list": lambda: compute_medication_expected(patient_id, params.get("status")),
        "condition_list": lambda: compute_condition_expected(patient_id),
        "allergy_list": lambda: compute_allergy_expected(patient_id),
        "immunization_list": lambda: compute_immunization_expected(patient_id),
        "procedure_list": lambda: compute_procedure_expected(patient_id),
        "encounter_list": lambda: compute_encounter_expected(patient_id, params.get("limit", 5)),
        "lab_trend": lambda: compute_lab_trend_expected(
            patient_id,
            params["lab_type"],
            params["code"],
            params.get("periods", 4),
        ),
    }

    handler = handlers.get(query_type)
    if handler is None:
        return {"error": f"Unknown query type: {query_type}"}
    return handler()
if __name__ == "__main__":
    # Manual smoke test: generate cases for one patient and print the expected
    # values for the first three. Imported here so the module itself does not
    # depend on test_generator.
    from test_generator import generate_all_test_cases
    import json
    print("Generating test cases...")
    cases = generate_all_test_cases(num_patients=1)
    print(f"\nComputing expected values for {len(cases)} test cases...")
    for case in cases[:3]:  # Show first 3
        print(f"\n{'=' * 60}")
        print(f"Case: {case['case_id']}")
        print(f"Query: {case['query']}")
        expected = compute_expected_values(case)
        # Plain string: the original f-string had no placeholders (ruff F541).
        print("Expected values:")
        # default=str stringifies anything json cannot encode natively.
        print(json.dumps(expected, indent=2, default=str))
|