| """Classifier-Free Guidance sampling + counterfactual rollouts. |
| |
| Given a trained Block Diffusion CWM, we sample patient trajectories by |
| iteratively unmasking tokens from fully-masked input. At each denoise step: |
| |
| 1. Run the conditional model with cond=c -> logits_c |
2. Run the same model with cond=<NULL> (unconditional pass) -> logits_null
| 3. Guided logits = (1+gamma) * logits_c - gamma * logits_null |
| 4. Greedy or temperature-sampled token at each MASK position |
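5. Commit the top-k most confident predictions among still-masked positions
   (confidence-based unmasking schedule); committed tokens are never re-masked,
   and the sequence is fully unmasked at t=0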
| |
| Counterfactual pair: |
| - sample one batch with cond = treatment |
| - sample another batch with cond = <NO_TX> |
| - compare outcome-token distributions to estimate ATE |
| |
| This is the core counterfactual primitive used by tte_validate.py and |
| sensitivity.py. |
| """ |
| from __future__ import annotations |
| import logging |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| from .block_diffusion import BlockDiffusionTransformer |
|
|
# Module-level logger. The dotted name makes it a child of the "gemeo.cwm"
# logger, so configuring that parent also configures this module.
log = logging.getLogger("gemeo.cwm.sample")
|
|
|
|
def _guided_logits(
    model: BlockDiffusionTransformer,
    x: torch.Tensor,
    t: torch.Tensor,
    cond: torch.Tensor,
    null: torch.Tensor,
    gamma: float,
    mask_token: int,
) -> torch.Tensor:
    """Return float32 CFG logits with the mask token given zero probability.

    gamma <= 0 disables guidance and skips the null-conditioned forward pass.
    """
    logits = model(x, t, cond)
    if gamma > 0:
        logits = (1 + gamma) * logits - gamma * model(x, t, null)
    # float32 keeps softmax / temperature scaling from overflowing when the
    # model runs in half precision; -inf gives the mask token exactly 0 mass.
    logits = logits.float()
    is_mask = torch.zeros(logits.size(-1), dtype=torch.bool, device=logits.device)
    is_mask[mask_token] = True
    return logits.masked_fill(is_mask, float("-inf"))


def _propose_tokens(
    logits: torch.Tensor,
    temperature: float,
    generator: torch.Generator | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Choose a candidate token for every position, plus its confidence."""
    if temperature <= 0:
        # Greedy decoding: argmax token and its probability.
        confs, preds = F.softmax(logits, dim=-1).max(dim=-1)
        return preds, confs
    probs = F.softmax(logits / temperature, dim=-1)
    flat = probs.reshape(-1, probs.size(-1))
    preds = torch.multinomial(flat, 1, generator=generator).view(probs.shape[:-1])
    # Confidence of a sampled position is the probability of the drawn token.
    confs = probs.gather(-1, preds.unsqueeze(-1)).squeeze(-1)
    return preds, confs


@torch.no_grad()
def cfg_sample(
    model: BlockDiffusionTransformer,
    cond: torch.Tensor,
    *,
    seed_prefix: torch.Tensor | None = None,
    gamma: float = 2.0,
    n_steps: int | None = None,
    temperature: float = 1.0,
    null_cond: int = 0,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample (B, T) token trajectories with classifier-free guidance.

    Starts from an all-mask sequence of length ``model.cfg.max_seq_len``.
    ``seed_prefix``, if given, is written unmasked at the front and never
    changed. The function then runs ``n_steps`` denoising steps from t=1 to
    t=0. Each step proposes a token at every position from the guided
    logits. It then commits the most confident proposals among still-masked
    positions, so that the committed fraction of the free (non-prefix)
    positions tracks 1 - t. Committed tokens are never re-masked.

    Args:
        model: Called as ``model(x, t, cond)`` and expected to return
            per-position vocabulary logits of shape (B, T, V). Its ``cfg``
            supplies ``max_seq_len``, ``mask_token`` and the default
            ``n_diff_steps``.
        cond: Conditioning ids, one per sample (B = ``cond.size(0)``).
        seed_prefix: Optional fixed token prefix of shape (L,), (1, L) or
            (B, L).
        gamma: Guidance scale:
            logits = (1 + gamma) * logits_c - gamma * logits_null.
            gamma <= 0 disables guidance, and the null forward pass is skipped.
        n_steps: Number of denoising steps. A falsy value falls back to
            ``cfg.n_diff_steps``.
        temperature: If <= 0, decode greedily (argmax). If > 0, sample from
            softmax(logits / temperature). Pass 0 for deterministic output.
        null_cond: Conditioning id used for the unconditional pass.
        generator: Optional ``torch.Generator`` for reproducible sampling.

    Returns:
        LongTensor (B, T) with every mask position filled.

    Raises:
        ValueError: If the seed prefix is longer than ``max_seq_len``.
    """
    cfg = model.cfg
    n_steps = n_steps or cfg.n_diff_steps
    device = cond.device
    B = cond.size(0)
    T = cfg.max_seq_len
    mask_tok = cfg.mask_token

    x = torch.full((B, T), mask_tok, device=device, dtype=torch.long)
    fixed_mask = torch.zeros(B, T, dtype=torch.bool, device=device)
    n_fixed = 0
    if seed_prefix is not None:
        if seed_prefix.dim() == 1:
            seed_prefix = seed_prefix.unsqueeze(0)
        n_fixed = seed_prefix.size(1)
        if n_fixed > T:
            raise ValueError(
                f"seed_prefix length {n_fixed} exceeds max_seq_len {T}"
            )
        x[:, :n_fixed] = seed_prefix
        fixed_mask[:, :n_fixed] = True
    # The unmasking schedule runs over the free positions only. Counting the
    # prefix would waste the first steps revealing nothing.
    n_free = T - n_fixed

    ts = torch.linspace(1.0, 0.0, n_steps + 1, device=device)
    null = torch.full_like(cond, null_cond)

    for k in range(n_steps):
        t_cur = ts[k].unsqueeze(0).expand(B)
        logits = _guided_logits(model, x, t_cur, cond, null, gamma, mask_tok)
        preds, confs = _propose_tokens(logits, temperature, generator)

        keep_frac = 1.0 - ts[k + 1].item()
        target = n_fixed + int(round(keep_frac * n_free))
        revealed = (x != mask_tok) | fixed_mask
        need = (target - revealed.sum(dim=-1)).clamp(min=0)

        # Rank still-masked positions by confidence (0 = most confident) and
        # commit the first `need` of them in each row.
        scores = confs.masked_fill(revealed, float("-inf"))
        rank = scores.argsort(dim=-1, descending=True).argsort(dim=-1)
        pick = (rank < need.unsqueeze(-1)) & ~revealed
        x = torch.where(pick, preds, x)

    # Safety net: fill anything still masked (e.g. n_steps == 0, or mask
    # tokens present in the seed prefix) with a single greedy pass at t=0.
    mask_left = x == mask_tok
    if mask_left.any():
        t_zero = torch.zeros(B, device=device)
        logits = _guided_logits(model, x, t_zero, cond, null, gamma, mask_tok)
        x = torch.where(mask_left, logits.argmax(dim=-1), x)

    return x
|
|
|
|
def counterfactual_pair(
    model: BlockDiffusionTransformer,
    treatment_cond: int,
    null_treatment_cond: int,
    *,
    seed_prefix: torch.Tensor,
    n_samples: int = 100,
    gamma: float = 2.0,
    device: torch.device | None = None,
    temperature: float = 1.0,
    n_steps: int | None = None,
) -> dict:
    """Sample ``n_samples`` trajectories under treatment AND under no-treatment.

    Both arms start from the same seed prefix and use the same sampling
    settings; only the conditioning id differs. Comparing outcome frequencies
    between the two batches (e.g. with ``outcome_rate``) therefore gives a
    model-based treatment-effect estimate. Rows within one arm differ only if
    ``cfg_sample`` draws tokens stochastically.

    Args:
        model: Model passed through to ``cfg_sample``.
        treatment_cond: Conditioning id for the treated arm.
        null_treatment_cond: Conditioning id for the untreated arm.
        seed_prefix: Shared token prefix, shape (L,) or (1, L).
        n_samples: Trajectories per arm; must be >= 1.
        gamma: Guidance scale forwarded to ``cfg_sample``.
        device: Device for the prefix and conditioning tensors. Defaults to
            ``seed_prefix.device``.
        temperature: Forwarded to ``cfg_sample``.
        n_steps: Forwarded to ``cfg_sample`` (None -> the model's default).

    Returns:
        dict with "traj_treated" and "traj_untreated" (the ``cfg_sample``
        batches, one row per sample), plus "n", "treatment_cond",
        "null_cond", "gamma" and "temperature".

    Raises:
        ValueError: If ``n_samples < 1`` or ``seed_prefix`` has another shape.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    device = device if device is not None else seed_prefix.device

    # Move before expanding so only the single prefix row is transferred.
    prefix = seed_prefix.to(device)
    if prefix.dim() == 2 and prefix.size(0) == 1:
        prefix = prefix[0]
    if prefix.dim() != 1:
        raise ValueError(
            "seed_prefix must have shape (L,) or (1, L), "
            f"got {tuple(seed_prefix.shape)}"
        )
    # expand() is a view: all samples share the prefix storage.
    seed = prefix.unsqueeze(0).expand(n_samples, -1)

    cond_tx = torch.full((n_samples,), treatment_cond, device=device, dtype=torch.long)
    cond_null = torch.full((n_samples,), null_treatment_cond, device=device, dtype=torch.long)

    sample_kwargs = dict(
        seed_prefix=seed, gamma=gamma, temperature=temperature, n_steps=n_steps
    )
    traj_tx = cfg_sample(model, cond_tx, **sample_kwargs)
    traj_null = cfg_sample(model, cond_null, **sample_kwargs)

    return {
        "traj_treated": traj_tx,
        "traj_untreated": traj_null,
        "n": n_samples,
        "treatment_cond": treatment_cond,
        "null_cond": null_treatment_cond,
        "gamma": gamma,
        "temperature": temperature,
    }
|
|
|
|
def outcome_rate(traj: torch.Tensor, target_tok_ids: list[int]) -> float:
    """Fraction of trajectories containing ANY of the target outcome tokens.

    Args:
        traj: Token ids. The last dim is the sequence axis, and every leading
            index is one trajectory (e.g. (B, T) from ``cfg_sample``).
        target_tok_ids: Outcome token ids to look for. Empty -> 0.0.

    Returns:
        Mean over trajectories of "at least one target token present". An
        empty batch yields nan (a mean over zero trajectories).
    """
    if not target_tok_ids:
        return 0.0
    target = torch.tensor(target_tok_ids, dtype=traj.dtype, device=traj.device)
    # isin avoids materialising a (..., T, K) comparison tensor. It also
    # avoids a multi-dim `any`, which older torch releases do not support.
    has = torch.isin(traj, target).any(dim=-1)
    return has.float().mean().item()
|
|