# Source: meta-hack/scripts/heuristic_baseline.py
# Uploaded by: vvinayakkk
# Commit ec2d3a1: "Return score-only baseline and grader payloads"
"""
Heuristic Baseline
===================
A deterministic rule-based baseline agent for all 3 tasks.
Used by the /baseline endpoint for fast, reproducible scoring.
The LLM-based baseline (GroqCloud API) is in scripts/baseline_inference.py.
"""
from __future__ import annotations
from typing import Any, Dict
from models import (
AdverseEventTriageAction,
CausalityAssessment,
DeviationType,
ProtocolDeviationAction,
SafetyNarrativeAction,
TaskID,
TriageAction,
)
from server.environment import ClinicalTrialEnvironment
from tasks.case_bank import AE_CASES, DEVIATION_CASES, NARRATIVE_CASES
_SCORE_EPS = 1e-3
def _clamp_open_score(value: float) -> float:
return max(_SCORE_EPS, min(1.0 - _SCORE_EPS, float(value)))
# ─────────────────────────────────────────
# HEURISTIC AGENTS (rule-based)
# ─────────────────────────────────────────
def _heuristic_ae_triage(case: Dict[str, Any]) -> TriageAction:
    """
    Classify one adverse-event case with keyword rules.

    Severity, reporting timeline and seriousness are derived from
    ``case["narrative"]``; the MedDRA SOC / preferred term are derived from
    ``case["ae_description"]``. Both texts are lower-cased first, and in each
    rule table the first matching rule wins.

    Short severity keywords that commonly occur *inside* unrelated words are
    matched on word boundaries, so that e.g. "difficult" / "ventricular"
    (contain "icu"), "studied" (contains "died") and "nonfatal" (contains
    "fatal") no longer escalate the case. All other keywords keep plain
    substring semantics, which lets stems such as "hospitali" cover both
    spellings of "hospitalis/zed".

    NOTE(review): the accuracy figures previously quoted here (~70% easy,
    ~45% hard) pre-date the word-boundary fix and should be re-measured.

    Args:
        case: Mapping with at least the string fields "narrative" and
            "ae_description".

    Returns:
        A TriageAction for TaskID.ADVERSE_EVENT_TRIAGE carrying the
        heuristic classification.

    Raises:
        KeyError: If "narrative" or "ae_description" is missing.
    """
    import re  # function-scope import keeps this block self-contained

    narrative = case["narrative"].lower()
    ae_desc = case["ae_description"].lower()

    # Keywords that produce false positives as raw substrings.
    bounded_patterns = {
        "icu": r"\bicu\b",      # not "difficult", "ventricular", "particular"
        "died": r"\bdied\b",    # not "studied"
        "fatal": r"\bfatal",    # not "nonfatal"; still matches "fatality"
    }

    def narrative_mentions(keyword: str) -> bool:
        """Return True when *keyword* occurs in the narrative."""
        pattern = bounded_patterns.get(keyword)
        if pattern is None:
            return keyword in narrative
        return re.search(pattern, narrative) is not None

    # (severity, reporting timeline, is_serious, trigger keywords),
    # ordered from most to least severe.
    severity_rules = (
        ("fatal", "7-day", True, ("fatal", "death", "died")),
        (
            "life_threatening",
            "7-day",
            True,
            ("stemi", "cardiac arrest", "icu", "intensive care", "life-threatening"),
        ),
        ("severe", "15-day", True, ("hospitali", "encephalopathy", "grade 3", "severe")),
        ("moderate", "routine", False, ("grade 2", "moderate", "nausea")),
    )

    # Fallback when no rule fires.
    severity, timeline, is_serious = "mild", "routine", False
    for label, deadline, serious, keywords in severity_rules:
        if any(narrative_mentions(kw) for kw in keywords):
            severity, timeline, is_serious = label, deadline, serious
            break

    # (SOC, preferred term, trigger keywords). These stay substring matches
    # on purpose: medical compounds such as "pericardiac" or
    # "leukoencephalopathy" legitimately embed the trigger term.
    soc_rules = (
        ("Cardiac disorders", "Myocardial infarction", ("cardiac", "myocardial", "stemi", "heart")),
        ("Gastrointestinal disorders", "Nausea", ("nausea", "vomiting", "gastrointestinal")),
        ("Nervous system disorders", "Encephalopathy", ("encephalopathy", "neurological", "nervous")),
    )

    soc, pt = "General disorders", "Adverse event"
    for soc_name, preferred_term, keywords in soc_rules:
        if any(kw in ae_desc for kw in keywords):
            soc, pt = soc_name, preferred_term
            break

    return TriageAction(
        task_id=TaskID.ADVERSE_EVENT_TRIAGE,
        ae_triage=AdverseEventTriageAction(
            severity_classification=severity,
            reporting_timeline=timeline,
            meddra_soc=soc,
            meddra_preferred_term=pt,
            is_serious=is_serious,
            rationale="Heuristic baseline classification based on keyword matching.",
        ),
    )
