# yuxbox's picture
# Upload folder using huggingface_hub
# a1aaf30 verified
"""
Evaluation Metrics for ActiveMedAgent.
Unified metrics across all three datasets:
- MRR (Mean Reciprocal Rank)
- Acquisition Efficiency (normalized improvement)
- Top-1 Accuracy
- Acquisition Precision
- Uncertainty Calibration (ECE-style)
- Information-Theoretic Metrics (entropy, IG, VoI)
- Bootstrap confidence intervals
"""
import logging
from collections import Counter
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
from scipy import stats

import config
from agent import AgentResult
from datasets.base import MedicalCase
from information_gain import BeliefTrajectory, compute_information_metrics
logger = logging.getLogger(__name__)
@dataclass
class CaseMetrics:
    """Evaluation metrics for a single case (one agent run vs. ground truth)."""
    case_id: str  # Unique identifier of the case
    dataset: str  # Name of the dataset the case belongs to
    top1_correct: bool = False  # True when the top-ranked diagnosis matches ground truth (RR == 1)
    reciprocal_rank: float = 0.0  # 1/rank of the ground truth in the agent's ranking
    ground_truth_rank: int = -1  # 1-indexed rank of correct answer
    n_acquired: int = 0  # Number of information channels the agent acquired
    acquired_channels: list[str] = field(default_factory=list)  # Names of the acquired channels
    committed_early: bool = False  # True when the agent committed to a diagnosis early
    top1_confidence: float = 0.0  # Confidence of the top-ranked diagnosis
    acquisition_cost: float = 0.0  # Cost attributed to information acquisition
    total_case_cost: float = 0.0  # Total cost recorded for the case
@dataclass
class DatasetMetrics:
    """Metrics aggregated over all evaluated cases of one dataset."""
    dataset: str  # Dataset name
    n_cases: int  # Number of cases aggregated
    top1_accuracy: float  # Fraction of cases with the correct top-1 diagnosis
    mrr: float  # Mean Reciprocal Rank
    top1_accuracy_ci: tuple[float, float] = (0.0, 0.0)  # 95% CI
    mrr_ci: tuple[float, float] = (0.0, 0.0)  # 95% bootstrap CI for MRR
    mean_channels_acquired: float = 0.0  # Average number of channels acquired per case
    early_commit_rate: float = 0.0  # Fraction of cases where the agent committed early
    per_channel_request_rate: dict[str, float] = field(default_factory=dict)  # channel -> request fraction
    mean_acquisition_cost: float = 0.0  # Average acquisition cost per case
    mean_total_case_cost: float = 0.0  # Average total cost per case
def compute_reciprocal_rank(
    ranking: list[dict],
    ground_truth: str,
    candidates: list[str],
) -> float:
    """
    Compute the reciprocal rank of the ground truth in the agent's ranking.

    Matching is deliberately fuzzy: an entry matches when its name and the
    ground truth contain each other as substrings, or when the name overlaps
    a candidate string that itself contains the ground truth.

    Args:
        ranking: List of dicts, each with "name" and a 1-indexed "rank".
        ground_truth: Gold-standard diagnosis string.
        candidates: Candidate diagnosis strings used for indirect matching.

    Returns:
        1/rank of the first matching entry; 1/(N+1) when the ground truth is
        absent from a non-empty ranking of N entries; 0.0 for an empty ranking.
    """
    if not ranking:
        return 0.0
    gt_lower = ground_truth.lower().strip()
    for entry in ranking:
        name = entry.get("name", "").lower().strip()
        rank = entry.get("rank", 999)
        # Skip malformed entries: an empty name is a substring of everything
        # and would spuriously match; a rank <= 0 would break 1/rank.
        if not name or not isinstance(rank, (int, float)) or rank <= 0:
            continue
        # Flexible matching: check substring containment both ways
        if gt_lower in name or name in gt_lower:
            return 1.0 / rank
        # Check if it matches any candidate that matches ground truth
        for candidate in candidates:
            cand_lower = candidate.lower()
            if gt_lower in cand_lower and (name in cand_lower or cand_lower in name):
                return 1.0 / rank
    # Ground truth not found in ranking — credit 1/(N+1)
    return 1.0 / (len(ranking) + 1)
