# Uploaded via huggingface_hub (commit a1aaf30, verified) — author: yuxbox
"""
Cross-dataset analysis and figure generation.
Produces the key figures for the paper:
1. Acquisition Efficiency curves (all 3 datasets, shared y-axis)
2. Per-channel request frequency heatmap
3. Prompt sensitivity agreement matrix
4. OLIVES biomarker-tier acquisition analysis
5. NEJM difficulty-vs-acquisition scatter
"""
import json
import logging
from collections import Counter
from dataclasses import asdict
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy import stats

import config
from agent import AgentResult
from datasets.base import MedicalCase
from evaluation import (
    CaseMetrics,
    DatasetMetrics,
    evaluate_single_case,
    aggregate_metrics,
    compute_acquisition_efficiency,
    compute_acquisition_precision,
    compute_prompt_agreement,
    compute_regret_analysis,
)
matplotlib.rcParams["font.family"] = "serif"
matplotlib.rcParams["font.size"] = 11
logger = logging.getLogger(__name__)
class ExperimentAnalyzer:
"""Analyze and visualize results across all experiments."""
def __init__(self, results_dir: Path = None):
self.results_dir = results_dir or config.RESULTS_DIR
self.figures_dir = self.results_dir / "figures"
self.figures_dir.mkdir(parents=True, exist_ok=True)
def load_results(self, experiment_name: str) -> dict:
"""Load saved experiment results."""
path = self.results_dir / f"{experiment_name}.json"
if not path.exists():
logger.error(f"Results file not found: {path}")
return {}
with open(path) as f:
return json.load(f)
def save_results(self, data: dict, experiment_name: str):
"""Save experiment results."""
path = self.results_dir / f"{experiment_name}.json"
with open(path, "w") as f:
json.dump(data, f, indent=2, default=str)
logger.info(f"Results saved to {path}")
# ================================================================
# Figure 1: Acquisition Efficiency Curves
# ================================================================
def plot_acquisition_efficiency(
self,
results_by_dataset: dict[str, dict[int, DatasetMetrics]],
passive_metrics: dict[str, DatasetMetrics],
oracle_metrics: dict[str, DatasetMetrics],
save_name: str = "fig1_acquisition_efficiency",
):
"""
Main result figure: normalized acquisition efficiency vs budget K.
Args:
results_by_dataset: {dataset_name: {K: DatasetMetrics}}
passive_metrics: {dataset_name: DatasetMetrics} at K=0
oracle_metrics: {dataset_name: DatasetMetrics} with all channels
"""
fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))
colors = {"midas": "#E07A5F", "nejm": "#3D405B", "olives": "#81B29A"}
markers = {"midas": "o", "nejm": "s", "olives": "D"}
labels = {"midas": "MIDAS (Dermatology)", "nejm": "NEJM (Multi-Specialty)",
"olives": "OLIVES (Ophthalmology)"}
# Left panel: Raw MRR vs K
ax = axes[0]
for ds_name in ["midas", "nejm", "olives"]:
if ds_name not in results_by_dataset:
continue
ks = sorted(results_by_dataset[ds_name].keys())
mrrs = [results_by_dataset[ds_name][k].mrr for k in ks]
cis = [results_by_dataset[ds_name][k].mrr_ci for k in ks]
# Add passive at K=0
all_k = [0] + list(ks)
all_mrr = [passive_metrics[ds_name].mrr] + mrrs
all_lower = [passive_metrics[ds_name].mrr_ci[0]] + [c[0] for c in cis]
all_upper = [passive_metrics[ds_name].mrr_ci[1]] + [c[1] for c in cis]
ax.plot(all_k, all_mrr, color=colors[ds_name], marker=markers[ds_name],
label=labels[ds_name], linewidth=2, markersize=7)
ax.fill_between(all_k, all_lower, all_upper, alpha=0.15, color=colors[ds_name])
# Oracle line
ax.axhline(y=oracle_metrics[ds_name].mrr, color=colors[ds_name],
linestyle="--", alpha=0.4, linewidth=1)
ax.set_xlabel("Acquisition Budget (K)")
ax.set_ylabel("Mean Reciprocal Rank (MRR)")
ax.set_title("(a) Diagnostic Quality vs. Budget")
ax.legend(fontsize=9)
ax.set_xticks(range(max(4, max(max(r.keys()) for r in results_by_dataset.values()) + 1)))
ax.grid(True, alpha=0.3)
# Right panel: Normalized Acquisition Efficiency
ax = axes[1]
for ds_name in ["midas", "nejm", "olives"]:
if ds_name not in results_by_dataset:
continue
ks = sorted(results_by_dataset[ds_name].keys())
effs = []
for k in ks:
ae = compute_acquisition_efficiency(
results_by_dataset[ds_name][k].mrr,
passive_metrics[ds_name].mrr,
oracle_metrics[ds_name].mrr,
)
effs.append(ae)
all_k = [0] + list(ks)
all_eff = [0.0] + effs
ax.plot(all_k, all_eff, color=colors[ds_name], marker=markers[ds_name],
label=labels[ds_name], linewidth=2, markersize=7)
ax.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5, linewidth=1,
label="Oracle ceiling")
ax.set_xlabel("Acquisition Budget (K)")
ax.set_ylabel("Acquisition Efficiency")
ax.set_title("(b) Normalized Efficiency")
ax.legend(fontsize=9)
ax.set_ylim(-0.05, 1.15)
ax.grid(True, alpha=0.3)
plt.tight_layout()
save_path = self.figures_dir / f"{save_name}.pdf"
fig.savefig(save_path, dpi=300, bbox_inches="tight")
plt.close(fig)
logger.info(f"Saved figure: {save_path}")
# ================================================================
# Figure 2: Per-Channel Request Frequency
# ================================================================
def plot_channel_request_heatmap(
self,
results_by_dataset: dict[str, list[AgentResult]],
save_name: str = "fig2_channel_requests",
):
"""Heatmap showing which channels the agent requests most, by dataset."""
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
dataset_names = ["midas", "nejm", "olives"]
titles = ["MIDAS", "NEJM", "OLIVES"]
for idx, (ds_name, title) in enumerate(zip(dataset_names, titles)):
if ds_name not in results_by_dataset:
continue
results = results_by_dataset[ds_name]
# Count first-request frequency
first_requests: dict[str, int] = {}
for r in results:
if r.acquired_channels:
ch = r.acquired_channels[0]
first_requests[ch] = first_requests.get(ch, 0) + 1
# Count overall request frequency
all_requests: dict[str, int] = {}
for r in results:
for ch in r.acquired_channels:
all_requests[ch] = all_requests.get(ch, 0) + 1
if not all_requests:
continue
channels = sorted(all_requests.keys())
n = len(results)
ax = axes[idx]
data = np.array([
[first_requests.get(ch, 0) / n for ch in channels],
[all_requests.get(ch, 0) / n for ch in channels],
])
sns.heatmap(
data,
ax=ax,
xticklabels=[ch.replace("_", "\n") for ch in channels],
yticklabels=["First\nRequest", "Any\nRequest"],
annot=True,
fmt=".2f",
cmap="YlOrRd",
vmin=0,
vmax=1,
cbar_kws={"shrink": 0.8},
)
ax.set_title(title)
plt.tight_layout()
save_path = self.figures_dir / f"{save_name}.pdf"
fig.savefig(save_path, dpi=300, bbox_inches="tight")
plt.close(fig)
logger.info(f"Saved figure: {save_path}")
# ================================================================
# Figure 3: OLIVES Biomarker Tier Analysis
# ================================================================
def plot_olives_biomarker_tiers(
self,
results: list[AgentResult],
cases: list[MedicalCase],
save_name: str = "fig3_olives_biomarker_tiers",
):
"""
For OLIVES: does the agent request OCT more for OCT-dependent
biomarkers than for fundus-visible ones?
"""
oct_request_by_tier: dict[str, list[bool]] = {
"fundus_visible": [],
"oct_dependent": [],
}
for result, case in zip(results, cases):
if case.dataset != "olives":
continue
tier_labels = case.metadata.get("biomarker_tier_labels", {})
requested_oct = "oct_scan" in result.acquired_channels
# For cases where the eye has fundus-visible biomarkers
if tier_labels.get("fundus_visible"):
oct_request_by_tier["fundus_visible"].append(requested_oct)
# For cases where the eye has OCT-dependent biomarkers
if tier_labels.get("oct_dependent"):
oct_request_by_tier["oct_dependent"].append(requested_oct)
fig, ax = plt.subplots(figsize=(6, 4))
tiers = ["fundus_visible", "oct_dependent"]
tier_labels = ["Fundus-Visible\nBiomarkers", "OCT-Dependent\nBiomarkers"]
rates = []
cis_lower = []
cis_upper = []
for tier in tiers:
vals = oct_request_by_tier.get(tier, [])
if vals:
rate = np.mean(vals)
rates.append(rate)
# Wilson CI for proportions
n = len(vals)
z = 1.96
p = rate
denom = 1 + z ** 2 / n
center = (p + z ** 2 / (2 * n)) / denom
margin = z * np.sqrt((p * (1 - p) + z ** 2 / (4 * n)) / n) / denom
cis_lower.append(center - margin)
cis_upper.append(center + margin)
else:
rates.append(0)
cis_lower.append(0)
cis_upper.append(0)
colors_bar = ["#81B29A", "#E07A5F"]
bars = ax.bar(tier_labels, rates, color=colors_bar, edgecolor="white", width=0.5)
ax.errorbar(
tier_labels, rates,
yerr=[np.array(rates) - np.array(cis_lower),
np.array(cis_upper) - np.array(rates)],
fmt="none", ecolor="black", capsize=5,
)
ax.set_ylabel("OCT Request Rate")
ax.set_title("Agent's OCT Request Rate by Biomarker Type")
ax.set_ylim(0, 1.05)
ax.grid(True, axis="y", alpha=0.3)
# Add counts
for i, tier in enumerate(tiers):
n = len(oct_request_by_tier.get(tier, []))
ax.text(i, rates[i] + 0.05, f"n={n}", ha="center", fontsize=10)
plt.tight_layout()
save_path = self.figures_dir / f"{save_name}.pdf"
fig.savefig(save_path, dpi=300, bbox_inches="tight")
plt.close(fig)
logger.info(f"Saved figure: {save_path}")
# ================================================================
# Figure 4: NEJM Difficulty vs Acquisition Behavior
# ================================================================
def plot_nejm_difficulty_analysis(
self,
results: list[AgentResult],
cases: list[MedicalCase],
save_name: str = "fig4_nejm_difficulty",
):
"""
Scatter: human difficulty (physician correct rate) vs
agent's acquisition behavior (N channels requested + early commit).
"""
difficulties = []
n_acquired = []
committed_early = []
for result, case in zip(results, cases):
if case.dataset != "nejm":
continue
votes = case.metadata.get("votes", {})
if not votes:
continue
# Compute human difficulty (proportion correct)
total_votes = sum(float(v) for v in votes.values())
if total_votes == 0:
continue
gt = case.ground_truth
human_correct = 0.0
for key, val in votes.items():
if key in gt or gt.startswith(key):
human_correct = float(val) / total_votes if total_votes > 1 else float(val)
break
difficulties.append(human_correct)
n_acquired.append(len(result.acquired_channels))
committed_early.append(result.committed_early)
if not difficulties:
logger.warning("No NEJM cases with difficulty data found")
return
fig, axes = plt.subplots(1, 2, figsize=(11, 4.5))
# Left: Difficulty vs N channels acquired
ax = axes[0]
ax.scatter(difficulties, n_acquired, alpha=0.5, s=30, color="#3D405B", edgecolors="white")
# Add trend line
if len(difficulties) > 10:
z = np.polyfit(difficulties, n_acquired, 1)
p = np.poly1d(z)
x_line = np.linspace(min(difficulties), max(difficulties), 100)
ax.plot(x_line, p(x_line), "--", color="#E07A5F", linewidth=2,
label=f"Trend (slope={z[0]:.2f})")
# Correlation
r, pval = stats.pearsonr(difficulties, n_acquired)
ax.text(0.05, 0.95, f"r={r:.3f}, p={pval:.3f}",
transform=ax.transAxes, fontsize=9, verticalalignment="top")
ax.set_xlabel("Human Correct Rate (easier →)")
ax.set_ylabel("Channels Acquired by Agent")
ax.set_title("(a) Case Difficulty vs. Acquisition Amount")
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)
# Right: Difficulty bins vs early commit rate
ax = axes[1]
diff_arr = np.array(difficulties)
commit_arr = np.array(committed_early, dtype=float)
bins = [0, 0.25, 0.50, 0.75, 1.01]
bin_labels = ["<25%", "25-50%", "50-75%", ">75%"]
bin_rates = []
bin_ns = []
for i in range(len(bins) - 1):
mask = (diff_arr >= bins[i]) & (diff_arr < bins[i + 1])
if mask.sum() > 0:
bin_rates.append(commit_arr[mask].mean())
bin_ns.append(mask.sum())
else:
bin_rates.append(0)
bin_ns.append(0)
bar_colors = ["#E07A5F", "#F2CC8F", "#81B29A", "#3D405B"]
bars = ax.bar(bin_labels, bin_rates, color=bar_colors, edgecolor="white", width=0.6)
for i, (rate, n) in enumerate(zip(bin_rates, bin_ns)):
ax.text(i, rate + 0.02, f"n={n}", ha="center", fontsize=9)
ax.set_xlabel("Human Correct Rate (easier →)")
ax.set_ylabel("Agent Early Commit Rate")
ax.set_title("(b) Early Commitment vs. Difficulty")
ax.set_ylim(0, 1.05)
ax.grid(True, axis="y", alpha=0.3)
plt.tight_layout()
save_path = self.figures_dir / f"{save_name}.pdf"
fig.savefig(save_path, dpi=300, bbox_inches="tight")
plt.close(fig)
logger.info(f"Saved figure: {save_path}")
# ================================================================
# Figure 5: Regret Analysis
# ================================================================
def plot_regret_analysis(
self,
regret: dict,
dataset_name: str = "",
save_name: str = "fig5_regret_analysis",
):
"""
Visualize regret analysis results.
Left: Stacked bar showing recoverable vs unrecoverable errors.
Right: Per-channel regret scores (which missed channels cost the most).
"""
fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))
title_suffix = f" — {dataset_name.upper()}" if dataset_name else ""
# ---- Left panel: Error decomposition ----
ax = axes[0]
summary = regret["summary"]
n_correct = regret["n_cases"] - regret["n_active_wrong"]
n_recoverable = regret["n_recoverable"]
n_unrecoverable = regret["n_unrecoverable"]
categories = ["Agent\nCorrect", "Recoverable\nErrors", "Unrecoverable\nErrors"]
values = [n_correct, n_recoverable, n_unrecoverable]
colors_bar = ["#81B29A", "#F2CC8F", "#E07A5F"]
bars = ax.bar(categories, values, color=colors_bar, edgecolor="white", width=0.55)
for bar, val in zip(bars, values):
pct = val / regret["n_cases"] * 100 if regret["n_cases"] > 0 else 0
ax.text(
bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
f"{val}\n({pct:.0f}%)", ha="center", fontsize=10,
)
ax.set_ylabel("Number of Cases")
ax.set_title(f"(a) Error Decomposition{title_suffix}")
ax.grid(True, axis="y", alpha=0.3)
# ---- Right panel: Per-channel regret ----
ax = axes[1]
channel_scores = regret["channel_regret_scores"]
if channel_scores:
channels = list(channel_scores.keys())
regret_rates = [channel_scores[ch]["regret_rate"] for ch in channels]
miss_counts = [channel_scores[ch]["missed_in_recoverable"] for ch in channels]
# Sort by regret rate
sorted_idx = sorted(range(len(channels)), key=lambda i: -regret_rates[i])
channels = [channels[i] for i in sorted_idx]
regret_rates = [regret_rates[i] for i in sorted_idx]
miss_counts = [miss_counts[i] for i in sorted_idx]
y_pos = range(len(channels))
bar_colors = plt.cm.YlOrRd(np.linspace(0.3, 0.9, len(channels)))
bars = ax.barh(
y_pos, regret_rates, color=bar_colors, edgecolor="white", height=0.6,
)
ax.set_yticks(y_pos)
ax.set_yticklabels([ch.replace("_", " ").title() for ch in channels], fontsize=9)
ax.set_xlabel("Regret Rate")
ax.set_xlim(0, 1.05)
ax.invert_yaxis()
# Annotate with counts
for i, (rate, count) in enumerate(zip(regret_rates, miss_counts)):
ax.text(
rate + 0.02, i, f"n={count}",
va="center", fontsize=9, color="#333",
)
else:
ax.text(0.5, 0.5, "No channel data", ha="center", va="center",
transform=ax.transAxes, fontsize=12)
ax.set_title(f"(b) Channel Regret Scores{title_suffix}")
ax.grid(True, axis="x", alpha=0.3)
plt.tight_layout()
save_path = self.figures_dir / f"{save_name}.pdf"
fig.savefig(save_path, dpi=300, bbox_inches="tight")
plt.close(fig)
logger.info(f"Saved figure: {save_path}")
def print_regret_summary(self, regret: dict):
"""Print a concise text summary of regret analysis."""
s = regret["summary"]
print("\n" + "=" * 55)
print(" REGRET ANALYSIS")
print("=" * 55)
print(f" Total cases: {regret['n_cases']}")
print(f" Agent errors: {s['total_errors']} ({regret['error_rate']*100:.1f}%)")
print(f" Recoverable: {regret['n_recoverable']} ({s['recoverable_pct']:.1f}% of errors)")
print(f" Unrecoverable: {regret['n_unrecoverable']} ({s['unrecoverable_pct']:.1f}% of errors)")
print(f" Highest-regret channel: {s['highest_regret_channel']}")
print()
print(" Per-channel regret:")
for ch, scores in regret["channel_regret_scores"].items():
print(f" {ch:<25} regret={scores['regret_rate']:.2f} "
f"(missed in {scores['missed_in_recoverable']}/{scores['missed_in_all_wrong']} errors)")
print("=" * 55)
# ================================================================
# Summary Table
# ================================================================
def print_summary_table(
self,
all_metrics: dict[str, dict[str, DatasetMetrics]],
):
"""
Print the main results table.
Args:
all_metrics: {condition: {dataset: DatasetMetrics}}
where condition is "passive", "K=1", "K=2", "K=3",
"fixed_order", "oracle"
"""
header = f"{'Condition':<15} {'Dataset':<12} {'Top-1 Acc':<15} {'MRR':<15} {'Avg K':<8}"
print("=" * len(header))
print(header)
print("=" * len(header))
for condition in ["passive", "K=1", "K=2", "K=3", "fixed_order", "oracle"]:
if condition not in all_metrics:
continue
for ds in ["midas", "nejm", "olives"]:
if ds not in all_metrics[condition]:
continue
m = all_metrics[condition][ds]
acc_str = f"{m.top1_accuracy:.3f} ({m.top1_accuracy_ci[0]:.3f}-{m.top1_accuracy_ci[1]:.3f})"
mrr_str = f"{m.mrr:.3f} ({m.mrr_ci[0]:.3f}-{m.mrr_ci[1]:.3f})"
print(f"{condition:<15} {ds:<12} {acc_str:<15} {mrr_str:<15} {m.mean_channels_acquired:<8.1f}")
print("=" * len(header))