Spaces:
Sleeping
Sleeping
| """ | |
| Evaluation Metrics for ActiveMedAgent. | |
| Unified metrics across all three datasets: | |
| - MRR (Mean Reciprocal Rank) | |
| - Acquisition Efficiency (normalized improvement) | |
| - Top-1 Accuracy | |
| - Acquisition Precision | |
| - Uncertainty Calibration (ECE-style) | |
| - Information-Theoretic Metrics (entropy, IG, VoI) | |
| - Bootstrap confidence intervals | |
| """ | |
import logging
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
from scipy import stats

import config
from agent import AgentResult
from datasets.base import MedicalCase
from information_gain import BeliefTrajectory, compute_information_metrics
| logger = logging.getLogger(__name__) | |
@dataclass
class CaseMetrics:
    """Metrics for a single case.

    Fix: the `@dataclass` decorator was missing. Without it, keyword
    construction (see evaluate_single_case) raises TypeError and
    `field(default_factory=list)` stays an unevaluated Field object
    instead of producing a fresh list per instance.
    """

    case_id: str
    dataset: str
    top1_correct: bool = False
    reciprocal_rank: float = 0.0
    ground_truth_rank: int = -1  # 1-indexed rank of correct answer; -1 if not found
    n_acquired: int = 0
    acquired_channels: list[str] = field(default_factory=list)
    committed_early: bool = False
    top1_confidence: float = 0.0  # Confidence of the top-ranked diagnosis
    acquisition_cost: float = 0.0
    total_case_cost: float = 0.0
@dataclass
class DatasetMetrics:
    """Aggregated metrics for a dataset.

    Fix: the `@dataclass` decorator was missing. Without it, keyword
    construction in aggregate_metrics raises TypeError and
    `field(default_factory=dict)` is a raw Field object, not a dict.
    """

    dataset: str
    n_cases: int
    top1_accuracy: float
    mrr: float  # Mean Reciprocal Rank
    top1_accuracy_ci: tuple = (0.0, 0.0)  # 95% bootstrap CI (lower, upper)
    mrr_ci: tuple = (0.0, 0.0)
    mean_channels_acquired: float = 0.0
    early_commit_rate: float = 0.0
    per_channel_request_rate: dict = field(default_factory=dict)
    mean_acquisition_cost: float = 0.0
    mean_total_case_cost: float = 0.0
def compute_reciprocal_rank(
    ranking: list[dict],
    ground_truth: str,
    candidates: list[str],
) -> float:
    """
    Compute reciprocal rank of the ground truth in the agent's ranking.

    Matching is fuzzy: substring containment in either direction against the
    entry name, or indirectly via any candidate string that contains the
    ground truth. Returns 1/rank on the first match; if the ground truth is
    absent from a non-empty ranking, returns 1/(N+1); empty ranking gives 0.
    """
    if not ranking:
        return 0.0
    target = ground_truth.lower().strip()
    for entry in ranking:
        entry_name = entry.get("name", "").lower().strip()
        entry_rank = entry.get("rank", 999)
        # Direct fuzzy match: substring containment either way.
        matched = target in entry_name or entry_name in target
        if not matched:
            # Indirect match: entry overlaps a candidate that contains the truth.
            for cand in candidates:
                cand_lower = cand.lower()
                if target in cand_lower and (
                    entry_name in cand_lower or cand_lower in entry_name
                ):
                    matched = True
                    break
        if matched:
            return 1.0 / entry_rank
    # Ground truth never matched — credit as if ranked just past the end.
    return 1.0 / (len(ranking) + 1)
def evaluate_single_case(
    result: AgentResult,
    case: MedicalCase,
) -> CaseMetrics:
    """Evaluate a single agent result against ground truth.

    Builds a CaseMetrics from the agent's final ranking: reciprocal rank,
    top-1 correctness (RR == 1 means the truth is ranked first), the rank at
    which the ground truth appears (-1 if absent), and acquisition/cost stats
    copied straight off the result.
    """
    ranked = result.final_ranking
    truth = case.ground_truth
    rr = compute_reciprocal_rank(ranked, truth, case.candidates)

    # Locate the ground truth's 1-indexed rank via fuzzy substring match.
    truth_lower = truth.lower().strip()
    gt_rank = -1
    for entry in ranked:
        entry_name = entry.get("name", "").lower().strip()
        if truth_lower in entry_name or entry_name in truth_lower:
            gt_rank = entry.get("rank", -1)
            break

    return CaseMetrics(
        case_id=result.case_id,
        dataset=result.dataset,
        top1_correct=(rr == 1.0),
        reciprocal_rank=rr,
        ground_truth_rank=gt_rank,
        n_acquired=len(result.acquired_channels),
        acquired_channels=result.acquired_channels,
        committed_early=result.committed_early,
        top1_confidence=ranked[0]["confidence"] if ranked else 0.0,
        acquisition_cost=result.acquisition_cost,
        total_case_cost=result.total_case_cost,
    )
