multimodal_previsit / evaluation /facts_schema.py
frabbani
Fix fact extraction - pass raw data for simple tools.......
8daa8bf
#!/usr/bin/env python3
"""
Facts Schema
Defines the structured facts format that the agent should return
alongside its text responses. These facts are used for evaluation.
"""
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, asdict
import json
@dataclass
class VitalTrendFacts:
    """
    Structured facts for vital sign trend queries.

    Attributes:
        vital_type: Identifier of the vital sign the trend describes.
        days: Number of days associated with the query (presumably the
            look-back window - confirm against the caller).
        metrics: Summary statistics keyed by dataset label, e.g.
            {label: {min, max, avg, count, latest, earliest_date,
            latest_date}}.
    """

    vital_type: str
    days: int
    metrics: Dict[str, Dict[str, Any]]  # one stats dict per dataset label

    def to_dict(self) -> Dict:
        """Return the fields as a new plain dictionary (nested containers are copied)."""
        as_plain_dict = asdict(self)
        return as_plain_dict
@dataclass
class MedicationFacts:
    """
    Structured facts for medication queries.

    Attributes:
        status_filter: Status filter that was applied to the query, or
            None when no filter was given.
        count: Number of medication entries.
        medication_names: Display name of each medication entry.
    """

    status_filter: Optional[str]
    count: int
    medication_names: List[str]

    def to_dict(self) -> Dict:
        """Return the fields as a new plain dictionary (nested containers are copied)."""
        as_plain_dict = asdict(self)
        return as_plain_dict
@dataclass
class ConditionFacts:
    """
    Structured facts for condition queries.

    Attributes:
        count: Number of condition entries.
        condition_names: Display name of each condition entry.
    """

    count: int
    condition_names: List[str]

    def to_dict(self) -> Dict:
        """Return the fields as a new plain dictionary (nested containers are copied)."""
        as_plain_dict = asdict(self)
        return as_plain_dict
@dataclass
class AllergyFacts:
    """
    Structured facts for allergy queries.

    Attributes:
        count: Number of allergy entries.
        substances: Substance recorded for each allergy entry.
    """

    count: int
    substances: List[str]

    def to_dict(self) -> Dict:
        """Return the fields as a new plain dictionary (nested containers are copied)."""
        as_plain_dict = asdict(self)
        return as_plain_dict
@dataclass
class ImmunizationFacts:
    """
    Structured facts for immunization queries.

    Attributes:
        count: Number of immunization entries.
        vaccine_names: Vaccine display name of each immunization entry.
    """

    count: int
    vaccine_names: List[str]

    def to_dict(self) -> Dict:
        """Return the fields as a new plain dictionary (nested containers are copied)."""
        as_plain_dict = asdict(self)
        return as_plain_dict
@dataclass
class ProcedureFacts:
    """
    Structured facts for procedure queries.

    Attributes:
        count: Number of procedure entries.
        procedure_names: Display name of each procedure entry.
    """

    count: int
    procedure_names: List[str]

    def to_dict(self) -> Dict:
        """Return the fields as a new plain dictionary (nested containers are copied)."""
        as_plain_dict = asdict(self)
        return as_plain_dict
@dataclass
class EncounterFacts:
    """
    Structured facts for encounter queries.

    Attributes:
        count: Number of encounter entries.
        limit: Limit value supplied with the query (presumably the maximum
            number of encounters requested - confirm against the caller).
    """

    count: int
    limit: int

    def to_dict(self) -> Dict:
        """Return the fields as a new plain dictionary."""
        as_plain_dict = asdict(self)
        return as_plain_dict
@dataclass
class LabTrendFacts:
    """
    Structured facts for lab trend queries.

    Attributes:
        lab_type: Identifier of the lab test the trend describes.
        code: Lab code string; may be empty when none is available.
        unit: Unit of the reported values, or None when unknown.
        count: Number of values that went into the statistics.
        metrics: Summary statistics, e.g. {min, max, avg, latest,
            latest_date}.
    """

    lab_type: str
    code: str
    unit: Optional[str]
    count: int
    metrics: Dict[str, Any]  # flat stats dict for the single lab series

    def to_dict(self) -> Dict:
        """Return the fields as a new plain dictionary (nested containers are copied)."""
        as_plain_dict = asdict(self)
        return as_plain_dict
def extract_vital_facts_from_tool_result(
    tool_result: Dict, days: int = 30
) -> Optional[VitalTrendFacts]:
    """
    Extract structured facts from a vital chart tool result.

    The tool already returns structured JSON; this reshapes every dataset
    into summary statistics keyed by a normalised label (lower-cased, spaces
    replaced by underscores).

    Args:
        tool_result: Chart payload. Keys read here are "error",
            "chart_type", "title" and "datasets"; each dataset supplies a
            "label" and a "data" list of points carrying "value" and "date".
        days: Stored unchanged in the result's ``days`` field. Defaults to
            30, the value this function used to hard-code.

    Returns:
        A VitalTrendFacts instance, or None when the payload is empty,
        reports an error, or is not a "line" / "line_dual" chart.
    """
    # Local import (as before), hoisted out of the per-dataset loop.
    import statistics

    # A missing/empty payload is treated like an error result.
    if not tool_result or "error" in tool_result:
        return None
    if tool_result.get("chart_type", "") not in ("line", "line_dual"):
        return None

    metrics: Dict[str, Dict[str, Any]] = {}
    # `or []` / `or "unknown"` also cover keys that are present but null.
    for dataset in tool_result.get("datasets") or []:
        label = (dataset.get("label") or "unknown").lower().replace(" ", "_")

        # Only points that carry a value take part, so the reported dates
        # always belong to the same readings as the statistics.
        readings = [
            point
            for point in (dataset.get("data") or [])
            if point.get("value") is not None
        ]
        if not readings:
            continue

        values = [point["value"] for point in readings]
        # NOTE(review): "latest"/"earliest" assume the points arrive in
        # chronological order - confirm against the chart tool.
        metrics[label] = {
            "min": round(min(values), 1),
            "max": round(max(values), 1),
            "avg": round(statistics.mean(values), 1),
            "count": len(values),
            "latest": round(values[-1], 1),
            "earliest_date": readings[0].get("date") or None,
            "latest_date": readings[-1].get("date") or None,
        }

    return VitalTrendFacts(
        vital_type=(tool_result.get("title") or "").lower().replace(" ", "_"),
        days=days,
        metrics=metrics,
    )
