Spaces:
Sleeping
Sleeping
| """ | |
| Trajectory Collection for ActiveMedAgent. | |
| Phase 1 of the training pipeline: | |
| 1. Run zero-shot agent on all cases | |
| 2. Record full (state, action, reward) trajectories | |
| 3. Compute per-step rewards: did the acquisition improve the diagnosis? | |
| 4. Save trajectory dataset for Phase 2 policy learning | |
| Each trajectory step records: | |
| - state: current uncertainty, differential, acquired channels so far | |
| - action: which channel was requested | |
| - reward: MRR improvement after receiving the requested info | |
| - outcome: final diagnosis correctness | |
| """ | |
| import json | |
| import logging | |
| import random | |
| from dataclasses import dataclass, field, asdict | |
| from pathlib import Path | |
| import numpy as np | |
| from tqdm import tqdm | |
| import config | |
| from api_client import BaseVLMClient, create_client | |
| from agent import ActiveMedAgent, AgentResult | |
| from datasets.base import MedicalCase | |
| from evaluation import compute_reciprocal_rank | |
| logger = logging.getLogger(__name__) | |
@dataclass
class TrajectoryStep:
    """A single step in an acquisition trajectory.

    Records the agent's state before an acquisition, the action taken,
    and the outcome (diagnosis ranking / reward) after the requested
    information was received.
    """

    step_idx: int

    # ---- State representation (before the action) ----
    acquired_so_far: list[str]       # Channels acquired before this step
    available_channels: list[str]    # Channels still requestable at this step
    uncertainty_text: str            # Agent's stated reasoning / uncertainty
    differential_before: list[dict]  # Ranking before this acquisition
    mrr_before: float

    # ---- Action ----
    action: str  # Channel name requested (or "COMMIT")

    # ---- Outcome (computed after the action) ----
    differential_after: list[dict]  # Ranking after receiving the info
    mrr_after: float
    reward: float  # MRR improvement: mrr_after - mrr_before
    acquisition_cost: float = 0.0
    normalized_cost: float = 0.0
    utility_reward: float = 0.0  # Cost-aware reward used for policy learning
    diagnosis_changed: bool = False  # Did top-1 change?
    diagnosis_improved: bool = False  # Did it change to the correct answer?
@dataclass
class Trajectory:
    """Complete trajectory for one case.

    Holds the per-step records plus trajectory-level summary metrics
    (passive/oracle/final MRR and cumulative rewards) filled in by the
    collector after all steps are evaluated.
    """

    case_id: str
    dataset: str
    ground_truth: str
    candidates: list[str]
    # Forward reference (quoted) so the class does not depend on
    # definition order.
    steps: list["TrajectoryStep"] = field(default_factory=list)

    # ---- Summary metrics (filled in after collection) ----
    passive_mrr: float = 0.0  # MRR with no acquisitions (image-only)
    oracle_mrr: float = 0.0   # MRR with all information provided
    final_mrr: float = 0.0    # MRR at the end of the active trajectory
    total_reward: float = 0.0
    total_utility_reward: float = 0.0
    success: bool = False  # Did the agent get top-1 correct?
