"""
Generalization test runner for unseen (OOD) production cases.
Strategy:
- Use production_cases.py only as unknown test set.
- Reuse inference.py action policy path (single LLM attempt + deterministic fallback).
- Grade with deterministic task graders.
Output:
- outputs/generalization_results.json
"""
from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import Any, Dict, List

# Make the project root importable when this file is run as a script;
# the project imports below depend on this path bootstrap.
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from inference import choose_action, heuristic_action
from models import AdverseEventTriageAction, ProtocolDeviationAction, SafetyNarrativeAction
from tasks.graders import grade_ae_triage, grade_protocol_deviation, grade_safety_narrative
from tasks.production_cases import EXTRA_AE_CASES, EXTRA_DEVIATION_CASES, EXTRA_NARRATIVE_CASES

OUTPUT_FILE = ROOT / "outputs" / "generalization_results.json"

TASK_ID_AE = "adverse_event_triage"
TASK_ID_DEV = "protocol_deviation_audit"
TASK_ID_NAR = "safety_narrative_generation"


def _build_observation(task_id: str, case: Dict[str, Any], step_idx: int, total_steps: int) -> Dict[str, Any]:
    """Wrap a raw case dict in the observation envelope the action policy expects."""
    payload = dict(case)
    payload["step_count"] = step_idx
    payload["max_steps"] = total_steps
    if task_id == TASK_ID_AE:
        key, message = "ae_observation", "Unknown production AE case"
    elif task_id == TASK_ID_DEV:
        key, message = "deviation_observation", "Unknown production deviation case"
    else:
        key, message = "narrative_observation", "Unknown production narrative case"
    return {"task_id": task_id, key: payload, "message": message}


def _ensure_valid_action(task_id: str, observation: Dict[str, Any]) -> Dict[str, Any]:
    """Single LLM attempt via choose_action; deterministic heuristic on any failure."""
    try:
        action = choose_action(task_id, observation)
        if isinstance(action, dict) and action.get("task_id") == task_id:
            return action
    except Exception:  # noqa: BLE001
        pass
    print("LLM failed, using heuristic fallback")
    return heuristic_action(task_id, observation)

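
# Expected action shape, inferred from how the scorers below unpack it (the
# exact payload fields are defined by the Pydantic models imported above):
#
#   {"task_id": "adverse_event_triage", "ae_triage": {...}}
#   {"task_id": "protocol_deviation_audit", "deviation_audit": {...}}
#   {"task_id": "safety_narrative_generation", "safety_narrative": {...}}
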
def _score_ae_case(case: Dict[str, Any], step_idx: int, total_steps: int) -> float:
    observation = _build_observation(TASK_ID_AE, case, step_idx, total_steps)
    action = _ensure_valid_action(TASK_ID_AE, observation)
    try:
        action_model = AdverseEventTriageAction(**action["ae_triage"])
        return float(grade_ae_triage(action_model, case).total)
    except Exception:  # noqa: BLE001
        # Malformed action payload: re-score with the deterministic fallback.
        fallback = heuristic_action(TASK_ID_AE, observation)
        action_model = AdverseEventTriageAction(**fallback["ae_triage"])
        return float(grade_ae_triage(action_model, case).total)


def _score_deviation_case(case: Dict[str, Any], step_idx: int, total_steps: int) -> float:
    observation = _build_observation(TASK_ID_DEV, case, step_idx, total_steps)
    action = _ensure_valid_action(TASK_ID_DEV, observation)
    try:
        action_model = ProtocolDeviationAction(**action["deviation_audit"])
        return float(grade_protocol_deviation(action_model, case).total)
    except Exception:  # noqa: BLE001
        # Malformed action payload: re-score with the deterministic fallback.
        fallback = heuristic_action(TASK_ID_DEV, observation)
        action_model = ProtocolDeviationAction(**fallback["deviation_audit"])
        return float(grade_protocol_deviation(action_model, case).total)


def _score_narrative_case(case: Dict[str, Any], step_idx: int, total_steps: int) -> float:
    observation = _build_observation(TASK_ID_NAR, case, step_idx, total_steps)
    action = _ensure_valid_action(TASK_ID_NAR, observation)
    try:
        action_model = SafetyNarrativeAction(**action["safety_narrative"])
        return float(grade_safety_narrative(action_model, case).total)
    except Exception:  # noqa: BLE001
        # Malformed action payload: re-score with the deterministic fallback.
        fallback = heuristic_action(TASK_ID_NAR, observation)
        action_model = SafetyNarrativeAction(**fallback["safety_narrative"])
        return float(grade_safety_narrative(action_model, case).total)


def _mean(values: List[float]) -> float:
    if not values:
        return 0.0
    return float(sum(values) / len(values))


def run_generalization() -> Dict[str, Any]:
    """Score every held-out case per task, report the means, and persist a summary."""
    print("Running generalization test...")
    ae_scores = [
        _score_ae_case(case, idx, len(EXTRA_AE_CASES))
        for idx, case in enumerate(EXTRA_AE_CASES, start=1)
    ]
    dev_scores = [
        _score_deviation_case(case, idx, len(EXTRA_DEVIATION_CASES))
        for idx, case in enumerate(EXTRA_DEVIATION_CASES, start=1)
    ]
    nar_scores = [
        _score_narrative_case(case, idx, len(EXTRA_NARRATIVE_CASES))
        for idx, case in enumerate(EXTRA_NARRATIVE_CASES, start=1)
    ]
    per_task_scores = {
        TASK_ID_AE: round(_mean(ae_scores), 4),
        TASK_ID_DEV: round(_mean(dev_scores), 4),
        TASK_ID_NAR: round(_mean(nar_scores), 4),
    }
    mean_score = round(_mean(list(per_task_scores.values())), 4)
    print(f"{TASK_ID_AE}: {per_task_scores[TASK_ID_AE]:.4f}")
    print(f"{TASK_ID_DEV}: {per_task_scores[TASK_ID_DEV]:.4f}")
    print(f"{TASK_ID_NAR}: {per_task_scores[TASK_ID_NAR]:.4f}")
    print(f"Final mean score: {mean_score:.4f}")
    result = {
        "per_task_scores": per_task_scores,
        "mean_score": mean_score,
    }
    OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
    OUTPUT_FILE.write_text(json.dumps(result, indent=2), encoding="utf-8")
    print(f"Saved results to: {OUTPUT_FILE}")
    return result


if __name__ == "__main__":
    run_generalization()