def extract_lab_facts_from_tool_result(
    tool_result: Dict, code: str = ""
) -> Optional[LabTrendFacts]:
    """
    Extract structured facts from a lab chart tool result.

    Only the first dataset of the payload is summarised.

    Args:
        tool_result: Chart payload. Keys read here are "error", "datasets"
            and "unit"; the first dataset supplies a "label" and a "data"
            list of points carrying "value" and "date".
        code: Stored unchanged in the result's ``code`` field. Defaults to
            "" because the tool result itself does not carry a code.

    Returns:
        A LabTrendFacts instance, or None when the payload is empty,
        reports an error, or has no datasets / no data points. When data
        points exist but none carries a value, the result has ``count`` 0
        and empty ``metrics``.
    """
    # Local import (as before), placed at the top of the function.
    import statistics

    # A missing/empty payload is treated like an error result.
    if not tool_result or "error" in tool_result:
        return None

    datasets = tool_result.get("datasets")
    if not datasets:
        return None

    first_dataset = datasets[0]
    data_points = first_dataset.get("data")
    if not data_points:
        return None

    # Only points that carry a value take part, so "latest_date" always
    # belongs to the same reading as "latest".
    readings = [point for point in data_points if point.get("value") is not None]

    metrics: Dict[str, Any] = {}
    if readings:
        values = [point["value"] for point in readings]
        # NOTE(review): "latest" assumes the points arrive in chronological
        # order - confirm against the chart tool.
        metrics = {
            "min": round(min(values), 1),
            "max": round(max(values), 1),
            "avg": round(statistics.mean(values), 1),
            "latest": round(values[-1], 1),
            "latest_date": readings[-1].get("date") or None,
        }

    return LabTrendFacts(
        # `or "unknown"` also covers a label that is present but null.
        lab_type=(first_dataset.get("label") or "unknown").lower(),
        code=code,
        unit=tool_result.get("unit"),
        count=len(readings),
        metrics=metrics,
    )
def extract_medication_facts(medications: List[Dict], status_filter: Optional[str] = None) -> MedicationFacts:
    """
    Build MedicationFacts from a list of medication records.

    Each record contributes its "display" value (or "" when the key is
    absent); ``status_filter`` is passed through unchanged.
    """
    display_names = []
    for record in medications:
        display_names.append(record.get("display", ""))

    total = len(medications)
    return MedicationFacts(
        status_filter=status_filter,
        count=total,
        medication_names=display_names,
    )
def extract_condition_facts(conditions: List[Dict]) -> ConditionFacts:
    """
    Build ConditionFacts from a list of condition records.

    Each record contributes its "display" value, or "" when the key is
    absent.
    """
    display_names = []
    for record in conditions:
        display_names.append(record.get("display", ""))

    total = len(conditions)
    return ConditionFacts(count=total, condition_names=display_names)
def extract_allergy_facts(allergies: List[Dict]) -> AllergyFacts:
    """
    Build AllergyFacts from a list of allergy records.

    Each record contributes its "substance" value, or "" when the key is
    absent.
    """
    substance_names = []
    for record in allergies:
        substance_names.append(record.get("substance", ""))

    total = len(allergies)
    return AllergyFacts(count=total, substances=substance_names)
def extract_immunization_facts(immunizations: List[Dict]) -> ImmunizationFacts:
    """
    Build ImmunizationFacts from a list of immunization records.

    Each record contributes its "vaccine_display" value, or "" when the key
    is absent.
    """
    vaccine_labels = []
    for record in immunizations:
        vaccine_labels.append(record.get("vaccine_display", ""))

    total = len(immunizations)
    return ImmunizationFacts(count=total, vaccine_names=vaccine_labels)
def extract_procedure_facts(procedures: List[Dict]) -> ProcedureFacts:
    """
    Build ProcedureFacts from a list of procedure records.

    Each record contributes its "display" value, or "" when the key is
    absent.
    """
    display_names = []
    for record in procedures:
        display_names.append(record.get("display", ""))

    total = len(procedures)
    return ProcedureFacts(count=total, procedure_names=display_names)
def extract_encounter_facts(encounters: List[Dict], limit: int = 5) -> EncounterFacts:
    """
    Build EncounterFacts from a list of encounter records.

    Only the number of records and the supplied ``limit`` are recorded; the
    contents of the individual records are not inspected.
    """
    encounter_total = len(encounters)
    return EncounterFacts(count=encounter_total, limit=limit)