"""Classifier-Free Guidance sampling + counterfactual rollouts.

Given a trained Block Diffusion CWM, we sample patient trajectories by
iteratively unmasking tokens, starting from a fully-masked sequence (plus an
optional clean prompt prefix). At each denoise step:

  1. Run the conditional model with ``cond=c``           -> ``logits_c``
  2. Run the unconditional model with ``cond=null_cond`` -> ``logits_null``
  3. Guided logits = ``(1 + gamma) * logits_c - gamma * logits_null``
  4. Pick a token at each MASK position: greedy argmax, or a draw from the
     temperature-scaled distribution when ``sample=True``
  5. Reveal the most confident predictions so that a ``1 - t`` fraction of
     the originally-masked positions is revealed, until ``t = 0``

Counterfactual pair:
  - sample one batch with cond = treatment
  - sample another batch with cond = the no-treatment condition
  - compare outcome-token rates between the two batches to estimate the ATE

This is the core counterfactual primitive used by tte_validate.py and
sensitivity.py.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F

if TYPE_CHECKING:
    # Only referenced in annotations (lazy thanks to the __future__ import),
    # so the model module is not imported at runtime.
    from .block_diffusion import BlockDiffusionTransformer

log = logging.getLogger("gemeo.cwm.sample")

# Large negative logit/score: makes the MASK token unselectable and keeps
# already-revealed positions out of the top-k confidence pick.
_NEG_LARGE = -1e9


def _predict_tokens(
    model: BlockDiffusionTransformer,
    x: torch.Tensor,
    t: torch.Tensor,
    cond: torch.Tensor,
    null: torch.Tensor,
    *,
    gamma: float,
    mask_token: int,
    temperature: float,
    sample: bool,
    generator: torch.Generator | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run one guided prediction; return per-position ``(tokens, confidence)``."""
    logits = model(x, t, cond)
    if gamma > 0:
        # The unconditional pass is only needed when guidance is active.
        logits = (1 + gamma) * logits - gamma * model(x, t, null)
    # In-place to avoid copying a potentially large (B, T, V) tensor.
    logits[..., mask_token] = _NEG_LARGE
    if temperature != 1.0:
        logits = logits / max(temperature, 1e-3)
    probs = F.softmax(logits, dim=-1)
    if not sample:
        confs, preds = probs.max(dim=-1)
        return preds, confs
    flat = probs.reshape(-1, probs.size(-1)).float()
    preds = torch.multinomial(flat, 1, generator=generator).view(probs.shape[:-1])
    # Confidence of a sampled token is the probability it was drawn with.
    confs = probs.gather(-1, preds.unsqueeze(-1)).squeeze(-1)
    return preds, confs


