Spaces:
Sleeping
Sleeping
| """ | |
| Heuristic Baseline | |
| =================== | |
| A deterministic rule-based baseline agent for all 3 tasks. | |
| Used by the /baseline endpoint for fast, reproducible scoring. | |
| The LLM-based baseline (GroqCloud API) is in scripts/baseline_inference.py. | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict | |
| from models import ( | |
| AdverseEventTriageAction, | |
| CausalityAssessment, | |
| DeviationType, | |
| ProtocolDeviationAction, | |
| SafetyNarrativeAction, | |
| TaskID, | |
| TriageAction, | |
| ) | |
| from server.environment import ClinicalTrialEnvironment | |
| from tasks.case_bank import AE_CASES, DEVIATION_CASES, NARRATIVE_CASES | |
| _SCORE_EPS = 1e-3 | |
| def _clamp_open_score(value: float) -> float: | |
| return max(_SCORE_EPS, min(1.0 - _SCORE_EPS, float(value))) | |
| # ───────────────────────────────────────── | |
| # HEURISTIC AGENTS (rule-based) | |
| # ───────────────────────────────────────── | |
def _heuristic_ae_triage(case: Dict[str, Any]) -> TriageAction:
    """Classify an adverse-event case with ordered keyword rules.

    Severity, reporting timeline and seriousness are chosen from the
    lower-cased ``case["narrative"]``; the MedDRA system organ class and
    preferred term are chosen from the lower-cased ``case["ae_description"]``.
    In both tables the first rule with any matching keyword wins, and a
    fallback applies when nothing matches.

    The original author's note put accuracy at roughly 70% on easy cases and
    45% on hard cases; that figure cannot be verified from this module.

    Args:
        case: Case dict providing string values under ``"narrative"`` and
            ``"ae_description"``.

    Returns:
        A ``TriageAction`` for ``TaskID.ADVERSE_EVENT_TRIAGE`` carrying the
        heuristic classification.

    Raises:
        KeyError: If either required key is missing from *case*.
    """
    narrative_text = case["narrative"].lower()
    event_text = case["ae_description"].lower()

    # Ordered from most to least severe; each entry is
    # (keywords, (severity, reporting timeline, is_serious)).
    # NOTE(review): matching is plain substring containment, so short keywords
    # such as "icu" also hit inside words like "difficult" — confirm intended.
    severity_rules = (
        (("fatal", "death", "died"), ("fatal", "7-day", True)),
        (
            ("stemi", "cardiac arrest", "icu", "intensive care", "life-threatening"),
            ("life_threatening", "7-day", True),
        ),
        (
            ("hospitali", "encephalopathy", "grade 3", "severe"),
            ("severe", "15-day", True),
        ),
        (("grade 2", "moderate", "nausea"), ("moderate", "routine", False)),
    )
    severity, timeline, is_serious = next(
        (
            outcome
            for keywords, outcome in severity_rules
            if any(keyword in narrative_text for keyword in keywords)
        ),
        ("mild", "routine", False),
    )

    # Each entry is (keywords, (system organ class, preferred term)).
    coding_rules = (
        (
            ("cardiac", "myocardial", "stemi", "heart"),
            ("Cardiac disorders", "Myocardial infarction"),
        ),
        (
            ("nausea", "vomiting", "gastrointestinal"),
            ("Gastrointestinal disorders", "Nausea"),
        ),
        (
            ("encephalopathy", "neurological", "nervous"),
            ("Nervous system disorders", "Encephalopathy"),
        ),
    )
    soc, preferred_term = next(
        (
            coding
            for keywords, coding in coding_rules
            if any(keyword in event_text for keyword in keywords)
        ),
        ("General disorders", "Adverse event"),
    )

    triage = AdverseEventTriageAction(
        severity_classification=severity,
        reporting_timeline=timeline,
        meddra_soc=soc,
        meddra_preferred_term=preferred_term,
        is_serious=is_serious,
        rationale="Heuristic baseline classification based on keyword matching.",
    )
    return TriageAction(task_id=TaskID.ADVERSE_EVENT_TRIAGE, ae_triage=triage)
def _heuristic_deviation_audit(case: Dict[str, Any]) -> TriageAction:
    """Audit a site's findings by counting those that mention high-risk terms.

    A finding is flagged when any high-risk keyword occurs in its lower-cased
    ``"description"`` or ``"category"``. Two or more flagged findings make the
    deviation major and require a CAPA. The site risk score is
    ``2.5 * flagged + 0.3 * prior_deviations``, capped at 10.0.

    Args:
        case: Case dict with a ``"findings"`` list (each finding providing
            ``"id"``, ``"description"`` and ``"category"``) and an optional
            ``"prior_deviations"`` number (treated as 0 when absent).

    Returns:
        A ``TriageAction`` for ``TaskID.PROTOCOL_DEVIATION_AUDIT``.

    Raises:
        KeyError: If ``"findings"`` or a required finding key is missing.
    """
    # NOTE(review): matching is substring-based and "data" is very broad —
    # confirm that flagging every finding mentioning "data" is intended.
    risk_terms = (
        "eligibility", "blinding", "unblind", "sae report", "integrity",
        "data", "enroll", "hospitali", "ip accountability", "unaccounted",
        "endpoint", "consent", "delegate",
    )

    def mentions_risk_term(finding: Dict[str, Any]) -> bool:
        description = finding["description"].lower()
        category = finding["category"].lower()
        return any(term in description or term in category for term in risk_terms)

    flagged = [
        finding["id"] for finding in case["findings"] if mentions_risk_term(finding)
    ]
    flagged_total = len(flagged)

    # One threshold drives both the major/minor call and the CAPA requirement.
    escalate = flagged_total >= 2
    site_risk = min(
        10.0, flagged_total * 2.5 + case.get("prior_deviations", 0) * 0.3
    )

    if escalate:
        deviation_kind = DeviationType.MAJOR
        next_step = "Immediate escalation to CRA and Sponsor QA team for review."
    else:
        deviation_kind = DeviationType.MINOR
        next_step = "Document and include in next monitoring report."

    audit = ProtocolDeviationAction(
        deviation_type=deviation_kind,
        capa_required=escalate,
        site_risk_score=site_risk,
        flagged_finding_ids=flagged,
        recommended_action=next_step,
    )
    return TriageAction(task_id=TaskID.PROTOCOL_DEVIATION_AUDIT, deviation_audit=audit)
def _heuristic_narrative(case: Dict[str, Any]) -> TriageAction:
    """Generate a safety narrative by filling a fixed sentence template.

    The template is deliberately static: it always describes a drug-drug
    interaction with warfarin, a positive dechallenge and a "probably related"
    causality, and only the case-specific fields are substituted in.

    Args:
        case: Case dict providing ``"patient_demographics"`` (``age``,
            ``sex``), ``"adverse_event"`` (``onset_date``, ``term``,
            ``meddra_soc``, ``meddra_pt``, ``seriousness_criteria``),
            ``"lab_values_timeline"`` (list of dicts with ``date`` and
            optional ``INR`` / ``Hgb_g_dL``), ``"concomitant_medications"``
            (list of dicts with ``name`` and ``dose``), ``"study_drug"``,
            ``"medical_history"``, ``"action_taken"`` and
            ``"outcome_at_last_followup"``.

    Returns:
        A ``TriageAction`` for ``TaskID.SAFETY_NARRATIVE_GENERATION``.

    Raises:
        KeyError: If a required key is missing from *case* or its sub-dicts.
    """
    demographics = case["patient_demographics"]
    event = case["adverse_event"]
    lab_rows = case["lab_values_timeline"]
    medications = case["concomitant_medications"]

    conmed_summary = "; ".join(f"{med['name']} {med['dose']}" for med in medications)
    lab_summary = "; ".join(
        f"{row['date']}: INR {row.get('INR', 'N/A')}, Hgb {row.get('Hgb_g_dL', 'N/A')} g/dL"
        for row in lab_rows
    )
    history_summary = "; ".join(case["medical_history"])
    criteria_summary = ", ".join(event["seriousness_criteria"])
    # The second-to-last lab row is reported as the pre-event INR; with fewer
    # than two rows there is no such reading to quote.
    prior_inr = lab_rows[-2].get("INR", "N/A") if len(lab_rows) >= 2 else "N/A"

    sentences = (
        f"A {demographics['age']}-year-old {demographics['sex']} patient enrolled "
        f"in a clinical study received {case['study_drug']}.",
        f"Relevant medical history includes: {history_summary}.",
        f"Concomitant medications: {conmed_summary}.",
        f"On {event['onset_date']}, the patient developed {event['term']} "
        f"(MedDRA: {event['meddra_soc']} / {event['meddra_pt']}), "
        f"meeting seriousness criteria of {criteria_summary}.",
        f"Laboratory values over time: {lab_summary}.",
        f"Notable INR elevation to {prior_inr} was observed prior to the event, "
        f"suggesting a potential drug-drug interaction between {case['study_drug']} "
        "and warfarin.",
        f"Action taken: {case['action_taken']}.",
        # The dash below restores a character that had been mis-encoded.
        "Dechallenge was positive — the event resolved following drug discontinuation.",
        f"Outcome at last follow-up: {case['outcome_at_last_followup']}.",
        "Causality assessment: The event is considered probably related to the "
        "study drug, given the temporal relationship, positive dechallenge, and "
        "the plausible pharmacokinetic interaction with warfarin resulting in "
        "supratherapeutic INR levels.",
    )
    narrative = " ".join(sentences)

    return TriageAction(
        task_id=TaskID.SAFETY_NARRATIVE_GENERATION,
        safety_narrative=SafetyNarrativeAction(
            narrative_text=narrative,
            causality_assessment=CausalityAssessment.PROBABLY_RELATED,
            key_temporal_flags=[
                "INR elevation 2 days prior to event",
                "onset after dose initiation day 14",
                "positive dechallenge on drug discontinuation",
            ],
            dechallenge_positive=True,
            rechallenge_positive=None,
        ),
    )
| # ───────────────────────────────────────── | |
| # MAIN RUNNER | |
| # ───────────────────────────────────────── | |
def _run_task_cases(
    env: ClinicalTrialEnvironment, task_id: Any, cases: Any, agent: Any
) -> Dict[str, Any]:
    """Reset *env* for one task, step through *cases*, and return its mean reward.

    Args:
        env: Environment exposing ``reset(task_id=...)`` and ``step(action)``;
            each step result is read for ``.reward`` and ``.done``.
        task_id: Task identifier passed to ``env.reset``.
        cases: Iterable of case dicts fed to *agent* one at a time.
        agent: Callable mapping a case dict to the action given to ``env.step``.

    Returns:
        ``{"mean_reward": <float>}``: the clamped mean of the clamped per-step
        rewards rounded to 4 places, or the clamped ``_SCORE_EPS`` floor when
        no step was taken.
    """
    # NOTE(review): assumes the environment serves cases in the same order as
    # the case bank iterated here — confirm against ClinicalTrialEnvironment.
    env.reset(task_id=task_id)
    rewards = []
    for case in cases:
        outcome = env.step(agent(case))
        rewards.append(_clamp_open_score(float(outcome.reward)))
        if outcome.done:
            break
    if not rewards:
        return {"mean_reward": _clamp_open_score(_SCORE_EPS)}
    return {"mean_reward": round(_clamp_open_score(sum(rewards) / len(rewards)), 4)}


def run_heuristic_baseline() -> Dict[str, Any]:
    """Run the heuristic baseline on all 3 tasks and return the scores.

    Returns:
        A dict with ``"baseline_type"``, ``"description"``, a ``"tasks"``
        mapping from each ``TaskID`` to ``{"mean_reward": ...}``, and an
        ``"overall_mean_reward"`` that is the clamped, rounded mean of the
        per-task means.
    """
    env = ClinicalTrialEnvironment()
    results: Dict[str, Any] = {
        "baseline_type": "heuristic",
        "description": "Rule-based keyword matching baseline — establishes lower bound.",
        "tasks": {},
    }

    # (task id, case bank, agent) for each task, run in this order on one env.
    task_plan = (
        (TaskID.ADVERSE_EVENT_TRIAGE, AE_CASES, _heuristic_ae_triage),
        (TaskID.PROTOCOL_DEVIATION_AUDIT, DEVIATION_CASES, _heuristic_deviation_audit),
        (TaskID.SAFETY_NARRATIVE_GENERATION, NARRATIVE_CASES, _heuristic_narrative),
    )
    for task_id, cases, agent in task_plan:
        results["tasks"][task_id] = _run_task_cases(env, task_id, cases, agent)

    task_means = [entry["mean_reward"] for entry in results["tasks"].values()]
    results["overall_mean_reward"] = round(
        _clamp_open_score(sum(task_means) / len(task_means)), 4
    )
    return results
if __name__ == "__main__":
    # Script entry point: run the baseline and print the scores as indented JSON.
    import json

    print(json.dumps(run_heuristic_baseline(), indent=2))