trixyL committed
Commit 3617dac · 1 Parent(s): 886ef62

dump: initial dump

Files changed (5)
  1. README.md +4 -4
  2. app.py +47 -0
  3. model.py +776 -0
  4. model/model.safetensors +3 -0
  5. requirements.txt +7 -0
README.md CHANGED
@@ -1,14 +1,14 @@
  ---
- title: Mnist Diff Demo
+ title: MNIST Diffusion (TransformerLM)
- emoji: 🐢
+ emoji: 🧪
  colorFrom: red
- colorTo: blue
+ colorTo: yellow
  sdk: gradio
  sdk_version: 6.5.1
  python_version: '3.12'
  app_file: app.py
  pinned: false
- short_description: Generate MNIST like numerals
+ short_description: Discrete diffusion MNIST digit generation
  ---
  
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ import gradio as gr
+ import spaces
+ 
+ from model import generate_images, load_model
+ 
+ MODEL_READY = False
+ 
+ 
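+ # Load weights lazily inside the @spaces.GPU-decorated request path so that, on
+ # ZeroGPU Spaces, CUDA initialization happens during the call rather than at import.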
+ def ensure_model_loaded():
+     global MODEL_READY
+     if not MODEL_READY:
+         load_model()
+         MODEL_READY = True
+ 
+ 
+ @spaces.GPU
+ @torch.inference_mode()
+ def predict(label: int, steps: int, num_samples: int):
+     ensure_model_loaded()
+     return generate_images(label=label, steps=steps, num_samples=num_samples)
+ 
+ 
+ with gr.Blocks(title="MNIST Diffusion") as demo:
+     gr.Markdown("# MNIST Diffusion")
+     gr.Markdown(
+         "Discrete diffusion model for MNIST digits. "
+         "Sampling uses fixed CFG=2.0, temperature=0.6, top_p=0.99."
+     )
+ 
+     gallery = gr.Gallery(label="Samples", show_label=True, columns=6, rows=3, height=360)
+ 
+     with gr.Row():
+         label = gr.Dropdown([str(i) for i in range(10)], value="6", label="Label")
+         steps = gr.Slider(1, 784, value=784, step=1, label="Steps")
+         num_samples = gr.Slider(1, 36, value=16, step=1, label="Samples")
+ 
+     generate_btn = gr.Button("Generate")
+ 
+     generate_btn.click(
+         fn=predict,
+         inputs=[label, steps, num_samples],
+         outputs=gallery,
+     )
+ 
+ if __name__ == "__main__":
+     demo.launch()
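
The Gradio UI is a thin wrapper around model.generate_images, so the same entry point can be driven headlessly. A minimal sketch (not part of the commit; it assumes the model/model.safetensors checkpoint is present locally):

    from model import generate_images

    # Returns a list of 28x28 grayscale PIL images for the digit "3".
    images = generate_images(label=3, steps=128, num_samples=4)
    images[0].save("sample.png")
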
model.py ADDED
@@ -0,0 +1,776 @@
+ from __future__ import annotations
+ 
+ import os
+ from typing import Tuple, List
+ 
+ import numpy as np
+ import torch
+ from PIL import Image
+ from safetensors.torch import load_file
+ from einops import einsum, rearrange
+ 
+ 
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+ CHECKPOINT_PATH = os.path.join(BASE_DIR, "model", "model.safetensors")
+ 
+ MODEL_CONFIG = {
+     "model_type": "image",
+     "label_vocab_size": 11,
+     "vocab_size": 33,
+     "pixel_bins": 32,
+     "context_length": 784,
+     "d_model": 256,
+     "num_layers": 8,
+     "num_heads": 16,
+     "d_ff": 1024,
+     "rope_theta": 10000.0,
+     "attention_backend": "torch_sdpa",
+     "attention_sdp_backend": "auto",
+     "device": "cuda",
+     "dtype": "float16",
+     "mask_token_id": 32,
+     "null_label_id": 10,
+     "image_height": 28,
+     "image_width": 28,
+ }
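+ # Token layout implied by the config: pixel intensities are quantized into 32 bins
+ # (token ids 0-31) and id 32 is the mask token, giving vocab_size = 33. Labels 0-9
+ # are the digit classes; id 10 is the null label used for classifier-free guidance.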
+ 
+ INFER_CONFIG = {
+     "block_length": 784,
+     "temperature": 0.6,
+     "top_p": 0.99,
+     "cfg_scale": 2.0,
+     "remasking": "random",
+ }
+ 
+ DTYPES = {
+     "float16": torch.float16,
+     "float32": torch.float32,
+     "bfloat16": torch.bfloat16,
+ }
+ 
+ 
+ def _resolve_device_dtype(device: str, dtype_name: str) -> Tuple[str, torch.dtype]:
+     # Fall back to CPU (and to float32, which CPU kernels support) when CUDA is absent.
+     resolved_device = device
+     if device == "cuda" and not torch.cuda.is_available():
+         resolved_device = "cpu"
+ 
+     resolved_dtype = DTYPES[dtype_name]
+     if resolved_device == "cpu" and resolved_dtype == torch.float16:
+         resolved_dtype = torch.float32
+ 
+     return resolved_device, resolved_dtype
+ 
+ 
+ def set_sdp_backend(backend: str) -> None:
+     backend = backend.lower()
+     allowed = {"auto", "flash", "mem_efficient", "math"}
+     if backend not in allowed:
+         raise ValueError(f"attention_sdp_backend must be one of {sorted(allowed)}")
+     if not torch.cuda.is_available():
+         return
+     if backend == "auto":
+         torch.backends.cuda.enable_flash_sdp(True)
+         torch.backends.cuda.enable_mem_efficient_sdp(True)
+         torch.backends.cuda.enable_math_sdp(True)
+         return
+     torch.backends.cuda.enable_flash_sdp(backend == "flash")
+     torch.backends.cuda.enable_mem_efficient_sdp(backend == "mem_efficient")
+     torch.backends.cuda.enable_math_sdp(backend == "math")
+ 
+ 
+ class Linear(torch.nn.Module):
+     def __init__(self, in_features, out_features, device=None, dtype=None):
+         super().__init__()
+         self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype))
+         mean = 0.0
+         std = 2 / (in_features + out_features)
+         a = mean - 3 * std
+         b = mean + 3 * std
+         torch.nn.init.trunc_normal_(self.weight, mean=mean, std=std, a=a, b=b)
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         y = einsum(self.weight, x, "out_features in_features, ... in_features -> ... out_features")
+         return y
+ 
+ 
+ class Embedding(torch.nn.Module):
+     def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
+         super().__init__()
+         self.num_embeddings = num_embeddings
+         self.embedding_dim = embedding_dim
+         self.weight = torch.nn.Parameter(torch.empty(num_embeddings, embedding_dim, device=device, dtype=dtype))
+         torch.nn.init.trunc_normal_(self.weight, mean=0, std=1, a=-3, b=3)
+ 
+     def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
+         embeds = self.weight[token_ids]
+         return embeds
+ 
+ 
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
+         super().__init__()
+         self.eps = eps
+         self.d_model = d_model
+         self.weight = torch.nn.Parameter(torch.empty(d_model, device=device, dtype=dtype))
+         torch.nn.init.ones_(self.weight)
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Normalize in float32 for numerical stability, then cast back.
+         in_dtype = x.dtype
+         x = x.to(torch.float32)
+         rms = torch.sqrt(torch.mean(x ** 2, dim=-1) + self.eps).unsqueeze(-1)
+         x = (1 / rms) * (x * self.weight)
+         return x.to(in_dtype)
+ 
+ 
+ class SwiGLU(torch.nn.Module):
+     def __init__(self, d_model: int, d_ff: int, device=None, dtype=None):
+         super().__init__()
+         self.w1 = Linear(d_model, d_ff, device=device, dtype=dtype)
+         self.w2 = Linear(d_ff, d_model, device=device, dtype=dtype)
+         self.w3 = Linear(d_model, d_ff, device=device, dtype=dtype)
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # SwiGLU(x) = W2(SiLU(W1 x) * W3 x)
+         w1x = self.w1(x)
+         w3x = self.w3(x)
+         silu = w1x * torch.sigmoid(w1x)
+         glu = silu * w3x
+         w2x = self.w2(glu)
+         return w2x
+ 
+ 
+ def softmax(x: torch.Tensor, dim: int):
+     # Numerically stable softmax: subtract the per-row max before exponentiating.
+     x_max = x.max(dim=dim, keepdim=True).values
+     x_stable = x - x_max
+     exp_x = torch.exp(x_stable)
+     sum_exp_x = exp_x.sum(dim=dim, keepdim=True)
+     return exp_x / sum_exp_x
+ 
+ 
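+ # Nucleus (top-p) filtering over a probability tensor: keep the smallest set of
+ # highest-probability tokens whose cumulative mass reaches p, zero out the rest,
+ # and renormalize. p <= 0 degenerates to argmax; p >= 1 keeps the full distribution.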
+ def top_p_filter(probs: torch.Tensor, p: float) -> torch.Tensor:
+     if probs.dim() < 2:
+         raise ValueError("probs must have at least 2 dimensions")
+     orig_shape = probs.shape
+     vocab = orig_shape[-1]
+     probs = probs.reshape(-1, vocab)
+     if p <= 0:
+         argmax = probs.argmax(dim=-1)
+         out = torch.zeros_like(probs)
+         out.scatter_(-1, argmax.unsqueeze(-1), 1.0)
+         return out.reshape(orig_shape)
+     if p >= 1:
+         return (probs / probs.sum(dim=-1, keepdim=True)).reshape(orig_shape)
+ 
+     sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+     cumulative = torch.cumsum(sorted_probs, dim=-1)
+ 
+     # Keep tokens while cumulative mass is <= p, plus the first token that crosses p.
+     keep = cumulative <= p
+     keep[..., 0] = True
+     first_ge = (cumulative >= p).float().argmax(dim=-1)
+     rows = torch.arange(keep.shape[0], device=keep.device)
+     keep[rows, first_ge] = True
+ 
+     filtered_sorted = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
+     norm = filtered_sorted.sum(dim=-1, keepdim=True).clamp_min(1e-12)
+     filtered_sorted = filtered_sorted / norm
+ 
+     # Scatter the filtered values back into the original vocabulary order.
+     filtered = torch.zeros_like(probs)
+     filtered.scatter_(dim=-1, index=sorted_indices, src=filtered_sorted)
+     return filtered.reshape(orig_shape)
+ 
+ 
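+ # Temperature sampling via a Gumbel-max-style trick (the same form used in
+ # LLaDA-style samplers): taking argmax of exp(logits) / (-log u) ** T in float64
+ # perturbs the greedy choice; temperature <= 0 leaves the logits untouched (greedy).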
+ def add_gumbel_noise(logits: torch.Tensor, temperature: float, *, generator: torch.Generator | None = None) -> torch.Tensor:
+     if temperature <= 0:
+         return logits
+ 
+     noise = torch.rand(logits.shape, device=logits.device, dtype=torch.float64, generator=generator)
+     gumbel_noise = (-torch.log(noise)) ** temperature
+     logits64 = logits.to(torch.float64)
+     perturbed = logits64.exp() / gumbel_noise
+     return perturbed.to(logits.dtype)
+ 
+ 
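+ # Split each row's masked-token count as evenly as possible across `steps`.
+ # E.g. 784 masked tokens over 100 steps -> 84 steps unmask 8 tokens and 16 steps
+ # unmask 7 (784 = 7 * 100 + 84, so the remainder goes to the first 84 steps).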
+ def compute_transfer_schedule(mask: torch.Tensor, steps: int) -> torch.Tensor:
+     if steps <= 0:
+         raise ValueError("steps must be > 0")
+     if mask.dim() != 2:
+         raise ValueError("mask must be 2D (batch, block_length)")
+ 
+     counts = mask.sum(dim=1, keepdim=True).to(torch.int64)
+     base = counts // steps
+     remainder = counts % steps
+ 
+     schedule = base.expand(-1, steps).clone()
+     for idx in range(schedule.size(0)):
+         r = remainder[idx, 0].item()
+         if r > 0:
+             schedule[idx, :r] += 1
+     return schedule
+ 
+ 
+ def _prepare_attention_mask(attention_mask: torch.Tensor, ref_tensor: torch.Tensor) -> torch.Tensor:
+     # Broadcast a 2D (batch, keys) or 3D (batch, queries, keys) boolean mask up to
+     # the 4D (batch, heads, queries, keys) shape expected by the attention kernels.
+     mask = attention_mask.to(device=ref_tensor.device, dtype=torch.bool)
+     if mask.dim() == 2:
+         mask = mask[:, None, None, :]
+     elif mask.dim() == 3:
+         mask = mask[:, None, :, :]
+     elif mask.dim() != 4:
+         raise ValueError("attention_mask must be 2D, 3D, or 4D")
+     return mask
+ 
+ 
+ def scaled_dot_product_attention(
+     Q: torch.Tensor,
+     K: torch.Tensor,
+     V: torch.Tensor,
+     attention_mask: torch.Tensor | None = None,
+ ):
+     # Reference implementation: softmax(Q K^T / sqrt(d_k)) V with optional masking.
+     scale = torch.tensor(Q.shape[-1], device=Q.device, dtype=Q.dtype).sqrt()
+     qk_score = einsum(Q, K, "batch_size ... n d_k, batch_size ... m d_k -> batch_size ... n m") / scale
+     if attention_mask is not None:
+         mask = _prepare_attention_mask(attention_mask, qk_score)
+         qk_score = qk_score.masked_fill(~mask, float("-inf"))
+     softmax_qk_score = softmax(qk_score, dim=-1)
+     attn = einsum(softmax_qk_score, V, "batch_size ... n m, batch_size ... m d_k -> batch_size ... n d_k")
+     return attn
+ 
+ 
+ def torch_scaled_dot_product_attention(
+     Q: torch.Tensor,
+     K: torch.Tensor,
+     V: torch.Tensor,
+     attention_mask: torch.Tensor | None = None,
+ ):
+     Q = Q.contiguous()
+     K = K.contiguous()
+     V = V.contiguous()
+     mask = None
+     if attention_mask is not None:
+         mask = _prepare_attention_mask(attention_mask, Q)
+     return torch.nn.functional.scaled_dot_product_attention(Q, K, V, attn_mask=mask, dropout_p=0.0, is_causal=False)
+ 
+ 
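+ # Rotary position embedding over interleaved channel pairs: each (even, odd) pair
+ # of head channels is rotated by a position-dependent phase pos / theta**(2i/d_k),
+ # with cos/sin tables precomputed up to max_seq_len and indexed by token position.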
+ class RotaryPositionalEmbedding(torch.nn.Module):
+     def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
+         super().__init__()
+         self.device = device
+ 
+         theta_i = theta ** (torch.arange(0, d_k, 2).float() / d_k)
+         position = torch.arange(max_seq_len)
+ 
+         phases = position.unsqueeze(1) / theta_i.unsqueeze(0)
+         phases_cos = torch.cos(phases)
+         phases_sin = torch.sin(phases)
+         phases_combined = torch.stack([phases_cos, phases_sin], dim=-1).to(device=device)
+ 
+         self.register_buffer("phases", phases_combined, persistent=False)
+ 
+     def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
+         x = rearrange(x, "... (d_k p) -> ... d_k p", p=2)
+         x1 = x[..., 0]
+         x2 = x[..., 1]
+ 
+         phases_cos = self.phases[..., 0][token_positions].to(dtype=x.dtype)
+         phases_sin = self.phases[..., 1][token_positions].to(dtype=x.dtype)
+ 
+         x_rotated = torch.stack([
+             x1 * phases_cos - x2 * phases_sin,
+             x1 * phases_sin + x2 * phases_cos,
+         ], dim=-1)
+ 
+         return x_rotated.flatten(-2)
+ 
+ 
+ class MultiheadSelfAttentionRoPE(torch.nn.Module):
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         max_seq_len: int,
+         theta: float,
+         attention_backend: str = "custom",
+         device=None,
+         dtype=None,
+     ):
+         super().__init__()
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.d_k = self.d_model // self.num_heads
+         self.d_v = self.d_k
+         self.max_seq_len = max_seq_len
+         self.theta = theta
+         if attention_backend not in {"custom", "torch_sdpa"}:
+             raise ValueError("attention_backend must be one of ['custom', 'torch_sdpa']")
+         self.attention_backend = attention_backend
+ 
+         self.q_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+         self.k_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+         self.v_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+         self.output_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+ 
+         self.rope = RotaryPositionalEmbedding(self.theta, self.d_k, self.max_seq_len, device)
+ 
+     def forward(
+         self,
+         x: torch.Tensor,
+         token_positions: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         wqx = self.q_proj(x)
+         wqx_rearr = rearrange(wqx, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k", num_heads=self.num_heads, d_k=self.d_k)
+         wqx_rearr_rope = self.rope(wqx_rearr, token_positions)
+ 
+         wkx = self.k_proj(x)
+         wkx_rearr = rearrange(wkx, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k", num_heads=self.num_heads, d_k=self.d_k)
+         wkx_rearr_rope = self.rope(wkx_rearr, token_positions)
+ 
+         wvx = self.v_proj(x)
+         wvx_rearr = rearrange(wvx, "... seq_len (num_heads d_v) -> ... num_heads seq_len d_v", num_heads=self.num_heads, d_v=self.d_v)
+ 
+         if self.attention_backend == "torch_sdpa":
+             attn = torch_scaled_dot_product_attention(
+                 wqx_rearr_rope,
+                 wkx_rearr_rope,
+                 wvx_rearr,
+                 attention_mask=attention_mask,
+             )
+         else:
+             attn = scaled_dot_product_attention(
+                 wqx_rearr_rope,
+                 wkx_rearr_rope,
+                 wvx_rearr,
+                 attention_mask=attention_mask,
+             )
+         attn_rearr = rearrange(attn, "... num_heads seq_len d_v -> ... seq_len (num_heads d_v)", num_heads=self.num_heads, d_v=self.d_v)
+         attn_rearr_proj = self.output_proj(attn_rearr)
+         return attn_rearr_proj
+ 
+ 
+ class MultiheadCrossAttentionRoPE(torch.nn.Module):
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         max_seq_len: int,
+         theta: float,
+         attention_backend: str = "custom",
+         device=None,
+         dtype=None,
+     ):
+         super().__init__()
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.d_k = self.d_model // self.num_heads
+         self.d_v = self.d_k
+         self.max_seq_len = max_seq_len
+         self.theta = theta
+         if attention_backend not in {"custom", "torch_sdpa"}:
+             raise ValueError("attention_backend must be one of ['custom', 'torch_sdpa']")
+         self.attention_backend = attention_backend
+ 
+         self.q_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+         self.k_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+         self.v_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+         self.output_proj = Linear(self.d_model, self.d_model, device=device, dtype=dtype)
+ 
+         self.rope = RotaryPositionalEmbedding(self.theta, self.d_k, self.max_seq_len, device)
+ 
+     def forward(
+         self,
+         x: torch.Tensor,
+         context: torch.Tensor,
+         token_positions: torch.Tensor,
+         context_token_positions: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         wqx = self.q_proj(x)
+         wqx_rearr = rearrange(wqx, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k", num_heads=self.num_heads, d_k=self.d_k)
+         wqx_rearr_rope = self.rope(wqx_rearr, token_positions)
+ 
+         wkx = self.k_proj(context)
+         wkx_rearr = rearrange(wkx, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k", num_heads=self.num_heads, d_k=self.d_k)
+         wkx_rearr_rope = self.rope(wkx_rearr, context_token_positions)
+ 
+         wvx = self.v_proj(context)
+         wvx_rearr = rearrange(wvx, "... seq_len (num_heads d_v) -> ... num_heads seq_len d_v", num_heads=self.num_heads, d_v=self.d_v)
+ 
+         if self.attention_backend == "torch_sdpa":
+             attn = torch_scaled_dot_product_attention(
+                 wqx_rearr_rope,
+                 wkx_rearr_rope,
+                 wvx_rearr,
+                 attention_mask=attention_mask,
+             )
+         else:
+             attn = scaled_dot_product_attention(
+                 wqx_rearr_rope,
+                 wkx_rearr_rope,
+                 wvx_rearr,
+                 attention_mask=attention_mask,
+             )
+         attn_rearr = rearrange(attn, "... num_heads seq_len d_v -> ... seq_len (num_heads d_v)", num_heads=self.num_heads, d_v=self.d_v)
+         attn_rearr_proj = self.output_proj(attn_rearr)
+         return attn_rearr_proj
+ 
+ 
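+ # Pre-norm transformer block: RMSNorm -> self-attention over pixel tokens,
+ # RMSNorm -> cross-attention to the (length-1) label embedding, RMSNorm ->
+ # SwiGLU FFN, each wrapped in a residual connection.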
+ class TransformerImageBlock(torch.nn.Module):
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         max_seq_len: int,
+         theta: float,
+         d_ff: int,
+         attention_backend: str = "custom",
+         device=None,
+         dtype=None,
+     ):
+         super().__init__()
+         self.ffn = SwiGLU(d_model, d_ff, device, dtype)
+         self.self_attn = MultiheadSelfAttentionRoPE(
+             d_model,
+             num_heads,
+             max_seq_len,
+             theta,
+             attention_backend=attention_backend,
+             device=device,
+             dtype=dtype,
+         )
+         self.cross_attn = MultiheadCrossAttentionRoPE(
+             d_model,
+             num_heads,
+             max_seq_len,
+             theta,
+             attention_backend=attention_backend,
+             device=device,
+             dtype=dtype,
+         )
+         self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
+         self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
+         self.ln3 = RMSNorm(d_model, device=device, dtype=dtype)
+ 
+     def forward(
+         self,
+         x: torch.Tensor,
+         token_positions: torch.Tensor,
+         context: torch.Tensor,
+         context_token_positions: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         ln1x = self.ln1(x)
+         x = x + self.self_attn(ln1x, token_positions, attention_mask=attention_mask)
+         ln2x = self.ln2(x)
+         x = x + self.cross_attn(
+             ln2x,
+             context,
+             token_positions,
+             context_token_positions,
+             attention_mask=None,
+         )
+         ln3x = self.ln3(x)
+         x = x + self.ffn(ln3x)
+         return x
+ 
+ 
+ class TransformerImage(torch.nn.Module):
+     def __init__(
+         self,
+         vocab_size: int,
+         context_length: int,
+         d_model: int,
+         num_layers: int,
+         num_heads: int,
+         d_ff: int,
+         rope_theta: float,
+         label_vocab_size: int,
+         attention_backend: str = "custom",
+         device=None,
+         dtype=None,
+     ):
+         super().__init__()
+         self.context_length = context_length
+         self.token_embeddings = Embedding(vocab_size, d_model, device, dtype)
+         self.label_embeddings = Embedding(label_vocab_size, d_model, device, dtype)
+         self.layers = torch.nn.ModuleList(
+             [
+                 TransformerImageBlock(
+                     d_model,
+                     num_heads,
+                     context_length,
+                     rope_theta,
+                     d_ff,
+                     attention_backend=attention_backend,
+                     device=device,
+                     dtype=dtype,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+         self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
+         self.lm_head = Linear(d_model, vocab_size, device, dtype)
+ 
+     def forward(
+         self,
+         in_indices: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+         context: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         if context is None:
+             raise ValueError("context must be provided for TransformerImage")
+         output_seq = self.token_embeddings(in_indices)
+         context_emb = self.label_embeddings(context).unsqueeze(-2)
+         token_positions = torch.arange(output_seq.shape[-2], device=output_seq.device, dtype=torch.long)
+         context_token_positions = torch.arange(context_emb.shape[-2], device=output_seq.device, dtype=torch.long)
+         for layer in self.layers:
+             output_seq = layer(
+                 output_seq,
+                 token_positions,
+                 context_emb,
+                 context_token_positions,
+                 attention_mask=attention_mask,
+             )
+         normed_output_seq = self.ln_final(output_seq)
+         logits = self.lm_head(normed_output_seq)
+         return logits
+ 
+ 
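+ # Masked-diffusion sampling loop: every generated position starts as the mask
+ # token; at each step the model predicts all masked tokens at once, a scheduled
+ # number of them are committed (chosen at random or by model confidence), and the
+ # rest stay masked for the next step. CFG mixes conditional and null-label logits.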
+ @torch.no_grad()
+ def image_diffusion_generate(
+     model,
+     prompt_indices: torch.Tensor,
+     *,
+     context: torch.Tensor,
+     mask_id: int,
+     eos_token_id: int | None = None,
+     steps: int,
+     gen_length: int,
+     block_length: int,
+     temperature: float = 0.0,
+     top_p: float | None = None,
+     cfg_scale: float = 0.0,
+     uncond_context: torch.Tensor | None = None,
+     remasking: str = "random",
+     logits_eos_inf: bool = False,
+     confidence_eos_eot_inf: bool = False,
+     generator: torch.Generator | None = None,
+ ) -> torch.Tensor:
+     if prompt_indices.dim() != 2:
+         raise ValueError("prompt_indices must be 2D (batch, seq)")
+     if context.dim() != 1:
+         raise ValueError("context must be 1D (batch,)")
+     if prompt_indices.shape[0] != context.shape[0]:
+         raise ValueError("context batch size must match prompt batch size")
+     if block_length <= 0:
+         raise ValueError("block_length must be > 0")
+     if steps <= 0:
+         raise ValueError("steps must be > 0")
+ 
+     if gen_length <= 0:
+         return prompt_indices
+ 
+     blocks = max(1, int(np.ceil(gen_length / block_length)))
+     if steps < blocks:
+         raise ValueError("steps must be >= number of blocks")
+     base_steps = steps // blocks
+     extra_steps = steps % blocks
+ 
+     device = prompt_indices.device
+     batch_size, prompt_len = prompt_indices.shape
+     total_len = prompt_len + gen_length
+ 
+     context_limit = getattr(model, "context_length", None)
+     if context_limit is not None and total_len > int(context_limit):
+         raise ValueError("prompt length + gen_length exceeds model context_length")
+ 
+     x = torch.full(
+         (batch_size, total_len),
+         fill_value=mask_id,
+         device=device,
+         dtype=prompt_indices.dtype,
+     )
+     x[:, :prompt_len] = prompt_indices
+ 
+     if uncond_context is not None:
+         if uncond_context.dim() != 1:
+             raise ValueError("uncond_context must be 1D (batch,)")
+         if uncond_context.shape[0] != batch_size:
+             raise ValueError("uncond_context batch size must match prompt batch size")
+         uncond_context = uncond_context.to(device=context.device, dtype=context.dtype)
+ 
+     for block_idx in range(blocks):
+         block_start = prompt_len + block_idx * block_length
+         block_end = min(block_start + block_length, total_len)
+         block_steps = base_steps + (1 if block_idx < extra_steps else 0)
+         if block_steps <= 0:
+             block_steps = 1
+         block_mask = (x[:, block_start:block_end] == mask_id)
+         transfer_counts = compute_transfer_schedule(block_mask, block_steps)
+ 
+         for step_idx in range(block_steps):
+             mask_index = (x == mask_id)
+             if cfg_scale > 0.0:
+                 if uncond_context is None:
+                     raise ValueError("uncond_context must be set when cfg_scale > 0 for image_diffusion_generate")
+                 # Classifier-free guidance: extrapolate from the null-label logits
+                 # toward the conditional logits by a factor of (cfg_scale + 1).
+                 cond_logits = model(x, context=context)
+                 uncond_logits = model(x, context=uncond_context)
+                 logits = uncond_logits + (cfg_scale + 1.0) * (cond_logits - uncond_logits)
+             else:
+                 logits = model(x, context=context)
+ 
+             if logits_eos_inf and eos_token_id is not None:
+                 logits[:, :, eos_token_id] = float("-inf")
+ 
+             if top_p is not None:
+                 # Restrict sampling to the nucleus by sending filtered-out logits to -inf.
+                 probs = softmax(logits, dim=-1)
+                 probs = top_p_filter(probs, float(top_p))
+                 logits = torch.where(
+                     probs > 0,
+                     logits,
+                     torch.full_like(logits, float("-inf")),
+                 )
+ 
+             logits_with_noise = add_gumbel_noise(logits, temperature, generator=generator)
+             predictions = torch.argmax(logits_with_noise, dim=-1)
+             predictions = torch.where(mask_index, predictions, x)
+ 
+             if remasking == "low_confidence":
+                 probs = softmax(logits, dim=-1)
+                 confidence = torch.squeeze(
+                     torch.gather(probs, dim=-1, index=torch.unsqueeze(predictions, -1)),
+                     -1,
+                 )
+             elif remasking == "random":
+                 confidence = torch.rand(
+                     (batch_size, total_len),
+                     device=device,
+                     dtype=torch.float32,
+                     generator=generator,
+                 )
+             else:
+                 raise ValueError(f"Unsupported remasking strategy: {remasking}")
+ 
+             if confidence_eos_eot_inf and eos_token_id is not None:
+                 confidence = torch.where(
+                     predictions == eos_token_id,
+                     torch.full_like(confidence, float("-inf")),
+                     confidence,
+                 )
+ 
+             # Only still-masked positions inside the current block are eligible.
+             confidence[:, block_end:] = float("-inf")
+             confidence = torch.where(mask_index, confidence, torch.full_like(confidence, float("-inf")))
+ 
+             transfer_mask = torch.zeros_like(mask_index)
+             for b in range(batch_size):
+                 k = int(transfer_counts[b, step_idx].item())
+                 if k <= 0:
+                     continue
+                 available = confidence[b] > float("-inf")
+                 available_count = int(available.sum().item())
+                 if available_count == 0:
+                     continue
+                 if available_count < k:
+                     k = available_count
+                 topk_indices = torch.topk(confidence[b], k=k, dim=-1).indices
+                 transfer_mask[b, topk_indices] = True
+ 
+             x = torch.where(transfer_mask, predictions, x)
+ 
+     return x
+ 
+ 
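+ # Invert training-time quantization by mapping each token id to its bin-center
+ # gray level: with 32 bins the scale is 8, so bin k -> round((k + 0.5) * 8 - 0.5),
+ # e.g. bin 0 -> 4 and bin 31 -> 252.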
+ def dequantize_tokens_to_uint8(tokens: np.ndarray, *, pixel_bins: int) -> np.ndarray:
+     if pixel_bins == 256:
+         return tokens.astype(np.uint8)
+     vals = np.clip(tokens.astype(np.int32), 0, int(pixel_bins) - 1)
+     scale = 256.0 / float(pixel_bins)
+     restored = np.round((vals + 0.5) * scale - 0.5)
+     return np.clip(restored, 0, 255).astype(np.uint8)
+ 
+ 
+ MODEL = None
+ DEVICE = None
+ DTYPE = None
+ 
+ 
+ def load_model():
+     global MODEL, DEVICE, DTYPE
+     if MODEL is not None:
+         return MODEL, DEVICE, DTYPE
+ 
+     if not os.path.exists(CHECKPOINT_PATH):
+         raise FileNotFoundError(f"Missing checkpoint at {CHECKPOINT_PATH}")
+ 
+     device, dtype = _resolve_device_dtype(MODEL_CONFIG["device"], MODEL_CONFIG["dtype"])
+     set_sdp_backend(MODEL_CONFIG["attention_sdp_backend"])
+ 
+     model = TransformerImage(
+         vocab_size=MODEL_CONFIG["vocab_size"],
+         context_length=MODEL_CONFIG["context_length"],
+         d_model=MODEL_CONFIG["d_model"],
+         num_layers=MODEL_CONFIG["num_layers"],
+         num_heads=MODEL_CONFIG["num_heads"],
+         d_ff=MODEL_CONFIG["d_ff"],
+         rope_theta=MODEL_CONFIG["rope_theta"],
+         label_vocab_size=MODEL_CONFIG["label_vocab_size"],
+         attention_backend=MODEL_CONFIG["attention_backend"],
+         device=device,
+         dtype=dtype,
+     )
+ 
+     model_state = load_file(CHECKPOINT_PATH)
+     model.load_state_dict(model_state)
+     model.eval().to(device)
+ 
+     MODEL = model
+     DEVICE = device
+     DTYPE = dtype
+     return MODEL, DEVICE, DTYPE
+ 
+ 
+ @torch.inference_mode()
+ def generate_images(label: int, steps: int, num_samples: int) -> List[Image.Image]:
+     model, device, _ = load_model()
+ 
+     num_samples = int(num_samples)
+     label = int(label)
+     steps = int(steps)
+ 
+     context = torch.full((num_samples,), label, device=device, dtype=torch.long)
+     prompt = torch.empty((num_samples, 0), device=device, dtype=torch.long)
+ 
+     cfg_scale = float(INFER_CONFIG["cfg_scale"])
+     uncond_context = None
+     if cfg_scale > 0.0:
+         null_label_id = int(MODEL_CONFIG["null_label_id"])
+         uncond_context = torch.full((num_samples,), null_label_id, device=device, dtype=torch.long)
+ 
+     out_indices = image_diffusion_generate(
+         model,
+         prompt,
+         context=context,
+         mask_id=int(MODEL_CONFIG["mask_token_id"]),
+         eos_token_id=None,
+         steps=steps,
+         gen_length=int(MODEL_CONFIG["context_length"]),
+         block_length=int(INFER_CONFIG["block_length"]),
+         temperature=float(INFER_CONFIG["temperature"]),
+         top_p=float(INFER_CONFIG["top_p"]),
+         cfg_scale=cfg_scale,
+         uncond_context=uncond_context,
+         remasking=str(INFER_CONFIG["remasking"]),
+         logits_eos_inf=False,
+         confidence_eos_eot_inf=False,
+         generator=None,
+     )
+ 
+     h = int(MODEL_CONFIG["image_height"])
+     w = int(MODEL_CONFIG["image_width"])
+     pixel_bins = int(MODEL_CONFIG["pixel_bins"])
+ 
+     images: List[Image.Image] = []
+     for i in range(num_samples):
+         tokens = out_indices[i].detach().cpu().to(torch.int32).numpy().reshape(h, w)
+         arr = dequantize_tokens_to_uint8(tokens, pixel_bins=pixel_bins)
+         img = Image.fromarray(arr, mode="L")
+         images.append(img)
+ 
+     return images
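
The sampler threads a torch.Generator through all of its random draws, but generate_images hard-codes generator=None, so outputs are not reproducible from the UI. A minimal sketch of seeded generation by calling the sampler directly (not part of the commit; argument values other than the config lookups are illustrative):

    import torch
    from model import (MODEL_CONFIG, INFER_CONFIG, load_model,
                       image_diffusion_generate)

    model, device, _ = load_model()
    gen = torch.Generator(device=device).manual_seed(0)
    n = 4
    context = torch.full((n,), 7, device=device, dtype=torch.long)  # digit "7"
    uncond = torch.full((n,), MODEL_CONFIG["null_label_id"], device=device, dtype=torch.long)
    prompt = torch.empty((n, 0), device=device, dtype=torch.long)

    out = image_diffusion_generate(
        model, prompt, context=context,
        mask_id=MODEL_CONFIG["mask_token_id"],
        steps=128, gen_length=MODEL_CONFIG["context_length"],
        block_length=INFER_CONFIG["block_length"],
        temperature=INFER_CONFIG["temperature"],
        top_p=INFER_CONFIG["top_p"],
        cfg_scale=INFER_CONFIG["cfg_scale"], uncond_context=uncond,
        remasking=INFER_CONFIG["remasking"], generator=gen,
    )  # (4, 784) token ids in [0, 31]
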
model/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f863ca7bfd2fc11fc6cf4f3df57567655a43bf4cf9ccaa66f254ed6ed248c9e0
+ size 42058920
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio
+ spaces
+ torch
+ einops
+ safetensors
+ numpy
+ pillow