MetaCortex-Dynamics committed on
Commit
e5f14b1
·
verified ·
1 Parent(s): 70c4b1b

Create pipeline/mdlm/model.py

Browse files
Files changed (1) hide show
  1. pipeline/mdlm/model.py +292 -0
pipeline/mdlm/model.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MDLM — Masked Diffusion Language Model for governed structures.
3
+
4
+ Architecture:
5
+ - Small transformer encoder (4 layers, 128 dim, 4 heads)
6
+ - Absorbing-state masking: tokens → <MASK> at rate alpha(t)
7
+ - Denoising: predict original token from masked sequence
8
+ - Loss: cross-entropy on masked positions (reweighted MLM)
9
+
10
+ Masking schedules:
11
+ A: hierarchical (Tier 1 → Tier 2 → Tier 3+readiness)
12
+ B: flat (operators only, no readiness staging)
13
+ C: Uniform random
14
+ D: inverted (Tier 3 → Tier 2 → Tier 1)
15
+
16
+ Per PLAN-GHA-002 §4.4: A > B > C > D predicted.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import math
22
+ import random
23
+ from enum import Enum
24
+
25
+ try:
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ HAS_TORCH = True
30
+ except ImportError:
31
+ HAS_TORCH = False
32
+
33
+ from pipeline.mdlm.tokenizer import (
34
+ VOCAB_SIZE, MASK, PAD, NEVER_MASKED,
35
+ TIER_1_TOKENS, TIER_2_TOKENS, TIER_3_TOKENS,
36
+ get_tier, pad_sequence,
37
+ )
38
+
39
+
40
class MaskingSchedule(str, Enum):
    """Which masking curriculum to apply during diffusion training.

    The four variants form the experimental conditions of the hierarchical
    hypothesis test; each value is the single-letter condition label.
    """

    # Staged by tier: Tier 1 → Tier 2 → CL + PreAttest.
    HIERARCHICAL = "A"
    # Staged over operators only; no readiness staging within tiers.
    FLAT = "B"
    # Every maskable token treated identically, at random.
    UNIFORM = "C"
    # Reverse staging: CL first, Tier 1 last.
    INVERTED = "D"
46
+
47
+
48
+ if HAS_TORCH:
49
+
50
class StructureModel(nn.Module):
    """Small transformer encoder that denoises masked governed-structure sequences.

    Token, position, and diffusion-timestep embeddings are summed, passed
    through a `nn.TransformerEncoder`, and projected back to vocabulary
    logits. PAD positions are excluded from attention via the key-padding
    mask.
    """

    def __init__(
        self,
        vocab_size: int = VOCAB_SIZE,
        d_model: int = 128,
        nhead: int = 4,
        num_layers: int = 4,
        max_len: int = 40,
        dropout: float = 0.1,
        max_timesteps: int = 1000,
    ):
        """Build the encoder.

        Args:
            vocab_size: size of the token vocabulary.
            d_model: embedding / hidden width.
            nhead: number of attention heads.
            num_layers: stacked encoder layers.
            max_len: longest sequence the positional table supports.
            dropout: dropout rate inside the encoder layers.
            max_timesteps: size of the learned diffusion-timestep embedding
                table (was previously hard-coded to 1000); `forward` must
                receive timesteps strictly below this value.
        """
        super().__init__()
        self.d_model = d_model
        # padding_idx keeps the PAD embedding at zero and frozen.
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=PAD)
        self.pos_embedding = nn.Embedding(max_len, d_model)
        # One learned vector per diffusion timestep.
        self.timestep_embedding = nn.Embedding(max_timesteps, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_proj = nn.Linear(d_model, vocab_size)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, seq_len) — token ids with some positions masked
        t: (batch,) — diffusion timestep (0 = clean, T = fully masked)
        Returns: (batch, seq_len, vocab_size) — logits for each position
        """
        B, L = x.shape
        positions = torch.arange(L, device=x.device).unsqueeze(0).expand(B, -1)

        h = self.embedding(x) + self.pos_embedding(positions)
        # Broadcast the per-sample timestep embedding across all positions.
        h = h + self.timestep_embedding(t).unsqueeze(1)

        # Mask PAD positions out of attention.
        pad_mask = (x == PAD)
        h = self.transformer(h, src_key_padding_mask=pad_mask)
        return self.output_proj(h)
90
+
91
+
92
def apply_mask(
    tokens: torch.Tensor,
    mask_rate: float,
    schedule: MaskingSchedule,
    timestep: int = 0,
    total_timesteps: int = 100,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Corrupt a batch of sequences by replacing eligible tokens with MASK.

    Each token is masked independently with a probability determined by its
    tier and the active schedule (see `_tier_mask_prob`). Protected tokens
    (NEVER_MASKED, or tier 0) are left untouched.

    Returns:
        masked_tokens: copy of `tokens` with some positions set to MASK
        mask_positions: boolean tensor, True where a token was masked
    """
    batch_size, seq_len = tokens.shape
    masked = tokens.clone()
    mask_positions = torch.zeros(
        batch_size, seq_len, dtype=torch.bool, device=tokens.device
    )

    for row in range(batch_size):
        for col in range(seq_len):
            token_id = tokens[row, col].item()
            # Skip protected tokens; `or` short-circuits so get_tier only
            # runs for tokens that are not in NEVER_MASKED.
            if token_id in NEVER_MASKED or (tier := get_tier(token_id)) == 0:
                continue

            # Schedule-dependent per-tier mask probability.
            prob = _tier_mask_prob(tier, mask_rate, schedule, timestep, total_timesteps)
            if random.random() < prob:
                masked[row, col] = MASK
                mask_positions[row, col] = True

    return masked, mask_positions
