"""Sampling primitives for CDF: AR mode, denoise mode, counterfactual rollouts. Diffusion Forcing flexibility — the same model handles: AR mode: Sigma_future = 1, sigma_past = 0. Roll forward like an autoregressive transformer but with per-token noise control. Denoise mode (bidirectional): Sigma low everywhere. Run k denoise steps, model fills the whole sequence. Counterfactual mode (the TTE primitive): Sigma=0 on observed tokens (clamp them clean), sigma=1 on tokens to generate. Condition on (cohort, intervention_action_id). Sample N times, compare distributions of outcome tokens. CFG (classifier-free guidance) wraps any mode: logits_g = (1 + gamma) * logits(c) - gamma * logits(null_c) Shortcut Forcing (Dreamer 4) reduces denoise steps from 32-64 to 4 via distilled student model — implemented in distill.py. """ from __future__ import annotations import logging import torch import torch.nn.functional as F from .diffusion_forcing import CDFTransformer log = logging.getLogger("gemeo.cdf.sample") @torch.no_grad() def sample_denoise( model: CDFTransformer, cond: torch.Tensor, *, seed_prefix: torch.Tensor | None = None, observed_mask: torch.Tensor | None = None, # (B, T) True = clamped clean action: torch.Tensor | None = None, gamma: float = 2.0, n_steps: int = 32, null_cond: int = 0, schedule: str = "cosine", ) -> torch.Tensor: """Denoise-mode sampling: fully-masked sequence + iterative refinement. Supports: - seed_prefix: clean tokens kept at sigma=0 for positions [0, L) - observed_mask: arbitrary positions to clamp (counterfactual mode) - CFG via (cond, null_cond) pair """ cfg = model.cfg device = cond.device B = cond.size(0) T = cfg.max_seq_len # Init with MASK x = torch.full((B, T), cfg.mask_token, device=device, dtype=torch.long) fixed_mask = torch.zeros(B, T, dtype=torch.bool, device=device) if seed_prefix is not None: L = seed_prefix.size(1) x[:, :L] = seed_prefix fixed_mask[:, :L] = True if observed_mask is not None: fixed_mask |= observed_mask # Build noise schedule if schedule == "cosine": # smooth cosine from 1 -> 0 ts = torch.cos(torch.linspace(0, torch.pi/2, n_steps+1, device=device)) else: ts = torch.linspace(1.0, 0.0, n_steps+1, device=device) null = torch.full_like(cond, null_cond) null_action = (torch.full_like(action, cfg.n_latent_actions) if action is not None and cfg.use_latent_action else None) for k in range(n_steps): # Per-token sigma: fixed positions at 0, dynamic positions at ts[k] sigma = torch.where(fixed_mask, torch.zeros_like(ts[k:k+1]).expand(B, T), torch.full((B, T), ts[k].item(), device=device)) logits_c = model(x, sigma, cond, action) if gamma > 0: logits_n = model(x, sigma, null, null_action) logits = (1 + gamma) * logits_c - gamma * logits_n else: logits = logits_c logits[:, :, cfg.mask_token] = -1e9 probs = F.softmax(logits, dim=-1) confs, preds = probs.max(dim=-1) # Confidence-based remasking: reveal top-(1 - ts[k+1]) fraction of free tokens t_next = ts[k+1].item() target_kept = int(round((1 - t_next) * T)) revealed = (x != cfg.mask_token) | fixed_mask already = revealed.sum(dim=-1) new_x = x.clone() for b in range(B): need = max(0, target_kept - int(already[b].item())) if need == 0: continue confs_b = torch.where(revealed[b], torch.full_like(confs[b], -1e9), confs[b]) topi = confs_b.topk(need).indices new_x[b, topi] = preds[b, topi] x = new_x # Final cleanup mask_left = x == cfg.mask_token if mask_left.any(): sigma_final = torch.zeros(B, T, device=device) logits_c = model(x, sigma_final, cond, action) if gamma > 0: logits_n = model(x, sigma_final, null, null_action) logits = (1 + gamma) * logits_c - gamma * logits_n else: logits = logits_c logits[:, :, cfg.mask_token] = -1e9 preds = logits.argmax(-1) x = torch.where(mask_left, preds, x) return x @torch.no_grad() def sample_ar( model: CDFTransformer, cond: torch.Tensor, prefix: torch.Tensor, *, action: torch.Tensor | None = None, max_new: int = 50, temperature: float = 1.0, gamma: float = 0.0, null_cond: int = 0, ) -> torch.Tensor: """AR-mode sampling: future tokens at sigma=1, past at sigma=0. Faster than denoise mode when you only want to continue a prefix. """ cfg = model.cfg device = cond.device B = cond.size(0) x = prefix.clone().to(device) if x.dim() == 1: x = x.unsqueeze(0) null = torch.full_like(cond, null_cond) null_action = (torch.full_like(action, cfg.n_latent_actions) if action is not None and cfg.use_latent_action else None) for _ in range(max_new): T_now = x.size(1) if T_now >= cfg.max_seq_len: break # Pad with MASK x_pad = torch.cat([x, torch.full((B, 1), cfg.mask_token, device=device, dtype=torch.long)], dim=1) sigma = torch.zeros(B, T_now + 1, device=device) sigma[:, -1] = 1.0 a_pad = None if action is not None and cfg.use_latent_action: a_pad = torch.cat([action[:, :T_now], torch.full((B, 1), cfg.n_latent_actions, device=device, dtype=torch.long)], dim=1) logits = model(x_pad, sigma, cond, a_pad) if gamma > 0: logits_n = model(x_pad, sigma, null, null_action) logits = (1 + gamma) * logits - gamma * logits_n logits[:, :, cfg.mask_token] = -1e9 p = F.softmax(logits[:, -1] / max(temperature, 1e-3), dim=-1) nxt = torch.multinomial(p, 1) x = torch.cat([x, nxt], dim=1) return x @torch.no_grad() def counterfactual_rollout( model: CDFTransformer, seed_prefix: torch.Tensor, treatment_cond: int, untreated_cond: int, *, treatment_action: int | None = None, untreated_action: int | None = None, n_samples: int = 100, gamma: float = 2.0, n_steps: int = 32, ) -> dict: """Sample paired counterfactual trajectories under treatment vs no-treatment. Two ways to specify the intervention: - via cond id (cohort-level): treatment_cond / untreated_cond - via latent action id (per-token): treatment_action / untreated_action """ cfg = model.cfg device = next(model.parameters()).device seed = seed_prefix.unsqueeze(0).expand(n_samples, -1).to(device) T = cfg.max_seq_len cond_tx = torch.full((n_samples,), treatment_cond, device=device, dtype=torch.long) cond_null = torch.full((n_samples,), untreated_cond, device=device, dtype=torch.long) action_tx = action_null = None if cfg.use_latent_action: action_tx = torch.full((n_samples, T), treatment_action if treatment_action is not None else cfg.n_latent_actions, device=device, dtype=torch.long) action_null = torch.full((n_samples, T), untreated_action if untreated_action is not None else cfg.n_latent_actions, device=device, dtype=torch.long) traj_tx = sample_denoise(model, cond_tx, seed_prefix=seed, action=action_tx, gamma=gamma, n_steps=n_steps) traj_null = sample_denoise(model, cond_null, seed_prefix=seed, action=action_null, gamma=gamma, n_steps=n_steps) return { "traj_treated": traj_tx, "traj_untreated": traj_null, "n": n_samples, "treatment_cond": treatment_cond, "untreated_cond": untreated_cond, "gamma": gamma, } def outcome_rate(traj: torch.Tensor, target_ids: list[int]) -> float: if not target_ids: return 0.0 target = torch.tensor(target_ids, device=traj.device) has = (traj.unsqueeze(-1) == target).any(dim=(-1, -2)) return has.float().mean().item()