# activemedagent-demo / trajectory.py
# Hugging Face Hub upload metadata: uploaded by yuxbox via huggingface_hub,
# revision a1aaf30 (verified).
"""
Trajectory Collection for ActiveMedAgent.
Phase 1 of the training pipeline:
1. Run zero-shot agent on all cases
2. Record full (state, action, reward) trajectories
3. Compute per-step rewards: did the acquisition improve the diagnosis?
4. Save trajectory dataset for Phase 2 policy learning
Each trajectory step records:
- state: current uncertainty, differential, acquired channels so far
- action: which channel was requested
- reward: MRR improvement after receiving the requested info
- outcome: final diagnosis correctness
"""
import json
import logging
import random
from dataclasses import dataclass, field, asdict
from pathlib import Path
import numpy as np
from tqdm import tqdm
import config
from api_client import BaseVLMClient, create_client
from agent import ActiveMedAgent, AgentResult
from datasets.base import MedicalCase
from evaluation import compute_reciprocal_rank
logger = logging.getLogger(__name__)
@dataclass
class TrajectoryStep:
    """A single step in an acquisition trajectory."""
    step_idx: int  # 0-based position of this step within the trajectory
    # State representation (as observed BEFORE the action was taken)
    acquired_so_far: list[str]  # channels already acquired before this step
    available_channels: list[str]  # requestable channels not yet acquired
    uncertainty_text: str  # agent's free-text reasoning at this step
    differential_before: list[dict]  # Ranking before this acquisition
    mrr_before: float  # reciprocal rank of the ground truth before the action
    # Action
    action: str  # Channel name requested (or "COMMIT")
    # Outcome (computed after the action)
    differential_after: list[dict]  # Ranking after receiving the info
    mrr_after: float  # reciprocal rank of the ground truth after the action
    reward: float  # MRR improvement: mrr_after - mrr_before
    acquisition_cost: float = 0.0  # raw cost of the requested channel
    normalized_cost: float = 0.0  # cost / max requestable cost for this case
    utility_reward: float = 0.0  # Cost-aware reward used for policy learning
    diagnosis_changed: bool = False  # Did top-1 change?
    diagnosis_improved: bool = False  # Did it change to the correct answer?
@dataclass
class Trajectory:
    """Complete trajectory for one case."""
    case_id: str  # identifier of the medical case
    dataset: str  # dataset the case came from
    ground_truth: str  # correct diagnosis for the case
    candidates: list[str]  # candidate diagnoses the agent ranks
    steps: list[TrajectoryStep] = field(default_factory=list)  # per-step records
    passive_mrr: float = 0.0  # MRR with no acquisitions (baseline floor)
    oracle_mrr: float = 0.0  # MRR with all information acquired (ceiling)
    final_mrr: float = 0.0  # MRR after the agent's final acquisition
    total_reward: float = 0.0  # sum of per-step MRR improvements
    total_utility_reward: float = 0.0  # sum of cost-aware per-step rewards
    success: bool = False  # Did the agent get top-1 correct?