def evaluate_single_case(
    result: AgentResult,
    case: MedicalCase,
) -> CaseMetrics:
    """
    Evaluate a single agent result against ground truth.

    Args:
        result: Agent output (final ranking, acquired channels, costs).
        case: Medical case carrying the ground-truth diagnosis and candidates.

    Returns:
        CaseMetrics with rank-based correctness and acquisition statistics.
    """
    ranking = result.final_ranking
    gt = case.ground_truth
    candidates = case.candidates
    rr = compute_reciprocal_rank(ranking, gt, candidates)
    top1_correct = rr == 1.0  # RR=1 means correct answer is ranked first
    # Use .get with a default: ranking entries may lack a "confidence" key.
    top1_conf = ranking[0].get("confidence", 0.0) if ranking else 0.0
    # Determine ground truth rank in agent's output
    gt_rank = -1
    gt_lower = gt.lower().strip()
    for entry in ranking:
        name = entry.get("name", "").lower().strip()
        # Skip empty names: "" is a substring of everything and would
        # spuriously claim the ground-truth rank.
        if name and (gt_lower in name or name in gt_lower):
            gt_rank = entry.get("rank", -1)
            break
    return CaseMetrics(
        case_id=result.case_id,
        dataset=result.dataset,
        top1_correct=top1_correct,
        reciprocal_rank=rr,
        ground_truth_rank=gt_rank,
        n_acquired=len(result.acquired_channels),
        acquired_channels=result.acquired_channels,
        committed_early=result.committed_early,
        top1_confidence=top1_conf,
        acquisition_cost=result.acquisition_cost,
        total_case_cost=result.total_case_cost,
    )
def aggregate_metrics(
    case_metrics: list[CaseMetrics],
    dataset_name: str,
    n_bootstrap: Optional[int] = None,
) -> DatasetMetrics:
    """
    Aggregate per-case metrics into dataset-level stats with bootstrap CIs.

    Args:
        case_metrics: Per-case evaluation results.
        dataset_name: Label stored on the returned DatasetMetrics.
        n_bootstrap: Bootstrap resamples for the CIs; defaults to
            config.N_BOOTSTRAP when None.

    Returns:
        DatasetMetrics with means, 95% CIs, and per-channel request rates.
    """
    if n_bootstrap is None:
        n_bootstrap = config.N_BOOTSTRAP
    n = len(case_metrics)
    if n == 0:
        # Degenerate dataset: return a zeroed-out metrics object.
        return DatasetMetrics(dataset=dataset_name, n_cases=0, top1_accuracy=0, mrr=0)
    accuracies = np.array([int(cm.top1_correct) for cm in case_metrics])
    rrs = np.array([cm.reciprocal_rank for cm in case_metrics])
    top1_acc = float(np.mean(accuracies))
    mrr = float(np.mean(rrs))
    # Bootstrap CIs
    acc_ci = _bootstrap_ci(accuracies, n_bootstrap)
    mrr_ci = _bootstrap_ci(rrs, n_bootstrap)
    # Channel request rates: fraction of cases in which each channel was acquired.
    channel_counts = Counter(
        ch for cm in case_metrics for ch in cm.acquired_channels
    )
    channel_rates = {ch: count / n for ch, count in channel_counts.items()}
    return DatasetMetrics(
        dataset=dataset_name,
        n_cases=n,
        top1_accuracy=top1_acc,
        mrr=mrr,
        top1_accuracy_ci=acc_ci,
        mrr_ci=mrr_ci,
        mean_channels_acquired=float(np.mean([cm.n_acquired for cm in case_metrics])),
        early_commit_rate=float(np.mean([int(cm.committed_early) for cm in case_metrics])),
        per_channel_request_rate=channel_rates,
        mean_acquisition_cost=float(np.mean([cm.acquisition_cost for cm in case_metrics])),
        mean_total_case_cost=float(np.mean([cm.total_case_cost for cm in case_metrics])),
    )
def compute_acquisition_efficiency(
    mrr_at_k: float,
    mrr_passive: float,
    mrr_oracle: float,
) -> float:
    """
    Normalized Acquisition Efficiency.

    AE(K) = (MRR_K - MRR_passive) / (MRR_oracle - MRR_passive)

    Returns 0 when oracle == passive (no headroom to improve); 1 means the
    active policy matched the oracle; values above 1 would mean the active
    policy beat the oracle (not expected in practice).
    """
    headroom = mrr_oracle - mrr_passive
    if abs(headroom) < 1e-8:
        # No measurable gap between oracle and passive baselines.
        return 0.0
    gain = mrr_at_k - mrr_passive
    return gain / headroom
