| """Sampling primitives for CDF: AR mode, denoise mode, counterfactual rollouts. |
| |
| Diffusion Forcing flexibility — the same model handles: |
| |
| AR mode: |
| Sigma_future = 1, sigma_past = 0. Roll forward like an autoregressive |
| transformer but with per-token noise control. |
| |
| Denoise mode (bidirectional): |
| Start fully masked and anneal sigma from 1 toward 0 over k denoise steps; the model fills the whole sequence. |
| |
| Counterfactual mode (the TTE primitive): |
| Sigma=0 on observed tokens (clamp them clean), sigma=1 on tokens to |
| generate. Condition on (cohort, intervention_action_id). Sample N times, |
| compare distributions of outcome tokens. |
| |
| CFG (classifier-free guidance) wraps any mode: |
| logits_g = (1 + gamma) * logits(c) - gamma * logits(null_c) |
| |
| Shortcut Forcing (Dreamer 4) reduces denoise steps from 32-64 to 4 via |
| distilled student model — implemented in distill.py. |
| """ |
| from __future__ import annotations |
| import logging |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| from .diffusion_forcing import CDFTransformer |
|
|
# Module-level logger; the dotted name nests it under the "gemeo.cdf" logger hierarchy.
log: logging.Logger = logging.getLogger(name="gemeo.cdf.sample")
|
|
|
|
@torch.no_grad()
def sample_denoise(
    model: CDFTransformer,
    cond: torch.Tensor,
    *,
    seed_prefix: torch.Tensor | None = None,
    observed_mask: torch.Tensor | None = None,
    action: torch.Tensor | None = None,
    gamma: float = 2.0,
    n_steps: int = 32,
    null_cond: int = 0,
    schedule: str = "cosine",
    temperature: float = 0.0,
    observed_tokens: torch.Tensor | None = None,
) -> torch.Tensor:
    """Denoise-mode sampling: fully-masked sequence + iterative refinement.

    Starts from a sequence of ``cfg.mask_token`` and, over ``n_steps`` steps,
    anneals the noise level sigma from 1 toward 0 along ``schedule``. At each
    step the model predicts every position, and the most confident
    still-masked positions are committed so that about ``(1 - t_next) * T``
    positions are filled in total. A final pass fills anything left masked.

    Args:
        model: Denoiser called as ``model(x, sigma, cond, action)``. Its output
            is indexed as ``logits[:, :, token]`` (batch, position, vocab).
        cond: Conditioning ids. The batch size is taken from ``cond.size(0)``.
        seed_prefix: Clean tokens written to positions ``[0, L)`` and held at
            sigma=0.
        observed_mask: Extra positions held at sigma=0 (counterfactual mode).
            Cast to bool on ``cond``'s device.
        action: Latent action ids forwarded to the model.
        gamma: CFG strength. ``<= 0`` skips the unconditional pass.
        n_steps: Number of refinement steps.
        null_cond: Cond id used for the unconditional CFG branch.
        schedule: ``"cosine"``. Any other value selects a linear schedule.
        temperature: ``0`` (default) keeps the deterministic
            argmax / confidence-ranking behaviour. ``> 0`` samples tokens from
            the tempered softmax, so repeated calls can differ.
        observed_tokens: Values written at ``observed_mask`` positions. If
            omitted, those positions stay masked (still at sigma=0) until the
            final pass fills them.

    Returns:
        Token ids of shape ``(B, cfg.max_seq_len)``.
    """
    cfg = model.cfg
    device = cond.device
    B = cond.size(0)
    T = cfg.max_seq_len
    mask_id = cfg.mask_token

    # Start fully masked; ``fixed`` marks positions pinned clean at sigma=0.
    x = torch.full((B, T), mask_id, device=device, dtype=torch.long)
    fixed = torch.zeros(B, T, dtype=torch.bool, device=device)
    if seed_prefix is not None:
        L = seed_prefix.size(1)
        x[:, :L] = seed_prefix
        fixed[:, :L] = True
    if observed_mask is not None:
        observed = observed_mask.to(device=device, dtype=torch.bool)
        if observed_tokens is not None:
            x = torch.where(
                observed, observed_tokens.to(device=device, dtype=torch.long), x
            )
        fixed |= observed

    # Noise levels t_0 = 1 > ... > t_n ~ 0. Convert to Python floats once so
    # the loop does not synchronise with the device on every step.
    if schedule == "cosine":
        ts = torch.cos(torch.linspace(0, torch.pi / 2, n_steps + 1, device=device))
    else:
        ts = torch.linspace(1.0, 0.0, n_steps + 1, device=device)
    levels = ts.tolist()

    null = torch.full_like(cond, null_cond)
    null_action = (torch.full_like(action, cfg.n_latent_actions)
                   if action is not None and cfg.use_latent_action else None)

    def _guided_logits(tokens: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        # CFG: (1 + gamma) * logits(cond) - gamma * logits(null).
        logits = model(tokens, sigma, cond, action)
        if gamma > 0:
            logits_null = model(tokens, sigma, null, null_action)
            logits = (1 + gamma) * logits - gamma * logits_null
        logits[:, :, mask_id] = -1e9  # never emit the mask token
        return logits

    def _pick(logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Returns (chosen token ids, probability of the chosen token) per position.
        if temperature > 0:
            probs = F.softmax(logits / temperature, dim=-1)
            flat = probs.reshape(-1, probs.size(-1))
            preds = torch.multinomial(flat, 1).reshape(probs.shape[:-1])
            confs = probs.gather(-1, preds.unsqueeze(-1)).squeeze(-1)
            return preds, confs
        confs, preds = F.softmax(logits, dim=-1).max(dim=-1)
        return preds, confs

    for k in range(n_steps):
        sigma = torch.full((B, T), levels[k], device=device).masked_fill(fixed, 0.0)
        preds, confs = _pick(_guided_logits(x, sigma))

        # Commit enough positions that ~(1 - t_next) * T are filled in total,
        # choosing the most confident still-masked ones. This is vectorised over
        # the batch instead of a per-row topk with one host sync per row.
        target_kept = int(round((1 - levels[k + 1]) * T))
        revealed = (x != mask_id) | fixed
        need = (target_kept - revealed.sum(dim=-1)).clamp(min=0)
        # Probabilities are >= 0, so -1 pushes revealed positions to the end.
        order = confs.masked_fill(revealed, -1.0).argsort(dim=-1, descending=True)
        rank = order.argsort(dim=-1)
        # ``& ~revealed`` guarantees pinned or committed tokens are never overwritten.
        take = (rank < need.unsqueeze(-1)) & ~revealed
        x = torch.where(take, preds, x)

    # Fill any positions still masked, e.g. observed_mask positions without values.
    mask_left = x == mask_id
    if mask_left.any():
        logits = _guided_logits(x, torch.zeros(B, T, device=device))
        preds = _pick(logits)[0] if temperature > 0 else logits.argmax(-1)
        x = torch.where(mask_left, preds, x)
    return x
|
|
|
|
@torch.no_grad()
def sample_ar(
    model: CDFTransformer,
    cond: torch.Tensor,
    prefix: torch.Tensor,
    *,
    action: torch.Tensor | None = None,
    max_new: int = 50,
    temperature: float = 1.0,
    gamma: float = 0.0,
    null_cond: int = 0,
) -> torch.Tensor:
    """AR-mode sampling: future tokens at sigma=1, past at sigma=0.

    Faster than denoise mode when you only want to continue a prefix. Each
    step appends one mask token at sigma=1, asks the model for logits, and
    samples the next token from the last position. Generation stops after
    ``max_new`` tokens or when the sequence reaches ``cfg.max_seq_len``.

    Args:
        model: Called as ``model(x, sigma, cond, action)``. Its output is
            indexed as ``logits[:, :, token]``.
        cond: Conditioning ids. The batch size is taken from ``cond.size(0)``.
        prefix: ``(B, L)`` tokens, or a single ``(L,)`` prefix shared by every
            batch row.
        action: Latent action ids. The first ``T_now`` columns are used at each
            step (only when ``cfg.use_latent_action``).
        max_new: Maximum number of tokens to append.
        temperature: Softmax temperature, floored at 1e-3.
        gamma: CFG strength. ``<= 0`` skips the unconditional pass.
        null_cond: Cond id used for the unconditional CFG branch.

    Returns:
        The prefix with the sampled tokens appended along dim 1.
    """
    cfg = model.cfg
    device = cond.device
    B = cond.size(0)
    x = prefix.clone().to(device)
    if x.dim() == 1:
        # A single shared prefix: replicate it across the batch.
        x = x.unsqueeze(0).expand(B, -1)
    null = torch.full_like(cond, null_cond)
    use_action = action is not None and cfg.use_latent_action

    for _ in range(max_new):
        T_now = x.size(1)
        if T_now >= cfg.max_seq_len:
            break

        x_pad = torch.cat([x, torch.full((B, 1), cfg.mask_token,
                                         device=device, dtype=torch.long)], dim=1)
        sigma = torch.zeros(B, T_now + 1, device=device)
        sigma[:, -1] = 1.0  # only the new slot is noisy

        a_pad = null_a_pad = None
        if use_action:
            a_pad = torch.cat([action[:, :T_now],
                               torch.full((B, 1), cfg.n_latent_actions,
                                          device=device, dtype=torch.long)], dim=1)
            # The null action must match x_pad's length (T_now + 1), not the
            # full action length.
            null_a_pad = torch.full_like(a_pad, cfg.n_latent_actions)

        logits = model(x_pad, sigma, cond, a_pad)
        if gamma > 0:
            logits_n = model(x_pad, sigma, null, null_a_pad)
            logits = (1 + gamma) * logits - gamma * logits_n
        logits[:, :, cfg.mask_token] = -1e9  # never emit the mask token
        p = F.softmax(logits[:, -1] / max(temperature, 1e-3), dim=-1)
        nxt = torch.multinomial(p, 1)
        x = torch.cat([x, nxt], dim=1)
    return x
|
|
|
|
@torch.no_grad()
def counterfactual_rollout(
    model: CDFTransformer,
    seed_prefix: torch.Tensor,
    treatment_cond: int,
    untreated_cond: int,
    *,
    treatment_action: int | None = None,
    untreated_action: int | None = None,
    n_samples: int = 100,
    gamma: float = 2.0,
    n_steps: int = 32,
) -> dict:
    """Sample paired counterfactual trajectories under treatment vs no-treatment.

    Two ways to specify the intervention:
      - via cond id (cohort-level): treatment_cond / untreated_cond
      - via latent action id (per-token): treatment_action / untreated_action

    Both arms share the same seed prefix, ``n_samples``, ``gamma`` and
    ``n_steps``, and are generated with ``sample_denoise``. When
    ``cfg.use_latent_action`` is False, the action ids are ignored.

    NOTE(review): ``sample_denoise`` picks tokens by argmax and
    confidence ranking, so identical inputs give identical rows. As called
    here, every one of the ``n_samples`` rows within an arm is expected to be
    the same. Confirm this is intended before reading the rows as a
    distribution.

    Args:
        model: Provides ``cfg`` and ``parameters()``. Sampling runs on the
            device of its first parameter.
        seed_prefix: ``(L,)`` or ``(1, L)`` clean prefix shared by all samples.
        treatment_cond / untreated_cond: Cond ids for the two arms.
        treatment_action / untreated_action: Latent action ids for the two
            arms. ``None`` uses ``cfg.n_latent_actions`` (the null-action id
            used elsewhere in this module).
        n_samples: Rows per arm.
        gamma, n_steps: Forwarded to ``sample_denoise``.

    Returns:
        dict with ``traj_treated``, ``traj_untreated`` (each ``(n_samples, T)``),
        and the settings used: ``n``, ``treatment_cond``, ``untreated_cond``,
        ``gamma``, ``treatment_action`` and ``untreated_action``.
    """
    cfg = model.cfg
    device = next(model.parameters()).device
    if seed_prefix.dim() == 2 and seed_prefix.size(0) == 1:
        # Accept a batch-of-one prefix as well as a flat one.
        seed_prefix = seed_prefix.squeeze(0)
    seed = seed_prefix.unsqueeze(0).expand(n_samples, -1).to(device)
    T = cfg.max_seq_len

    def _ids(value: int, shape: tuple[int, ...]) -> torch.Tensor:
        # Constant LongTensor of ``value`` on the sampling device.
        return torch.full(shape, value, device=device, dtype=torch.long)

    cond_tx = _ids(treatment_cond, (n_samples,))
    cond_null = _ids(untreated_cond, (n_samples,))

    action_tx = action_null = None
    if cfg.use_latent_action:
        no_action = cfg.n_latent_actions
        action_tx = _ids(no_action if treatment_action is None else treatment_action,
                         (n_samples, T))
        action_null = _ids(no_action if untreated_action is None else untreated_action,
                           (n_samples, T))

    traj_tx = sample_denoise(model, cond_tx, seed_prefix=seed,
                             action=action_tx, gamma=gamma, n_steps=n_steps)
    traj_null = sample_denoise(model, cond_null, seed_prefix=seed,
                               action=action_null, gamma=gamma, n_steps=n_steps)
    return {
        "traj_treated": traj_tx, "traj_untreated": traj_null,
        "n": n_samples, "treatment_cond": treatment_cond,
        "untreated_cond": untreated_cond, "gamma": gamma,
        "treatment_action": treatment_action,
        "untreated_action": untreated_action,
    }
|
|
|
|
def outcome_rate(traj: torch.Tensor, target_ids: list[int]) -> float:
    """Return the fraction of trajectories containing at least one of ``target_ids``.

    Each slice along the last dimension of ``traj`` is one trajectory. A 1-D
    ``traj`` is a single trajectory, giving 0.0 or 1.0. The result is 0.0 when
    ``target_ids`` is empty or ``traj`` has no elements, including an empty
    batch, which previously produced NaN.

    ``torch.isin`` avoids building a ``(..., T, K)`` comparison tensor and the
    multi-dim ``any`` call that older torch versions do not support.
    """
    if not target_ids or traj.numel() == 0:
        return 0.0
    target = torch.tensor(target_ids, device=traj.device, dtype=traj.dtype)
    has = torch.isin(traj, target).any(dim=-1)
    return has.float().mean().item()
|
|