| """ | |
| Generalization test runner for unseen (OOD) production cases. | |
| Strategy: | |
| - Use production_cases.py only as unknown test set. | |
| - Reuse inference.py action policy path (single LLM attempt + deterministic fallback). | |
| - Grade with deterministic task graders. | |
| Output: | |
| - outputs/generalization_results.json | |
| """ | |
from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import Any, Dict, List

# Make the repo root importable when this script is run directly.
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from inference import choose_action, heuristic_action
from models import AdverseEventTriageAction, ProtocolDeviationAction, SafetyNarrativeAction
from tasks.graders import grade_ae_triage, grade_protocol_deviation, grade_safety_narrative
from tasks.production_cases import EXTRA_AE_CASES, EXTRA_DEVIATION_CASES, EXTRA_NARRATIVE_CASES

OUTPUT_FILE = ROOT / "outputs" / "generalization_results.json"

TASK_ID_AE = "adverse_event_triage"
TASK_ID_DEV = "protocol_deviation_audit"
TASK_ID_NAR = "safety_narrative_generation"
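

# Observation construction: all three tasks share the same payload shape
# (case fields plus step bookkeeping) but nest it under a task-specific key,
# presumably matching the observation shape inference.py's choose_action reads.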
def _build_observation(task_id: str, case: Dict[str, Any], step_idx: int, total_steps: int) -> Dict[str, Any]:
    payload = dict(case)
    payload["step_count"] = step_idx
    payload["max_steps"] = total_steps
    if task_id == TASK_ID_AE:
        return {
            "task_id": task_id,
            "ae_observation": payload,
            "message": "Unknown production AE case",
        }
    if task_id == TASK_ID_DEV:
        return {
            "task_id": task_id,
            "deviation_observation": payload,
            "message": "Unknown production deviation case",
        }
    return {
        "task_id": task_id,
        "narrative_observation": payload,
        "message": "Unknown production narrative case",
    }
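

# Action policy: one LLM attempt via choose_action; any exception, or an
# action whose task_id does not match, falls through to the deterministic
# heuristic.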
def _ensure_valid_action(task_id: str, observation: Dict[str, Any]) -> Dict[str, Any]:
    try:
        action = choose_action(task_id, observation)
        if isinstance(action, dict) and action.get("task_id") == task_id:
            return action
    except Exception:  # noqa: BLE001
        pass
    print("LLM attempt failed or returned an invalid action; using heuristic fallback")
    return heuristic_action(task_id, observation)
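

# Scoring: each task maps to its action payload key, action model, and grader.
# A case is graded from the chosen action; if model validation or grading
# fails, the heuristic action is graded instead so every case yields a score.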
_TASK_SPECS: Dict[str, Any] = {
    TASK_ID_AE: ("ae_triage", AdverseEventTriageAction, grade_ae_triage),
    TASK_ID_DEV: ("deviation_audit", ProtocolDeviationAction, grade_protocol_deviation),
    TASK_ID_NAR: ("safety_narrative", SafetyNarrativeAction, grade_safety_narrative),
}


def _score_case(task_id: str, case: Dict[str, Any], step_idx: int, total_steps: int) -> float:
    action_key, model_cls, grader = _TASK_SPECS[task_id]
    observation = _build_observation(task_id, case, step_idx, total_steps)
    action = _ensure_valid_action(task_id, observation)
    try:
        action_model = model_cls(**action[action_key])
        return float(grader(action_model, case).total)
    except Exception:  # noqa: BLE001
        fallback = heuristic_action(task_id, observation)
        action_model = model_cls(**fallback[action_key])
        return float(grader(action_model, case).total)


def _mean(values: List[float]) -> float:
    if not values:
        return 0.0
    return float(sum(values) / len(values))
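

# Driver: score every held-out case per task, report the per-task means and
# their macro average, and persist the summary as JSON.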
def run_generalization() -> Dict[str, Any]:
    print("Running generalization test...")
    ae_scores = [
        _score_case(TASK_ID_AE, case, idx, len(EXTRA_AE_CASES))
        for idx, case in enumerate(EXTRA_AE_CASES, start=1)
    ]
    dev_scores = [
        _score_case(TASK_ID_DEV, case, idx, len(EXTRA_DEVIATION_CASES))
        for idx, case in enumerate(EXTRA_DEVIATION_CASES, start=1)
    ]
    nar_scores = [
        _score_case(TASK_ID_NAR, case, idx, len(EXTRA_NARRATIVE_CASES))
        for idx, case in enumerate(EXTRA_NARRATIVE_CASES, start=1)
    ]

    per_task_scores = {
        TASK_ID_AE: round(_mean(ae_scores), 4),
        TASK_ID_DEV: round(_mean(dev_scores), 4),
        TASK_ID_NAR: round(_mean(nar_scores), 4),
    }
    mean_score = round(_mean(list(per_task_scores.values())), 4)

    for task_id, score in per_task_scores.items():
        print(f"{task_id}: {score:.4f}")
    print(f"Final mean score: {mean_score:.4f}")

    result = {
        "per_task_scores": per_task_scores,
        "mean_score": mean_score,
    }
    OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
    OUTPUT_FILE.write_text(json.dumps(result, indent=2), encoding="utf-8")
    print(f"Saved results to: {OUTPUT_FILE}")
    return result


if __name__ == "__main__":
    run_generalization()
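

# Example invocation (the path is illustrative; ROOT resolution assumes this
# file sits one directory below the repo root, e.g. in scripts/):
#
#   python scripts/run_generalization.py
#
# Per-task scores are printed and the summary is written to
# outputs/generalization_results.json.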