def compute_acquisition_precision(
    active_results: list[AgentResult],
    passive_results: list[AgentResult],
    cases: list[MedicalCase],
) -> dict:
    """
    Acquisition Precision: when the agent requests info, does the diagnosis change?

    Two sub-metrics:
    - request_change_rate: fraction of cases with acquisitions whose top-1
      diagnosis changed relative to the passive run
    - change_correctness: among diagnosis changes, fraction that landed on
      the correct answer

    Raises:
        ValueError: if the three input lists are not the same length.
    """
    # Explicit check instead of `assert`: assertions are stripped under -O.
    if not (len(active_results) == len(passive_results) == len(cases)):
        raise ValueError(
            "active_results, passive_results, and cases must be aligned"
        )
    total_acquisitions = 0
    diagnosis_changed = 0
    change_improved = 0
    for active, passive, case in zip(active_results, passive_results, cases):
        passive_top1 = _get_top1_name(passive.final_ranking)
        active_top1 = _get_top1_name(active.final_ranking)
        n_acq = len(active.acquired_channels)
        if n_acq > 0:
            total_acquisitions += 1
            if passive_top1 != active_top1:
                diagnosis_changed += 1
                # Did it change to the correct answer?
                gt = case.ground_truth.lower().strip()
                if gt in active_top1.lower() or active_top1.lower() in gt:
                    change_improved += 1
    return {
        "total_cases_with_acquisition": total_acquisitions,
        "request_change_rate": (
            diagnosis_changed / total_acquisitions if total_acquisitions > 0 else 0
        ),
        "change_correctness": (
            change_improved / diagnosis_changed if diagnosis_changed > 0 else 0
        ),
    }
def compute_prompt_agreement(
    results_by_variant: dict[str, list[AgentResult]],
) -> dict:
    """
    Prompt sensitivity analysis: measure agreement across prompt variants.

    Only cases present under every variant are compared.

    Returns:
        - top1_agreement: fraction of compared cases where all variants agree
          on the top-1 diagnosis (case-insensitive name match)
        - acquisition_agreement: fraction of compared cases where all variants
          request the same first channel (or all commit without acquiring)
        - n_cases_compared: number of cases present under all variants
    """
    variants = list(results_by_variant.keys())
    if len(variants) < 2:
        # With fewer than two variants there is nothing to disagree about.
        return {"top1_agreement": 1.0, "acquisition_agreement": 1.0}
    # Group results by case_id so variants can be compared per case.
    # (The original also built an unused `case_ids` set — removed.)
    by_case: dict[str, dict[str, AgentResult]] = {}
    for variant, results in results_by_variant.items():
        for r in results:
            by_case.setdefault(r.case_id, {})[variant] = r
    top1_agree_count = 0
    acq_agree_count = 0
    total = 0
    for variant_results in by_case.values():
        if len(variant_results) < len(variants):
            continue  # Skip cases not run under all variants
        total += 1
        # Top-1 agreement (case-insensitive)
        top1s = {
            _get_top1_name(vr.final_ranking).lower()
            for vr in variant_results.values()
        }
        if len(top1s) == 1:
            top1_agree_count += 1
        # First-acquisition agreement; "_committed_" marks "no acquisition".
        first_acqs = {
            vr.acquired_channels[0] if vr.acquired_channels else "_committed_"
            for vr in variant_results.values()
        }
        if len(first_acqs) == 1:
            acq_agree_count += 1
    return {
        "top1_agreement": top1_agree_count / total if total > 0 else 0,
        "acquisition_agreement": acq_agree_count / total if total > 0 else 0,
        "n_cases_compared": total,
    }
