File size: 5,489 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""Classifier-Free Guidance sampling + counterfactual rollouts.

Given a trained Block Diffusion CWM, we sample patient trajectories by
iteratively unmasking tokens from fully-masked input. At each denoise step:

  1. Run the conditional model with cond=c -> logits_c
  2. Run unconditional model with cond=<NULL> -> logits_null
  3. Guided logits = (1+gamma) * logits_c - gamma * logits_null
  4. Greedy (argmax) token at each MASK position by default; temperature
     rescales the guided logits before the softmax
  5. Top-k confidence remasking schedule (D3PM-style) until t=0

Counterfactual pair:
  - sample one batch with cond = treatment
  - sample another batch with cond = <NO_TX>
  - compare outcome-token distributions to estimate ATE

This is the core counterfactual primitive used by tte_validate.py and
sensitivity.py.
"""
from __future__ import annotations
import logging

import torch
import torch.nn.functional as F

from .block_diffusion import BlockDiffusionTransformer

log = logging.getLogger("gemeo.cwm.sample")


@torch.no_grad()
def cfg_sample(
    model: BlockDiffusionTransformer,
    cond: torch.Tensor,                  # (B,) condition ids
    *,
    seed_prefix: torch.Tensor | None = None,  # (B, L) prompt tokens (kept clean)
    gamma: float = 2.0,
    n_steps: int | None = None,
    temperature: float = 1.0,
    null_cond: int = 0,
    stochastic: bool = False,
) -> torch.Tensor:
    """Sample (B, T) trajectories with classifier-free guidance.

    The sequence starts fully masked (T = ``model.cfg.max_seq_len``), except
    for ``seed_prefix`` which is written unmasked at the front. The remaining
    positions are filled in over ``n_steps`` rounds; each round commits the
    highest-confidence masked positions so that, after the step going from
    ``t`` to ``t_next``, ``round((1 - t_next) * T)`` positions per row are
    revealed.

    Args:
        model: Callable as ``model(x, t, cond)`` returning per-position
            logits; exposes ``cfg.n_diff_steps``, ``cfg.max_seq_len`` and
            ``cfg.mask_token``.
        cond: (B,) condition ids; also determines batch size and device.
        seed_prefix: Optional (B, L) prompt tokens copied into the first L
            positions. These positions are never re-predicted by the
            denoising loop.
        gamma: Guidance scale. Guided logits are
            ``(1 + gamma) * logits_c - gamma * logits_null``. For
            ``gamma <= 0`` the conditional logits are used as-is and the
            unconditional forward pass is skipped.
        n_steps: Number of denoising steps; ``None`` (or 0) falls back to
            ``cfg.n_diff_steps``.
        temperature: Divides the guided logits before the softmax (clamped
            to at least 1e-3).
        null_cond: Condition id used for the unconditional pass.
        stochastic: If True, draw each candidate token from the softmax
            distribution instead of taking the argmax; the confidence used
            for ranking is then the probability of the drawn token. The
            default (False) is greedy.

    Returns:
        (B, T) long tensor of token ids.

    Raises:
        ValueError: If ``seed_prefix`` is longer than ``cfg.max_seq_len``.
    """
    cfg = model.cfg
    n_steps = n_steps or cfg.n_diff_steps
    device = cond.device
    B = cond.size(0)
    T = cfg.max_seq_len

    # Start with full MASK; fixed_mask marks the prompt positions.
    x = torch.full((B, T), cfg.mask_token, device=device, dtype=torch.long)
    fixed_mask = torch.zeros(B, T, dtype=torch.bool, device=device)
    if seed_prefix is not None:
        L = seed_prefix.size(1)
        if L > T:
            raise ValueError(
                f"seed_prefix length {L} exceeds max_seq_len {T}"
            )
        x[:, :L] = seed_prefix
        fixed_mask[:, :L] = True

    ts = torch.linspace(1.0, 0.0, n_steps + 1, device=device)
    # One host transfer for the whole schedule instead of one per step.
    t_values = ts.tolist()
    null = torch.full_like(cond, null_cond)

    for k in range(n_steps):
        t_cur = ts[k].unsqueeze(0).expand(B)
        logits = _cfg_logits(model, x, t_cur, cond, null, gamma, cfg.mask_token)
        if temperature != 1.0:
            logits = logits / max(temperature, 1e-3)

        probs = F.softmax(logits, dim=-1)
        if stochastic:
            flat = probs.reshape(-1, probs.size(-1))
            preds = torch.multinomial(flat.float(), 1).view(probs.shape[:-1])
            confs = probs.gather(-1, preds.unsqueeze(-1)).squeeze(-1)
        else:
            confs, preds = probs.max(dim=-1)

        # Number of positions per row that must be revealed after this step.
        target_kept = int(round((1.0 - t_values[k + 1]) * T))

        # Already-revealed tokens (and the prompt) stay as they are.
        revealed = (x != cfg.mask_token) | fixed_mask
        need_per_row = (target_kept - revealed.sum(dim=-1)).clamp_min(0).tolist()

        # Revealed positions get a huge negative score so topk only ever
        # selects positions that are still masked.
        confs_for_pick = torch.where(revealed, torch.full_like(confs, -1e9), confs)
        new_x = x.clone()
        for b, need in enumerate(need_per_row):
            if need == 0:
                continue
            topi = confs_for_pick[b].topk(need).indices
            new_x[b, topi] = preds[b, topi]
        x = new_x

    # Final pass: replace any remaining MASK with the greedy prediction,
    # using the same guidance rule as the loop.
    mask_left = x == cfg.mask_token
    if mask_left.any():
        t_zero = torch.zeros(B, device=device)
        logits = _cfg_logits(model, x, t_zero, cond, null, gamma, cfg.mask_token)
        x = torch.where(mask_left, logits.argmax(dim=-1), x)

    return x


def _cfg_logits(
    model: BlockDiffusionTransformer,
    x: torch.Tensor,
    t: torch.Tensor,
    cond: torch.Tensor,
    null: torch.Tensor,
    gamma: float,
    mask_token: int,
) -> torch.Tensor:
    """Return guided logits for ``x`` with the MASK token made unselectable."""
    logits = model(x, t, cond)
    if gamma > 0:
        # The unconditional pass is only needed when guidance is active.
        logits = (1 + gamma) * logits - gamma * model(x, t, null)
    # Avoid ever predicting the MASK token itself.
    logits[:, :, mask_token] = -1e9
    return logits


def counterfactual_pair(
    model: BlockDiffusionTransformer,
    treatment_cond: int,
    null_treatment_cond: int,
    *,
    seed_prefix: torch.Tensor,
    n_samples: int = 100,
    gamma: float = 2.0,
    device: torch.device | None = None,
) -> dict:
    """Sample n_samples trajectories under treatment AND no-treatment.

    The same prompt is replicated ``n_samples`` times and passed to
    ``cfg_sample`` twice: once with every row conditioned on
    ``treatment_cond`` and once with every row conditioned on
    ``null_treatment_cond``.

    Args:
        model: Trained model forwarded to ``cfg_sample``.
        treatment_cond: Condition id for the treated arm.
        null_treatment_cond: Condition id for the untreated arm.
        seed_prefix: Shared prompt tokens, shape (L,) or (1, L).
        n_samples: Number of trajectories per arm.
        gamma: Guidance scale forwarded to ``cfg_sample``.
        device: Device for the condition and prompt tensors; defaults to
            ``seed_prefix.device``.

    Returns:
        dict with keys ``traj_treated`` and ``traj_untreated`` (the two
        sample batches) plus the inputs ``n``, ``treatment_cond``,
        ``null_cond`` and ``gamma``. No frequency tables are computed here.

    Raises:
        ValueError: If ``seed_prefix`` is not a single prompt.

    NOTE(review): as called here, cfg_sample picks the highest-probability
    token at each position, so rows within one arm presumably only differ
    if the model's forward pass is itself stochastic -- confirm before
    reading outcome rates as Monte Carlo estimates.
    """
    if device is None:
        device = seed_prefix.device

    if seed_prefix.dim() == 1:
        seed_prefix = seed_prefix.unsqueeze(0)
    if seed_prefix.dim() != 2 or seed_prefix.size(0) != 1:
        raise ValueError(
            "seed_prefix must have shape (L,) or (1, L), got "
            f"{tuple(seed_prefix.shape)}"
        )
    # Move the single prompt first, then broadcast it as a view.
    seed = seed_prefix.to(device).expand(n_samples, -1)

    cond_tx = torch.full((n_samples,), treatment_cond, device=device, dtype=torch.long)
    cond_null = torch.full((n_samples,), null_treatment_cond, device=device, dtype=torch.long)

    traj_tx = cfg_sample(model, cond_tx, seed_prefix=seed, gamma=gamma)
    traj_null = cfg_sample(model, cond_null, seed_prefix=seed, gamma=gamma)

    return {
        "traj_treated": traj_tx,
        "traj_untreated": traj_null,
        "n": n_samples,
        "treatment_cond": treatment_cond,
        "null_cond": null_treatment_cond,
        "gamma": gamma,
    }


def outcome_rate(traj: torch.Tensor, target_tok_ids: list[int]) -> float:
    """Fraction of trajectories containing ANY of the target outcome tokens.

    Args:
        traj: Token ids with trajectories along the first axis and positions
            along the last, e.g. (N, T).
        target_tok_ids: Token ids that count as the outcome.

    Returns:
        Share of rows with at least one target token, in [0, 1]. Returns 0.0
        when there are no target ids or ``traj`` is empty (instead of NaN
        from averaging over zero rows).
    """
    if not target_tok_ids or traj.numel() == 0:
        return 0.0
    target = torch.tensor(target_tok_ids, device=traj.device)
    # isin avoids both the (N, T, K) broadcast intermediate and the
    # tuple-of-dims form of Tensor.any, which older PyTorch releases reject.
    has = torch.isin(traj, target).any(dim=-1)
    return has.float().mean().item()