def aggregate_metrics(
    case_metrics: list[CaseMetrics],
    dataset_name: str,
    n_bootstrap: Optional[int] = None,
) -> DatasetMetrics:
    """Aggregate per-case metrics into dataset-level stats with bootstrap CIs.

    Fix: the `n_bootstrap` annotation was `int` with a `None` default; it is
    now correctly `Optional[int]`.

    Args:
        case_metrics: Per-case evaluation results.
        dataset_name: Name recorded on the returned DatasetMetrics.
        n_bootstrap: Number of bootstrap resamples; defaults to
            config.N_BOOTSTRAP when None.

    Returns:
        DatasetMetrics with mean accuracy/MRR, 95% bootstrap CIs, acquisition
        stats, and per-channel request rates. Empty input yields a zeroed
        DatasetMetrics.
    """
    if n_bootstrap is None:
        n_bootstrap = config.N_BOOTSTRAP
    n = len(case_metrics)
    if n == 0:
        return DatasetMetrics(dataset=dataset_name, n_cases=0, top1_accuracy=0, mrr=0)
    accuracies = np.array([int(cm.top1_correct) for cm in case_metrics])
    rrs = np.array([cm.reciprocal_rank for cm in case_metrics])
    top1_acc = float(np.mean(accuracies))
    mrr = float(np.mean(rrs))
    # Bootstrap CIs for the two headline metrics.
    acc_ci = _bootstrap_ci(accuracies, n_bootstrap)
    mrr_ci = _bootstrap_ci(rrs, n_bootstrap)
    # Fraction of cases in which each channel was requested at least once.
    channel_counts: dict[str, int] = {}
    for cm in case_metrics:
        for ch in cm.acquired_channels:
            channel_counts[ch] = channel_counts.get(ch, 0) + 1
    channel_rates = {ch: count / n for ch, count in channel_counts.items()}
    return DatasetMetrics(
        dataset=dataset_name,
        n_cases=n,
        top1_accuracy=top1_acc,
        mrr=mrr,
        top1_accuracy_ci=acc_ci,
        mrr_ci=mrr_ci,
        mean_channels_acquired=float(np.mean([cm.n_acquired for cm in case_metrics])),
        early_commit_rate=float(np.mean([int(cm.committed_early) for cm in case_metrics])),
        per_channel_request_rate=channel_rates,
        mean_acquisition_cost=float(np.mean([cm.acquisition_cost for cm in case_metrics])),
        mean_total_case_cost=float(np.mean([cm.total_case_cost for cm in case_metrics])),
    )
def compute_acquisition_efficiency(
    mrr_at_k: float,
    mrr_passive: float,
    mrr_oracle: float,
) -> float:
    """
    Normalized Acquisition Efficiency.

    AE(K) = (MRR_K - MRR_passive) / (MRR_oracle - MRR_passive)

    Returns 0 if oracle = passive (no room for improvement),
    can exceed 1 if active outperforms oracle (shouldn't happen normally).
    """
    headroom = mrr_oracle - mrr_passive
    # Degenerate case: the oracle offers no improvement over passive.
    if abs(headroom) < 1e-8:
        return 0.0
    gain = mrr_at_k - mrr_passive
    return gain / headroom
def compute_acquisition_precision(
    active_results: list[AgentResult],
    passive_results: list[AgentResult],
    cases: list[MedicalCase],
) -> dict:
    """
    Acquisition Precision: when the agent requests info, does the diagnosis change?

    Two sub-metrics:
    - request_change_rate: fraction of acquisitions that changed the top-1 diagnosis
    - change_correctness: among diagnosis changes, fraction that were improvements
    """
    assert len(active_results) == len(passive_results) == len(cases)
    n_with_acquisition = 0
    n_changed = 0
    n_improved = 0
    for active_res, passive_res, med_case in zip(active_results, passive_results, cases):
        # Only cases where the active agent actually acquired something count.
        if not active_res.acquired_channels:
            continue
        n_with_acquisition += 1
        before = _get_top1_name(passive_res.final_ranking)
        after = _get_top1_name(active_res.final_ranking)
        if before == after:
            continue
        n_changed += 1
        # Improvement = the new top-1 fuzzily matches the ground truth.
        truth = med_case.ground_truth.lower().strip()
        after_lower = after.lower()
        if truth in after_lower or after_lower in truth:
            n_improved += 1
    return {
        "total_cases_with_acquisition": n_with_acquisition,
        "request_change_rate": (
            n_changed / n_with_acquisition if n_with_acquisition > 0 else 0
        ),
        "change_correctness": (
            n_improved / n_changed if n_changed > 0 else 0
        ),
    }
