youraveragedev commited on
Commit
26b231d
Β·
verified Β·
1 Parent(s): 2ad81e0

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +340 -0
model.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Small MDLM (Masked Diffusion Language Model) for text generation.
3
+
4
+ Based on: "Simple and Effective Masked Diffusion Language Models" (Sahoo et al., NeurIPS 2024)
5
+ Architecture: DiT backbone with adaLN-zero conditioning, RoPE, bidirectional attention.
6
+ No flash_attn dependency β€” uses PyTorch native scaled_dot_product_attention.
7
+ """
8
+
9
+ import math
10
+ import typing
11
+ import json
12
+ import os
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from transformers import PreTrainedModel, PretrainedConfig
18
+ from transformers.modeling_outputs import MaskedLMOutput
19
+
20
+
21
+ class MDLMConfig(PretrainedConfig):
22
+ """Configuration for a small MDLM text diffusion model."""
23
+ model_type = "mdlm"
24
+
25
+ def __init__(
26
+ self,
27
+ vocab_size: int = 50258,
28
+ model_length: int = 256,
29
+ hidden_dim: int = 512,
30
+ cond_dim: int = 128,
31
+ n_blocks: int = 6,
32
+ n_heads: int = 8,
33
+ dropout: float = 0.1,
34
+ time_conditioning: bool = True,
35
+ mlp_ratio: int = 4,
36
+ mask_token_id: int = 50257,
37
+ **kwargs
38
+ ):
39
+ super().__init__(**kwargs)
40
+ self.vocab_size = vocab_size
41
+ self.model_length = model_length
42
+ self.hidden_dim = hidden_dim
43
+ self.cond_dim = cond_dim
44
+ self.n_blocks = n_blocks
45
+ self.n_heads = n_heads
46
+ self.dropout = dropout
47
+ self.time_conditioning = time_conditioning
48
+ self.mlp_ratio = mlp_ratio
49
+ self.mask_token_id = mask_token_id
50
+
51
+
52
+ # ─── Rotary Position Embeddings ───────────────────────────
53
+
54
+ class RotaryEmbedding(nn.Module):
55
+ def __init__(self, dim, base=10000):
56
+ super().__init__()
57
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
58
+ self.register_buffer("inv_freq", inv_freq)
59
+
60
+ def forward(self, seq_len, device):
61
+ t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
62
+ freqs = torch.outer(t, self.inv_freq)
63
+ return torch.cat([freqs, freqs], dim=-1) # (seq_len, dim)
64
+
65
+
66
+ def rotate_half(x):
67
+ x1, x2 = x.chunk(2, dim=-1)
68
+ return torch.cat((-x2, x1), dim=-1)
69
+
70
+
71
+ def apply_rotary_pos_emb(q, k, freqs):
72
+ """Apply RoPE to query and key tensors."""
73
+ cos = freqs.cos().unsqueeze(0).unsqueeze(2) # (1, seq, 1, dim)
74
+ sin = freqs.sin().unsqueeze(0).unsqueeze(2) # (1, seq, 1, dim)
75
+ q = q * cos + rotate_half(q) * sin
76
+ k = k * cos + rotate_half(k) * sin
77
+ return q, k
78
+
79
+
80
+ # ─── Timestep Embedding ──────────────────────────────────
81
+
82
+ class TimestepEmbedder(nn.Module):
83
+ def __init__(self, hidden_size, frequency_embedding_size=256):
84
+ super().__init__()
85
+ self.mlp = nn.Sequential(
86
+ nn.Linear(frequency_embedding_size, hidden_size),
87
+ nn.SiLU(),
88
+ nn.Linear(hidden_size, hidden_size),
89
+ )
90
+ self.frequency_embedding_size = frequency_embedding_size
91
+
92
+ @staticmethod
93
+ def timestep_embedding(t, dim, max_period=10000):
94
+ half = dim // 2
95
+ freqs = torch.exp(
96
+ -math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half
97
+ )
98
+ args = t[:, None].float() * freqs[None]
99
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
100
+ if dim % 2:
101
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
102
+ return embedding
103
+
104
+ def forward(self, t):
105
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
106
+ return self.mlp(t_freq)
107
+
108
+
109
+ # ─── LayerNorm ────────────────────────────────────────────
110
+
111
+ class LayerNorm(nn.Module):
112
+ def __init__(self, dim):
113
+ super().__init__()
114
+ self.weight = nn.Parameter(torch.ones(dim))
115
+ self.dim = dim
116
+
117
+ def forward(self, x):
118
+ with torch.amp.autocast("cuda", enabled=False):
119
+ x = F.layer_norm(x.float(), [self.dim])
120
+ return x * self.weight[None, None, :]
121
+
122
+
123
+ # ─── DiT Block with adaLN-zero ───────────────────────────
124
+
125
+ class DDiTBlock(nn.Module):
126
+ def __init__(self, dim, n_heads, cond_dim, mlp_ratio=4, dropout=0.1):
127
+ super().__init__()
128
+ self.n_heads = n_heads
129
+ self.head_dim = dim // n_heads
130
+
131
+ self.norm1 = LayerNorm(dim)
132
+ self.attn_qkv = nn.Linear(dim, 3 * dim, bias=False)
133
+ self.attn_out = nn.Linear(dim, dim, bias=False)
134
+
135
+ self.norm2 = LayerNorm(dim)
136
+ self.mlp = nn.Sequential(
137
+ nn.Linear(dim, mlp_ratio * dim),
138
+ nn.GELU(approximate="tanh"),
139
+ nn.Linear(mlp_ratio * dim, dim),
140
+ )
141
+ self.dropout = nn.Dropout(dropout)
142
+ self.drop_p = dropout
143
+
144
+ # adaLN-zero: 6 modulation params (shift, scale, gate for attn & mlp)
145
+ self.adaLN_modulation = nn.Linear(cond_dim, 6 * dim, bias=True)
146
+ nn.init.zeros_(self.adaLN_modulation.weight)
147
+ nn.init.zeros_(self.adaLN_modulation.bias)
148
+
149
+ def forward(self, x, rotary_freqs, c):
150
+ B, S, D = x.shape
151
+
152
+ # adaLN modulation
153
+ mod = self.adaLN_modulation(c)[:, None, :] # (B, 1, 6*D)
154
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=-1)
155
+
156
+ # ── Self-Attention ──
157
+ h = self.norm1(x)
158
+ h = h * (1 + scale_msa) + shift_msa
159
+
160
+ qkv = self.attn_qkv(h)
161
+ qkv = qkv.view(B, S, 3, self.n_heads, self.head_dim)
162
+ q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
163
+ # q, k, v: (B, S, n_heads, head_dim)
164
+
165
+ # Apply RoPE
166
+ q, k = apply_rotary_pos_emb(q, k, rotary_freqs)
167
+
168
+ # Transpose to (B, n_heads, S, head_dim) for SDPA
169
+ q = q.transpose(1, 2)
170
+ k = k.transpose(1, 2)
171
+ v = v.transpose(1, 2)
172
+
173
+ # Bidirectional attention (no causal mask)
174
+ attn_out = F.scaled_dot_product_attention(
175
+ q, k, v,
176
+ dropout_p=self.drop_p if self.training else 0.0,
177
+ is_causal=False,
178
+ )
179
+ attn_out = attn_out.transpose(1, 2).reshape(B, S, D)
180
+
181
+ attn_out = self.attn_out(attn_out)
182
+ x = x + gate_msa * self.dropout(attn_out)
183
+
184
+ # ── MLP ──
185
+ h = self.norm2(x)
186
+ h = h * (1 + scale_mlp) + shift_mlp
187
+ x = x + gate_mlp * self.dropout(self.mlp(h))
188
+
189
+ return x
190
+
191
+
192
+ # ─── Final Layer ──────────────────────────────────────────
193
+
194
+ class DDitFinalLayer(nn.Module):
195
+ def __init__(self, hidden_size, out_channels, cond_dim):
196
+ super().__init__()
197
+ self.norm_final = LayerNorm(hidden_size)
198
+ self.linear = nn.Linear(hidden_size, out_channels)
199
+ nn.init.zeros_(self.linear.weight)
200
+ nn.init.zeros_(self.linear.bias)
201
+
202
+ self.adaLN_modulation = nn.Linear(cond_dim, 2 * hidden_size, bias=True)
203
+ nn.init.zeros_(self.adaLN_modulation.weight)
204
+ nn.init.zeros_(self.adaLN_modulation.bias)
205
+
206
+ def forward(self, x, c):
207
+ shift, scale = self.adaLN_modulation(c)[:, None, :].chunk(2, dim=-1)
208
+ x = self.norm_final(x)
209
+ x = x * (1 + scale) + shift
210
+ return self.linear(x)
211
+
212
+
213
+ # ─── Full Model ──────────────────────────────────────────
214
+
215
+ class MDLM(PreTrainedModel):
216
+ """
217
+ Small Masked Diffusion Language Model.
218
+
219
+ Forward pass: given noisy input_ids and timesteps t ∈ [0,1],
220
+ predicts logits over vocab for each position.
221
+ """
222
+ config_class = MDLMConfig
223
+
224
+ def __init__(self, config: MDLMConfig):
225
+ super().__init__(config)
226
+ self.config = config
227
+
228
+ self.vocab_embed = nn.Embedding(config.vocab_size, config.hidden_dim)
229
+ nn.init.kaiming_uniform_(self.vocab_embed.weight, a=math.sqrt(5))
230
+
231
+ self.sigma_map = TimestepEmbedder(config.cond_dim)
232
+ self.rotary_emb = RotaryEmbedding(config.hidden_dim // config.n_heads)
233
+
234
+ self.blocks = nn.ModuleList([
235
+ DDiTBlock(
236
+ config.hidden_dim,
237
+ config.n_heads,
238
+ config.cond_dim,
239
+ mlp_ratio=config.mlp_ratio,
240
+ dropout=config.dropout,
241
+ )
242
+ for _ in range(config.n_blocks)
243
+ ])
244
+
245
+ self.output_layer = DDitFinalLayer(
246
+ config.hidden_dim, config.vocab_size, config.cond_dim
247
+ )
248
+
249
+ # Separate output projection (no weight tying with embeddings)
250
+ self.post_init()
251
+
252
+ def get_num_params(self):
253
+ return sum(p.numel() for p in self.parameters())
254
+
255
+ def forward(
256
+ self,
257
+ input_ids: torch.LongTensor,
258
+ timesteps: torch.FloatTensor,
259
+ output_hidden_states: bool = False,
260
+ return_dict: bool = True,
261
+ ):
262
+ B, S = input_ids.shape
263
+
264
+ x = self.vocab_embed(input_ids)
265
+
266
+ if not self.config.time_conditioning:
267
+ timesteps = torch.zeros_like(timesteps)
268
+
269
+ c = F.silu(self.sigma_map(timesteps))
270
+
271
+ rotary_freqs = self.rotary_emb(S, device=x.device)
272
+
273
+ all_hidden = [x] if output_hidden_states else None
274
+
275
+ # Mixed precision: let the outer training loop handle autocast
276
+ for block in self.blocks:
277
+ x = block(x, rotary_freqs, c)
278
+ if output_hidden_states:
279
+ all_hidden.append(x)
280
+ logits = self.output_layer(x, c)
281
+
282
+ if return_dict:
283
+ return MaskedLMOutput(logits=logits, hidden_states=all_hidden, loss=None)
284
+ return logits
285
+
286
+
287
+ # ─── Sampling ─────────────────────────────────────────────
288
+
289
+ @torch.no_grad()
290
+ def sample(
291
+ model: MDLM,
292
+ seq_len: int,
293
+ batch_size: int = 1,
294
+ num_steps: int = 100,
295
+ temperature: float = 0.7,
296
+ device: str = "cuda",
297
+ ):
298
+ """
299
+ Ancestral sampling from MDLM.
300
+
301
+ Start from all [MASK] tokens.
302
+ At each step s→t (t < s): unmask tokens with probability (1 - t/s),
303
+ using model predictions.
304
+ """
305
+ mask_id = model.config.mask_token_id
306
+
307
+ # Start with all masked
308
+ x = torch.full((batch_size, seq_len), mask_id, dtype=torch.long, device=device)
309
+
310
+ # Discretize time from 1β†’0
311
+ timesteps = torch.linspace(1.0, 0.0, num_steps + 1, device=device)
312
+
313
+ for i in range(num_steps):
314
+ t_now = timesteps[i]
315
+ t_next = timesteps[i + 1]
316
+
317
+ # Get model predictions
318
+ t_batch = torch.full((batch_size,), t_now.item(), device=device)
319
+ output = model(x, t_batch, return_dict=True)
320
+ logits = output.logits / temperature
321
+
322
+ # Sample from predicted distribution
323
+ probs = F.softmax(logits, dim=-1)
324
+ predicted = torch.multinomial(probs.view(-1, probs.shape[-1]), 1).view(batch_size, seq_len)
325
+
326
+ # Determine which masked positions to unmask
327
+ is_masked = (x == mask_id)
328
+
329
+ if t_next <= 0:
330
+ # Last step: unmask everything
331
+ x = torch.where(is_masked, predicted, x)
332
+ else:
333
+ # Unmask with probability (1 - t_next/t_now)
334
+ unmask_prob = 1.0 - (t_next / t_now)
335
+ unmask = torch.bernoulli(
336
+ torch.full_like(x, unmask_prob, dtype=torch.float)
337
+ ).bool() & is_masked
338
+ x = torch.where(unmask, predicted, x)
339
+
340
+ return x