Spaces:
Sleeping
Sleeping
| """ | |
| Cross-dataset analysis and figure generation. | |
| Produces the key figures for the paper: | |
| 1. Acquisition Efficiency curves (all 3 datasets, shared y-axis) | |
| 2. Per-channel request frequency heatmap | |
| 3. Prompt sensitivity agreement matrix | |
| 4. OLIVES biomarker-tier acquisition analysis | |
| 5. NEJM difficulty-vs-acquisition scatter | |
| """ | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from dataclasses import asdict | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| import seaborn as sns | |
| from scipy import stats | |
| from agent import AgentResult | |
| from datasets.base import MedicalCase | |
| from evaluation import ( | |
| CaseMetrics, | |
| DatasetMetrics, | |
| evaluate_single_case, | |
| aggregate_metrics, | |
| compute_acquisition_efficiency, | |
| compute_acquisition_precision, | |
| compute_prompt_agreement, | |
| compute_regret_analysis, | |
| ) | |
| import config | |
# Paper-wide figure typography: serif face at 11 pt.
matplotlib.rcParams.update({"font.family": "serif", "font.size": 11})

logger = logging.getLogger(__name__)
| class ExperimentAnalyzer: | |
| """Analyze and visualize results across all experiments.""" | |
| def __init__(self, results_dir: Path = None): | |
| self.results_dir = results_dir or config.RESULTS_DIR | |
| self.figures_dir = self.results_dir / "figures" | |
| self.figures_dir.mkdir(parents=True, exist_ok=True) | |
| def load_results(self, experiment_name: str) -> dict: | |
| """Load saved experiment results.""" | |
| path = self.results_dir / f"{experiment_name}.json" | |
| if not path.exists(): | |
| logger.error(f"Results file not found: {path}") | |
| return {} | |
| with open(path) as f: | |
| return json.load(f) | |
| def save_results(self, data: dict, experiment_name: str): | |
| """Save experiment results.""" | |
| path = self.results_dir / f"{experiment_name}.json" | |
| with open(path, "w") as f: | |
| json.dump(data, f, indent=2, default=str) | |
| logger.info(f"Results saved to {path}") | |
| # ================================================================ | |
| # Figure 1: Acquisition Efficiency Curves | |
| # ================================================================ | |
| def plot_acquisition_efficiency( | |
| self, | |
| results_by_dataset: dict[str, dict[int, DatasetMetrics]], | |
| passive_metrics: dict[str, DatasetMetrics], | |
| oracle_metrics: dict[str, DatasetMetrics], | |
| save_name: str = "fig1_acquisition_efficiency", | |
| ): | |
| """ | |
| Main result figure: normalized acquisition efficiency vs budget K. | |
| Args: | |
| results_by_dataset: {dataset_name: {K: DatasetMetrics}} | |
| passive_metrics: {dataset_name: DatasetMetrics} at K=0 | |
| oracle_metrics: {dataset_name: DatasetMetrics} with all channels | |
| """ | |
| fig, axes = plt.subplots(1, 2, figsize=(12, 4.5)) | |
| colors = {"midas": "#E07A5F", "nejm": "#3D405B", "olives": "#81B29A"} | |
| markers = {"midas": "o", "nejm": "s", "olives": "D"} | |
| labels = {"midas": "MIDAS (Dermatology)", "nejm": "NEJM (Multi-Specialty)", | |
| "olives": "OLIVES (Ophthalmology)"} | |
| # Left panel: Raw MRR vs K | |
| ax = axes[0] | |
| for ds_name in ["midas", "nejm", "olives"]: | |
| if ds_name not in results_by_dataset: | |
| continue | |
| ks = sorted(results_by_dataset[ds_name].keys()) | |
| mrrs = [results_by_dataset[ds_name][k].mrr for k in ks] | |
| cis = [results_by_dataset[ds_name][k].mrr_ci for k in ks] | |
| # Add passive at K=0 | |
| all_k = [0] + list(ks) | |
| all_mrr = [passive_metrics[ds_name].mrr] + mrrs | |
| all_lower = [passive_metrics[ds_name].mrr_ci[0]] + [c[0] for c in cis] | |
| all_upper = [passive_metrics[ds_name].mrr_ci[1]] + [c[1] for c in cis] | |
| ax.plot(all_k, all_mrr, color=colors[ds_name], marker=markers[ds_name], | |
| label=labels[ds_name], linewidth=2, markersize=7) | |
| ax.fill_between(all_k, all_lower, all_upper, alpha=0.15, color=colors[ds_name]) | |
| # Oracle line | |
| ax.axhline(y=oracle_metrics[ds_name].mrr, color=colors[ds_name], | |
| linestyle="--", alpha=0.4, linewidth=1) | |
| ax.set_xlabel("Acquisition Budget (K)") | |
| ax.set_ylabel("Mean Reciprocal Rank (MRR)") | |
| ax.set_title("(a) Diagnostic Quality vs. Budget") | |
| ax.legend(fontsize=9) | |
| ax.set_xticks(range(max(4, max(max(r.keys()) for r in results_by_dataset.values()) + 1))) | |
| ax.grid(True, alpha=0.3) | |
| # Right panel: Normalized Acquisition Efficiency | |
| ax = axes[1] | |
| for ds_name in ["midas", "nejm", "olives"]: | |
| if ds_name not in results_by_dataset: | |
| continue | |
| ks = sorted(results_by_dataset[ds_name].keys()) | |
| effs = [] | |
| for k in ks: | |
| ae = compute_acquisition_efficiency( | |
| results_by_dataset[ds_name][k].mrr, | |
| passive_metrics[ds_name].mrr, | |
| oracle_metrics[ds_name].mrr, | |
| ) | |
| effs.append(ae) | |
| all_k = [0] + list(ks) | |
| all_eff = [0.0] + effs | |
| ax.plot(all_k, all_eff, color=colors[ds_name], marker=markers[ds_name], | |
| label=labels[ds_name], linewidth=2, markersize=7) | |
| ax.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5, linewidth=1, | |
| label="Oracle ceiling") | |
| ax.set_xlabel("Acquisition Budget (K)") | |
| ax.set_ylabel("Acquisition Efficiency") | |
| ax.set_title("(b) Normalized Efficiency") | |
| ax.legend(fontsize=9) | |
| ax.set_ylim(-0.05, 1.15) | |
| ax.grid(True, alpha=0.3) | |
| plt.tight_layout() | |
| save_path = self.figures_dir / f"{save_name}.pdf" | |
| fig.savefig(save_path, dpi=300, bbox_inches="tight") | |
| plt.close(fig) | |
| logger.info(f"Saved figure: {save_path}") | |
| # ================================================================ | |
| # Figure 2: Per-Channel Request Frequency | |
| # ================================================================ | |
| def plot_channel_request_heatmap( | |
| self, | |
| results_by_dataset: dict[str, list[AgentResult]], | |
| save_name: str = "fig2_channel_requests", | |
| ): | |
| """Heatmap showing which channels the agent requests most, by dataset.""" | |
| fig, axes = plt.subplots(1, 3, figsize=(14, 4)) | |
| dataset_names = ["midas", "nejm", "olives"] | |
| titles = ["MIDAS", "NEJM", "OLIVES"] | |
| for idx, (ds_name, title) in enumerate(zip(dataset_names, titles)): | |
| if ds_name not in results_by_dataset: | |
| continue | |
| results = results_by_dataset[ds_name] | |
| # Count first-request frequency | |
| first_requests: dict[str, int] = {} | |
| for r in results: | |
| if r.acquired_channels: | |
| ch = r.acquired_channels[0] | |
| first_requests[ch] = first_requests.get(ch, 0) + 1 | |
| # Count overall request frequency | |
| all_requests: dict[str, int] = {} | |
| for r in results: | |
| for ch in r.acquired_channels: | |
| all_requests[ch] = all_requests.get(ch, 0) + 1 | |
| if not all_requests: | |
| continue | |
| channels = sorted(all_requests.keys()) | |
| n = len(results) | |
| ax = axes[idx] | |
| data = np.array([ | |
| [first_requests.get(ch, 0) / n for ch in channels], | |
| [all_requests.get(ch, 0) / n for ch in channels], | |
| ]) | |
| sns.heatmap( | |
| data, | |
| ax=ax, | |
| xticklabels=[ch.replace("_", "\n") for ch in channels], | |
| yticklabels=["First\nRequest", "Any\nRequest"], | |
| annot=True, | |
| fmt=".2f", | |
| cmap="YlOrRd", | |
| vmin=0, | |
| vmax=1, | |
| cbar_kws={"shrink": 0.8}, | |
| ) | |
| ax.set_title(title) | |
| plt.tight_layout() | |
| save_path = self.figures_dir / f"{save_name}.pdf" | |
| fig.savefig(save_path, dpi=300, bbox_inches="tight") | |
| plt.close(fig) | |
| logger.info(f"Saved figure: {save_path}") | |
| # ================================================================ | |
| # Figure 3: OLIVES Biomarker Tier Analysis | |
| # ================================================================ | |
| def plot_olives_biomarker_tiers( | |
| self, | |
| results: list[AgentResult], | |
| cases: list[MedicalCase], | |
| save_name: str = "fig3_olives_biomarker_tiers", | |
| ): | |
| """ | |
| For OLIVES: does the agent request OCT more for OCT-dependent | |
| biomarkers than for fundus-visible ones? | |
| """ | |
| oct_request_by_tier: dict[str, list[bool]] = { | |
| "fundus_visible": [], | |
| "oct_dependent": [], | |
| } | |
| for result, case in zip(results, cases): | |
| if case.dataset != "olives": | |
| continue | |
| tier_labels = case.metadata.get("biomarker_tier_labels", {}) | |
| requested_oct = "oct_scan" in result.acquired_channels | |
| # For cases where the eye has fundus-visible biomarkers | |
| if tier_labels.get("fundus_visible"): | |
| oct_request_by_tier["fundus_visible"].append(requested_oct) | |
| # For cases where the eye has OCT-dependent biomarkers | |
| if tier_labels.get("oct_dependent"): | |
| oct_request_by_tier["oct_dependent"].append(requested_oct) | |
| fig, ax = plt.subplots(figsize=(6, 4)) | |
| tiers = ["fundus_visible", "oct_dependent"] | |
| tier_labels = ["Fundus-Visible\nBiomarkers", "OCT-Dependent\nBiomarkers"] | |
| rates = [] | |
| cis_lower = [] | |
| cis_upper = [] | |
| for tier in tiers: | |
| vals = oct_request_by_tier.get(tier, []) | |
| if vals: | |
| rate = np.mean(vals) | |
| rates.append(rate) | |
| # Wilson CI for proportions | |
| n = len(vals) | |
| z = 1.96 | |
| p = rate | |
| denom = 1 + z ** 2 / n | |
| center = (p + z ** 2 / (2 * n)) / denom | |
| margin = z * np.sqrt((p * (1 - p) + z ** 2 / (4 * n)) / n) / denom | |
| cis_lower.append(center - margin) | |
| cis_upper.append(center + margin) | |
| else: | |
| rates.append(0) | |
| cis_lower.append(0) | |
| cis_upper.append(0) | |
| colors_bar = ["#81B29A", "#E07A5F"] | |
| bars = ax.bar(tier_labels, rates, color=colors_bar, edgecolor="white", width=0.5) | |
| ax.errorbar( | |
| tier_labels, rates, | |
| yerr=[np.array(rates) - np.array(cis_lower), | |
| np.array(cis_upper) - np.array(rates)], | |
| fmt="none", ecolor="black", capsize=5, | |
| ) | |
| ax.set_ylabel("OCT Request Rate") | |
| ax.set_title("Agent's OCT Request Rate by Biomarker Type") | |
| ax.set_ylim(0, 1.05) | |
| ax.grid(True, axis="y", alpha=0.3) | |
| # Add counts | |
| for i, tier in enumerate(tiers): | |
| n = len(oct_request_by_tier.get(tier, [])) | |
| ax.text(i, rates[i] + 0.05, f"n={n}", ha="center", fontsize=10) | |
| plt.tight_layout() | |
| save_path = self.figures_dir / f"{save_name}.pdf" | |
| fig.savefig(save_path, dpi=300, bbox_inches="tight") | |
| plt.close(fig) | |
| logger.info(f"Saved figure: {save_path}") | |
| # ================================================================ | |
| # Figure 4: NEJM Difficulty vs Acquisition Behavior | |
| # ================================================================ | |
| def plot_nejm_difficulty_analysis( | |
| self, | |
| results: list[AgentResult], | |
| cases: list[MedicalCase], | |
| save_name: str = "fig4_nejm_difficulty", | |
| ): | |
| """ | |
| Scatter: human difficulty (physician correct rate) vs | |
| agent's acquisition behavior (N channels requested + early commit). | |
| """ | |
| difficulties = [] | |
| n_acquired = [] | |
| committed_early = [] | |
| for result, case in zip(results, cases): | |
| if case.dataset != "nejm": | |
| continue | |
| votes = case.metadata.get("votes", {}) | |
| if not votes: | |
| continue | |
| # Compute human difficulty (proportion correct) | |
| total_votes = sum(float(v) for v in votes.values()) | |
| if total_votes == 0: | |
| continue | |
| gt = case.ground_truth | |
| human_correct = 0.0 | |
| for key, val in votes.items(): | |
| if key in gt or gt.startswith(key): | |
| human_correct = float(val) / total_votes if total_votes > 1 else float(val) | |
| break | |
| difficulties.append(human_correct) | |
| n_acquired.append(len(result.acquired_channels)) | |
| committed_early.append(result.committed_early) | |
| if not difficulties: | |
| logger.warning("No NEJM cases with difficulty data found") | |
| return | |
| fig, axes = plt.subplots(1, 2, figsize=(11, 4.5)) | |
| # Left: Difficulty vs N channels acquired | |
| ax = axes[0] | |
| ax.scatter(difficulties, n_acquired, alpha=0.5, s=30, color="#3D405B", edgecolors="white") | |
| # Add trend line | |
| if len(difficulties) > 10: | |
| z = np.polyfit(difficulties, n_acquired, 1) | |
| p = np.poly1d(z) | |
| x_line = np.linspace(min(difficulties), max(difficulties), 100) | |
| ax.plot(x_line, p(x_line), "--", color="#E07A5F", linewidth=2, | |
| label=f"Trend (slope={z[0]:.2f})") | |
| # Correlation | |
| r, pval = stats.pearsonr(difficulties, n_acquired) | |
| ax.text(0.05, 0.95, f"r={r:.3f}, p={pval:.3f}", | |
| transform=ax.transAxes, fontsize=9, verticalalignment="top") | |
| ax.set_xlabel("Human Correct Rate (easier →)") | |
| ax.set_ylabel("Channels Acquired by Agent") | |
| ax.set_title("(a) Case Difficulty vs. Acquisition Amount") | |
| ax.legend(fontsize=9) | |
| ax.grid(True, alpha=0.3) | |
| # Right: Difficulty bins vs early commit rate | |
| ax = axes[1] | |
| diff_arr = np.array(difficulties) | |
| commit_arr = np.array(committed_early, dtype=float) | |
| bins = [0, 0.25, 0.50, 0.75, 1.01] | |
| bin_labels = ["<25%", "25-50%", "50-75%", ">75%"] | |
| bin_rates = [] | |
| bin_ns = [] | |
| for i in range(len(bins) - 1): | |
| mask = (diff_arr >= bins[i]) & (diff_arr < bins[i + 1]) | |
| if mask.sum() > 0: | |
| bin_rates.append(commit_arr[mask].mean()) | |
| bin_ns.append(mask.sum()) | |
| else: | |
| bin_rates.append(0) | |
| bin_ns.append(0) | |
| bar_colors = ["#E07A5F", "#F2CC8F", "#81B29A", "#3D405B"] | |
| bars = ax.bar(bin_labels, bin_rates, color=bar_colors, edgecolor="white", width=0.6) | |
| for i, (rate, n) in enumerate(zip(bin_rates, bin_ns)): | |
| ax.text(i, rate + 0.02, f"n={n}", ha="center", fontsize=9) | |
| ax.set_xlabel("Human Correct Rate (easier →)") | |
| ax.set_ylabel("Agent Early Commit Rate") | |
| ax.set_title("(b) Early Commitment vs. Difficulty") | |
| ax.set_ylim(0, 1.05) | |
| ax.grid(True, axis="y", alpha=0.3) | |
| plt.tight_layout() | |
| save_path = self.figures_dir / f"{save_name}.pdf" | |
| fig.savefig(save_path, dpi=300, bbox_inches="tight") | |
| plt.close(fig) | |
| logger.info(f"Saved figure: {save_path}") | |
| # ================================================================ | |
| # Figure 5: Regret Analysis | |
| # ================================================================ | |
| def plot_regret_analysis( | |
| self, | |
| regret: dict, | |
| dataset_name: str = "", | |
| save_name: str = "fig5_regret_analysis", | |
| ): | |
| """ | |
| Visualize regret analysis results. | |
| Left: Stacked bar showing recoverable vs unrecoverable errors. | |
| Right: Per-channel regret scores (which missed channels cost the most). | |
| """ | |
| fig, axes = plt.subplots(1, 2, figsize=(12, 4.5)) | |
| title_suffix = f" — {dataset_name.upper()}" if dataset_name else "" | |
| # ---- Left panel: Error decomposition ---- | |
| ax = axes[0] | |
| summary = regret["summary"] | |
| n_correct = regret["n_cases"] - regret["n_active_wrong"] | |
| n_recoverable = regret["n_recoverable"] | |
| n_unrecoverable = regret["n_unrecoverable"] | |
| categories = ["Agent\nCorrect", "Recoverable\nErrors", "Unrecoverable\nErrors"] | |
| values = [n_correct, n_recoverable, n_unrecoverable] | |
| colors_bar = ["#81B29A", "#F2CC8F", "#E07A5F"] | |
| bars = ax.bar(categories, values, color=colors_bar, edgecolor="white", width=0.55) | |
| for bar, val in zip(bars, values): | |
| pct = val / regret["n_cases"] * 100 if regret["n_cases"] > 0 else 0 | |
| ax.text( | |
| bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5, | |
| f"{val}\n({pct:.0f}%)", ha="center", fontsize=10, | |
| ) | |
| ax.set_ylabel("Number of Cases") | |
| ax.set_title(f"(a) Error Decomposition{title_suffix}") | |
| ax.grid(True, axis="y", alpha=0.3) | |
| # ---- Right panel: Per-channel regret ---- | |
| ax = axes[1] | |
| channel_scores = regret["channel_regret_scores"] | |
| if channel_scores: | |
| channels = list(channel_scores.keys()) | |
| regret_rates = [channel_scores[ch]["regret_rate"] for ch in channels] | |
| miss_counts = [channel_scores[ch]["missed_in_recoverable"] for ch in channels] | |
| # Sort by regret rate | |
| sorted_idx = sorted(range(len(channels)), key=lambda i: -regret_rates[i]) | |
| channels = [channels[i] for i in sorted_idx] | |
| regret_rates = [regret_rates[i] for i in sorted_idx] | |
| miss_counts = [miss_counts[i] for i in sorted_idx] | |
| y_pos = range(len(channels)) | |
| bar_colors = plt.cm.YlOrRd(np.linspace(0.3, 0.9, len(channels))) | |
| bars = ax.barh( | |
| y_pos, regret_rates, color=bar_colors, edgecolor="white", height=0.6, | |
| ) | |
| ax.set_yticks(y_pos) | |
| ax.set_yticklabels([ch.replace("_", " ").title() for ch in channels], fontsize=9) | |
| ax.set_xlabel("Regret Rate") | |
| ax.set_xlim(0, 1.05) | |
| ax.invert_yaxis() | |
| # Annotate with counts | |
| for i, (rate, count) in enumerate(zip(regret_rates, miss_counts)): | |
| ax.text( | |
| rate + 0.02, i, f"n={count}", | |
| va="center", fontsize=9, color="#333", | |
| ) | |
| else: | |
| ax.text(0.5, 0.5, "No channel data", ha="center", va="center", | |
| transform=ax.transAxes, fontsize=12) | |
| ax.set_title(f"(b) Channel Regret Scores{title_suffix}") | |
| ax.grid(True, axis="x", alpha=0.3) | |
| plt.tight_layout() | |
| save_path = self.figures_dir / f"{save_name}.pdf" | |
| fig.savefig(save_path, dpi=300, bbox_inches="tight") | |
| plt.close(fig) | |
| logger.info(f"Saved figure: {save_path}") | |
    def print_regret_summary(self, regret: dict):
        """Print a concise text summary of regret analysis.

        Reads the top-level counts (n_cases, n_recoverable,
        n_unrecoverable, error_rate), the "summary" sub-dict
        (total_errors, recoverable_pct, unrecoverable_pct,
        highest_regret_channel) and the per-channel
        "channel_regret_scores" mapping from *regret*.
        """
        s = regret["summary"]
        print("\n" + "=" * 55)
        print(" REGRET ANALYSIS")
        print("=" * 55)
        print(f" Total cases: {regret['n_cases']}")
        print(f" Agent errors: {s['total_errors']} ({regret['error_rate']*100:.1f}%)")
        print(f" Recoverable: {regret['n_recoverable']} ({s['recoverable_pct']:.1f}% of errors)")
        print(f" Unrecoverable: {regret['n_unrecoverable']} ({s['unrecoverable_pct']:.1f}% of errors)")
        print(f" Highest-regret channel: {s['highest_regret_channel']}")
        print()
        print(" Per-channel regret:")
        # One line per channel: regret rate, plus how often the channel was
        # missed among recoverable errors vs all errors that missed it.
        for ch, scores in regret["channel_regret_scores"].items():
            print(f" {ch:<25} regret={scores['regret_rate']:.2f} "
                  f"(missed in {scores['missed_in_recoverable']}/{scores['missed_in_all_wrong']} errors)")
        print("=" * 55)
