# gemeo-sus / src/sample.py
# Last commit (timmers): GEMEO/SUS v6 recurrence-aware (RAVEN) — new-onset Top-1 60.1% vs baseline 38.2%, defeats autocorrelation trap. GEMEO Arch v2.0 Principle 7 proven.
# Commit 908ea05 (verified)
"""Sampling primitives for CDF: AR mode, denoise mode, counterfactual rollouts.

Diffusion Forcing flexibility — the same model handles:

  AR mode:
      sigma_future = 1, sigma_past = 0. Roll forward like an autoregressive
      transformer, but with per-token noise control.

  Denoise mode (bidirectional):
      Every free token starts fully noised and sigma is annealed from 1 toward
      0 over k denoise steps; the model fills in the whole sequence.

  Counterfactual mode (the TTE primitive):
      sigma = 0 on observed tokens (clamp them clean), sigma = 1 on tokens to
      generate. Condition on (cohort, intervention_action_id), sample N times,
      and compare the distributions of outcome tokens.

CFG (classifier-free guidance) wraps any mode:

    logits_g = (1 + gamma) * logits(c) - gamma * logits(null_c)

Shortcut Forcing (Dreamer 4) reduces denoise steps from 32-64 to 4 via a
distilled student model — implemented in distill.py.
"""
from __future__ import annotations
import logging
import torch
import torch.nn.functional as F
from .diffusion_forcing import CDFTransformer
# Module-level logger for the CDF sampling utilities (namespace "gemeo.cdf.sample").
log = logging.getLogger("gemeo.cdf.sample")
def _denoise_logits(
    model: CDFTransformer,
    x: torch.Tensor,
    sigma: torch.Tensor,
    cond: torch.Tensor,
    action: torch.Tensor | None,
    null: torch.Tensor,
    null_action: torch.Tensor | None,
    gamma: float,
    mask_id: int,
) -> torch.Tensor:
    """Return CFG-guided logits with the MASK token made effectively unselectable.

    The unconditional pass (``null`` / ``null_action``) runs only when gamma > 0.
    """
    logits = model(x, sigma, cond, action)
    if gamma > 0:
        logits_n = model(x, sigma, null, null_action)
        logits = (1 + gamma) * logits - gamma * logits_n
    logits[:, :, mask_id] = -1e9
    return logits


def _denoise_pick(
    logits: torch.Tensor, temperature: float
) -> tuple[torch.Tensor, torch.Tensor]:
    """Choose one token per position; return (token ids, probability of each pick).

    temperature <= 0 takes the most probable token; > 0 samples from
    softmax(logits / temperature).
    """
    if temperature > 0:
        probs = F.softmax(logits / temperature, dim=-1)
        flat = torch.multinomial(probs.reshape(-1, probs.size(-1)), 1)
        preds = flat.view(probs.shape[:-1])
        confs = probs.gather(-1, preds.unsqueeze(-1)).squeeze(-1)
        return preds, confs
    confs, preds = F.softmax(logits, dim=-1).max(dim=-1)
    return preds, confs


