File size: 8,464 Bytes
908ea05 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 | """Sampling primitives for CDF: AR mode, denoise mode, counterfactual rollouts.
Diffusion Forcing flexibility — the same model handles:

AR mode:
    sigma_future = 1, sigma_past = 0. Roll forward like an autoregressive
    transformer, but with per-token noise control.

Denoise mode (bidirectional):
    Start from an all-MASK sequence and run k denoise steps. Every free
    position shares one noise level, annealed from 1 to 0, and the model
    fills in the whole sequence.

Counterfactual mode (the TTE primitive):
    sigma = 0 on observed tokens (clamped clean), sigma = 1 (annealed
    towards 0) on tokens to generate. Condition on (cohort,
    intervention_action_id), sample N times, and compare the distributions
    of outcome tokens.

CFG (classifier-free guidance) wraps any mode:
    logits_g = (1 + gamma) * logits(c) - gamma * logits(null_c)

Shortcut Forcing (Dreamer 4) reduces the number of denoise steps from
32-64 to 4 via a distilled student model — implemented in distill.py.
"""
from __future__ import annotations
import logging
import torch
import torch.nn.functional as F
from .diffusion_forcing import CDFTransformer
log = logging.getLogger("gemeo.cdf.sample")
def _guided_logits(
    model: CDFTransformer,
    x: torch.Tensor,
    sigma: torch.Tensor,
    cond: torch.Tensor,
    action: torch.Tensor | None,
    null: torch.Tensor,
    null_action: torch.Tensor | None,
    gamma: float,
    mask_token: int,
) -> torch.Tensor:
    """Run the model (with CFG when gamma > 0) and forbid the MASK token.

    Returns ``(1 + gamma) * model(cond) - gamma * model(null)`` when
    gamma > 0, otherwise the plain conditional logits. In both cases the
    MASK token's logit is set to -1e9 so it is never selected.
    """
    logits = model(x, sigma, cond, action)
    if gamma > 0:
        logits_n = model(x, sigma, null, null_action)
        logits = (1 + gamma) * logits - gamma * logits_n
    logits[:, :, mask_token] = -1e9
    return logits


def _pick_tokens(
    logits: torch.Tensor, temperature: float
) -> tuple[torch.Tensor, torch.Tensor]:
    """Choose one token per position; return ``(token_ids, confidences)``.

    temperature <= 0: greedy argmax; confidence is its softmax probability.
    temperature > 0: draw from softmax(logits / temperature); confidence is
    the drawn token's probability under that tempered distribution.
    """
    if temperature <= 0:
        confs, preds = F.softmax(logits, dim=-1).max(dim=-1)
        return preds, confs
    probs = F.softmax(logits / temperature, dim=-1)
    preds = torch.multinomial(probs.reshape(-1, probs.size(-1)), 1)
    preds = preds.view(probs.shape[:-1])
    confs = probs.gather(-1, preds.unsqueeze(-1)).squeeze(-1)
    return preds, confs