class TrajectoryCollector:
    """
    Collect acquisition trajectories with per-step rewards.

    Unlike the basic agent.diagnose(), this runs the agent
    step-by-step, evaluating the diagnosis after EACH acquisition
    to compute fine-grained reward signals.

    Uses the tool-use agent architecture: runs the full agent for
    acquisition decisions, then evaluates intermediate states via
    the agent's get_diagnosis_at_state() helper.
    """

    def __init__(
        self,
        client: BaseVLMClient,
        prompt_variant: str = "A",
        budget: int = 3,
    ):
        """
        Args:
            client: VLM client shared by the active and evaluation agents.
            prompt_variant: Prompt template variant forwarded to the agents.
            budget: Maximum number of acquisitions the active agent may make.
        """
        self.client = client
        self.prompt_variant = prompt_variant
        self.budget = budget

    @staticmethod
    def _top1(ranking: list) -> str:
        """Return the top-1 diagnosis name of a ranking ("" if empty)."""
        return ranking[0]["name"] if ranking else ""

    def collect_trajectory(self, case: MedicalCase) -> Trajectory:
        """
        Collect a full trajectory with per-step rewards for one case.

        Strategy:
        1. Get passive baseline (image-only diagnosis)
        2. Get oracle ceiling (all-info diagnosis)
        3. Run the active agent and record its decisions
        4. For each acquisition step, evaluate the intermediate
           diagnosis to compute per-step MRR reward

        Returns:
            A fully populated Trajectory for ``case``.
        """
        traj = Trajectory(
            case_id=case.case_id,
            dataset=case.dataset,
            ground_truth=case.ground_truth,
            candidates=case.candidates,
        )

        # ---- Evaluation agent (budget=0, just for scoring) ----
        eval_agent = ActiveMedAgent(
            self.client, self.prompt_variant, budget=0
        )

        # ---- Passive baseline (MRR with no acquisition) ----
        passive_result = eval_agent.diagnose_passive(case)
        passive_mrr = compute_reciprocal_rank(
            passive_result.final_ranking, case.ground_truth, case.candidates
        )
        traj.passive_mrr = passive_mrr

        # ---- Oracle ceiling (MRR with all info) ----
        oracle_result = eval_agent.diagnose_oracle(case)
        traj.oracle_mrr = compute_reciprocal_rank(
            oracle_result.final_ranking, case.ground_truth, case.candidates
        )

        # ---- Run the active agent to get its acquisition decisions ----
        active_agent = ActiveMedAgent(
            self.client, self.prompt_variant, budget=self.budget
        )
        active_result = active_agent.diagnose(case)

        # ---- Replay each step, scoring the intermediate diagnosis ----
        current_mrr = passive_mrr
        current_ranking = passive_result.final_ranking
        acquired_so_far: list[str] = []
        for step_idx, step in enumerate(active_result.steps):
            if step.committed:
                # Agent committed early: record a terminal zero-reward
                # step (cost/change fields keep their dataclass defaults)
                # and stop replaying.
                traj.steps.append(TrajectoryStep(
                    step_idx=step_idx,
                    acquired_so_far=list(acquired_so_far),
                    available_channels=[
                        n for n in case.requestable_names
                        if n not in acquired_so_far
                    ],
                    uncertainty_text=step.reasoning or "",
                    differential_before=current_ranking,
                    mrr_before=current_mrr,
                    action="COMMIT",
                    differential_after=current_ranking,
                    mrr_after=current_mrr,
                    reward=0.0,
                ))
                break

            channel = step.requested_channel
            if not channel:
                # Step neither committed nor requested anything; skip it.
                continue

            # Channels still requestable in the state BEFORE this action.
            available = [
                n for n in case.requestable_names
                if n not in acquired_so_far
            ]
            before_ranking = current_ranking
            before_mrr = current_mrr

            # Execute the acquisition, then re-score the diagnosis.
            acquired_so_far.append(channel)
            after_ranking, _ = eval_agent.get_diagnosis_at_state(
                case, list(acquired_so_far)
            )
            after_mrr = compute_reciprocal_rank(
                after_ranking, case.ground_truth, case.candidates
            )

            # Raw reward is the MRR improvement; the utility reward
            # additionally charges a normalized acquisition cost.
            reward = after_mrr - before_mrr
            channel_cost = case.get_channel_cost(channel)
            # Guard against division by zero / tiny denominators.
            max_requestable_cost = max(case.get_max_requestable_cost(), 1.0)
            normalized_cost = channel_cost / max_requestable_cost
            utility_reward = reward - (
                config.COST_PENALTY_LAMBDA * normalized_cost
            )

            # Did the top-1 diagnosis change, and did it change TO a
            # (substring) match of the ground truth?
            top1_before = self._top1(before_ranking).lower()
            top1_after = self._top1(after_ranking).lower()
            diagnosis_changed = top1_before != top1_after
            gt_lower = case.ground_truth.lower()
            # BUGFIX: an empty after-ranking previously counted as an
            # improvement, because "" is a substring of any string.
            diagnosis_improved = (
                diagnosis_changed
                and bool(top1_after)
                and (gt_lower in top1_after or top1_after in gt_lower)
            )

            traj.steps.append(TrajectoryStep(
                step_idx=step_idx,
                # State is recorded as of BEFORE the action, hence [:-1].
                acquired_so_far=list(acquired_so_far[:-1]),
                available_channels=available,
                uncertainty_text=step.reasoning or "",
                differential_before=before_ranking,
                mrr_before=before_mrr,
                action=channel,
                differential_after=after_ranking,
                mrr_after=after_mrr,
                reward=reward,
                acquisition_cost=channel_cost,
                normalized_cost=normalized_cost,
                utility_reward=utility_reward,
                diagnosis_changed=diagnosis_changed,
                diagnosis_improved=diagnosis_improved,
            ))

            # Advance the replay state for the next step.
            current_mrr = after_mrr
            current_ranking = after_ranking

        # ---- Finalize trajectory ----
        traj.final_mrr = current_mrr
        traj.total_reward = sum(s.reward for s in traj.steps)
        traj.total_utility_reward = sum(s.utility_reward for s in traj.steps)
        # MRR is exactly 1.0 iff the ground truth is ranked top-1 (1/rank),
        # so float equality is safe here.
        traj.success = (current_mrr == 1.0)
        return traj

    def collect_dataset(
        self,
        cases: list[MedicalCase],
        max_cases: "int | None" = None,
        save_path: "Path | None" = None,
    ) -> list[Trajectory]:
        """
        Collect trajectories for all cases.

        Args:
            cases: Cases to process.
            max_cases: If truthy, only the first ``max_cases`` cases are used.
            save_path: If given, trajectories are dumped as JSON to this path.

        Returns:
            The successfully collected trajectories; cases that raise are
            logged (with traceback) and skipped.
        """
        if max_cases:
            cases = cases[:max_cases]

        trajectories = []
        for case in tqdm(cases, desc="Collecting trajectories", ncols=80):
            try:
                trajectories.append(self.collect_trajectory(case))
            except Exception:
                # Best-effort collection: keep going, but preserve the
                # full traceback (logger.error dropped it before).
                logger.exception(f"Failed on {case.case_id}")
                continue

        # Save
        if save_path:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            with open(save_path, "w") as f:
                json.dump(
                    [asdict(t) for t in trajectories],
                    f, indent=2, default=str,
                )
            logger.info(f"Saved {len(trajectories)} trajectories to {save_path}")

        # Report statistics
        self._report_stats(trajectories)
        return trajectories

    def _report_stats(self, trajectories: list[Trajectory]):
        """Log summary statistics of collected trajectories."""
        n = len(trajectories)
        if n == 0:
            return
        logger.info(f"\n{'='*50}")
        logger.info(f"Trajectory Collection Summary (n={n})")
        logger.info(f"{'='*50}")

        success_rate = np.mean([t.success for t in trajectories])
        avg_steps = np.mean([len(t.steps) for t in trajectories])
        avg_reward = np.mean([t.total_reward for t in trajectories])
        avg_utility = np.mean([t.total_utility_reward for t in trajectories])
        avg_passive_mrr = np.mean([t.passive_mrr for t in trajectories])
        avg_final_mrr = np.mean([t.final_mrr for t in trajectories])
        avg_oracle_mrr = np.mean([t.oracle_mrr for t in trajectories])
        logger.info(f" Success rate: {success_rate:.3f}")
        logger.info(f" Avg steps taken: {avg_steps:.1f}")
        logger.info(f" Avg total reward: {avg_reward:.3f}")
        logger.info(f" Avg utility reward: {avg_utility:.3f}")
        logger.info(
            f" MRR: passive={avg_passive_mrr:.3f} -> "
            f"active={avg_final_mrr:.3f} -> oracle={avg_oracle_mrr:.3f}"
        )

        # Per-action reward statistics (COMMIT steps excluded).
        all_steps = [
            s for t in trajectories for s in t.steps
            if s.action != "COMMIT"
        ]
        if all_steps:
            action_rewards: dict[str, list[float]] = {}
            for s in all_steps:
                action_rewards.setdefault(s.action, []).append(s.utility_reward)
            logger.info("\n Per-channel utility statistics:")
            # Sort channels by mean utility, highest first.
            for action, rewards in sorted(
                action_rewards.items(), key=lambda x: -np.mean(x[1])
            ):
                logger.info(
                    f" {action:<25} mean_utility={np.mean(rewards):+.3f} "
                    f"n={len(rewards)} "
                    f"positive_rate={np.mean([r > 0 for r in rewards]):.2f}"
                )