127
+
128
+
129
def _tier_mask_prob(
    tier: int,
    base_rate: float,
    schedule: MaskingSchedule,
    timestep: int,
    total_timesteps: int,
) -> float:
    """Compute mask probability for a token based on its tier and the schedule.

    t_frac runs from 0 (clean) to 1 (fully masked). Staged schedules split
    the timeline into thirds: an "early" tier ramps to base_rate over the
    first third, a "late" tier only starts ramping once t_frac passes its
    threshold. The max()/min() clamps make the separate `if t_frac > x`
    guards of the original redundant, so they were dropped — the returned
    values are unchanged.

    NOTE(review): schedules A (HIERARCHICAL) and B (FLAT) currently yield
    identical probabilities, although the module docstring says B drops the
    readiness staging — confirm whether B was meant to differ.
    """
    t_frac = timestep / max(total_timesteps, 1)  # 0 = clean, 1 = fully masked

    def ramp_after(threshold: float) -> float:
        # Linear 0 → base_rate ramp over [threshold, threshold + 0.34].
        return base_rate * max(0.0, (t_frac - threshold) / 0.34)

    def ramp_early() -> float:
        # Linear ramp reaching base_rate by t_frac = 0.33.
        return base_rate * min(1.0, t_frac / 0.33)

    if schedule == MaskingSchedule.UNIFORM:
        return base_rate

    # A and B share the same staging: Tier 1 masked last, Tier 3
    # (CL + PreAttest) masked first.
    if schedule in (MaskingSchedule.HIERARCHICAL, MaskingSchedule.FLAT):
        if tier == 1:
            return ramp_after(0.66)
        if tier == 2:
            return ramp_after(0.33)
        return ramp_early()  # tier 3

    if schedule == MaskingSchedule.INVERTED:
        # Inverted staging: Tier 1 masked first, Tier 3 last.
        if tier == 1:
            return ramp_early()
        if tier == 2:
            return ramp_after(0.33)
        return ramp_after(0.66)

    # Unknown schedule value: fall back to the uniform rate.
    return base_rate
171
+
172
+
173
def compute_loss(
    model: StructureModel,
    batch: torch.Tensor,
    schedule: MaskingSchedule,
    timestep: int,
    total_timesteps: int = 100,
    mask_rate: float = 0.5,
) -> torch.Tensor:
    """Denoising objective: cross-entropy restricted to masked positions.

    The batch is corrupted with `apply_mask` under the given schedule and
    timestep; the model then predicts the original tokens and the loss is
    taken only where tokens were masked (reweighted-MLM style). When the
    corruption step happens to mask nothing, a zero tensor with
    requires_grad is returned so callers never hit a NaN from an empty
    selection.
    """
    device = next(model.parameters()).device
    clean = batch.to(device)

    corrupted, was_masked = apply_mask(
        clean, mask_rate, schedule, timestep, total_timesteps
    )

    # Nothing was masked → nothing to denoise.
    if not was_masked.any():
        return torch.tensor(0.0, device=device, requires_grad=True)

    steps = torch.full((clean.size(0),), timestep, dtype=torch.long, device=device)
    logits = model(corrupted, steps)

    # Supervise only the corrupted positions.
    return F.cross_entropy(
        logits[was_masked],
        clean[was_masked],
        ignore_index=PAD,
    )
