# Axiom-Ref / pipeline/mdlm/model.py
# MetaCortex-Dynamics's picture
# Create pipeline/mdlm/model.py
# e5f14b1 verified
"""
MDLM β€” Masked Diffusion Language Model for governed structures.
Architecture:
- Small transformer encoder (4 layers, 128 dim, 4 heads)
- Absorbing-state masking: tokens β†’ <MASK> at rate alpha(t)
- Denoising: predict original token from masked sequence
- Loss: cross-entropy on masked positions (reweighted MLM)
Masking schedules:
A: hierarchical hierarchical (Tier 1 β†’ Tier 2 β†’ Tier 3+readiness)
B: flat hierarchical (operators only, no readiness staging)
C: Uniform random
D: inverted inverted
Per PLAN-GHA-002 Β§4.4: A > B > C > D predicted.
"""
from __future__ import annotations
import math
import random
from enum import Enum
try:
import torch
import torch.nn as nn
import torch.nn.functional as F
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
from pipeline.mdlm.tokenizer import (
VOCAB_SIZE, MASK, PAD, NEVER_MASKED,
TIER_1_TOKENS, TIER_2_TOKENS, TIER_3_TOKENS,
get_tier, pad_sequence,
)
class MaskingSchedule(str, Enum):
    """Identifiers for the masking schedules compared in the hierarchy test.

    Members are also plain strings (``str`` mixin), so a member compares
    equal to its one-letter code, e.g. ``MaskingSchedule.UNIFORM == "C"``.
    """

    # A - hierarchical staging: Tier 1 → Tier 2 → CL+PreAttest
    HIERARCHICAL = "A"
    # B - flat: operators only, uniform within tiers
    FLAT = "B"
    # C - uniform random over all maskable tokens
    UNIFORM = "C"
    # D - inverted staging: CL first, Tier 1 last
    INVERTED = "D"
if HAS_TORCH:

    class StructureModel(nn.Module):
        """Small transformer encoder used to denoise masked token sequences.

        Token, position and diffusion-timestep embeddings are summed, passed
        through a stack of ``nn.TransformerEncoderLayer`` blocks, and projected
        back to vocabulary logits at every position.

        Args:
            vocab_size: Number of token ids; sizes the token embedding and the
                output projection.
            d_model: Embedding / hidden width (feed-forward width is 4x this).
            nhead: Number of attention heads per encoder layer.
            num_layers: Number of stacked encoder layers.
            max_len: Size of the learned position table, i.e. the longest
                sequence ``forward`` accepts.
            dropout: Dropout probability used inside the encoder layers.
            max_timesteps: Size of the timestep embedding table. Timesteps
                given to ``forward`` must lie in ``[0, max_timesteps)``.
                Defaults to 1000, the previously hard-coded value.
        """

        def __init__(
            self,
            vocab_size: int = VOCAB_SIZE,
            d_model: int = 128,
            nhead: int = 4,
            num_layers: int = 4,
            max_len: int = 40,
            dropout: float = 0.1,
            max_timesteps: int = 1000,
        ):
            super().__init__()
            self.d_model = d_model
            # padding_idx keeps the PAD embedding row out of gradient updates.
            self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=PAD)
            self.pos_embedding = nn.Embedding(max_len, d_model)
            # One learned vector per diffusion timestep.
            self.timestep_embedding = nn.Embedding(max_timesteps, d_model)
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * 4,
                dropout=dropout,
                batch_first=True,
            )
            self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.output_proj = nn.Linear(d_model, vocab_size)

        def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
            """Return per-position vocabulary logits for a (partly masked) batch.

            Args:
                x: (batch, seq_len) token ids with some positions masked.
                t: (batch,) diffusion timestep per sequence
                    (0 = clean, T = fully masked).

            Returns:
                (batch, seq_len, vocab_size) logits.

            Raises:
                IndexError: If ``seq_len`` exceeds the position table size
                    (``max_len`` given at construction).
            """
            batch_size, seq_len = x.shape
            max_len = self.pos_embedding.num_embeddings
            if seq_len > max_len:
                # Fail early with a readable message instead of an opaque
                # out-of-range lookup inside the position embedding.
                raise IndexError(
                    f"sequence length {seq_len} exceeds model max_len {max_len}"
                )
            positions = (
                torch.arange(seq_len, device=x.device)
                .unsqueeze(0)
                .expand(batch_size, -1)
            )
            h = self.embedding(x) + self.pos_embedding(positions)
            # Broadcast the per-sequence timestep vector over all positions.
            h = h + self.timestep_embedding(t).unsqueeze(1)
            # PAD positions are excluded as attention keys.
            pad_mask = x == PAD
            h = self.transformer(h, src_key_padding_mask=pad_mask)
            return self.output_proj(h)
