File size: 10,700 Bytes
e5f14b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
"""
MDLM β€” Masked Diffusion Language Model for governed structures.

Architecture:
  - Small transformer encoder (4 layers, 128 dim, 4 heads)
  - Absorbing-state masking: tokens β†’ <MASK> at rate alpha(t)
  - Denoising: predict original token from masked sequence
  - Loss: cross-entropy on masked positions (reweighted MLM)

Masking schedules:
  A: hierarchical (Tier 1 β†’ Tier 2 β†’ Tier 3+readiness)
  B: flat hierarchical (operators only, no readiness staging)
  C: Uniform random
  D: inverted (reverse staging: Tier 3 first, Tier 1 last)

Per PLAN-GHA-002 Β§4.4: A > B > C > D predicted.
"""

from __future__ import annotations

import math
import random
from enum import Enum

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

from pipeline.mdlm.tokenizer import (
    VOCAB_SIZE, MASK, PAD, NEVER_MASKED,
    TIER_1_TOKENS, TIER_2_TOKENS, TIER_3_TOKENS,
    get_tier, pad_sequence,
)


class MaskingSchedule(str, Enum):
    """Masking schedule variants for the hierarchical hypothesis test.

    The values are the short experiment labels used throughout the plan
    (per the module docstring, predicted quality ordering is A > B > C > D).
    Inherits from str so members compare equal to their letter codes.
    """
    HIERARCHICAL = "A"  # staged: Tier 1 β†’ Tier 2 β†’ CL+PreAttest
    FLAT = "B"  # flat: operators only, uniform within tiers
    UNIFORM = "C"           # uniform random over all maskable tokens
    INVERTED = "D"      # inverted: CL first, Tier 1 last


