| """ | |
| Acquisition Policy Learning for ActiveMedAgent. | |
| Three learned policies, all API-based or CPU-only: | |
| 1. RewardWeightedICL: Select the best past trajectories as in-context | |
| examples for the VLM. The VLM sees "here's what worked before on | |
| similar cases" and makes better acquisition decisions. | |
| 2. PolicyNetwork: A small MLP trained on CPU that predicts which channel | |
| to request given a featurized state. Cheap, fast, interpretable. | |
| 3. SelfReflectivePolicy: The VLM critiques its own past failures | |
| and generates an improved acquisition strategy. | |
| All three produce an acquisition policy that replaces the zero-shot | |
| decision in agent.py. | |
| """ | |

import json
import logging
import random
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path

import numpy as np

import config
from api_client import BaseVLMClient
from datasets.base import MedicalCase
from trajectory import Trajectory, TrajectoryStep

logger = logging.getLogger(__name__)

# ================================================================
# Approach 1: Reward-Weighted In-Context Learning (ICL)
# ================================================================


class RewardWeightedICL:
    """
    Learn an acquisition policy via reward-weighted few-shot prompting.

    Strategy:
    1. From collected trajectories, identify GOOD acquisition decisions
       (positive reward) and BAD ones (negative/zero reward)
    2. For each new case, retrieve the K most similar past cases
       (by dataset + channel overlap + uncertainty similarity)
    3. Construct few-shot examples showing good acquisitions
    4. The VLM sees concrete examples of "when uncertain about X,
       requesting Y helped" and makes better decisions

    This is essentially offline policy improvement via in-context learning.
    """

    def __init__(
        self,
        trajectories: list[Trajectory],
        n_examples: int = 3,
        min_reward: float = 0.05,
    ):
        self.n_examples = n_examples
        self.min_reward = min_reward

        # Index good and bad acquisition decisions
        self.good_decisions: list[dict] = []
        self.bad_decisions: list[dict] = []

        for traj in trajectories:
            for step in traj.steps:
                if step.action == "COMMIT":
                    continue
                decision = {
                    "case_id": traj.case_id,
                    "dataset": traj.dataset,
                    "acquired_before": step.acquired_so_far,
                    "action": step.action,
                    "uncertainty": step.uncertainty_text,
                    "reward": step.utility_reward,
                    "mrr_reward": step.reward,
                    "cost": step.acquisition_cost,
                    "diagnosis_changed": step.diagnosis_changed,
                    "diagnosis_improved": step.diagnosis_improved,
                    "mrr_before": step.mrr_before,
                    "mrr_after": step.mrr_after,
                }
                if step.utility_reward >= min_reward:
                    self.good_decisions.append(decision)
                else:
                    self.bad_decisions.append(decision)

        logger.info(
            f"RewardWeightedICL: {len(self.good_decisions)} good, "
            f"{len(self.bad_decisions)} bad decisions indexed"
        )

    def get_few_shot_examples(
        self,
        case: MedicalCase,
        acquired_so_far: list[str],
    ) -> str:
        """
        Retrieve the best few-shot examples for the current case state.

        Returns formatted text to prepend to the acquisition prompt.
        """
        # Filter to same dataset
        candidates = [d for d in self.good_decisions if d["dataset"] == case.dataset]
        if not candidates:
            candidates = self.good_decisions  # Fallback to cross-dataset

        # Score by similarity to current state
        scored = []
        for d in candidates:
            similarity = self._compute_similarity(d, acquired_so_far)
            scored.append((similarity, d))
        scored.sort(key=lambda x: (-x[0], -x[1]["reward"]))

        # Take top N
        selected = scored[: self.n_examples]
        if not selected:
            return ""

        # Format as few-shot examples
        lines = [
            "Here are examples of helpful acquisition decisions from similar past cases:\n"
        ]
        for i, (sim, d) in enumerate(selected):
            lines.append(f"Example {i + 1}:")
            lines.append(f"  Already acquired: {d['acquired_before'] or ['(nothing)']}")
            lines.append(f"  Uncertainty: {d['uncertainty'][:150]}")
            lines.append(f"  Decision: REQUEST {d['action']}")
            lines.append(
                f"  Outcome: MRR improved from {d['mrr_before']:.2f} to {d['mrr_after']:.2f} "
                f"(reward: {d['reward']:+.3f})"
            )
            lines.append("")
        lines.append(
            "Learn from these examples. Prioritize channels that resolved similar uncertainties.\n"
        )
        return "\n".join(lines)

    def _compute_similarity(self, decision: dict, acquired_so_far: list[str]) -> float:
        """
        Compute similarity between a past decision and the current state.

        Based on acquisition-stage overlap.
        """
        past_acquired = set(decision["acquired_before"])
        current_acquired = set(acquired_so_far)

        # Jaccard similarity of acquisition state
        if not past_acquired and not current_acquired:
            return 1.0  # Both at start
        union = past_acquired | current_acquired
        intersection = past_acquired & current_acquired
        stage_sim = len(intersection) / max(len(union), 1)

        # Bonus for same acquisition stage (same number of channels acquired)
        stage_match = 1.0 if len(past_acquired) == len(current_acquired) else 0.5

        return stage_sim * 0.5 + stage_match * 0.5
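
# Example usage for RewardWeightedICL (illustrative sketch only; `load_trajectories`,
# `case`, `base_acquisition_prompt`, and the channel name are hypothetical stand-ins
# for the corresponding plumbing in agent.py):
#
#     trajectories = load_trajectories(Path("runs/trajectories.json"))
#     icl = RewardWeightedICL(trajectories, n_examples=3, min_reward=0.05)
#     few_shot = icl.get_few_shot_examples(case, acquired_so_far=["fundus_photo"])
#     acquisition_prompt = few_shot + base_acquisition_prompt  # prepend retrieved examples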


# ================================================================
# Approach 2: Lightweight Policy Network (CPU-only)
# ================================================================


class PolicyNetwork:
    """
    Small MLP that predicts which channel to request.

    State features (input):
    - One-hot: which channels have been acquired
    - One-hot: which dataset this is
    - Scalar: current top-1 confidence
    - Scalar: confidence gap (top1 - top2)
    - Scalar: acquisition step index (0, 1, 2)

    Output: probability distribution over requestable channels.

    Trained with cross-entropy loss weighted by trajectory reward.
    Runs entirely on CPU; no GPU needed. This is a <1000-parameter model.
    """

    def __init__(
        self,
        all_channels: list[str],
        all_datasets: list[str],
        hidden_dim: int = 32,
    ):
        self.all_channels = sorted(all_channels)
        self.all_datasets = sorted(all_datasets)
        self.channel_to_idx = {c: i for i, c in enumerate(self.all_channels)}
        self.dataset_to_idx = {d: i for i, d in enumerate(self.all_datasets)}
        self.n_channels = len(self.all_channels)
        self.n_datasets = len(self.all_datasets)

        # Feature dimension: acquired_mask + dataset_onehot + confidence + gap + step
        self.input_dim = self.n_channels + self.n_datasets + 3
        self.hidden_dim = hidden_dim
        self.output_dim = self.n_channels

        # Initialize weights (small random, CPU numpy)
        rng = np.random.RandomState(config.SEED)
        scale1 = np.sqrt(2.0 / self.input_dim)
        scale2 = np.sqrt(2.0 / hidden_dim)
        self.W1 = rng.randn(self.input_dim, hidden_dim).astype(np.float32) * scale1
        self.b1 = np.zeros(hidden_dim, dtype=np.float32)
        self.W2 = rng.randn(hidden_dim, self.output_dim).astype(np.float32) * scale2
        self.b2 = np.zeros(self.output_dim, dtype=np.float32)

        self.trained = False

    def featurize(
        self,
        dataset: str,
        acquired: list[str],
        top1_confidence: float,
        top2_confidence: float,
        step_idx: int,
    ) -> np.ndarray:
        """Convert state to feature vector."""
        features = np.zeros(self.input_dim, dtype=np.float32)

        # Acquired channels mask
        for ch in acquired:
            if ch in self.channel_to_idx:
                features[self.channel_to_idx[ch]] = 1.0

        # Dataset one-hot
        offset = self.n_channels
        if dataset in self.dataset_to_idx:
            features[offset + self.dataset_to_idx[dataset]] = 1.0

        # Scalars
        offset += self.n_datasets
        features[offset] = top1_confidence
        features[offset + 1] = top1_confidence - top2_confidence  # Confidence gap
        features[offset + 2] = step_idx / 3.0  # Normalized step

        return features

    def predict(
        self,
        features: np.ndarray,
        available_channels: list[str],
    ) -> dict[str, float]:
        """
        Forward pass: predict channel selection probabilities.

        Returns dict mapping channel_name → probability.
        Only available (not yet acquired) channels get nonzero probability.
        """
        # Forward pass: input → ReLU → softmax (masked)
        h = np.maximum(0, features @ self.W1 + self.b1)  # ReLU
        logits = h @ self.W2 + self.b2

        # Mask unavailable channels to -inf
        mask = np.full(self.output_dim, -1e9, dtype=np.float32)
        for ch in available_channels:
            if ch in self.channel_to_idx:
                mask[self.channel_to_idx[ch]] = 0.0
        logits = logits + mask

        # Softmax
        logits = logits - logits.max()
        exp_logits = np.exp(logits)
        probs = exp_logits / (exp_logits.sum() + 1e-8)

        return {
            ch: float(probs[self.channel_to_idx[ch]])
            for ch in available_channels
            if ch in self.channel_to_idx
        }

    def train(
        self,
        trajectories: list[Trajectory],
        lr: float = 0.01,
        n_epochs: int = 100,
        reward_temperature: float = 1.0,
    ):
        """
        Train the policy network on collected trajectories.

        Uses reward-weighted cross-entropy:
            loss = -sum(reward * log(P(action|state)))
        Positive rewards encourage the action; negative rewards discourage it.
        """
        # Build training data
        X = []
        actions = []
        rewards = []
        available_masks = []

        for traj in trajectories:
            for step in traj.steps:
                if step.action == "COMMIT":
                    continue
                if step.action not in self.channel_to_idx:
                    continue

                # Extract features from the step's state
                top1_conf = (
                    step.differential_before[0]["confidence"]
                    if step.differential_before
                    else 0.5
                )
                top2_conf = (
                    step.differential_before[1]["confidence"]
                    if len(step.differential_before) > 1
                    else 0.0
                )
                feat = self.featurize(
                    dataset=traj.dataset,
                    acquired=step.acquired_so_far,
                    top1_confidence=top1_conf,
                    top2_confidence=top2_conf,
                    step_idx=step.step_idx,
                )
                X.append(feat)
                actions.append(self.channel_to_idx[step.action])

                # Reward shaping: normalize across trajectories
                rewards.append(step.utility_reward)

                # Available channels mask
                mask = np.zeros(self.output_dim, dtype=np.float32)
                for ch in step.available_channels:
                    if ch in self.channel_to_idx:
                        mask[self.channel_to_idx[ch]] = 1.0
                available_masks.append(mask)

        if not X:
            logger.warning("No training data available for policy network")
            return

        X = np.array(X)
        actions = np.array(actions)
        rewards = np.array(rewards)
        available_masks = np.array(available_masks)

        # Normalize rewards
        if rewards.std() > 0:
            rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # Apply temperature
        weights = np.exp(rewards * reward_temperature)
        weights = weights / weights.sum() * len(weights)  # Normalize to mean=1

        n = len(X)
        logger.info(f"Training policy network on {n} state-action pairs for {n_epochs} epochs")

        for epoch in range(n_epochs):
            # Forward pass
            h = np.maximum(0, X @ self.W1 + self.b1)
            logits = h @ self.W2 + self.b2

            # Mask unavailable channels
            logits = logits + (1 - available_masks) * (-1e9)

            # Softmax
            logits_shifted = logits - logits.max(axis=1, keepdims=True)
            exp_logits = np.exp(logits_shifted)
            probs = exp_logits / (exp_logits.sum(axis=1, keepdims=True) + 1e-8)

            # Cross-entropy loss (reward-weighted)
            action_probs = probs[np.arange(n), actions]
            loss = -np.mean(weights * np.log(action_probs + 1e-8))

            # Backward pass (manual gradient)
            # dL/d_logits = probs - one_hot(action), weighted by reward
            grad_logits = probs.copy()
            grad_logits[np.arange(n), actions] -= 1.0
            grad_logits *= weights[:, np.newaxis] / n

            # Gradient for W2, b2
            grad_W2 = h.T @ grad_logits
            grad_b2 = grad_logits.sum(axis=0)

            # Gradient for W1, b1 (through ReLU)
            grad_h = grad_logits @ self.W2.T
            grad_h *= (h > 0).astype(np.float32)  # ReLU derivative
            grad_W1 = X.T @ grad_h
            grad_b1 = grad_h.sum(axis=0)

            # Update
            self.W1 -= lr * grad_W1
            self.b1 -= lr * grad_b1
            self.W2 -= lr * grad_W2
            self.b2 -= lr * grad_b2

            if (epoch + 1) % 20 == 0:
                # Compute accuracy
                predicted = np.argmax(probs, axis=1)
                accuracy = np.mean(predicted == actions)
                logger.info(f"  Epoch {epoch + 1}: loss={loss:.4f}, accuracy={accuracy:.3f}")

        self.trained = True
        logger.info("Policy network training complete")

    def get_action(
        self,
        case: MedicalCase,
        acquired: list[str],
        differential: list[dict],
        step_idx: int,
    ) -> str:
        """Select the best channel to request using the learned policy."""
        available = [ch for ch in case.requestable_names if ch not in acquired]
        if not available:
            return "COMMIT"

        top1_conf = differential[0]["confidence"] if differential else 0.5
        top2_conf = differential[1]["confidence"] if len(differential) > 1 else 0.0

        features = self.featurize(
            dataset=case.dataset,
            acquired=acquired,
            top1_confidence=top1_conf,
            top2_confidence=top2_conf,
            step_idx=step_idx,
        )
        probs = self.predict(features, available)
        if not probs:
            return random.choice(available)

        # Select the highest-probability channel
        best_channel = max(probs, key=probs.get)
        return best_channel

    def save(self, path: Path):
        """Save model weights."""
        np.savez(
            path,
            W1=self.W1, b1=self.b1,
            W2=self.W2, b2=self.b2,
            channels=self.all_channels,
            datasets=self.all_datasets,
        )
        logger.info(f"Saved policy network to {path}")

    def load(self, path: Path):
        """Load model weights."""
        data = np.load(path, allow_pickle=True)
        self.W1 = data["W1"]
        self.b1 = data["b1"]
        self.W2 = data["W2"]
        self.b2 = data["b2"]
        self.trained = True
        logger.info(f"Loaded policy network from {path}")


# ================================================================
# Approach 3: Self-Reflective Refinement
# ================================================================


class SelfReflectivePolicy:
    """
    The VLM critiques its own past failures and generates improved strategies.

    Pipeline:
    1. Collect cases where zero-shot acquisition was suboptimal
       (the agent requested info that didn't help, or missed info that would have)
    2. Show the VLM its own failure traces and ask it to generate
       "acquisition rules" (structured if-then policies)
    3. Inject these self-generated rules into the system prompt
    4. Re-run with the improved prompt

    This is a form of self-play / self-improvement via reflection.
    """

    def __init__(self, client: BaseVLMClient, dataset_name: str):
        self.client = client
        self.dataset_name = dataset_name
        self.rules: list[str] = []

    def generate_rules_from_failures(
        self,
        trajectories: list[Trajectory],
        n_failure_examples: int = 10,
    ) -> list[str]:
        """
        Analyze failures and generate acquisition rules.

        A "failure" is a case where:
        - Agent requested a channel with zero or negative utility
        - Agent didn't request a channel that would have helped
        - Agent committed too early (final MRR << oracle MRR)
        """
        # Collect failure examples
        failures = []
        for traj in trajectories:
            if traj.dataset != self.dataset_name:
                continue

            # Type 1: Unhelpful acquisitions
            for step in traj.steps:
                if step.action != "COMMIT" and step.utility_reward <= 0:
                    failures.append({
                        "type": "unhelpful_acquisition",
                        "case_id": traj.case_id,
                        "action": step.action,
                        "uncertainty": step.uncertainty_text[:200],
                        "utility_reward": step.utility_reward,
                        "mrr_reward": step.reward,
                        "cost": step.acquisition_cost,
                        "available": step.available_channels,
                    })

            # Type 2: Premature commitment
            if traj.final_mrr < traj.oracle_mrr - 0.2:
                failures.append({
                    "type": "premature_commit",
                    "case_id": traj.case_id,
                    "acquired": [s.action for s in traj.steps if s.action != "COMMIT"],
                    "final_mrr": traj.final_mrr,
                    "oracle_mrr": traj.oracle_mrr,
                    "gap": traj.oracle_mrr - traj.final_mrr,
                })

        if not failures:
            logger.info("No failures found; the zero-shot policy may already be strong")
            return []

        # Sample failures
        random.shuffle(failures)
        sampled = failures[:n_failure_examples]

        # Ask the VLM to analyze the failures and generate rules
        failure_text = json.dumps(sampled, indent=2, default=str)
        prompt = f"""You are analyzing an AI medical diagnostic agent's acquisition failures on {self.dataset_name} cases.

The agent must decide what additional information to request (imaging modalities, clinical data, etc.) before making a diagnosis.

Here are examples of FAILED acquisition decisions:

{failure_text}

Based on these failures, generate 5-8 specific, actionable ACQUISITION RULES that would improve future decisions.

Format each rule as:
RULE N: IF [condition about the current state/uncertainty] THEN [specific acquisition action] BECAUSE [reasoning]

Rules should be specific to the {self.dataset_name} dataset and its available channels.
Focus on patterns across failures, not individual cases.
Be concrete: "request OCT when uncertain about subretinal fluid" is better than "request more information when uncertain."

Respond ONLY with the rules, no preamble."""

        response = self.client.call_with_retry(
            system_prompt="You are an expert in medical diagnostic AI systems.",
            user_text=prompt,
            images=None,
            temperature=0.3,
            max_tokens=2048,
        )

        # Parse rules
        rules = []
        for line in response.text.split("\n"):
            line = line.strip()
            if line.startswith("RULE") or line.startswith("Rule"):
                rules.append(line)
            elif rules and line:
                # Continuation of the previous rule
                rules[-1] += " " + line

        self.rules = rules
        logger.info(f"Generated {len(rules)} acquisition rules from {len(sampled)} failures")
        for r in rules:
            logger.info(f"  {r[:120]}...")
        return rules

    def get_enhanced_system_prompt(self, base_prompt: str) -> str:
        """
        Inject learned rules into the system prompt.

        This is the key mechanism: the VLM's behavior is modified
        by giving it its own self-generated rules as instructions.
        """
        if not self.rules:
            return base_prompt

        rules_text = "\n".join(self.rules)
        injection = f"""

LEARNED ACQUISITION STRATEGY (from analyzing past diagnostic cases):

The following rules have been learned from analyzing cases where acquisition
decisions were suboptimal. Apply these rules when deciding what information to request:

{rules_text}

Apply these rules in addition to your general diagnostic reasoning."""
        return base_prompt + injection

    def save_rules(self, path: Path):
        """Save generated rules."""
        with open(path, "w") as f:
            json.dump({"dataset": self.dataset_name, "rules": self.rules}, f, indent=2)

    def load_rules(self, path: Path):
        """Load previously generated rules."""
        with open(path) as f:
            data = json.load(f)
        self.rules = data["rules"]
        logger.info(f"Loaded {len(self.rules)} rules for {self.dataset_name}")