def apply_mask(
    tokens: torch.Tensor,
    mask_rate: float,
    schedule: MaskingSchedule,
    timestep: int = 0,
    total_timesteps: int = 100,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Randomly replace maskable tokens with MASK according to a schedule.

    Every token that is neither in ``NEVER_MASKED`` nor of tier 0 is masked
    independently with the probability ``_tier_mask_prob`` yields for its
    tier. Randomness comes from Python's ``random`` module, one draw per
    maskable token in row-major order.

    Args:
        tokens: (batch, seq_len) tensor of token ids; it is not modified.
        mask_rate: Base masking rate handed to ``_tier_mask_prob``.
        schedule: Masking schedule variant.
        timestep: Current diffusion timestep.
        total_timesteps: Total number of diffusion timesteps.

    Returns:
        masked_tokens: copy of ``tokens`` with selected positions set to MASK.
        mask_positions: boolean tensor (True = was masked).
    """
    batch_size, seq_len = tokens.shape
    corrupted = tokens.clone()
    was_masked = torch.zeros(
        batch_size, seq_len, dtype=torch.bool, device=tokens.device
    )
    # Pull the ids to Python once rather than calling .item() per element.
    for row, row_ids in enumerate(tokens.tolist()):
        for col, token_id in enumerate(row_ids):
            maskable = token_id not in NEVER_MASKED
            tier = get_tier(token_id) if maskable else 0
            if tier == 0:
                continue
            # Per-tier probability for this schedule and timestep.
            prob = _tier_mask_prob(
                tier, mask_rate, schedule, timestep, total_timesteps
            )
            if random.random() < prob:
                corrupted[row, col] = MASK
                was_masked[row, col] = True
    return corrupted, was_masked
def _tier_mask_prob(
    tier: int,
    base_rate: float,
    schedule: MaskingSchedule,
    timestep: int,
    total_timesteps: int,
) -> float:
    """Compute the mask probability for a token from its tier and the schedule.

    The diffusion time ``t_frac = timestep / total_timesteps`` (0 = clean,
    1 = fully masked) is split into three windows. Each tier is assigned one
    window; its probability ramps linearly from 0 to ``base_rate`` across that
    window and stays at ``base_rate`` afterwards:

    * early  - ramps over ``[0, 0.33]``
    * middle - ramps from ``0.33`` (width 0.34)
    * late   - ramps from ``0.66`` (width 0.34)

    Every ramp is clamped to ``[0, 1]`` so the result never exceeds
    ``base_rate``. (Previously only the early ramp was clamped, so the middle
    ramp climbed to roughly ``1.97 * base_rate`` at ``t_frac = 1``.)

    Args:
        tier: Token tier; 1 and 2 are handled explicitly, anything else is
            treated like tier 3.
        base_rate: Maximum mask probability for any tier.
        schedule: Masking schedule variant.
        timestep: Current diffusion timestep.
        total_timesteps: Total number of timesteps (values < 1 are treated as 1).

    Returns:
        Mask probability for a token of the given tier.
    """
    t_frac = timestep / max(total_timesteps, 1)  # 0 = clean, 1 = fully masked

    if schedule == MaskingSchedule.UNIFORM:
        return base_rate

    def ramp(start: float, width: float) -> float:
        # Linear 0 → 1 ramp beginning at `start`, clamped on both ends.
        return min(1.0, max(0.0, (t_frac - start) / width))

    early = ramp(0.0, 0.33)
    middle = ramp(0.33, 0.34)
    late = ramp(0.66, 0.34)

    # NOTE(review): FLAT currently uses exactly the HIERARCHICAL windows. The
    # original comment said witness tokens should get tier-2 priority under
    # FLAT, but no such distinction is implemented here - confirm intent.
    if schedule in (MaskingSchedule.HIERARCHICAL, MaskingSchedule.FLAT):
        # Tier 1 is masked last (unmasked first); tier 3 is masked first.
        windows = {1: late, 2: middle}
        return base_rate * windows.get(tier, early)

    if schedule == MaskingSchedule.INVERTED:
        # Reverse order: tier 1 is masked first, tier 3 last.
        windows = {1: early, 2: middle}
        return base_rate * windows.get(tier, late)

    # Unknown schedule: fall back to the untiered base rate.
    return base_rate
def compute_loss(
    model: StructureModel,
    batch: torch.Tensor,
    schedule: MaskingSchedule,
    timestep: int,
    total_timesteps: int = 100,
    mask_rate: float = 0.5,
) -> torch.Tensor:
    """Compute the MDLM denoising loss for one batch at one timestep.

    The batch is moved to the model's device, corrupted with ``apply_mask``,
    and the model is asked to recover the original tokens. The loss is the
    cross-entropy over masked positions only (targets equal to PAD are
    ignored).

    Args:
        model: Denoising model, called as ``model(masked_tokens, timesteps)``.
        batch: (batch, seq_len) tensor of clean token ids.
        schedule: Masking schedule variant.
        timestep: Diffusion timestep shared by every sequence in the batch.
        total_timesteps: Total number of diffusion timesteps.
        mask_rate: Base masking rate.

    Returns:
        Scalar loss tensor. When no position was masked, a fresh zero tensor
        with ``requires_grad=True`` is returned instead (avoids NaN).
    """
    device = next(model.parameters()).device
    targets = batch.to(device)
    timesteps = torch.full(
        (targets.size(0),), timestep, dtype=torch.long, device=device
    )
    corrupted, masked_at = apply_mask(
        targets, mask_rate, schedule, timestep, total_timesteps
    )

    if masked_at.any():
        logits = model(corrupted, timesteps)
        # Boolean indexing keeps only the masked positions.
        predicted = logits[masked_at]
        expected = targets[masked_at]
        return F.cross_entropy(predicted, expected, ignore_index=PAD)

    # Nothing was masked: there is nothing to score.
    return torch.tensor(0.0, device=device, requires_grad=True)
def generate(
    model: StructureModel,
    num_samples: int,
    seq_len: int,
    schedule: MaskingSchedule,
    total_timesteps: int = 50,
    g_slots: int = 3,
    s_slots: int = 4,
    f_slots: int = 3,
) -> torch.Tensor:
    """Generate governed structures by template-guided iterative unmasking.

    The channel_b frame is IMPOSED (governance), not learned:
        <BOS> <G> [MASK slots] </G> <S> [MASK slots] </S> <F> [MASK slots] </F>
        [witness MASK slots] <EOS>
    The model fills in operator tokens and witness attestation status.
    This respects PROPOSE ≠ PROMOTE: the frame is governance,
    the content is what the kernel crystallizes.

    Timesteps run from ``total_timesteps`` down to 0. Under the HIERARCHICAL
    schedule a sampled token is only committed while its tier's window is
    active, mirroring the training-time masking order in reverse:
    tier 1 while ``t_frac >= 0.66``, tier 2 while ``0.33 <= t_frac < 0.66``,
    tier 3 while ``t_frac < 0.33``; at step 0 everything is committed.

    Args:
        model: Denoising model; it is switched to eval mode and left there.
        num_samples: Number of sequences to generate.
        seq_len: Output length. The template is padded with PAD up to this
            length, or truncated if it is longer.
        schedule: Masking schedule variant controlling unmasking order.
        total_timesteps: Number of reverse-diffusion steps. Each step index
            is fed to the model as its timestep, so it must be valid for the
            model's timestep embedding.
        g_slots, s_slots, f_slots: Number of operator MASK slots per modality.
            Should match the corpus distribution.

    Returns:
        (num_samples, seq_len) tensor of token ids containing no MASK.
    """
    from pipeline.mdlm.tokenizer import (
        BOS, EOS, G_OPEN, G_CLOSE, S_OPEN, S_CLOSE, F_OPEN, F_CLOSE,
        WIT_OFFSET,
    )

    device = next(model.parameters()).device

    # Build the fixed frame with configurable slot counts.
    template = [BOS]
    template += [G_OPEN] + [MASK] * g_slots + [G_CLOSE]
    template += [S_OPEN] + [MASK] * s_slots + [S_CLOSE]
    template += [F_OPEN] + [MASK] * f_slots + [F_CLOSE]
    # 7 witness pairs: WIT_TOKEN MASK
    for w in range(7):
        template += [WIT_OFFSET + w, MASK]
    template.append(EOS)
    # Pad (a negative count yields no padding) and clip to seq_len.
    template += [PAD] * (seq_len - len(template))
    template = template[:seq_len]

    samples = (
        torch.tensor(template, dtype=torch.long, device=device)
        .unsqueeze(0)
        .repeat(num_samples, 1)
    )

    # Guard against total_timesteps == 0, consistent with _tier_mask_prob.
    denom = max(total_timesteps, 1)

    model.eval()
    with torch.no_grad():
        for step in range(total_timesteps, -1, -1):
            # Still-masked positions in row-major order; only these are
            # sampled, in the same order a full (b, i) scan would visit them.
            pending = (samples == MASK).nonzero().tolist()
            if not pending:
                break

            t_tensor = torch.full(
                (num_samples,), step, dtype=torch.long, device=device
            )
            probs = F.softmax(model(samples, t_tensor), dim=-1)
            t_frac = step / denom

            for b, i in pending:
                pred = torch.multinomial(probs[b, i], 1).item()

                if schedule == MaskingSchedule.HIERARCHICAL:
                    # t_frac counts DOWN during generation, so tier 1 (masked
                    # last in training) must be released first, at high t_frac.
                    tier = get_tier(pred)
                    should_unmask = (
                        step == 0  # commit everything at the final step
                        or (tier == 1 and t_frac >= 0.66)
                        or (tier == 2 and 0.33 <= t_frac < 0.66)
                        or (tier == 3 and t_frac < 0.33)
                    )
                else:
                    # NOTE(review): the other schedules apply no staging at
                    # generation time; every slot is committed as soon as it
                    # is sampled.
                    should_unmask = True

                if should_unmask:
                    samples[b, i] = pred

        # Final pass: force-unmask any remaining MASK tokens.
        remaining = samples == MASK
        if remaining.any():
            t_tensor = torch.zeros((num_samples,), dtype=torch.long, device=device)
            logits = model(samples, t_tensor)
            # Exclude MASK itself so the greedy fill cannot leave a slot masked.
            logits[..., MASK] = float("-inf")
            samples[remaining] = logits.argmax(dim=-1)[remaining]

    return samples