@torch.no_grad()
def sample_denoise(
    model: CDFTransformer,
    cond: torch.Tensor,
    *,
    seed_prefix: torch.Tensor | None = None,
    observed_mask: torch.Tensor | None = None,  # (B, T) True = clamped clean
    action: torch.Tensor | None = None,
    gamma: float = 2.0,
    n_steps: int = 32,
    null_cond: int = 0,
    schedule: str = "cosine",
    observed_tokens: torch.Tensor | None = None,
    temperature: float = 0.0,
) -> torch.Tensor:
    """Denoise-mode sampling: start fully masked and iteratively unmask.

    Each of the ``n_steps`` iterations runs the model with every free
    position at a shared noise level ``ts[k]`` (annealed 1 -> 0 by
    ``schedule``) and clamped positions at sigma=0. It then reveals the
    most confident free predictions until a ``(1 - ts[k+1])`` fraction of
    the sequence is non-MASK. Any MASK still present afterwards is filled
    in one final pass with sigma=0 everywhere.

    Args:
        model: called as ``model(x, sigma, cond, action)``. Its output is
            indexed as (B, T, vocab) logits. Must expose ``model.cfg``.
        cond: conditioning ids. ``cond.size(0)`` sets the batch size B and
            ``cond.device`` the working device.
        seed_prefix: clean tokens for positions [0, L), shape (B, L) or
            (L,) (shared by all rows). Clamped at sigma=0.
        observed_mask: positions to clamp at sigma=0 (counterfactual mode),
            shape (B, T) or broadcastable to it.
        action: latent-action ids for the conditional pass. The
            unconditional CFG pass uses ``cfg.n_latent_actions`` in their
            place when ``cfg.use_latent_action`` is set, else None.
        gamma: CFG strength; 0 disables the unconditional pass.
        n_steps: number of refinement steps (0 = single final pass).
        null_cond: id used for the unconditional CFG pass.
        schedule: "cosine" uses cos over [0, pi/2]. Any other value gives a
            linear 1 -> 0 schedule.
        observed_tokens: values written at ``observed_mask`` positions
            (broadcast like the mask; they override ``seed_prefix`` where
            both apply). Without them, observed positions start as MASK and
            are only filled by the final pass.
        temperature: 0 (default) keeps greedy decoding. A value > 0 samples
            tokens from the tempered distribution, so repeated calls differ.

    Returns:
        LongTensor of shape (B, cfg.max_seq_len). Positions still MASK after
        the loop are replaced by model predictions in the final pass.

    Raises:
        ValueError: if ``seed_prefix`` is longer than ``cfg.max_seq_len``,
            or ``observed_tokens`` is given without ``observed_mask``.
    """
    cfg = model.cfg
    device = cond.device
    B = cond.size(0)
    T = cfg.max_seq_len
    mask_token = cfg.mask_token

    # Start from an all-MASK canvas; fixed_mask marks positions held at sigma=0.
    x = torch.full((B, T), mask_token, device=device, dtype=torch.long)
    fixed_mask = torch.zeros(B, T, dtype=torch.bool, device=device)
    if seed_prefix is not None:
        seed_prefix = seed_prefix.to(device)
        if seed_prefix.dim() == 1:
            seed_prefix = seed_prefix.unsqueeze(0)
        L = seed_prefix.size(1)
        if L > T:
            raise ValueError(f"seed_prefix length {L} exceeds max_seq_len {T}")
        x[:, :L] = seed_prefix
        fixed_mask[:, :L] = True
    if observed_tokens is not None and observed_mask is None:
        raise ValueError("observed_tokens requires observed_mask")
    if observed_mask is not None:
        observed_mask = observed_mask.to(device=device, dtype=torch.bool)
        if observed_tokens is not None:
            values = observed_tokens.to(device=device, dtype=torch.long)
            x = torch.where(observed_mask, values, x)
        fixed_mask = fixed_mask | observed_mask

    # Noise schedule as Python floats: avoids a device sync per step.
    if schedule == "cosine":
        ts = torch.cos(torch.linspace(0, torch.pi / 2, n_steps + 1)).tolist()
    else:
        ts = torch.linspace(1.0, 0.0, n_steps + 1).tolist()

    null = torch.full_like(cond, null_cond)
    null_action = (torch.full_like(action, cfg.n_latent_actions)
                   if action is not None and cfg.use_latent_action else None)

    for k in range(n_steps):
        # Clamped positions are clean (sigma=0); all others share ts[k].
        sigma = torch.full((B, T), ts[k], device=device).masked_fill_(fixed_mask, 0.0)
        logits = _guided_logits(model, x, sigma, cond, action,
                                null, null_action, gamma, mask_token)
        preds, confs = _pick_tokens(logits, temperature)

        # Reveal enough free tokens that (1 - ts[k+1]) of the whole sequence
        # is non-MASK, taking the most confident free positions first.
        target_kept = int(round((1 - ts[k + 1]) * T))
        revealed = (x != mask_token) | fixed_mask
        need = (target_kept - revealed.sum(dim=-1)).clamp(min=0)  # (B,)
        free_confs = confs.masked_fill(revealed, float("-inf"))
        # rank[b, t] = position of t in row b's descending-confidence order.
        rank = free_confs.argsort(dim=-1, descending=True).argsort(dim=-1)
        x = torch.where(rank < need.unsqueeze(-1), preds, x)

    # Final cleanup: fill anything still MASK in one clean (sigma=0) pass.
    mask_left = x == mask_token
    if mask_left.any():
        sigma_final = torch.zeros(B, T, device=device)
        logits = _guided_logits(model, x, sigma_final, cond, action,
                                null, null_action, gamma, mask_token)
        preds, _ = _pick_tokens(logits, temperature)
        x = torch.where(mask_left, preds, x)
    return x
