gemeo-twin-stack / src /gemeo /cwm /cfg_sample.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""Classifier-Free Guidance sampling + counterfactual rollouts.
Given a trained Block Diffusion CWM, we sample patient trajectories by
iteratively unmasking tokens from fully-masked input. At each denoise step:
1. Run the conditional model with cond=c -> logits_c
2. Run unconditional model with cond=<NULL> -> logits_null
3. Guided logits = (1+gamma) * logits_c - gamma * logits_null
4. Greedy (argmax) token at each MASK position; temperature only rescales
the logits and therefore the confidences used for remasking
5. Top-k confidence remasking schedule (D3PM-style) until t=0
Counterfactual pair:
- sample one batch with cond = treatment
- sample another batch with cond = <NO_TX>
- compare outcome-token distributions to estimate ATE
This is the core counterfactual primitive used by tte_validate.py and
sensitivity.py.
"""
from __future__ import annotations
import logging
import torch
import torch.nn.functional as F
from .block_diffusion import BlockDiffusionTransformer
log = logging.getLogger("gemeo.cwm.sample")
def _guided_logits(
    model: BlockDiffusionTransformer,
    x: torch.Tensor,
    t: torch.Tensor,
    cond: torch.Tensor,
    null: torch.Tensor,
    gamma: float,
    mask_token: int,
) -> torch.Tensor:
    """Return classifier-free-guided logits with the MASK token suppressed."""
    logits = model(x, t, cond)
    if gamma > 0:
        # The unconditional pass is only needed when guidance is active.
        logits = (1 + gamma) * logits - gamma * model(x, t, null)
    # MASK must never be emitted as a prediction.
    logits[:, :, mask_token] = -1e9
    return logits


@torch.no_grad()
def cfg_sample(
    model: BlockDiffusionTransformer,
    cond: torch.Tensor,  # (B,) condition ids
    *,
    seed_prefix: torch.Tensor | None = None,  # (B, L) prompt tokens (kept clean)
    gamma: float = 2.0,
    n_steps: int | None = None,
    temperature: float = 1.0,
    null_cond: int = 0,
) -> torch.Tensor:
    """Sample (B, T) trajectories with classifier-free guidance.

    Starts from a fully masked sequence of length ``model.cfg.max_seq_len``
    and, over ``n_steps`` denoising steps, reveals the highest-confidence
    masked positions so that after step ``k`` roughly ``(1 - t_{k+1}) * T``
    positions per row are revealed. Tokens are always the argmax of the
    guided distribution; nothing in this function draws random numbers, so
    the output is deterministic unless the model's forward pass is itself
    stochastic.

    Args:
        model: Callable as ``model(x, t, cond)`` returning per-position
            logits over the vocabulary; must expose ``cfg`` with
            ``n_diff_steps``, ``max_seq_len`` and ``mask_token``.
        cond: (B,) condition ids.
        seed_prefix: Optional (B, L) prompt written to the first L positions
            and never re-sampled by the denoising loop.
        gamma: Guidance strength. With ``gamma > 0`` the logits are
            ``(1 + gamma) * logits_c - gamma * logits_null``; with
            ``gamma <= 0`` the conditional logits are used as-is and the
            unconditional forward pass is skipped.
        n_steps: Number of denoising steps; ``None`` or ``0`` falls back to
            ``model.cfg.n_diff_steps``.
        temperature: Divides the logits (floored at 1e-3) before the
            softmax. Because tokens are chosen by argmax, this only changes
            the confidences that decide which positions are revealed first.
        null_cond: Condition id used for the unconditional pass.

    Returns:
        (B, T) long tensor of token ids with no MASK token left.
    """
    cfg = model.cfg
    n_steps = n_steps or cfg.n_diff_steps
    device = cond.device
    B = cond.size(0)
    T = cfg.max_seq_len

    # Start from an all-MASK canvas; prompt positions are pinned.
    x = torch.full((B, T), cfg.mask_token, device=device, dtype=torch.long)
    fixed_mask = torch.zeros(B, T, dtype=torch.bool, device=device)
    if seed_prefix is not None:
        prefix_len = seed_prefix.size(1)
        x[:, :prefix_len] = seed_prefix
        fixed_mask[:, :prefix_len] = True

    ts = torch.linspace(1.0, 0.0, n_steps + 1, device=device)
    null = torch.full_like(cond, null_cond)

    for k in range(n_steps):
        t_cur = ts[k].unsqueeze(0).expand(B)
        logits = _guided_logits(model, x, t_cur, cond, null, gamma, cfg.mask_token)
        if temperature != 1.0:
            logits = logits / max(temperature, 1e-3)
        probs = F.softmax(logits, dim=-1)
        confs, preds = probs.max(dim=-1)

        # Number of positions per row that should be revealed after this step.
        t_next = ts[k + 1].item()
        target_kept = int(round((1.0 - t_next) * T))

        # Revealed (or pinned) positions are excluded from the pick.
        revealed = (x != cfg.mask_token) | fixed_mask
        confs_for_pick = torch.where(revealed, torch.full_like(confs, -1e9), confs)

        # One host transfer per step instead of one per batch row.
        need = (target_kept - revealed.sum(dim=-1)).clamp(min=0).tolist()

        new_x = x.clone()
        for b, n_new in enumerate(need):
            if n_new == 0:
                continue
            topi = confs_for_pick[b].topk(n_new).indices
            new_x[b, topi] = preds[b, topi]
        x = new_x

    # Final pass: fill any remaining MASK greedily, using the same guidance
    # rule as the loop (including the gamma <= 0 case).
    mask_left = x == cfg.mask_token
    if mask_left.any():
        t_zero = torch.zeros(B, device=device)
        logits = _guided_logits(model, x, t_zero, cond, null, gamma, cfg.mask_token)
        x = torch.where(mask_left, logits.argmax(dim=-1), x)
    return x
def counterfactual_pair(
    model: BlockDiffusionTransformer,
    treatment_cond: int,
    null_treatment_cond: int,
    *,
    seed_prefix: torch.Tensor,
    n_samples: int = 100,
    gamma: float = 2.0,
    device: torch.device | None = None,
) -> dict:
    """Sample n_samples trajectories under treatment AND no-treatment.

    Both arms share the same prompt and guidance strength and differ only in
    the condition id passed to ``cfg_sample``.

    NOTE: ``cfg_sample`` decodes by argmax and draws no random numbers, so
    the rows within one arm are identical unless the model's forward pass is
    itself stochastic.

    Args:
        model: Trained model forwarded to ``cfg_sample``.
        treatment_cond: Condition id for the treated arm.
        null_treatment_cond: Condition id for the untreated arm.
        seed_prefix: Prompt tokens shared by every sample, either 1-D ``(L,)``
            or 2-D ``(1, L)`` / ``(n_samples, L)``.
        n_samples: Number of trajectories per arm.
        gamma: Guidance strength forwarded to ``cfg_sample``.
        device: Device for the prompt and condition tensors; defaults to
            ``seed_prefix.device``.

    Returns:
        Dict with the two sample batches (``traj_treated``,
        ``traj_untreated``) and the settings that produced them (``n``,
        ``treatment_cond``, ``null_cond``, ``gamma``). Outcome statistics are
        not computed here.
    """
    if device is None:
        device = seed_prefix.device

    # Move the single prompt first, then broadcast it as a view across the batch.
    seed = seed_prefix if seed_prefix.dim() == 2 else seed_prefix.unsqueeze(0)
    seed = seed.to(device).expand(n_samples, -1)

    cond_tx = torch.full((n_samples,), treatment_cond, device=device, dtype=torch.long)
    cond_null = torch.full((n_samples,), null_treatment_cond, device=device, dtype=torch.long)

    traj_tx = cfg_sample(model, cond_tx, seed_prefix=seed, gamma=gamma)
    traj_null = cfg_sample(model, cond_null, seed_prefix=seed, gamma=gamma)

    return {
        "traj_treated": traj_tx,
        "traj_untreated": traj_null,
        "n": n_samples,
        "treatment_cond": treatment_cond,
        "null_cond": null_treatment_cond,
        "gamma": gamma,
    }
def outcome_rate(traj: torch.Tensor, target_tok_ids: list[int]) -> float:
    """Fraction of trajectories containing ANY of the target outcome tokens.

    Args:
        traj: (N, T) tensor of token ids, one trajectory per row.
        target_tok_ids: Token ids that count as the outcome.

    Returns:
        Share of rows in ``traj`` that contain at least one target token.
        Returns 0.0 when there are no target ids or ``traj`` is empty, so the
        result is never NaN.
    """
    if not target_tok_ids or traj.numel() == 0:
        return 0.0
    target = torch.tensor(target_tok_ids, device=traj.device)
    # (N, T, K) comparison reduced one axis at a time; single-dim reductions
    # work on PyTorch versions that do not accept a tuple for ``dim``.
    has = (traj.unsqueeze(-1) == target).any(dim=-1).any(dim=-1)
    return has.float().mean().item()