def compute_prompt_agreement(
    results_by_variant: dict[str, list[AgentResult]],
) -> dict:
    """
    Prompt sensitivity analysis: measure agreement across prompt variants.

    Returns:
    - top1_agreement: fraction of cases where all variants agree on top-1
    - acquisition_agreement: fraction of cases where all variants request
      the same first channel
    """
    variant_names = list(results_by_variant.keys())
    if len(variant_names) < 2:
        return {"top1_agreement": 1.0, "acquisition_agreement": 1.0}

    # Group results by case_id, keyed by variant name.
    grouped: dict[str, dict[str, AgentResult]] = {}
    for variant_name, variant_results in results_by_variant.items():
        for res in variant_results:
            grouped.setdefault(res.case_id, {})[variant_name] = res

    n_compared = 0
    n_top1_agree = 0
    n_acq_agree = 0
    for per_variant in grouped.values():
        if len(per_variant) < len(variant_names):
            continue  # Skip cases not present in every variant
        n_compared += 1
        # All variants agree on the top-1 diagnosis (case-insensitive)?
        top1_names = {
            _get_top1_name(res.final_ranking).lower() for res in per_variant.values()
        }
        if len(top1_names) == 1:
            n_top1_agree += 1
        # All variants agree on the first acquired channel (or all committed)?
        first_channels = {
            res.acquired_channels[0] if res.acquired_channels else "_committed_"
            for res in per_variant.values()
        }
        if len(first_channels) == 1:
            n_acq_agree += 1

    return {
        "top1_agreement": n_top1_agree / n_compared if n_compared > 0 else 0,
        "acquisition_agreement": n_acq_agree / n_compared if n_compared > 0 else 0,
        "n_cases_compared": n_compared,
    }
