File size: 5,489 Bytes
"""Classifier-Free Guidance sampling + counterfactual rollouts.
Given a trained Block Diffusion CWM, we sample patient trajectories by
iteratively unmasking tokens from fully-masked input. At each denoise step:
1. Run the conditional model with cond=c -> logits_c
2. Run unconditional model with cond=<NULL> -> logits_null
3. Guided logits = (1+gamma) * logits_c - gamma * logits_null
4. Greedy (argmax) token at each MASK position; temperature only rescales the logits behind the confidence scores
5. Top-k confidence remasking schedule (D3PM-style) until t=0
Counterfactual pair:
- sample one batch with cond = treatment
- sample another batch with cond = <NO_TX>
- compare outcome-token distributions to estimate ATE
This is the core counterfactual primitive used by tte_validate.py and
sensitivity.py.
"""
from __future__ import annotations
import logging
import torch
import torch.nn.functional as F
from .block_diffusion import BlockDiffusionTransformer
log = logging.getLogger("gemeo.cwm.sample")
def _guided_logits(model, x, t, cond, null, gamma, mask_token):
    """Return guidance-combined logits for ``x`` with the MASK token suppressed."""
    logits = model(x, t, cond)
    if gamma > 0:
        # The unconditional pass is only needed when guidance is active;
        # skipping it otherwise saves one full forward pass.
        logits = (1 + gamma) * logits - gamma * model(x, t, null)
    # Avoid sampling MASK token
    logits[:, :, mask_token] = -1e9
    return logits


@torch.no_grad()
def cfg_sample(
    model: BlockDiffusionTransformer,
    cond: torch.Tensor,  # (B,) condition ids
    *,
    seed_prefix: torch.Tensor | None = None,  # (B, L) prompt tokens (kept clean)
    gamma: float = 2.0,
    n_steps: int | None = None,
    temperature: float = 1.0,
    null_cond: int = 0,
) -> torch.Tensor:
    """Sample (B, T) trajectories with classifier-free guidance.

    The sequence starts fully masked (``T = model.cfg.max_seq_len``). When
    ``seed_prefix`` is given it is written, unmasked, into the first L
    positions and those positions are never re-predicted by the denoising
    loop. At each step the highest-confidence masked positions are revealed
    so that ``round((1 - t_next) * T)`` positions are filled per row.

    Token choice is always the argmax of the guided distribution;
    ``temperature`` only rescales the logits before the softmax that yields
    the per-position confidence scores.

    Args:
        model: Callable as ``model(x, t, cond)`` returning logits whose last
            axis indexes tokens; must expose ``cfg.max_seq_len``,
            ``cfg.mask_token`` and ``cfg.n_diff_steps``.
        cond: Condition ids, one per batch row.
        seed_prefix: Optional prompt tokens of shape (B, L).
        gamma: Guidance strength. Values ``<= 0`` disable guidance and use
            the conditional logits alone (no unconditional pass is run).
        n_steps: Number of denoising steps; ``None`` or ``0`` falls back to
            ``cfg.n_diff_steps``.
        temperature: Divisor applied to the logits (floored at 1e-3).
        null_cond: Condition id used for the unconditional pass.

    Returns:
        Long tensor of shape (B, T) on ``cond.device``.

    Raises:
        ValueError: If ``seed_prefix`` is longer than ``cfg.max_seq_len``.
    """
    cfg = model.cfg
    n_steps = n_steps or cfg.n_diff_steps
    device = cond.device
    B = cond.size(0)
    T = cfg.max_seq_len

    # Start with full MASK
    x = torch.full((B, T), cfg.mask_token, device=device, dtype=torch.long)
    fixed_mask = torch.zeros(B, T, dtype=torch.bool, device=device)
    if seed_prefix is not None:
        L = seed_prefix.size(1)
        if L > T:
            raise ValueError(
                f"seed_prefix length {L} exceeds cfg.max_seq_len {T}"
            )
        x[:, :L] = seed_prefix
        fixed_mask[:, :L] = True

    ts = torch.linspace(1.0, 0.0, n_steps + 1, device=device)
    null = torch.full_like(cond, null_cond)

    for k in range(n_steps):
        t_cur = ts[k].unsqueeze(0).expand(B)
        logits = _guided_logits(
            model, x, t_cur, cond, null, gamma, cfg.mask_token
        )
        if temperature != 1.0:
            logits = logits / max(temperature, 1e-3)
        probs = F.softmax(logits, dim=-1)
        confs, preds = probs.max(dim=-1)

        # Confidence-based remasking: after this step, (1 - t_next) * T
        # positions per row should be revealed.
        t_next = ts[k + 1].item()
        target_kept = int(round((1.0 - t_next) * T))

        # Already-revealed tokens stay; exclude them from the top-k pick.
        revealed = (x != cfg.mask_token) | fixed_mask
        confs_for_pick = torch.where(
            revealed, torch.full_like(confs, -1e9), confs
        )

        # One host transfer per step instead of one .item() per row.
        for b, n_revealed in enumerate(revealed.sum(dim=-1).tolist()):
            need = target_kept - n_revealed
            if need <= 0:
                continue
            topi = confs_for_pick[b].topk(need).indices
            x[b, topi] = preds[b, topi]

    # Final pass: replace any remaining MASK with greedy predictions, using
    # the same guidance rule as the loop above.
    mask_left = x == cfg.mask_token
    if mask_left.any():
        t_zero = torch.zeros(B, device=device)
        logits = _guided_logits(
            model, x, t_zero, cond, null, gamma, cfg.mask_token
        )
        x = torch.where(mask_left, logits.argmax(dim=-1), x)
    return x