@torch.no_grad()
def sample_denoise(
    model: CDFTransformer,
    cond: torch.Tensor,
    *,
    seed_prefix: torch.Tensor | None = None,
    observed_mask: torch.Tensor | None = None,  # (B, T) True = clamped clean
    action: torch.Tensor | None = None,
    gamma: float = 2.0,
    n_steps: int = 32,
    null_cond: int = 0,
    schedule: str = "cosine",
    observed_tokens: torch.Tensor | None = None,
    temperature: float = 0.0,
) -> torch.Tensor:
    """Denoise-mode sampling: fully-masked sequence + iterative refinement.

    Starts from an all-MASK sequence of length ``model.cfg.max_seq_len``. Over
    ``n_steps`` iterations it reveals the still-masked free positions the
    model is most confident about, so that after step k a ``1 - ts[k+1]``
    fraction of each row's *free* positions is filled. Clamped positions
    (prefix / observed) are fed with sigma=0, are never overwritten by the
    loop, and do not count toward that budget. Any MASK tokens left at the
    end are filled by one final sigma=0 pass.

    Args:
        model: network called as ``model(x, sigma, cond, action)``; its output
            is indexed as ``logits[:, :, token]`` (assumed (B, T, vocab) —
            TODO confirm against CDFTransformer).
        cond: conditioning ids; ``cond.size(0)`` sets the batch size B.
        seed_prefix: clean tokens written to positions ``[0, L)`` and clamped.
        observed_mask: (B, T) bool, True = clamp that position (counterfactual mode).
        action: latent-action ids forwarded to the model unchanged.
        gamma: CFG strength; ``gamma <= 0`` disables the unconditional pass.
        n_steps: number of refinement iterations.
        null_cond: id used for ``cond`` in the unconditional CFG pass.
        schedule: ``"cosine"`` for a cosine decay 1 -> 0; any other value is linear.
        observed_tokens: optional (B, T) token values written at the
            ``observed_mask`` positions. Without it, observed positions stay
            MASK during refinement and are only filled by the final pass.
        temperature: 0 (default) keeps the deterministic most-probable-token
            decoding; > 0 samples tokens from ``softmax(logits / temperature)``
            so repeated calls yield different trajectories.

    Returns:
        (B, T) long tensor of token ids.

    Raises:
        ValueError: if ``observed_tokens`` is given without ``observed_mask``.
    """
    cfg = model.cfg
    device = cond.device
    B = cond.size(0)
    T = cfg.max_seq_len
    mask_id = cfg.mask_token

    # Init with MASK, then clamp the prefix / observed positions.
    x = torch.full((B, T), mask_id, device=device, dtype=torch.long)
    fixed_mask = torch.zeros(B, T, dtype=torch.bool, device=device)
    if seed_prefix is not None:
        L = seed_prefix.size(1)
        x[:, :L] = seed_prefix.to(device)
        fixed_mask[:, :L] = True
    if observed_tokens is not None and observed_mask is None:
        raise ValueError("observed_tokens requires observed_mask")
    if observed_mask is not None:
        observed_mask = observed_mask.to(device=device, dtype=torch.bool)
        fixed_mask |= observed_mask
        if observed_tokens is not None:
            x = torch.where(observed_mask, observed_tokens.to(device), x)
    free_mask = ~fixed_mask
    n_free = free_mask.sum(dim=-1)  # (B,) positions the sampler has to fill

    # Noise schedule 1 -> 0. Clamp because cos(pi/2) is slightly negative in float32.
    if schedule == "cosine":
        ts = torch.cos(torch.linspace(0, torch.pi / 2, n_steps + 1, device=device))
    else:
        ts = torch.linspace(1.0, 0.0, n_steps + 1, device=device)
    ts = ts.clamp(0.0, 1.0)

    null = torch.full_like(cond, null_cond)
    null_action = (torch.full_like(action, cfg.n_latent_actions)
                   if action is not None and cfg.use_latent_action else None)

    for k in range(n_steps):
        # Clamped positions are clean (sigma=0); free ones share this step's level.
        sigma = ts[k] * free_mask.to(ts.dtype)
        logits = _denoise_logits(model, x, sigma, cond, action,
                                 null, null_action, gamma, mask_id)
        preds, confs = _denoise_pick(logits, temperature)

        # Per-row budget over FREE positions only, so a long clamped prefix
        # does not postpone all reveals to the last few steps.
        target_free = torch.round((1.0 - ts[k + 1]) * n_free).long()
        candidates = (x == mask_id) & free_mask
        filled_free = n_free - candidates.sum(dim=-1)
        need = (target_free - filled_free).clamp(min=0)

        # Confidence rank among still-masked free positions (0 = most confident);
        # non-candidates get -1 so they sort after every real probability.
        scores = confs.masked_fill(~candidates, -1.0)
        rank = scores.argsort(dim=-1, descending=True).argsort(dim=-1)
        reveal = candidates & (rank < need.unsqueeze(-1))
        x = torch.where(reveal, preds, x)

    # Final cleanup: fill anything still MASK (e.g. observed positions given
    # without observed_tokens) in one clean pass.
    mask_left = x == mask_id
    if mask_left.any():
        sigma_final = torch.zeros(B, T, device=device)
        logits = _denoise_logits(model, x, sigma_final, cond, action,
                                 null, null_action, gamma, mask_id)
        if temperature > 0:
            preds, _ = _denoise_pick(logits, temperature)
        else:
            preds = logits.argmax(-1)
        x = torch.where(mask_left, preds, x)
    return x