def _heuristic_deviation_audit(case: Dict[str, Any]) -> TriageAction:
    """Audit a site's findings by counting those that mention a high-risk term.

    A finding is flagged when any high-risk term is a substring of its
    lower-cased "description" or "category". Two or more flagged findings
    make the deviation MAJOR and require a CAPA. The site risk score is
    2.5 per flagged finding plus 0.3 per prior deviation, capped at 10.

    Args:
        case: Mapping with a "findings" list (each item has "id",
            "description" and "category") and an optional
            "prior_deviations" count (defaults to 0).

    Returns:
        A TriageAction for TaskID.PROTOCOL_DEVIATION_AUDIT.
    """
    risky_terms = (
        "eligibility", "blinding", "unblind", "sae report", "integrity",
        "data", "enroll", "hospitali", "ip accountability", "unaccounted",
        "endpoint", "consent", "delegate",
    )

    def is_high_risk(finding: Dict[str, Any]) -> bool:
        description = finding["description"].lower()
        category = finding["category"].lower()
        return any(
            term in description or term in category for term in risky_terms
        )

    flagged = [finding["id"] for finding in case["findings"] if is_high_risk(finding)]
    hits = len(flagged)

    # The same threshold drives both the MAJOR classification and the CAPA flag.
    escalate = hits >= 2
    prior = case.get("prior_deviations", 0)
    score = min(10.0, hits * 2.5 + prior * 0.3)

    if escalate:
        severity = DeviationType.MAJOR
        advice = "Immediate escalation to CRA and Sponsor QA team for review."
    else:
        severity = DeviationType.MINOR
        advice = "Document and include in next monitoring report."

    audit = ProtocolDeviationAction(
        deviation_type=severity,
        capa_required=escalate,
        site_risk_score=score,
        flagged_finding_ids=flagged,
        recommended_action=advice,
    )
    return TriageAction(
        task_id=TaskID.PROTOCOL_DEVIATION_AUDIT,
        deviation_audit=audit,
    )
def _heuristic_narrative(case: Dict[str, Any]) -> TriageAction:
    """Heuristic narrative generation - structured template filling.

    Fills a fixed prose template with the case's demographics, study drug,
    medical history, concomitant medications, adverse-event details, lab
    timeline, action taken and outcome.

    NOTE(review): the template hard-codes a warfarin drug-drug interaction,
    a positive dechallenge, a "probably related" causality and three fixed
    temporal flags regardless of the case content. This looks tailored to
    one specific case in the bank; confirm before reusing on other cases.

    Args:
        case: Mapping with the keys "patient_demographics" (age, sex),
            "adverse_event" (onset_date, term, meddra_soc, meddra_pt,
            seriousness_criteria), "lab_values_timeline" (list of dicts with
            "date" and optional "INR" / "Hgb_g_dL"),
            "concomitant_medications" (list of dicts with "name" and
            "dose"), "study_drug", "medical_history", "action_taken" and
            "outcome_at_last_followup".

    Returns:
        A TriageAction for TaskID.SAFETY_NARRATIVE_GENERATION.

    Raises:
        KeyError: If any required key listed above is missing.
    """
    dem = case["patient_demographics"]
    ae = case["adverse_event"]
    labs = case["lab_values_timeline"]
    conmeds = case["concomitant_medications"]

    conmed_str = "; ".join(f"{med['name']} {med['dose']}" for med in conmeds)
    lab_str = "; ".join(
        f"{lab['date']}: INR {lab.get('INR', 'N/A')}, Hgb {lab.get('Hgb_g_dL', 'N/A')} g/dL"
        for lab in labs
    )
    history_str = "; ".join(case["medical_history"])
    criteria_str = ", ".join(ae["seriousness_criteria"])

    # Second-to-last lab entry; presumably the draw just before the event
    # (verify against the case bank). Needs at least two entries.
    pre_event_inr = labs[-2].get("INR", "N/A") if len(labs) >= 2 else "N/A"

    narrative = (
        f"A {dem['age']}-year-old {dem['sex']} patient enrolled in a clinical study "
        f"received {case['study_drug']}. "
        f"Relevant medical history includes: {history_str}. "
        f"Concomitant medications: {conmed_str}. "
        f"On {ae['onset_date']}, the patient developed {ae['term']} (MedDRA: {ae['meddra_soc']} / {ae['meddra_pt']}), "
        f"meeting seriousness criteria of {criteria_str}. "
        f"Laboratory values over time: {lab_str}. "
        f"Notable INR elevation to {pre_event_inr} was observed prior to the event, "
        f"suggesting a potential drug-drug interaction between {case['study_drug']} and warfarin. "
        f"Action taken: {case['action_taken']}. "
        # \u2014 is an em dash; escaped so the literal cannot be re-corrupted
        # into mojibake by a wrong file encoding.
        "Dechallenge was positive \u2014 the event resolved following drug discontinuation. "
        f"Outcome at last follow-up: {case['outcome_at_last_followup']}. "
        "Causality assessment: The event is considered probably related to the study drug, "
        "given the temporal relationship, positive dechallenge, and the plausible pharmacokinetic "
        "interaction with warfarin resulting in supratherapeutic INR levels."
    )

    return TriageAction(
        task_id=TaskID.SAFETY_NARRATIVE_GENERATION,
        safety_narrative=SafetyNarrativeAction(
            narrative_text=narrative,
            causality_assessment=CausalityAssessment.PROBABLY_RELATED,
            key_temporal_flags=[
                "INR elevation 2 days prior to event",
                "onset after dose initiation day 14",
                "positive dechallenge on drug discontinuation",
            ],
            dechallenge_positive=True,
            rechallenge_positive=None,
        ),
    )