def counterfactual_pair(
    model: BlockDiffusionTransformer,
    treatment_cond: int,
    null_treatment_cond: int,
    *,
    seed_prefix: torch.Tensor,
    n_samples: int = 100,
    gamma: float = 2.0,
    device: torch.device | None = None,
) -> dict:
    """Sample n_samples trajectories under treatment AND no-treatment.

    The same prefix is used for every row of both batches; only the
    condition id differs between the two ``cfg_sample`` calls.

    NOTE(review): ``cfg_sample`` as written selects tokens by argmax, so the
    rows within one batch may come out identical unless the model's forward
    pass is itself stochastic - confirm this is intended.

    Args:
        model: Model forwarded to ``cfg_sample``.
        treatment_cond: Condition id for the treated arm.
        null_treatment_cond: Condition id for the untreated arm.
        seed_prefix: Prompt tokens, either 1-D ``(L,)`` or 2-D with a leading
            dimension of 1 or ``n_samples``; broadcast to ``(n_samples, L)``.
        n_samples: Number of trajectories per arm.
        gamma: Guidance strength forwarded to ``cfg_sample``.
        device: Device for the condition and prefix tensors; defaults to
            ``seed_prefix.device``.

    Returns:
        Dict with keys ``"traj_treated"`` and ``"traj_untreated"`` (the two
        sampled batches) plus the echoed settings ``"n"``,
        ``"treatment_cond"``, ``"null_cond"`` and ``"gamma"``. No frequency
        tables are computed here; see ``outcome_rate`` for a per-batch
        summary.
    """
    if device is None:
        device = seed_prefix.device

    # Accept an already-batched (1, L) / (n_samples, L) prefix as well as the
    # original 1-D form.
    seed = seed_prefix if seed_prefix.dim() == 2 else seed_prefix.unsqueeze(0)
    seed = seed.expand(n_samples, -1).to(device)

    cond_tx = torch.full(
        (n_samples,), treatment_cond, device=device, dtype=torch.long
    )
    cond_null = torch.full(
        (n_samples,), null_treatment_cond, device=device, dtype=torch.long
    )
    traj_tx = cfg_sample(model, cond_tx, seed_prefix=seed, gamma=gamma)
    traj_null = cfg_sample(model, cond_null, seed_prefix=seed, gamma=gamma)
    return {
        "traj_treated": traj_tx,
        "traj_untreated": traj_null,
        "n": n_samples,
        "treatment_cond": treatment_cond,
        "null_cond": null_treatment_cond,
        "gamma": gamma,
    }
def outcome_rate(traj: torch.Tensor, target_tok_ids: list[int]) -> float:
    """Fraction of trajectories containing ANY of the target outcome tokens.

    Args:
        traj: Token tensor whose last axis is the sequence position; the
            mean is taken over whatever leading axes remain.
        target_tok_ids: Token ids that count as the outcome. An empty list
            yields 0.0.

    Returns:
        The share of trajectories with at least one target token, as a
        Python float.
    """
    if not target_tok_ids:
        return 0.0
    target = torch.tensor(target_tok_ids, device=traj.device)
    # Compare every position against every target id, then collapse the
    # target axis and the position axis one at a time. Chained single-axis
    # reductions are used because a tuple ``dim`` for ``any`` is only
    # accepted by newer PyTorch releases.
    matches = traj.unsqueeze(-1) == target
    has = matches.any(dim=-1).any(dim=-1)
    return has.float().mean().item()
|