@torch.no_grad()
def sample_ar(
    model: CDFTransformer,
    cond: torch.Tensor,
    prefix: torch.Tensor,
    *,
    action: torch.Tensor | None = None,
    max_new: int = 50,
    temperature: float = 1.0,
    gamma: float = 0.0,
    null_cond: int = 0,
) -> torch.Tensor:
    """AR-mode sampling: future tokens at sigma=1, past at sigma=0.

    Faster than denoise mode when you only want to continue a prefix. Each
    step appends one MASK slot (sigma=1) after the clean context (sigma=0),
    runs the model, and samples the next token from the last position.
    Stops after ``max_new`` tokens or once the sequence reaches
    ``model.cfg.max_seq_len``.

    Args:
        model: network called as ``model(x, sigma, cond, action)``; the new
            token is sampled from ``logits[:, -1]``.
        cond: conditioning ids; ``cond.size(0)`` sets the batch size B.
        prefix: (L,) or (B', L) tokens to continue. A single row is broadcast
            to all B rows; otherwise B' must equal B.
        action: per-position latent-action ids, used only when
            ``cfg.use_latent_action``; sliced to the current length each step.
        max_new: maximum number of tokens to append.
        temperature: softmax temperature (floored at 1e-3).
        gamma: CFG strength; ``gamma <= 0`` disables the unconditional pass.
        null_cond: id used for ``cond`` in the unconditional CFG pass.

    Returns:
        (B, L + n_generated) long tensor.
    """
    cfg = model.cfg
    device = cond.device
    B = cond.size(0)
    x = prefix.to(device).clone()
    if x.dim() == 1:
        x = x.unsqueeze(0)
    if x.size(0) == 1 and B > 1:
        # One shared prefix: give every batch row its own copy to extend.
        x = x.expand(B, -1).clone()
    use_action = action is not None and cfg.use_latent_action
    null = torch.full_like(cond, null_cond)

    for _ in range(max_new):
        T_now = x.size(1)
        if T_now >= cfg.max_seq_len:
            break
        # Append one MASK slot; it is the only noisy (sigma=1) position.
        slot = torch.full((B, 1), cfg.mask_token, device=device, dtype=torch.long)
        x_pad = torch.cat([x, slot], dim=1)
        sigma = torch.zeros(B, T_now + 1, device=device)
        sigma[:, -1] = 1.0

        a_pad = null_a = None
        if use_action:
            # cfg.n_latent_actions is the id this module uses for "no action":
            # it fills the new slot in the conditional pass and every position
            # in the unconditional pass. Both are sized to match x_pad.
            null_a = torch.full((B, T_now + 1), cfg.n_latent_actions,
                                device=device, dtype=torch.long)
            a_pad = torch.cat([action[:, :T_now], null_a[:, -1:]], dim=1)

        logits = model(x_pad, sigma, cond, a_pad)
        if gamma > 0:
            logits_n = model(x_pad, sigma, null, null_a)
            logits = (1 + gamma) * logits - gamma * logits_n
        logits[:, :, cfg.mask_token] = -1e9
        p = F.softmax(logits[:, -1] / max(temperature, 1e-3), dim=-1)
        nxt = torch.multinomial(p, 1)
        x = torch.cat([x, nxt], dim=1)
    return x