@torch.no_grad()
def cfg_sample(
    model: BlockDiffusionTransformer,
    cond: torch.Tensor,  # (B,) condition ids
    *,
    seed_prefix: torch.Tensor | None = None,  # (B, L) or (1, L) prompt tokens (kept clean)
    gamma: float = 2.0,
    n_steps: int | None = None,
    temperature: float = 1.0,
    null_cond: int = 0,
    sample: bool = False,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample ``(B, T)`` trajectories with classifier-free guidance.

    ``T`` is ``model.cfg.max_seq_len``. The optional ``seed_prefix`` is placed
    unmasked at positions ``[0, L)`` and never modified. The remaining
    ``T - L`` positions start as ``model.cfg.mask_token`` and are revealed over
    ``n_steps`` steps: ``t`` runs linearly from 1 to 0, and after step ``k``
    a ``1 - t_{k+1}`` fraction of those positions is revealed, choosing the
    most confident still-masked predictions. Revealed tokens are never
    re-masked.

    Args:
        model: Called as ``model(x, t, cond)``; its output is indexed as
            per-position logits over the vocabulary (last dim).
        cond: ``(B,)`` condition ids.
        seed_prefix: Optional prompt of shape ``(B, L)`` or ``(1, L)``.
        gamma: Guidance strength; ``gamma <= 0`` disables guidance and
            skips the unconditional forward pass.
        n_steps: Number of denoise steps (default ``model.cfg.n_diff_steps``).
            ``0`` fills every generated position in one pass at ``t = 0``.
        temperature: When ``!= 1``, logits are divided by
            ``max(temperature, 1e-3)`` before the softmax.
        null_cond: Condition id for the unconditional branch.
        sample: If True, draw each token from the tempered distribution;
            if False (default), take the argmax.
        generator: Optional RNG for ``sample=True``; must live on the same
            device as the model output.

    Returns:
        ``(B, T)`` long tensor. The MASK token is never predicted.

    Raises:
        ValueError: If ``n_steps`` is negative or ``seed_prefix`` has the
            wrong shape or is longer than ``max_seq_len``.
    """
    cfg = model.cfg
    if n_steps is None:
        n_steps = cfg.n_diff_steps
    if n_steps < 0:
        raise ValueError(f"n_steps must be >= 0, got {n_steps}")
    device = cond.device
    B = cond.size(0)
    T = cfg.max_seq_len
    mask_token = cfg.mask_token

    x = torch.full((B, T), mask_token, device=device, dtype=torch.long)
    fixed_mask = torch.zeros(B, T, dtype=torch.bool, device=device)
    n_fixed = 0
    if seed_prefix is not None:
        if seed_prefix.dim() != 2 or seed_prefix.size(0) not in (1, B):
            raise ValueError(
                f"seed_prefix must have shape (B, L) or (1, L) with B={B}, "
                f"got {tuple(seed_prefix.shape)}"
            )
        n_fixed = seed_prefix.size(1)
        if n_fixed > T:
            raise ValueError(f"seed_prefix length {n_fixed} exceeds max_seq_len {T}")
        x[:, :n_fixed] = seed_prefix.to(device)
        fixed_mask[:, :n_fixed] = True
    n_fill = T - n_fixed  # number of positions the sampler must generate

    ts = torch.linspace(1.0, 0.0, n_steps + 1, device=device)
    t_vals = ts.tolist()  # one host sync instead of one per step
    null = torch.full_like(cond, null_cond)
    predict = dict(
        gamma=gamma,
        mask_token=mask_token,
        temperature=temperature,
        sample=sample,
        generator=generator,
    )

    revealed = fixed_mask.clone()
    n_revealed = n_fixed  # identical for every row: each step reveals `need`
    for k in range(n_steps):
        # Schedule over the generated span only; counting the prompt here
        # would turn the first steps into no-ops.
        target = n_fixed + int(round((1.0 - t_vals[k + 1]) * n_fill))
        need = target - n_revealed
        if need <= 0:
            continue  # nothing to reveal this step, so skip the forward passes
        t_cur = ts[k].unsqueeze(0).expand(B)
        preds, confs = _predict_tokens(model, x, t_cur, cond, null, **predict)
        # Reveal the `need` most confident still-masked positions per row.
        confs = confs.masked_fill(revealed, _NEG_LARGE)
        topi = confs.topk(need, dim=-1).indices
        x = x.scatter(-1, topi, preds.gather(-1, topi))
        revealed = revealed.scatter(-1, topi, torch.ones_like(topi, dtype=torch.bool))
        n_revealed += need

    # Safety net (e.g. n_steps == 0): fill any generated position that is
    # still masked in a single pass at t = 0. The prompt is never touched.
    mask_left = (x == mask_token) & ~fixed_mask
    if mask_left.any():
        t_zero = torch.zeros(B, device=device)
        preds, _ = _predict_tokens(model, x, t_zero, cond, null, **predict)
        x = torch.where(mask_left, preds, x)
    return x


def counterfactual_pair(
    model: BlockDiffusionTransformer,
    treatment_cond: int,
    null_treatment_cond: int,
    *,
    seed_prefix: torch.Tensor,
    n_samples: int = 100,
    gamma: float = 2.0,
    device: torch.device | None = None,
    temperature: float = 1.0,
    sample: bool = True,
    n_steps: int | None = None,
    generator: torch.Generator | None = None,
) -> dict:
    """Sample ``n_samples`` trajectories under treatment AND no-treatment.

    Both arms start from the same prompt and differ only in the condition id.
    Tokens are drawn stochastically by default (``sample=True``). With greedy
    decoding, every row within an arm would be identical, so outcome rates
    could only be 0 or 1.

    Args:
        model: Passed through to :func:`cfg_sample`.
        treatment_cond: Condition id for the treated arm.
        null_treatment_cond: Condition id for the untreated arm.
        seed_prefix: Shared prompt of shape ``(L,)`` or ``(1, L)``.
        n_samples: Trajectories per arm (must be >= 1).
        gamma, temperature, sample, n_steps, generator: Forwarded to
            :func:`cfg_sample`.
        device: Device for sampling; defaults to ``seed_prefix.device``.

    Returns:
        Dict with the two ``(n_samples, T)`` batches ``traj_treated`` and
        ``traj_untreated``, plus the settings used: ``n``, ``treatment_cond``,
        ``null_cond``, ``gamma``, ``temperature`` and ``sample``. Use
        :func:`outcome_rate` on each batch to compare outcome frequencies.

    Raises:
        ValueError: If ``n_samples < 1`` or ``seed_prefix`` has the wrong shape.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    if seed_prefix.dim() == 2 and seed_prefix.size(0) == 1:
        seed_prefix = seed_prefix[0]
    if seed_prefix.dim() != 1:
        raise ValueError(
            f"seed_prefix must have shape (L,) or (1, L), got {tuple(seed_prefix.shape)}"
        )
    if device is None:
        device = seed_prefix.device
    seed = seed_prefix.to(device).unsqueeze(0).expand(n_samples, -1)
    common = dict(
        seed_prefix=seed,
        gamma=gamma,
        n_steps=n_steps,
        temperature=temperature,
        sample=sample,
        generator=generator,
    )
    cond_tx = torch.full((n_samples,), treatment_cond, device=device, dtype=torch.long)
    cond_null = torch.full((n_samples,), null_treatment_cond, device=device, dtype=torch.long)
    traj_tx = cfg_sample(model, cond_tx, **common)
    traj_null = cfg_sample(model, cond_null, **common)
    return {
        "traj_treated": traj_tx,
        "traj_untreated": traj_null,
        "n": n_samples,
        "treatment_cond": treatment_cond,
        "null_cond": null_treatment_cond,
        "gamma": gamma,
        "temperature": temperature,
        "sample": sample,
    }


def outcome_rate(traj: torch.Tensor, target_tok_ids: list[int]) -> float:
    """Fraction of trajectories containing ANY of the target outcome tokens.

    ``traj`` holds one trajectory per row (last dim = sequence). Returns 0.0
    when ``target_tok_ids`` is empty; an empty batch yields NaN (mean of
    nothing).
    """
    if not target_tok_ids:
        return 0.0
    target = torch.as_tensor(list(target_tok_ids), device=traj.device, dtype=traj.dtype)
    has = torch.isin(traj, target).any(dim=-1)
    return has.float().mean().item()