class TrajectoryCollector:
    """
    Collect acquisition trajectories with per-step rewards.

    Unlike the basic ``agent.diagnose()``, this runs the agent
    step-by-step, evaluating the diagnosis after EACH acquisition
    to compute fine-grained reward signals.

    Uses the tool-use agent architecture: runs the full agent for
    acquisition decisions, then evaluates intermediate states via
    the agent's ``get_diagnosis_at_state()`` helper.
    """

    def __init__(
        self,
        client: BaseVLMClient,
        prompt_variant: str = "A",
        budget: int = 3,
    ):
        """
        Args:
            client: VLM client shared by every agent instance created here.
            prompt_variant: Prompt variant forwarded to the agents.
            budget: Maximum number of acquisitions for the active agent.
        """
        self.client = client
        self.prompt_variant = prompt_variant
        self.budget = budget

    def _mrr(self, ranking: list[dict], case: MedicalCase) -> float:
        """Reciprocal rank of the case's ground truth within *ranking*."""
        return compute_reciprocal_rank(ranking, case.ground_truth, case.candidates)

    def collect_trajectory(self, case: MedicalCase) -> Trajectory:
        """
        Collect a full trajectory with per-step rewards for one case.

        Strategy:
          1. Get passive baseline (image-only diagnosis)
          2. Get oracle ceiling (all-info diagnosis)
          3. Run the active agent and record its decisions
          4. For each acquisition step, evaluate the intermediate
             diagnosis to compute per-step MRR reward

        Returns:
            A fully populated :class:`Trajectory` for *case*.
        """
        traj = Trajectory(
            case_id=case.case_id,
            dataset=case.dataset,
            ground_truth=case.ground_truth,
            candidates=case.candidates,
        )

        # Evaluation agent (budget=0): used only to score states, never
        # to make acquisition decisions.
        eval_agent = ActiveMedAgent(
            self.client, self.prompt_variant, budget=0
        )

        # ---- Passive baseline: MRR with no acquisition ----
        passive_result = eval_agent.diagnose_passive(case)
        passive_mrr = self._mrr(passive_result.final_ranking, case)
        traj.passive_mrr = passive_mrr

        # ---- Oracle ceiling: MRR with all info provided up front ----
        oracle_result = eval_agent.diagnose_oracle(case)
        traj.oracle_mrr = self._mrr(oracle_result.final_ranking, case)

        # ---- Run the active agent to get its acquisition decisions ----
        active_agent = ActiveMedAgent(
            self.client, self.prompt_variant, budget=self.budget
        )
        active_result = active_agent.diagnose(case)

        # ---- Re-evaluate each intermediate state for per-step rewards ----
        current_mrr = passive_mrr
        current_ranking = passive_result.final_ranking
        acquired_so_far: list[str] = []

        for step_idx, step in enumerate(active_result.steps):
            if step.committed:
                # Agent committed early — record a terminal no-op step
                # (zero reward, state unchanged) and stop.
                traj.steps.append(
                    TrajectoryStep(
                        step_idx=step_idx,
                        acquired_so_far=list(acquired_so_far),
                        available_channels=[
                            n for n in case.requestable_names
                            if n not in acquired_so_far
                        ],
                        uncertainty_text=step.reasoning or "",
                        differential_before=current_ranking,
                        mrr_before=current_mrr,
                        action="COMMIT",
                        differential_after=current_ranking,
                        mrr_after=current_mrr,
                        reward=0.0,
                        acquisition_cost=0.0,
                        normalized_cost=0.0,
                        utility_reward=0.0,
                        diagnosis_changed=False,
                        diagnosis_improved=False,
                    )
                )
                break

            channel = step.requested_channel
            if not channel:
                # Step neither committed nor requested a channel — skip it.
                continue

            available = [
                n for n in case.requestable_names
                if n not in acquired_so_far
            ]

            # State BEFORE this acquisition.
            before_ranking = current_ranking
            before_mrr = current_mrr

            # Execute the acquisition, then score the resulting state.
            acquired_so_far.append(channel)
            after_ranking, _ = eval_agent.get_diagnosis_at_state(
                case, list(acquired_so_far)
            )
            after_mrr = self._mrr(after_ranking, case)

            # Raw reward is the MRR improvement from this acquisition.
            reward = after_mrr - before_mrr

            # Cost-aware utility: penalize expensive channels. Cost is
            # normalized by the most expensive requestable channel,
            # floored at 1.0 to avoid division by zero.
            channel_cost = case.get_channel_cost(channel)
            max_requestable_cost = max(case.get_max_requestable_cost(), 1.0)
            normalized_cost = channel_cost / max_requestable_cost
            utility_reward = reward - (
                config.COST_PENALTY_LAMBDA * normalized_cost
            )

            # Did the top-1 diagnosis change, and did it become correct?
            top1_before = before_ranking[0]["name"] if before_ranking else ""
            top1_after = after_ranking[0]["name"] if after_ranking else ""
            diagnosis_changed = top1_before.lower() != top1_after.lower()
            gt_lower = case.ground_truth.lower()
            # Substring match in either direction tolerates minor phrasing
            # differences between model output and ground truth labels.
            diagnosis_improved = diagnosis_changed and (
                gt_lower in top1_after.lower()
                or top1_after.lower() in gt_lower
            )

            traj.steps.append(
                TrajectoryStep(
                    step_idx=step_idx,
                    # State excludes the channel just acquired.
                    acquired_so_far=list(acquired_so_far[:-1]),
                    available_channels=available,
                    uncertainty_text=step.reasoning or "",
                    differential_before=before_ranking,
                    mrr_before=before_mrr,
                    action=channel,
                    differential_after=after_ranking,
                    mrr_after=after_mrr,
                    reward=reward,
                    acquisition_cost=channel_cost,
                    normalized_cost=normalized_cost,
                    utility_reward=utility_reward,
                    diagnosis_changed=diagnosis_changed,
                    diagnosis_improved=diagnosis_improved,
                )
            )

            # Advance the tracked state for the next step.
            current_mrr = after_mrr
            current_ranking = after_ranking

        # ---- Finalize trajectory-level summary metrics ----
        traj.final_mrr = current_mrr
        traj.total_reward = sum(s.reward for s in traj.steps)
        traj.total_utility_reward = sum(s.utility_reward for s in traj.steps)
        traj.success = (current_mrr == 1.0)
        return traj

    def collect_dataset(
        self,
        cases: list[MedicalCase],
        max_cases: int | None = None,
        save_path: Path | None = None,
    ) -> list[Trajectory]:
        """
        Collect trajectories for all cases.

        Args:
            cases: Cases to process.
            max_cases: Optional cap on the number of cases processed.
            save_path: If given, write the trajectories to this JSON file.

        Returns:
            The successfully collected trajectories; failed cases are
            logged (with traceback) and skipped.
        """
        if max_cases:
            cases = cases[:max_cases]

        trajectories = []
        for case in tqdm(cases, desc="Collecting trajectories", ncols=80):
            try:
                trajectories.append(self.collect_trajectory(case))
            except Exception:
                # Best-effort: log with traceback and continue so one bad
                # case does not abort the whole collection run.
                logger.exception("Failed on %s", case.case_id)
                continue

        # Persist as JSON if requested. default=str stringifies any
        # non-JSON-native values inside rankings (deliberate best-effort).
        if save_path:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            with open(save_path, "w") as f:
                json.dump(
                    [asdict(t) for t in trajectories],
                    f, indent=2, default=str,
                )
            logger.info(
                "Saved %d trajectories to %s", len(trajectories), save_path
            )

        # Report statistics
        self._report_stats(trajectories)
        return trajectories

    def _report_stats(self, trajectories: list[Trajectory]) -> None:
        """Log summary statistics of collected trajectories."""
        n = len(trajectories)
        if n == 0:
            return

        logger.info(f"\n{'='*50}")
        logger.info(f"Trajectory Collection Summary (n={n})")
        logger.info(f"{'='*50}")

        success_rate = np.mean([t.success for t in trajectories])
        avg_steps = np.mean([len(t.steps) for t in trajectories])
        avg_reward = np.mean([t.total_reward for t in trajectories])
        avg_utility = np.mean([t.total_utility_reward for t in trajectories])
        avg_passive_mrr = np.mean([t.passive_mrr for t in trajectories])
        avg_final_mrr = np.mean([t.final_mrr for t in trajectories])
        avg_oracle_mrr = np.mean([t.oracle_mrr for t in trajectories])

        logger.info(f"  Success rate:       {success_rate:.3f}")
        logger.info(f"  Avg steps taken:    {avg_steps:.1f}")
        logger.info(f"  Avg total reward:   {avg_reward:.3f}")
        logger.info(f"  Avg utility reward: {avg_utility:.3f}")
        logger.info(
            f"  MRR: passive={avg_passive_mrr:.3f} -> "
            f"active={avg_final_mrr:.3f} -> oracle={avg_oracle_mrr:.3f}"
        )

        # Per-action reward statistics (COMMIT steps carry no acquisition).
        all_steps = [
            s for t in trajectories for s in t.steps
            if s.action != "COMMIT"
        ]
        if all_steps:
            action_rewards: dict[str, list[float]] = {}
            for s in all_steps:
                action_rewards.setdefault(s.action, []).append(s.utility_reward)

            logger.info("\n  Per-channel utility statistics:")
            # Sort channels by descending mean utility.
            for action, rewards in sorted(
                action_rewards.items(), key=lambda x: -np.mean(x[1])
            ):
                logger.info(
                    f"    {action:<25} mean_utility={np.mean(rewards):+.3f} "
                    f"n={len(rewards)} "
                    f"positive_rate={np.mean([r > 0 for r in rewards]):.2f}"
                )