def compute_regret_analysis(
    active_results: list[AgentResult],
    oracle_results: list[AgentResult],
    cases: list[MedicalCase],
) -> dict:
    """
    Regret Analysis: when the agent gets a case wrong, could a different
    acquisition strategy have saved it?

    For each case where active got it wrong:
    1. Did the oracle get it right? (recoverable error)
    2. Which channels were available but not requested? (missed channels)
    3. Among recoverable errors, which missing channels correlate most
       with oracle success? (high-regret channels)

    Args:
        active_results: Agent runs with active acquisition; index-aligned
            with `cases` and `oracle_results`.
        oracle_results: Runs under the oracle acquisition policy, same order.
        cases: Ground-truth cases providing `requestable_channels`.

    Returns a rich dict with per-case traces and aggregate statistics.
    """
    assert len(active_results) == len(oracle_results) == len(cases)
    per_case_regret = []
    n_active_wrong = 0
    n_oracle_right_when_active_wrong = 0  # recoverable
    n_both_wrong = 0  # unrecoverable — VLM reasoning bottleneck
    missed_channel_counts: dict[str, int] = {}  # channels not requested in recoverable cases
    missed_channel_total: dict[str, int] = {}  # total times a channel was missed (all wrong)
    for active, oracle, case in zip(active_results, oracle_results, cases):
        # "Correct" here means the ground truth is ranked first (RR == 1.0).
        active_rr = compute_reciprocal_rank(active.final_ranking, case.ground_truth, case.candidates)
        oracle_rr = compute_reciprocal_rank(oracle.final_ranking, case.ground_truth, case.candidates)
        active_correct = active_rr == 1.0
        oracle_correct = oracle_rr == 1.0
        if active_correct:
            continue  # No regret if agent got it right
        n_active_wrong += 1
        # Channels available but not acquired
        all_requestable = set(case.requestable_channels.keys())
        acquired = set(active.acquired_channels)
        missed = all_requestable - acquired
        case_entry = {
            "case_id": case.case_id,
            "ground_truth": case.ground_truth,
            "active_top1": _get_top1_name(active.final_ranking),
            "oracle_top1": _get_top1_name(oracle.final_ranking),
            "active_correct": False,
            "oracle_correct": oracle_correct,
            "acquired_channels": list(acquired),
            "missed_channels": list(missed),
            "recoverable": oracle_correct,
            "active_rr": active_rr,
            "oracle_rr": oracle_rr,
        }
        # Every missed channel counts toward the all-errors denominator.
        for ch in missed:
            missed_channel_total[ch] = missed_channel_total.get(ch, 0) + 1
        if oracle_correct:
            # Recoverable error: a better acquisition policy would have fixed it.
            n_oracle_right_when_active_wrong += 1
            for ch in missed:
                missed_channel_counts[ch] = missed_channel_counts.get(ch, 0) + 1
        else:
            n_both_wrong += 1
        per_case_regret.append(case_entry)
    # Compute per-channel regret score: how often a missed channel appears
    # in recoverable errors vs all errors
    channel_regret_scores = {}
    for ch in set(list(missed_channel_counts.keys()) + list(missed_channel_total.keys())):
        recoverable_miss = missed_channel_counts.get(ch, 0)
        total_miss = missed_channel_total.get(ch, 0)
        # Regret score: fraction of times this channel was missed AND oracle succeeded
        channel_regret_scores[ch] = {
            "missed_in_recoverable": recoverable_miss,
            "missed_in_all_wrong": total_miss,
            "regret_rate": recoverable_miss / total_miss if total_miss > 0 else 0.0,
        }
    # Sort channels by regret rate descending
    sorted_channels = sorted(
        channel_regret_scores.items(),
        key=lambda x: (-x[1]["regret_rate"], -x[1]["missed_in_recoverable"]),
    )
    return {
        "n_cases": len(cases),
        "n_active_wrong": n_active_wrong,
        "n_recoverable": n_oracle_right_when_active_wrong,
        "n_unrecoverable": n_both_wrong,
        "recovery_rate": (
            n_oracle_right_when_active_wrong / n_active_wrong
            if n_active_wrong > 0 else 0.0
        ),
        "error_rate": n_active_wrong / len(cases) if cases else 0.0,
        # dict(sorted_channels) preserves the regret-rate-descending order.
        "channel_regret_scores": dict(sorted_channels),
        "per_case_regret": per_case_regret,
        "summary": {
            "total_errors": n_active_wrong,
            "recoverable_pct": (
                n_oracle_right_when_active_wrong / n_active_wrong * 100
                if n_active_wrong > 0 else 0.0
            ),
            "unrecoverable_pct": (
                n_both_wrong / n_active_wrong * 100
                if n_active_wrong > 0 else 0.0
            ),
            "highest_regret_channel": sorted_channels[0][0] if sorted_channels else None,
        },
    }
def compute_info_theoretic_metrics(
    results: list[AgentResult],
) -> dict:
    """
    Compute information-theoretic metrics from belief trajectories.

    Extracts BeliefTrajectory objects from AgentResults and computes
    aggregate entropy, information gain, and per-channel value metrics.
    """
    # Keep only results that carry a non-empty belief trajectory.
    usable = []
    for res in results:
        traj = res.belief_trajectory
        if traj and traj.states:
            usable.append(traj)
    if not usable:
        return {"n_cases_with_trajectory": 0}
    info_metrics = compute_information_metrics(usable)
    info_metrics["n_cases_with_trajectory"] = len(usable)
    return info_metrics
| def _get_top1_name(ranking: list[dict]) -> str: | |
| """Get the name of the top-ranked diagnosis.""" | |
| if not ranking: | |
| return "" | |
| return ranking[0].get("name", "") | |
def _bootstrap_ci(
    values: np.ndarray, n_bootstrap: int = 1000, ci: float = 0.95
) -> tuple[float, float]:
    """Percentile bootstrap confidence interval for the mean of *values*.

    Uses a RandomState seeded with config.SEED so the interval is
    reproducible across runs. Empty input yields (0.0, 0.0).
    """
    if len(values) == 0:
        return (0.0, 0.0)
    rng = np.random.RandomState(config.SEED)
    sample_size = len(values)
    # Resample with replacement and record each resample's mean.
    boot_means = [
        np.mean(rng.choice(values, size=sample_size, replace=True))
        for _ in range(n_bootstrap)
    ]
    tail = (1 - ci) / 2
    lower = float(np.percentile(boot_means, tail * 100))
    upper = float(np.percentile(boot_means, (1 - tail) * 100))
    return (lower, upper)