# activemedagent-demo / policy.py
"""
Acquisition Policy Learning for ActiveMedAgent.
Three learned policies, all API-based or CPU-only:
1. RewardWeightedICL: Select the best past trajectories as in-context
examples for the VLM. The VLM sees "here's what worked before on
similar cases" and makes better acquisition decisions.
2. PolicyNetwork: A small MLP trained on CPU that predicts which channel
to request given a featurized state. Cheap, fast, interpretable.
3. SelfReflectivePolicy: The VLM critiques its own past failures
and generates an improved acquisition strategy.
All three produce an acquisition policy that replaces the zero-shot
decision in agent.py.
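
Example (a hypothetical wiring sketch; the exact agent.py hooks are
assumptions, not defined in this module):

    icl = RewardWeightedICL(trajectories)
    prefix = icl.get_few_shot_examples(case, acquired_so_far=[])
    # prepend `prefix` to the acquisition prompt before querying the VLM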
"""
import json
import logging
import random
from pathlib import Path
import numpy as np
import config
from api_client import BaseVLMClient
from datasets.base import MedicalCase
from trajectory import Trajectory
logger = logging.getLogger(__name__)
# ================================================================
# Approach 1: Reward-Weighted In-Context Learning (ICL)
# ================================================================
class RewardWeightedICL:
"""
Learn an acquisition policy via reward-weighted few-shot prompting.
Strategy:
1. From collected trajectories, identify GOOD acquisition decisions
(positive reward) and BAD ones (negative/zero reward)
2. For each new case, retrieve the K most similar past cases
(by dataset + channel overlap + uncertainty similarity)
3. Construct few-shot examples showing good acquisitions
4. The VLM sees concrete examples of "when uncertain about X,
requesting Y helped" and makes better decisions
This is essentially offline policy improvement via in-context learning.
"""
def __init__(
self,
trajectories: list[Trajectory],
n_examples: int = 3,
min_reward: float = 0.05,
):
self.n_examples = n_examples
self.min_reward = min_reward
# Index good acquisition decisions
self.good_decisions: list[dict] = []
self.bad_decisions: list[dict] = []
for traj in trajectories:
for step in traj.steps:
if step.action == "COMMIT":
continue
decision = {
"case_id": traj.case_id,
"dataset": traj.dataset,
"acquired_before": step.acquired_so_far,
"action": step.action,
"uncertainty": step.uncertainty_text,
"reward": step.utility_reward,
"mrr_reward": step.reward,
"cost": step.acquisition_cost,
"diagnosis_changed": step.diagnosis_changed,
"diagnosis_improved": step.diagnosis_improved,
"mrr_before": step.mrr_before,
"mrr_after": step.mrr_after,
}
if step.utility_reward >= min_reward:
self.good_decisions.append(decision)
else:
self.bad_decisions.append(decision)
logger.info(
f"RewardWeightedICL: {len(self.good_decisions)} good, "
f"{len(self.bad_decisions)} bad decisions indexed"
)
def get_few_shot_examples(
self,
case: MedicalCase,
acquired_so_far: list[str],
) -> str:
"""
Retrieve the best few-shot examples for the current case state.
Returns formatted text to prepend to the acquisition prompt.
"""
# Filter to same dataset
candidates = [d for d in self.good_decisions if d["dataset"] == case.dataset]
if not candidates:
candidates = self.good_decisions # Fallback to cross-dataset
# Score by similarity to current state
scored = []
for d in candidates:
similarity = self._compute_similarity(d, acquired_so_far)
scored.append((similarity, d))
scored.sort(key=lambda x: (-x[0], -x[1]["reward"]))
# Take top N
selected = scored[: self.n_examples]
if not selected:
return ""
# Format as few-shot examples
lines = [
"Here are examples of helpful acquisition decisions from similar past cases:\n"
]
for i, (sim, d) in enumerate(selected):
lines.append(f"Example {i + 1}:")
lines.append(f" Already acquired: {d['acquired_before'] or ['(nothing)']}")
lines.append(f" Uncertainty: {d['uncertainty'][:150]}")
lines.append(f" Decision: REQUEST {d['action']}")
lines.append(
f" Outcome: MRR improved from {d['mrr_before']:.2f} to {d['mrr_after']:.2f} "
f"(reward: {d['reward']:+.3f})"
)
lines.append("")
lines.append(
"Learn from these examples. Prioritize channels that resolved similar uncertainties.\n"
)
return "\n".join(lines)
def _compute_similarity(self, decision: dict, acquired_so_far: list[str]) -> float:
"""
Compute similarity between a past decision and current state.
Based on acquisition stage overlap.
"""
past_acquired = set(decision["acquired_before"])
current_acquired = set(acquired_so_far)
# Jaccard similarity of acquisition state
if not past_acquired and not current_acquired:
return 1.0 # Both at start
union = past_acquired | current_acquired
intersection = past_acquired & current_acquired
stage_sim = len(intersection) / max(len(union), 1)
# Bonus for same acquisition stage (same number of channels acquired)
stage_match = 1.0 if len(past_acquired) == len(current_acquired) else 0.5
return stage_sim * 0.5 + stage_match * 0.5
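# Worked example for RewardWeightedICL._compute_similarity (channel names are
# hypothetical): past_acquired={"OCT"}, current={"OCT", "fundus"} gives
# Jaccard = 1/2 = 0.5 and stage_match = 0.5 (1 vs 2 channels acquired), so
# similarity = 0.5*0.5 + 0.5*0.5 = 0.5.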
# ================================================================
# Approach 2: Lightweight Policy Network (CPU-only)
# ================================================================
class PolicyNetwork:
"""
Small MLP that predicts which channel to request.
State features (input):
- One-hot: which channels have been acquired
- One-hot: which dataset this is
- Scalar: current top-1 confidence
- Scalar: confidence gap (top1 - top2)
- Scalar: acquisition step index (0, 1, 2)
Output: probability distribution over requestable channels.
Trained with cross-entropy loss weighted by trajectory reward.
Runs entirely on CPU — no GPU needed. This is a <1000 parameter model.
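    Example parameter count (assuming, say, 5 channels and 3 datasets):
        input_dim = 5 + 3 + 3 = 11
        n_params  = 11*32 + 32 + 32*5 + 5 = 549   (well under 1000)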
"""
def __init__(
self,
all_channels: list[str],
all_datasets: list[str],
hidden_dim: int = 32,
):
self.all_channels = sorted(all_channels)
self.all_datasets = sorted(all_datasets)
self.channel_to_idx = {c: i for i, c in enumerate(self.all_channels)}
self.dataset_to_idx = {d: i for i, d in enumerate(self.all_datasets)}
self.n_channels = len(self.all_channels)
self.n_datasets = len(self.all_datasets)
# Feature dimension: acquired_mask + dataset_onehot + confidence + gap + step
self.input_dim = self.n_channels + self.n_datasets + 3
self.hidden_dim = hidden_dim
self.output_dim = self.n_channels
# Initialize weights (small random, CPU numpy)
rng = np.random.RandomState(config.SEED)
scale1 = np.sqrt(2.0 / self.input_dim)
scale2 = np.sqrt(2.0 / hidden_dim)
self.W1 = rng.randn(self.input_dim, hidden_dim).astype(np.float32) * scale1
self.b1 = np.zeros(hidden_dim, dtype=np.float32)
self.W2 = rng.randn(hidden_dim, self.output_dim).astype(np.float32) * scale2
self.b2 = np.zeros(self.output_dim, dtype=np.float32)
self.trained = False
def featurize(
self,
dataset: str,
acquired: list[str],
top1_confidence: float,
top2_confidence: float,
step_idx: int,
) -> np.ndarray:
"""Convert state to feature vector."""
features = np.zeros(self.input_dim, dtype=np.float32)
# Acquired channels mask
for ch in acquired:
if ch in self.channel_to_idx:
features[self.channel_to_idx[ch]] = 1.0
# Dataset one-hot
offset = self.n_channels
if dataset in self.dataset_to_idx:
features[offset + self.dataset_to_idx[dataset]] = 1.0
# Scalars
offset += self.n_datasets
features[offset] = top1_confidence
features[offset + 1] = top1_confidence - top2_confidence # Confidence gap
features[offset + 2] = step_idx / 3.0 # Normalized step
return features
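    # Featurize example (hypothetical vocab: channels=["CT", "MRI", "labs"],
    # datasets=["derm", "oph"]): featurize("derm", ["CT"], 0.6, 0.3, 1) ->
    #   [1, 0, 0,  1, 0,  0.6, 0.3, 0.333]
    #   (acquired mask, dataset one-hot, top-1 conf, confidence gap, step/3)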
def predict(
self,
features: np.ndarray,
available_channels: list[str],
) -> dict[str, float]:
"""
Forward pass: predict channel selection probabilities.
Returns dict mapping channel_name → probability.
Only available (not yet acquired) channels get nonzero probability.
"""
# Forward pass: input → ReLU → softmax (masked)
h = np.maximum(0, features @ self.W1 + self.b1) # ReLU
logits = h @ self.W2 + self.b2
# Mask unavailable channels to -inf
mask = np.full(self.output_dim, -1e9, dtype=np.float32)
for ch in available_channels:
if ch in self.channel_to_idx:
mask[self.channel_to_idx[ch]] = 0.0
logits = logits + mask
# Softmax
logits = logits - logits.max()
exp_logits = np.exp(logits)
probs = exp_logits / (exp_logits.sum() + 1e-8)
return {ch: float(probs[self.channel_to_idx[ch]])
for ch in available_channels if ch in self.channel_to_idx}
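    # Masking note: with available_channels=["MRI"] alone, essentially all
    # probability mass lands on "MRI", since every other logit is shifted to
    # ~-1e9 before the softmax.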
def train(
self,
trajectories: list[Trajectory],
lr: float = 0.01,
n_epochs: int = 100,
reward_temperature: float = 1.0,
):
"""
Train the policy network on collected trajectories.
        Uses reward-weighted cross-entropy:
            loss = -mean(w_i * log P(a_i | s_i)),  w_i ∝ exp(r_i * temperature)
        where r_i is the standardized utility reward: high-reward actions get
        large weights, low or negative rewards get small but still positive
        weights, so poor actions are de-emphasized rather than actively penalized.
"""
# Build training data
X = []
actions = []
rewards = []
available_masks = []
for traj in trajectories:
for step in traj.steps:
if step.action == "COMMIT":
continue
if step.action not in self.channel_to_idx:
continue
# Extract features from the step's state
top1_conf = step.differential_before[0]["confidence"] if step.differential_before else 0.5
top2_conf = step.differential_before[1]["confidence"] if len(step.differential_before) > 1 else 0.0
feat = self.featurize(
dataset=traj.dataset,
acquired=step.acquired_so_far,
top1_confidence=top1_conf,
top2_confidence=top2_conf,
step_idx=step.step_idx,
)
X.append(feat)
actions.append(self.channel_to_idx[step.action])
                # Raw utility reward; standardized over the whole batch below
                rewards.append(step.utility_reward)
# Available channels mask
mask = np.zeros(self.output_dim, dtype=np.float32)
for ch in step.available_channels:
if ch in self.channel_to_idx:
mask[self.channel_to_idx[ch]] = 1.0
available_masks.append(mask)
if not X:
logger.warning("No training data available for policy network")
return
X = np.array(X)
actions = np.array(actions)
rewards = np.array(rewards)
available_masks = np.array(available_masks)
# Normalize rewards
if rewards.std() > 0:
rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
# Apply temperature
weights = np.exp(rewards * reward_temperature)
weights = weights / weights.sum() * len(weights) # Normalize to mean=1
n = len(X)
logger.info(f"Training policy network on {n} state-action pairs for {n_epochs} epochs")
for epoch in range(n_epochs):
# Forward pass
h = np.maximum(0, X @ self.W1 + self.b1)
logits = h @ self.W2 + self.b2
# Mask unavailable channels
logits = logits + (1 - available_masks) * (-1e9)
# Softmax
logits_shifted = logits - logits.max(axis=1, keepdims=True)
exp_logits = np.exp(logits_shifted)
probs = exp_logits / (exp_logits.sum(axis=1, keepdims=True) + 1e-8)
# Cross-entropy loss (reward-weighted)
action_probs = probs[np.arange(n), actions]
loss = -np.mean(weights * np.log(action_probs + 1e-8))
# Backward pass (manual gradient)
# dL/d_logits = probs - one_hot(action), weighted by reward
grad_logits = probs.copy()
grad_logits[np.arange(n), actions] -= 1.0
grad_logits *= weights[:, np.newaxis] / n
# Gradient for W2, b2
grad_W2 = h.T @ grad_logits
grad_b2 = grad_logits.sum(axis=0)
# Gradient for W1, b1 (through ReLU)
grad_h = grad_logits @ self.W2.T
grad_h *= (h > 0).astype(np.float32) # ReLU derivative
grad_W1 = X.T @ grad_h
grad_b1 = grad_h.sum(axis=0)
# Update
self.W1 -= lr * grad_W1
self.b1 -= lr * grad_b1
self.W2 -= lr * grad_W2
self.b2 -= lr * grad_b2
if (epoch + 1) % 20 == 0:
# Compute accuracy
predicted = np.argmax(probs, axis=1)
accuracy = np.mean(predicted == actions)
logger.info(f" Epoch {epoch + 1}: loss={loss:.4f}, accuracy={accuracy:.3f}")
self.trained = True
logger.info("Policy network training complete")
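    # Weighting example: standardized rewards [-1, 0, 1] at temperature 1.0
    # give raw weights exp([-1, 0, 1]) ~= [0.368, 1.0, 2.718]; normalized to
    # mean 1 they become ~= [0.270, 0.734, 1.996].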
def get_action(
self,
case: MedicalCase,
acquired: list[str],
differential: list[dict],
step_idx: int,
) -> str:
"""Select the best channel to request using the learned policy."""
available = [ch for ch in case.requestable_names if ch not in acquired]
if not available:
return "COMMIT"
top1_conf = differential[0]["confidence"] if differential else 0.5
top2_conf = differential[1]["confidence"] if len(differential) > 1 else 0.0
features = self.featurize(
dataset=case.dataset,
acquired=acquired,
top1_confidence=top1_conf,
top2_confidence=top2_conf,
step_idx=step_idx,
)
probs = self.predict(features, available)
if not probs:
return random.choice(available)
# Select highest probability channel
best_channel = max(probs, key=probs.get)
return best_channel
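    # Hypothetical usage inside the acquisition loop (driver-side names are
    # assumptions):
    #   net = PolicyNetwork(all_channels, all_datasets)
    #   net.train(trajectories)
    #   action = net.get_action(case, acquired=[], differential=diff, step_idx=0)
    #   # `action` is a channel name, or "COMMIT" when nothing is left to request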
def save(self, path: Path):
"""Save model weights."""
np.savez(
path,
W1=self.W1, b1=self.b1,
W2=self.W2, b2=self.b2,
channels=self.all_channels,
datasets=self.all_datasets,
)
logger.info(f"Saved policy network to {path}")
    def load(self, path: Path):
        """Load model weights and the vocabularies saved alongside them."""
        data = np.load(path, allow_pickle=True)
        self.W1 = data["W1"]
        self.b1 = data["b1"]
        self.W2 = data["W2"]
        self.b2 = data["b2"]
        # Restore the vocabularies stored by save() so indices line up even if
        # the constructor was called with a different channel/dataset ordering.
        self.all_channels = [str(c) for c in data["channels"]]
        self.all_datasets = [str(d) for d in data["datasets"]]
        self.channel_to_idx = {c: i for i, c in enumerate(self.all_channels)}
        self.dataset_to_idx = {d: i for i, d in enumerate(self.all_datasets)}
        self.trained = True
        logger.info(f"Loaded policy network from {path}")
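    # Round-trip sketch (np.savez appends ".npz" when the path has no suffix):
    #   net.save(Path("policy_net"))          # writes policy_net.npz
    #   net2 = PolicyNetwork(all_channels, all_datasets)
    #   net2.load(Path("policy_net.npz"))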
# ================================================================
# Approach 3: Self-Reflective Refinement
# ================================================================
class SelfReflectivePolicy:
"""
The VLM critiques its own past failures and generates improved strategies.
Pipeline:
1. Collect cases where zero-shot acquisition was suboptimal
(the agent requested info that didn't help, or missed info that would have)
2. Show the VLM its own failure traces and ask it to generate
"acquisition rules" — structured if-then policies
3. Inject these self-generated rules into the system prompt
4. Re-run with the improved prompt
This is a form of self-play / self-improvement via reflection.
"""
def __init__(self, client: BaseVLMClient, dataset_name: str):
self.client = client
self.dataset_name = dataset_name
self.rules: list[str] = []
def generate_rules_from_failures(
self,
trajectories: list[Trajectory],
n_failure_examples: int = 10,
) -> list[str]:
"""
Analyze failures and generate acquisition rules.
A "failure" is a case where:
- Agent requested a channel with zero or negative utility
- Agent didn't request a channel that would have helped
- Agent committed too early (final MRR << oracle MRR)
"""
# Collect failure examples
failures = []
for traj in trajectories:
if traj.dataset != self.dataset_name:
continue
# Type 1: Unhelpful acquisitions
for step in traj.steps:
if step.action != "COMMIT" and step.utility_reward <= 0:
failures.append({
"type": "unhelpful_acquisition",
"case_id": traj.case_id,
"action": step.action,
"uncertainty": step.uncertainty_text[:200],
"utility_reward": step.utility_reward,
"mrr_reward": step.reward,
"cost": step.acquisition_cost,
"available": step.available_channels,
})
# Type 2: Premature commitment
if traj.final_mrr < traj.oracle_mrr - 0.2:
failures.append({
"type": "premature_commit",
"case_id": traj.case_id,
"acquired": [s.action for s in traj.steps if s.action != "COMMIT"],
"final_mrr": traj.final_mrr,
"oracle_mrr": traj.oracle_mrr,
"gap": traj.oracle_mrr - traj.final_mrr,
})
if not failures:
logger.info("No failures found — zero-shot policy may already be strong")
return []
# Sample failures
random.shuffle(failures)
sampled = failures[:n_failure_examples]
# Ask the VLM to analyze and generate rules
failure_text = json.dumps(sampled, indent=2, default=str)
prompt = f"""You are analyzing an AI medical diagnostic agent's acquisition failures on {self.dataset_name} cases.
The agent must decide what additional information to request (imaging modalities, clinical data, etc.) before making a diagnosis.
Here are examples of FAILED acquisition decisions:
{failure_text}
Based on these failures, generate 5-8 specific, actionable ACQUISITION RULES that would improve future decisions.
Format each rule as:
RULE N: IF [condition about the current state/uncertainty] THEN [specific acquisition action] BECAUSE [reasoning]
Rules should be specific to the {self.dataset_name} dataset and its available channels.
Focus on patterns across failures, not individual cases.
Be concrete — "request OCT when uncertain about subretinal fluid" is better than "request more information when uncertain."
Respond ONLY with the rules, no preamble."""
response = self.client.call_with_retry(
system_prompt="You are an expert in medical diagnostic AI systems.",
user_text=prompt,
images=None,
temperature=0.3,
max_tokens=2048,
)
# Parse rules
rules = []
for line in response.text.split("\n"):
line = line.strip()
if line.startswith("RULE") or line.startswith("Rule"):
rules.append(line)
elif rules and line and not line.startswith("RULE"):
# Continuation of previous rule
rules[-1] += " " + line
self.rules = rules
logger.info(f"Generated {len(rules)} acquisition rules from {len(sampled)} failures")
for r in rules:
logger.info(f" {r[:120]}...")
return rules
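    # Illustrative parsed rule (not real model output):
    #   'RULE 1: IF the top-2 diagnoses disagree on fluid presence THEN request
    #    OCT BECAUSE OCT directly resolves subretinal-fluid uncertainty'
    # Wrapped continuation lines are folded into the preceding rule string.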
def get_enhanced_system_prompt(self, base_prompt: str) -> str:
"""
Inject learned rules into the system prompt.
This is the key mechanism: the VLM's behavior is modified
by giving it its own self-generated rules as instructions.
"""
if not self.rules:
return base_prompt
rules_text = "\n".join(self.rules)
injection = f"""
LEARNED ACQUISITION STRATEGY (from analyzing past diagnostic cases):
The following rules have been learned from analyzing cases where acquisition
decisions were suboptimal. Apply these rules when deciding what information to request:
{rules_text}
Apply these rules in addition to your general diagnostic reasoning."""
return base_prompt + injection
def save_rules(self, path: Path):
"""Save generated rules."""
with open(path, "w") as f:
json.dump({"dataset": self.dataset_name, "rules": self.rules}, f, indent=2)
def load_rules(self, path: Path):
"""Load previously generated rules."""
with open(path) as f:
data = json.load(f)
self.rules = data["rules"]
logger.info(f"Loaded {len(self.rules)} rules for {self.dataset_name}")
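# End-to-end sketch of the reflective loop (hypothetical driver code; client
# construction and trajectory collection happen elsewhere):
#   policy = SelfReflectivePolicy(client, dataset_name="oct")
#   policy.generate_rules_from_failures(trajectories)
#   system_prompt = policy.get_enhanced_system_prompt(BASE_PROMPT)
#   policy.save_rules(Path("rules_oct.json"))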