def compute_regret_analysis(
    active_results: list[AgentResult],
    oracle_results: list[AgentResult],
    cases: list[MedicalCase],
) -> dict:
    """
    Regret Analysis: when the agent gets a case wrong, could a different
    acquisition strategy have saved it?
    For each case where active got it wrong:
    1. Did the oracle get it right? (recoverable error)
    2. Which channels were available but not requested? (missed channels)
    3. Among recoverable errors, which missing channels correlate most
       with oracle success? (high-regret channels)
    Returns a rich dict with per-case traces and aggregate statistics.
    """
    assert len(active_results) == len(oracle_results) == len(cases)
    per_case_regret = []
    n_active_wrong = 0
    n_oracle_right_when_active_wrong = 0  # recoverable
    n_both_wrong = 0  # unrecoverable — VLM reasoning bottleneck
    missed_channel_counts: dict[str, int] = {}  # channels not requested in recoverable cases
    missed_channel_total: dict[str, int] = {}  # total times a channel was missed (all wrong)
    for active, oracle, case in zip(active_results, oracle_results, cases):
        # "Correct" here means reciprocal rank of exactly 1 (top-1 hit).
        active_rr = compute_reciprocal_rank(active.final_ranking, case.ground_truth, case.candidates)
        oracle_rr = compute_reciprocal_rank(oracle.final_ranking, case.ground_truth, case.candidates)
        active_correct = active_rr == 1.0
        oracle_correct = oracle_rr == 1.0
        if active_correct:
            continue  # No regret if agent got it right
        n_active_wrong += 1
        # Channels available but not acquired
        all_requestable = set(case.requestable_channels.keys())
        acquired = set(active.acquired_channels)
        missed = all_requestable - acquired
        # Per-case trace entry collected for downstream inspection.
        case_entry = {
            "case_id": case.case_id,
            "ground_truth": case.ground_truth,
            "active_top1": _get_top1_name(active.final_ranking),
            "oracle_top1": _get_top1_name(oracle.final_ranking),
            "active_correct": False,
            "oracle_correct": oracle_correct,
            "acquired_channels": list(acquired),
            "missed_channels": list(missed),
            "recoverable": oracle_correct,
            "active_rr": active_rr,
            "oracle_rr": oracle_rr,
        }
        # Count every missed channel across all wrong cases...
        for ch in missed:
            missed_channel_total[ch] = missed_channel_total.get(ch, 0) + 1
        if oracle_correct:
            # ...and separately within the recoverable subset.
            n_oracle_right_when_active_wrong += 1
            for ch in missed:
                missed_channel_counts[ch] = missed_channel_counts.get(ch, 0) + 1
        else:
            n_both_wrong += 1
        per_case_regret.append(case_entry)
    # Compute per-channel regret score: how often a missed channel appears
    # in recoverable errors vs all errors
    channel_regret_scores = {}
    for ch in set(list(missed_channel_counts.keys()) + list(missed_channel_total.keys())):
        recoverable_miss = missed_channel_counts.get(ch, 0)
        total_miss = missed_channel_total.get(ch, 0)
        # Regret score: fraction of times this channel was missed AND oracle succeeded
        channel_regret_scores[ch] = {
            "missed_in_recoverable": recoverable_miss,
            "missed_in_all_wrong": total_miss,
            "regret_rate": recoverable_miss / total_miss if total_miss > 0 else 0.0,
        }
    # Sort channels by regret rate descending
    # (ties broken by recoverable-miss count, also descending).
    sorted_channels = sorted(
        channel_regret_scores.items(),
        key=lambda x: (-x[1]["regret_rate"], -x[1]["missed_in_recoverable"]),
    )
    return {
        "n_cases": len(cases),
        "n_active_wrong": n_active_wrong,
        "n_recoverable": n_oracle_right_when_active_wrong,
        "n_unrecoverable": n_both_wrong,
        "recovery_rate": (
            n_oracle_right_when_active_wrong / n_active_wrong
            if n_active_wrong > 0 else 0.0
        ),
        "error_rate": n_active_wrong / len(cases) if cases else 0.0,
        "channel_regret_scores": dict(sorted_channels),
        "per_case_regret": per_case_regret,
        "summary": {
            "total_errors": n_active_wrong,
            "recoverable_pct": (
                n_oracle_right_when_active_wrong / n_active_wrong * 100
                if n_active_wrong > 0 else 0.0
            ),
            "unrecoverable_pct": (
                n_both_wrong / n_active_wrong * 100
                if n_active_wrong > 0 else 0.0
            ),
            "highest_regret_channel": sorted_channels[0][0] if sorted_channels else None,
        },
    }
def compute_info_theoretic_metrics(
results: list[AgentResult],
) -> dict:
"""
Compute information-theoretic metrics from belief trajectories.
Extracts BeliefTrajectory objects from AgentResults and computes
aggregate entropy, information gain, and per-channel value metrics.
"""
trajectories = [
r.belief_trajectory for r in results
if r.belief_trajectory and r.belief_trajectory.states
]
if not trajectories:
return {"n_cases_with_trajectory": 0}
metrics = compute_information_metrics(trajectories)
metrics["n_cases_with_trajectory"] = len(trajectories)
return metrics
def _get_top1_name(ranking: list[dict]) -> str:
"""Get the name of the top-ranked diagnosis."""
if not ranking:
return ""
return ranking[0].get("name", "")
def _bootstrap_ci(
    values: np.ndarray, n_bootstrap: int = 1000, ci: float = 0.95
) -> tuple[float, float]:
    """
    Compute a percentile bootstrap confidence interval for the mean.

    Args:
        values: 1-D array of per-case scores.
        n_bootstrap: Number of bootstrap resamples.
        ci: Confidence level (0.95 gives a 95% interval).

    Returns:
        (lower, upper) percentile bounds of the resampled means;
        (0.0, 0.0) for empty input.
    """
    n = len(values)
    if n == 0:
        return (0.0, 0.0)
    rng = np.random.RandomState(config.SEED)  # fixed seed for reproducibility
    # Vectorized resampling: one (n_bootstrap, n) index draw and a single
    # axis-wise mean instead of a Python-level loop — same estimator,
    # much faster for large n_bootstrap.
    idx = rng.randint(0, n, size=(n_bootstrap, n))
    boot_means = np.asarray(values)[idx].mean(axis=1)
    alpha = (1 - ci) / 2
    lower = float(np.percentile(boot_means, alpha * 100))
    upper = float(np.percentile(boot_means, (1 - alpha) * 100))
    return (lower, upper)