@torch.no_grad()
def sample_ar(
    model: CDFTransformer,
    cond: torch.Tensor,
    prefix: torch.Tensor,
    *,
    action: torch.Tensor | None = None,
    max_new: int = 50,
    temperature: float = 1.0,
    gamma: float = 0.0,
    null_cond: int = 0,
) -> torch.Tensor:
    """AR-mode sampling: past tokens at sigma=0, the next slot at sigma=1.

    Faster than denoise mode when you only want to continue a prefix. Each
    step appends one MASK slot at sigma=1 and runs the model (with CFG when
    gamma > 0). It then samples that slot from
    ``softmax(logits / temperature)``, never emitting the MASK token.

    Args:
        model: called as ``model(x, sigma, cond, action)``; exposes ``model.cfg``.
        cond: conditioning ids; ``cond.size(0)`` sets the batch size B and
            ``cond.device`` the working device.
        prefix: tokens to continue, shape (B, L), (1, L) or (L,). A single
            prefix is broadcast to all B rows.
        action: latent-action ids, sliced to the current length. Used only
            when ``cfg.use_latent_action`` is set. The slot being generated,
            and every position of the unconditional CFG pass, gets the null
            action ``cfg.n_latent_actions``.
        max_new: maximum number of tokens to append. Generation also stops
            once the sequence reaches ``cfg.max_seq_len``.
        temperature: softmax temperature, clamped to at least 1e-3.
        gamma: CFG strength; 0 disables the unconditional pass.
        null_cond: id used for the unconditional CFG pass.

    Returns:
        LongTensor of shape (B, L + number_of_generated_tokens).

    Raises:
        ValueError: if ``prefix``'s batch dimension is neither 1 nor B.
    """
    cfg = model.cfg
    device = cond.device
    B = cond.size(0)

    x = prefix.to(device=device, dtype=torch.long)
    if x.dim() == 1:
        x = x.unsqueeze(0)
    if x.size(0) not in (1, B):
        raise ValueError(f"prefix batch size {x.size(0)} does not match cond batch size {B}")
    # Broadcast a shared prefix to every row; clone so the caller's tensor is never aliased.
    x = x.expand(B, -1).clone()

    use_action = action is not None and cfg.use_latent_action
    if use_action:
        action = action.to(device)
    null = torch.full_like(cond, null_cond)
    mask_slot = torch.full((B, 1), cfg.mask_token, device=device, dtype=torch.long)
    null_slot = (torch.full((B, 1), cfg.n_latent_actions, device=device, dtype=torch.long)
                 if use_action else None)
    tau = max(temperature, 1e-3)

    for _ in range(max_new):
        T_now = x.size(1)
        if T_now >= cfg.max_seq_len:
            break
        x_pad = torch.cat([x, mask_slot], dim=1)
        sigma = torch.zeros(B, T_now + 1, device=device)
        sigma[:, -1] = 1.0

        a_pad = a_null = None
        if use_action:
            a_pad = torch.cat([action[:, :T_now], null_slot], dim=1)
            # The unconditional pass needs an action tensor of the SAME length as x_pad.
            a_null = torch.full_like(a_pad, cfg.n_latent_actions)

        logits = model(x_pad, sigma, cond, a_pad)
        if gamma > 0:
            logits_n = model(x_pad, sigma, null, a_null)
            logits = (1 + gamma) * logits - gamma * logits_n

        next_logits = logits[:, -1]
        next_logits[:, cfg.mask_token] = -1e9
        p = F.softmax(next_logits / tau, dim=-1)
        nxt = torch.multinomial(p, 1)
        x = torch.cat([x, nxt], dim=1)
    return x