# ─────────────────────────────────────────
# MAIN RUNNER
# ─────────────────────────────────────────
def _mean_task_reward(
    env: ClinicalTrialEnvironment,
    task_id: Any,
    cases: Any,
    agent: Any,
) -> float:
    """Reset *env* for *task_id*, play *agent* over *cases*, return the mean reward.

    Each step reward is clamped before averaging. The loop stops early when
    the environment reports ``done``. The mean is clamped again and rounded
    to 4 decimals; with no rewards at all the lower clamp bound is returned.

    NOTE(review): the agent is fed cases from the case bank in list order;
    this assumes the environment scores steps against the same cases in the
    same order. Confirm against ClinicalTrialEnvironment.
    """
    env.reset(task_id=task_id)
    rewards = []
    for case in cases:
        result = env.step(agent(case))
        rewards.append(_clamp_open_score(float(result.reward)))
        if result.done:
            break
    if not rewards:
        return _clamp_open_score(_SCORE_EPS)
    return round(_clamp_open_score(sum(rewards) / len(rewards)), 4)


def run_heuristic_baseline() -> Dict[str, Any]:
    """Run heuristic baseline on all 3 tasks and return scores.

    Returns:
        A dict with the keys "baseline_type", "description", "tasks"
        (mapping each TaskID to ``{"mean_reward": float}``) and
        "overall_mean_reward" (clamped mean of the three task means,
        rounded to 4 decimals).
    """
    env = ClinicalTrialEnvironment()

    # (task, case bank, rule-based agent), evaluated in this order on one
    # shared environment instance.
    task_plan = (
        (TaskID.ADVERSE_EVENT_TRIAGE, AE_CASES, _heuristic_ae_triage),
        (TaskID.PROTOCOL_DEVIATION_AUDIT, DEVIATION_CASES, _heuristic_deviation_audit),
        (TaskID.SAFETY_NARRATIVE_GENERATION, NARRATIVE_CASES, _heuristic_narrative),
    )

    results: Dict[str, Any] = {
        "baseline_type": "heuristic",
        # \u2014 is an em dash, escaped so the emitted text cannot be
        # corrupted into mojibake by a wrong file encoding.
        "description": "Rule-based keyword matching baseline \u2014 establishes lower bound.",
        "tasks": {},
    }

    for task_id, cases, agent in task_plan:
        results["tasks"][task_id] = {
            "mean_reward": _mean_task_reward(env, task_id, cases, agent),
        }

    all_means = [entry["mean_reward"] for entry in results["tasks"].values()]
    results["overall_mean_reward"] = round(
        _clamp_open_score(sum(all_means) / len(all_means)), 4
    )
    return results
if __name__ == "__main__":
    # Script entry point: run the baseline once and pretty-print its scores
    # as JSON on stdout.
    import json

    baseline_report = run_heuristic_baseline()
    print(json.dumps(baseline_report, indent=2))