@torch.no_grad()
def counterfactual_rollout(
    model: CDFTransformer,
    seed_prefix: torch.Tensor,
    treatment_cond: int,
    untreated_cond: int,
    *,
    treatment_action: int | None = None,
    untreated_action: int | None = None,
    n_samples: int = 100,
    gamma: float = 2.0,
    n_steps: int = 32,
) -> dict:
    """Sample paired counterfactual trajectories under treatment vs no-treatment.

    Two ways to specify the intervention:
      - via cond id (cohort-level): treatment_cond / untreated_cond
      - via latent action id (per-token): treatment_action / untreated_action

    Both arms start from the same clean ``seed_prefix`` (clamped by
    ``sample_denoise``). They differ only in the cond id and, when
    ``model.cfg.use_latent_action`` is set, in the action id broadcast to
    every position. An action id of ``None`` falls back to
    ``cfg.n_latent_actions``. If the model does not use latent actions, any
    action ids passed here are ignored and a warning is logged.

    NOTE(review): ``sample_denoise`` is called with its default decoding,
    which picks the most probable token, so the ``n_samples`` rows of one arm
    are identical unless the model itself is stochastic (e.g. dropout left
    enabled). Confirm this is intended before treating them as a distribution.

    Args:
        model: CDF model; must expose ``cfg`` and ``parameters()``.
        seed_prefix: (L,) or (1, L) clean prefix tokens shared by every sample.
        treatment_cond: cond id for the treated arm.
        untreated_cond: cond id for the untreated arm.
        treatment_action: latent-action id for the treated arm (optional).
        untreated_action: latent-action id for the untreated arm (optional).
        n_samples: number of trajectories per arm.
        gamma: CFG strength forwarded to ``sample_denoise``.
        n_steps: denoise steps forwarded to ``sample_denoise``.

    Returns:
        dict with "traj_treated" / "traj_untreated" ((n_samples, T) token
        tensors), "n", the cond ids, the action ids as passed, and "gamma".
    """
    cfg = model.cfg
    device = next(model.parameters()).device
    T = cfg.max_seq_len

    prefix = seed_prefix
    if prefix.dim() == 2 and prefix.size(0) == 1:
        prefix = prefix[0]  # accept a single batched row as well as (L,)
    seed = prefix.unsqueeze(0).expand(n_samples, -1).to(device)

    cond_tx = torch.full((n_samples,), treatment_cond, device=device, dtype=torch.long)
    cond_null = torch.full((n_samples,), untreated_cond, device=device, dtype=torch.long)

    action_tx = action_null = None
    if cfg.use_latent_action:
        no_action = cfg.n_latent_actions
        action_tx = torch.full(
            (n_samples, T),
            no_action if treatment_action is None else treatment_action,
            device=device, dtype=torch.long,
        )
        action_null = torch.full(
            (n_samples, T),
            no_action if untreated_action is None else untreated_action,
            device=device, dtype=torch.long,
        )
    elif treatment_action is not None or untreated_action is not None:
        # Without latent actions the ids are never passed to the model, so the
        # intervention would silently be a no-op; make that visible.
        log.warning(
            "counterfactual_rollout: treatment/untreated actions ignored "
            "because model.cfg.use_latent_action is False"
        )

    traj_tx = sample_denoise(model, cond_tx, seed_prefix=seed,
                             action=action_tx, gamma=gamma, n_steps=n_steps)
    traj_null = sample_denoise(model, cond_null, seed_prefix=seed,
                               action=action_null, gamma=gamma, n_steps=n_steps)
    return {
        "traj_treated": traj_tx, "traj_untreated": traj_null,
        "n": n_samples, "treatment_cond": treatment_cond,
        "untreated_cond": untreated_cond, "gamma": gamma,
        "treatment_action": treatment_action,
        "untreated_action": untreated_action,
    }
def outcome_rate(traj: torch.Tensor, target_ids: list[int]) -> float:
    """Return the fraction of trajectories containing at least one target token.

    Args:
        traj: token ids whose last dimension is the sequence axis, e.g.
            (N, T) for N trajectories or (T,) for a single one.
        target_ids: token ids that count as the outcome. Any iterable of
            ints is accepted.

    Returns:
        Mean over trajectories of "any target id present". Returns 0.0 when
        ``target_ids`` is empty and NaN when ``traj`` holds no trajectories.
    """
    ids = list(target_ids)
    if not ids:
        return 0.0
    target = torch.tensor(ids, dtype=traj.dtype, device=traj.device)
    # torch.isin avoids an (N, T, K) comparison tensor, and reducing one dim at
    # a time avoids Tensor.any with a tuple of dims (recent PyTorch only).
    has = torch.isin(traj, target).any(dim=-1)
    return has.float().mean().item()