@torch.no_grad()
def counterfactual_rollout(
    model: CDFTransformer,
    seed_prefix: torch.Tensor,
    treatment_cond: int,
    untreated_cond: int,
    *,
    treatment_action: int | None = None,
    untreated_action: int | None = None,
    n_samples: int = 100,
    gamma: float = 2.0,
    n_steps: int = 32,
) -> dict:
    """Sample paired counterfactual trajectories under treatment vs no-treatment.

    Two ways to specify the intervention:
      - via cond id (cohort-level): treatment_cond / untreated_cond
      - via latent action id (per-token): treatment_action / untreated_action
        (only used when ``cfg.use_latent_action`` is set; None means the
        null action ``cfg.n_latent_actions``)

    Both arms clamp the same observed ``seed_prefix`` and are decoded with
    ``sample_denoise`` using the same ``gamma`` and ``n_steps``.

    Args:
        seed_prefix: observed clean prefix of shape (L,) or (1, L). It is
            replicated ``n_samples`` times.
        n_samples: rows per arm; must be >= 1.

    Returns:
        dict with "traj_treated" / "traj_untreated" (the two
        ``sample_denoise`` outputs) plus the settings used: "n",
        "treatment_cond", "untreated_cond", "gamma", "treatment_action",
        "untreated_action".

    Raises:
        ValueError: if ``n_samples < 1`` or ``seed_prefix`` is not a single
            1-D prefix.

    NOTE(review): ``sample_denoise`` is called here without any sampling
    option, and its visible decoding is greedy (argmax + confidence
    ranking). The ``n_samples`` rows of each arm may therefore be
    identical unless the model itself is stochastic — TODO confirm
    before reading the arms as distributions.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    cfg = model.cfg
    device = next(model.parameters()).device
    T = cfg.max_seq_len

    prefix = seed_prefix.to(device)
    if prefix.dim() == 2 and prefix.size(0) == 1:
        prefix = prefix.squeeze(0)
    if prefix.dim() != 1:
        raise ValueError(f"seed_prefix must have shape (L,) or (1, L), got {tuple(seed_prefix.shape)}")
    seed = prefix.unsqueeze(0).expand(n_samples, -1)

    def _run_arm(cond_id: int, action_id: int | None) -> torch.Tensor:
        # Decode one arm: constant cond id, and (optionally) a constant action id per token.
        cond = torch.full((n_samples,), cond_id, device=device, dtype=torch.long)
        action = None
        if cfg.use_latent_action:
            fill = cfg.n_latent_actions if action_id is None else action_id
            action = torch.full((n_samples, T), fill, device=device, dtype=torch.long)
        return sample_denoise(model, cond, seed_prefix=seed,
                              action=action, gamma=gamma, n_steps=n_steps)

    traj_tx = _run_arm(treatment_cond, treatment_action)
    traj_null = _run_arm(untreated_cond, untreated_action)
    return {
        "traj_treated": traj_tx, "traj_untreated": traj_null,
        "n": n_samples, "treatment_cond": treatment_cond,
        "untreated_cond": untreated_cond, "gamma": gamma,
        "treatment_action": treatment_action,
        "untreated_action": untreated_action,
    }
def outcome_rate(traj: torch.Tensor, target_ids: list[int]) -> float:
    """Return the fraction of trajectories that contain any of ``target_ids``.

    Args:
        traj: token ids. The last dim is treated as time, so an (N, T)
            batch yields a rate over its N rows, and a single (T,)
            trajectory yields 1.0 or 0.0.
        target_ids: token ids that count as the outcome (any iterable of ints).

    Returns:
        Mean over the leading dims of "some target id appears along the last
        dim". Returns 0.0 when ``target_ids`` is empty, and nan for an empty
        batch (mean of nothing).
    """
    ids = list(target_ids)
    if not ids:
        return 0.0
    target = torch.tensor(ids, device=traj.device, dtype=traj.dtype)
    # isin avoids materialising a (..., T, K) comparison tensor.
    has = torch.isin(traj, target).any(dim=-1)
    return has.float().mean().item()
|