| # ================================================================ | |
| # Summary Table | |
| # ================================================================ | |
| def print_summary_table( | |
| self, | |
| all_metrics: dict[str, dict[str, DatasetMetrics]], | |
| ): | |
| """ | |
| Print the main results table. | |
| Args: | |
| all_metrics: {condition: {dataset: DatasetMetrics}} | |
| where condition is "passive", "K=1", "K=2", "K=3", | |
| "fixed_order", "oracle" | |
| """ | |
| header = f"{'Condition':<15} {'Dataset':<12} {'Top-1 Acc':<15} {'MRR':<15} {'Avg K':<8}" | |
| print("=" * len(header)) | |
| print(header) | |
| print("=" * len(header)) | |
| for condition in ["passive", "K=1", "K=2", "K=3", "fixed_order", "oracle"]: | |
| if condition not in all_metrics: | |
| continue | |
| for ds in ["midas", "nejm", "olives"]: | |
| if ds not in all_metrics[condition]: | |
| continue | |
| m = all_metrics[condition][ds] | |
| acc_str = f"{m.top1_accuracy:.3f} ({m.top1_accuracy_ci[0]:.3f}-{m.top1_accuracy_ci[1]:.3f})" | |
| mrr_str = f"{m.mrr:.3f} ({m.mrr_ci[0]:.3f}-{m.mrr_ci[1]:.3f})" | |
| print(f"{condition:<15} {ds:<12} {acc_str:<15} {mrr_str:<15} {m.mean_channels_acquired:<8.1f}") | |
| print("=" * len(header)) | |