if HAS_TORCH:

    class StructureModel(nn.Module):
        """Small transformer encoder that denoises masked governed structures.

        Sums token, position, and diffusion-timestep embeddings, runs a
        stack of TransformerEncoder layers, and projects back to
        per-position vocabulary logits.
        """

        def __init__(
            self,
            vocab_size: int = VOCAB_SIZE,
            d_model: int = 128,
            nhead: int = 4,
            num_layers: int = 4,
            max_len: int = 40,
            dropout: float = 0.1,
        ):
            super().__init__()
            self.d_model = d_model
            # Token embedding; the PAD row is zeroed and receives no gradient.
            self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=PAD)
            # Learned absolute position embedding, valid for lengths <= max_len.
            self.pos_embedding = nn.Embedding(max_len, d_model)
            # One learned vector per diffusion timestep (supports t < 1000).
            self.timestep_embedding = nn.Embedding(1000, d_model)

            layer = nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=4 * d_model,
                dropout=dropout,
                batch_first=True,
            )
            self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
            self.output_proj = nn.Linear(d_model, vocab_size)

        def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
            """Denoise one batch.

            Args:
                x: (batch, seq_len) token ids, some positions masked.
                t: (batch,) diffusion timestep per sample (0 = clean,
                   T = fully masked).

            Returns:
                (batch, seq_len, vocab_size) logits for every position.
            """
            batch, seq_len = x.shape
            pos_ids = torch.arange(seq_len, device=x.device)
            pos_ids = pos_ids.unsqueeze(0).expand(batch, -1)

            hidden = self.embedding(x) + self.pos_embedding(pos_ids)
            # Broadcast a single timestep vector across all positions.
            hidden = hidden + self.timestep_embedding(t).unsqueeze(1)

            # PAD positions are excluded from attention.
            encoded = self.transformer(hidden, src_key_padding_mask=(x == PAD))
            return self.output_proj(encoded)


    def apply_mask(
        tokens: torch.Tensor,
        mask_rate: float,
        schedule: MaskingSchedule,
        timestep: int = 0,
        total_timesteps: int = 100,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Mask a batch of token sequences according to *schedule*.

        Each maskable token is independently replaced by MASK with a
        tier- and schedule-dependent probability (see _tier_mask_prob).

        Returns:
            masked_tokens: a copy of ``tokens`` with MASK substituted in.
            mask_positions: bool tensor, True where a token was masked.
        """
        batch, seq_len = tokens.shape
        masked_tokens = tokens.clone()
        was_masked = torch.zeros(
            batch, seq_len, dtype=torch.bool, device=tokens.device
        )

        for row in range(batch):
            for col in range(seq_len):
                token_id = tokens[row, col].item()
                # Structural tokens are never eligible for masking.
                if token_id in NEVER_MASKED:
                    continue

                tier = get_tier(token_id)
                if tier == 0:
                    continue

                # Schedule decides how aggressively this tier is masked
                # at the current point in the diffusion trajectory.
                prob = _tier_mask_prob(
                    tier, mask_rate, schedule, timestep, total_timesteps
                )
                if random.random() < prob:
                    masked_tokens[row, col] = MASK
                    was_masked[row, col] = True

        return masked_tokens, was_masked


    def _tier_mask_prob(
        tier: int,
        base_rate: float,
        schedule: MaskingSchedule,
        timestep: int,
        total_timesteps: int,
    ) -> float:
        """Compute mask probability for a token based on its tier and the schedule.

        The trajectory is split into thirds of diffusion progress: a tier's
        probability ramps linearly from 0 to base_rate within its assigned
        third and stays at its endpoint value outside it.
        """
        # Guard against total_timesteps == 0; t_frac in [0, 1].
        t_frac = timestep / max(total_timesteps, 1)  # 0 = clean, 1 = fully masked

        if schedule == MaskingSchedule.UNIFORM:
            # Schedule C: every maskable tier uses the same flat rate.
            return base_rate

        if schedule == MaskingSchedule.HIERARCHICAL:
            # Schedule A: Tier 1 is masked last (unmasked first),
            # Tier 3 (CL+PreAttest) is masked first (unmasked last).
            if tier == 1:
                return base_rate * max(0.0, (t_frac - 0.66) / 0.34) if t_frac > 0.66 else 0.0
            elif tier == 2:
                return base_rate * max(0.0, (t_frac - 0.33) / 0.34) if t_frac > 0.33 else 0.0
            else:  # tier 3
                return base_rate * min(1.0, t_frac / 0.33)

        if schedule == MaskingSchedule.FLAT:
            # NOTE(review): this branch is currently identical to the
            # HIERARCHICAL branch above, but the enum describes FLAT as
            # "operators only, uniform within tiers" β€” confirm whether the
            # duplication is intentional before trusting the A-vs-B comparison.
            if tier == 1:
                return base_rate * max(0.0, (t_frac - 0.66) / 0.34) if t_frac > 0.66 else 0.0
            elif tier == 2:
                return base_rate * max(0.0, (t_frac - 0.33) / 0.34) if t_frac > 0.33 else 0.0
            else:
                return base_rate * min(1.0, t_frac / 0.33)

        if schedule == MaskingSchedule.INVERTED:
            # Schedule D: mirror of A β€” Tier 1 is masked first, Tier 3 last.
            if tier == 1:
                return base_rate * min(1.0, t_frac / 0.33)
            elif tier == 2:
                return base_rate * max(0.0, (t_frac - 0.33) / 0.34) if t_frac > 0.33 else 0.0
            else:
                return base_rate * max(0.0, (t_frac - 0.66) / 0.34) if t_frac > 0.66 else 0.0

        # Unknown schedule value: fall back to the flat base rate.
        return base_rate


    def compute_loss(
        model: StructureModel,
        batch: torch.Tensor,
        schedule: MaskingSchedule,
        timestep: int,
        total_timesteps: int = 100,
        mask_rate: float = 0.5,
    ) -> torch.Tensor:
        """Compute the MDLM denoising loss for one batch.

        Masks the batch per *schedule*, runs the model, and scores
        cross-entropy on the masked positions only (reweighted MLM).
        When the schedule masks nothing at this timestep, a differentiable
        zero is returned so the training step is a no-op rather than NaN.
        """
        device = next(model.parameters()).device
        batch = batch.to(device)
        timesteps = torch.full(
            (batch.size(0),), timestep, dtype=torch.long, device=device
        )

        noisy, positions = apply_mask(
            batch, mask_rate, schedule, timestep, total_timesteps
        )

        if not positions.any():
            # Nothing was masked β€” skip the forward pass entirely.
            return torch.tensor(0.0, device=device, requires_grad=True)

        logits = model(noisy, timesteps)

        # Cross-entropy restricted to masked positions; PAD is ignored
        # defensively (PAD should never be masked in the first place).
        return F.cross_entropy(
            logits[positions],
            batch[positions],
            ignore_index=PAD,
        )


    def generate(
        model: StructureModel,
        num_samples: int,
        seq_len: int,
        schedule: MaskingSchedule,
        total_timesteps: int = 50,
        g_slots: int = 3,
        s_slots: int = 4,
        f_slots: int = 3,
    ) -> torch.Tensor:
        """Generate governed structures by template-guided iterative unmasking.

        The channel_b frame is IMPOSED (governance), not learned:
          <BOS> <G> [MASK slots] </G> <S> [MASK slots] </S> <F> [MASK slots] </F>
                [witness MASK slots] <EOS>

        The model fills in operator tokens and witness attestation status.
        This respects PROPOSE β‰  PROMOTE: the frame is governance,
        the content is what the kernel crystallizes.

        g_slots, s_slots, f_slots: number of operator MASK slots per modality.
        Should match the corpus distribution.

        Returns:
            (num_samples, seq_len) long tensor of generated token ids.

        NOTE(review): total_timesteps == 0 would divide by zero below;
        callers are expected to pass >= 1. Also, this puts the model in
        eval mode as a side effect and does not restore training mode.
        """
        device = next(model.parameters()).device
        # Local import to avoid pulling generation-only symbols at module load.
        from pipeline.mdlm.tokenizer import (
            BOS, EOS, G_OPEN, G_CLOSE, S_OPEN, S_CLOSE, F_OPEN, F_CLOSE,
            WIT_OFFSET, ATTESTED,
        )

        # Build template with configurable slot counts
        template = [BOS, G_OPEN] + [MASK] * g_slots + [G_CLOSE,
                    S_OPEN] + [MASK] * s_slots + [S_CLOSE,
                    F_OPEN] + [MASK] * f_slots + [F_CLOSE]
        # 7 witness pairs: WIT_TOKEN MASK
        for w in range(7):
            template.extend([WIT_OFFSET + w, MASK])
        template.append(EOS)

        # Pad to seq_len (and truncate if the template overran it β€”
        # truncation would drop the trailing witness slots/EOS; confirm
        # seq_len is always large enough for the configured slot counts).
        while len(template) < seq_len:
            template.append(PAD)
        template = template[:seq_len]

        # Every sample starts from the same governed frame.
        samples = torch.tensor([template] * num_samples, dtype=torch.long, device=device)

        model.eval()
        with torch.no_grad():
            # Reverse diffusion: walk t from fully-noised down to 0,
            # committing sampled tokens into still-masked positions.
            for step in range(total_timesteps, -1, -1):
                t_tensor = torch.full((num_samples,), step, dtype=torch.long, device=device)
                logits = model(samples, t_tensor)
                probs = F.softmax(logits, dim=-1)

                t_frac = step / total_timesteps

                for b in range(num_samples):
                    for i in range(seq_len):
                        # Only fill positions that are still masked.
                        if samples[b, i].item() != MASK:
                            continue

                        # Sample a candidate token, then decide by tier
                        # whether to commit it at this point in the trajectory.
                        pred = torch.multinomial(probs[b, i], 1).item()
                        tier = get_tier(pred)

                        # Tier-based unmasking schedule
                        should_unmask = False
                        if schedule == MaskingSchedule.HIERARCHICAL:
                            should_unmask = (tier == 1 and t_frac < 0.33) or \
                                          (tier == 2 and 0.33 <= t_frac < 0.66) or \
                                          (tier == 3 and t_frac >= 0.66) or \
                                          (step == 0)  # unmask everything at final step
                        else:
                            # NOTE(review): non-hierarchical schedules commit
                            # every sampled token on the first iteration,
                            # collapsing to single-step decoding β€” presumably
                            # intentional for the B/C/D baselines; confirm.
                            should_unmask = True

                        if should_unmask:
                            samples[b, i] = pred

            # Final pass: force-unmask any remaining MASK tokens
            # (greedy argmax rather than sampling).
            remaining = (samples == MASK)
            if remaining.any():
                t_tensor = torch.zeros((num_samples,), dtype=torch.long, device=device)
                logits = model(samples, t_tensor)
                for b in range(num_samples):
                    for i in range(seq_len):
                        if samples[b, i].item() == MASK:
                            samples[b, i] = logits[b, i].argmax().item()

        return samples