205
+
206
+
207
def generate(
    model: StructureModel,
    num_samples: int,
    seq_len: int,
    schedule: MaskingSchedule,
    total_timesteps: int = 50,
    g_slots: int = 3,
    s_slots: int = 4,
    f_slots: int = 3,
) -> torch.Tensor:
    """Generate governed structures by template-guided iterative unmasking.

    The channel_b frame is IMPOSED (governance), not learned:
        <BOS> <G> [MASK slots] </G> <S> [MASK slots] </S> <F> [MASK slots] </F>
        [witness MASK slots] <EOS>

    The model fills in operator tokens and witness attestation status.
    This respects PROPOSE ≠ PROMOTE: the frame is governance,
    the content is what the kernel crystallizes.

    Args:
        model: trained denoiser; run in eval mode, no gradients.
        num_samples: number of sequences to generate in parallel.
        seq_len: output length; the template is padded/truncated to this.
        schedule: only HIERARCHICAL gates unmasking by tier; every other
            schedule unmasks any sampled position immediately.
        total_timesteps: number of reverse-diffusion steps (step counts
            down from this value to 0).
        g_slots, s_slots, f_slots: number of operator MASK slots per
            modality; should match the corpus distribution.

    Returns:
        (num_samples, seq_len) long tensor of token ids on the model's device.

    NOTE(review): total_timesteps == 0 would divide by zero at
    `t_frac = step / total_timesteps` — confirm callers never pass 0.
    """
    device = next(model.parameters()).device
    # Local import to avoid a module-level cycle with the tokenizer.
    from pipeline.mdlm.tokenizer import (
        BOS, EOS, G_OPEN, G_CLOSE, S_OPEN, S_CLOSE, F_OPEN, F_CLOSE,
        WIT_OFFSET, ATTESTED,
    )
    # NOTE(review): ATTESTED is imported but never used below — confirm
    # whether witness slots were meant to be seeded with it.

    # Build template with configurable slot counts
    template = [BOS, G_OPEN] + [MASK] * g_slots + [G_CLOSE,
                S_OPEN] + [MASK] * s_slots + [S_CLOSE,
                F_OPEN] + [MASK] * f_slots + [F_CLOSE]
    # 7 witness pairs: WIT_TOKEN MASK
    for w in range(7):
        template.extend([WIT_OFFSET + w, MASK])
    template.append(EOS)

    # Pad to seq_len (and truncate if the template is already longer).
    while len(template) < seq_len:
        template.append(PAD)
    template = template[:seq_len]

    # Same imposed frame for every sample; only MASK slots differ after sampling.
    samples = torch.tensor([template] * num_samples, dtype=torch.long, device=device)

    model.eval()
    with torch.no_grad():
        # Reverse diffusion: walk the timestep from fully-masked (T) to clean (0).
        for step in range(total_timesteps, -1, -1):
            t_tensor = torch.full((num_samples,), step, dtype=torch.long, device=device)
            logits = model(samples, t_tensor)
            probs = F.softmax(logits, dim=-1)

            t_frac = step / total_timesteps

            for b in range(num_samples):
                for i in range(seq_len):
                    # Only still-masked slots are candidates for filling.
                    if samples[b, i].item() != MASK:
                        continue

                    # Sample a candidate token, then gate acceptance on the
                    # tier of the *sampled* token (not the position).
                    pred = torch.multinomial(probs[b, i], 1).item()
                    tier = get_tier(pred)

                    # Tier-based unmasking schedule.
                    # NOTE(review): training masks tier 1 only at high
                    # t_frac, so in the reverse process one would expect
                    # tier 1 to unmask at high t_frac; here tier 1 is
                    # accepted at t_frac < 0.33 and tier 3 at
                    # t_frac >= 0.66 — confirm this gating is not inverted
                    # relative to `_tier_mask_prob`.
                    should_unmask = False
                    if schedule == MaskingSchedule.HIERARCHICAL:
                        should_unmask = (tier == 1 and t_frac < 0.33) or \
                                        (tier == 2 and 0.33 <= t_frac < 0.66) or \
                                        (tier == 3 and t_frac >= 0.66) or \
                                        (step == 0)  # unmask everything at final step
                    else:
                        should_unmask = True

                    if should_unmask:
                        samples[b, i] = pred

    # Final pass: force-unmask any remaining MASK tokens with a greedy
    # argmax at timestep 0, so the output never contains MASK.
    remaining = (samples == MASK)
    if remaining.any():
        t_tensor = torch.zeros((num_samples,), dtype=torch.long, device=device)
        logits = model(samples, t_tensor)
        for b in range(num_samples):
            for i in range(seq_len):
                if samples[b, i].item() == MASK:
                    samples[b, i] = logits[b, i].argmax().item()

    return samples