thejagstudio commited on
Commit
c27ef68
Β·
1 Parent(s): 94f7913

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +797 -797
app.py CHANGED
@@ -1,798 +1,798 @@
1
- import gradio as gr
2
- import spaces
3
- import torch
4
- import torch.nn as nn
5
- from torch.nn import functional as F
6
- import numpy as np
7
- import math
8
- import os
9
- import pickle
10
- import requests
11
- import textwrap
12
- import subprocess
13
- import shutil
14
- import time
15
- from dataclasses import dataclass
16
- from typing import Optional
17
- from transformers import AutoTokenizer
18
-
19
- # ==============================================================================
20
- # ------------------------- VERSION 1: SHARED SETUP ----------------------------
21
- # ==============================================================================
22
-
23
- def setup_environment():
24
- """Checks for and sets up the necessary data for V1."""
25
- nano_gpt_repo_path = 'nanoGPT'
26
- data_dir_path = 'shakespeare_char'
27
- meta_path = os.path.join(data_dir_path, 'meta.pkl')
28
-
29
- if os.path.exists(meta_path):
30
- return
31
-
32
- print("Required data not found. Starting one-time setup...")
33
- if not os.path.exists(nano_gpt_repo_path):
34
- try:
35
- subprocess.run(['git', 'clone', 'https://github.com/karpathy/nanoGPT.git'], check=True, capture_output=True, text=True)
36
- except subprocess.CalledProcessError as e:
37
- print(f"Error cloning repository: {e.stderr}")
38
- pass
39
-
40
- source_data_dir = os.path.join(nano_gpt_repo_path, 'data', 'shakespeare_char')
41
- if not os.path.exists(data_dir_path) and os.path.exists(source_data_dir):
42
- shutil.copytree(source_data_dir, data_dir_path)
43
-
44
- # Check if we can run prepare
45
- prepare_script_path = os.path.join(data_dir_path, 'prepare.py')
46
- if os.path.exists(prepare_script_path) and not os.path.exists(meta_path):
47
- subprocess.run(['python', 'prepare.py'], check=True, cwd=data_dir_path, capture_output=True, text=True)
48
-
49
- setup_environment()
50
-
51
- def download_file(url, filename):
52
- if os.path.exists(filename):
53
- return
54
- print(f"Downloading '{filename}'...")
55
- try:
56
- response = requests.get(url, stream=True)
57
- response.raise_for_status()
58
- with open(filename, 'wb') as f:
59
- for chunk in response.iter_content(chunk_size=8192):
60
- f.write(chunk)
61
- except requests.exceptions.RequestException as e:
62
- print(f"Error downloading {url}: {e}")
63
-
64
- # ==============================================================================
65
- # ---------------------- VERSION 1: ARCHITECTURE & LOGIC -----------------------
66
- # ==============================================================================
67
-
68
- # V1 Constants and Meta Loading
69
- v1_data_dir = './shakespeare_char/'
70
- v1_meta_url = 'https://huggingface.co/spaces/thejagstudio/diffusion-gpt/resolve/main/meta.pkl'
71
- v1_meta_path = 'meta.pkl'
72
- download_file(v1_meta_url, v1_meta_path)
73
-
74
- v1_vocab_size = 65 # Fallback
75
- v1_itos = {}
76
- v1_stoi = {}
77
-
78
- if os.path.exists(v1_meta_path):
79
- with open(v1_meta_path, 'rb') as f:
80
- meta = pickle.load(f)
81
- v1_vocab_size = meta['vocab_size']
82
- v1_itos = meta['itos']
83
- v1_stoi = meta['stoi']
84
-
85
- v1_context_length = 256
86
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
87
-
88
- def v1_decode(indices_tensor: torch.Tensor):
89
- if indices_tensor.dim() > 1:
90
- indices_tensor = indices_tensor.squeeze(0)
91
- indices = indices_tensor.cpu().numpy()
92
- return ''.join([v1_itos.get(i, '?') for i in indices])
93
-
94
- def wrap_text(long_text, width=80):
95
- paragraphs = long_text.splitlines()
96
- wrapped = [textwrap.fill(p, width=width) if p else '' for p in paragraphs]
97
- return "\n".join(wrapped)
98
-
99
- @dataclass
100
- class V1_GPTConfig:
101
- block_size: int = 1024
102
- vocab_size: int = 50304
103
- n_layer: int = 12
104
- n_head: int = 12
105
- n_embd: int = 768
106
- cond_dim: int = 64
107
- dropout: float = 0.0
108
- bias: bool = False
109
-
110
- class V1_MLP(nn.Module):
111
- def __init__(self, config):
112
- super().__init__()
113
- self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
114
- self.gelu = nn.GELU()
115
- self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
116
- self.dropout = nn.Dropout(config.dropout)
117
- def forward(self, x):
118
- x = self.c_fc(x)
119
- x = self.gelu(x)
120
- x = self.c_proj(x)
121
- x = self.dropout(x)
122
- return x
123
-
124
- class V1_SelfAttention(nn.Module):
125
- def __init__(self, config):
126
- super().__init__()
127
- assert config.n_embd % config.n_head == 0
128
- self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
129
- self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
130
- self.attn_dropout = nn.Dropout(config.dropout)
131
- self.resid_dropout = nn.Dropout(config.dropout)
132
- self.n_head = config.n_head
133
- self.n_embd = config.n_embd
134
- self.dropout = config.dropout
135
- self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
136
- def forward(self, x):
137
- B, T, C = x.size()
138
- q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
139
- k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
140
- q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
141
- v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
142
- if self.flash:
143
- y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False)
144
- else:
145
- att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
146
- att = F.softmax(att, dim=-1)
147
- att = self.attn_dropout(att)
148
- y = att @ v
149
- y = y.transpose(1, 2).contiguous().view(B, T, C)
150
- y = self.resid_dropout(self.c_proj(y))
151
- return y
152
-
153
- def v1_modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
154
- return x * (1 + scale) + shift
155
-
156
- def v1_bias_add_scale(x: torch.Tensor, bias: Optional[torch.Tensor], scale: torch.Tensor, residual: Optional[torch.Tensor]) -> torch.Tensor:
157
- if bias is not None:
158
- out = scale * (x + bias)
159
- else:
160
- out = scale * x
161
- if residual is not None:
162
- out = residual + out
163
- return out
164
-
165
- class V1_DDiTBlock(nn.Module):
166
- def __init__(self, config):
167
- super().__init__()
168
- self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
169
- self.attn = V1_SelfAttention(config)
170
- self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
171
- self.mlp = V1_MLP(config)
172
- self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd)
173
- self.adaLN_modulation.weight.data.zero_()
174
- self.adaLN_modulation.bias.data.zero_()
175
- def forward(self, x, c):
176
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
177
- x_skip = x
178
- x = v1_modulate(self.ln_1(x), shift_msa, scale_msa)
179
- x = self.attn(x)
180
- x = v1_bias_add_scale(self.attn(self.ln_1(x)), None, gate_msa, x_skip)
181
- x = v1_bias_add_scale(self.mlp(v1_modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x)
182
- return x
183
-
184
- class V1_DDitFinalLayer(nn.Module):
185
- def __init__(self, config):
186
- super().__init__()
187
- self.norm_final = nn.LayerNorm(config.n_embd, bias=config.bias)
188
- self.linear = nn.Linear(config.n_embd, config.vocab_size)
189
- self.linear.weight.data.zero_()
190
- self.linear.bias.data.zero_()
191
- self.adaLN_modulation = nn.Linear(config.cond_dim, 2 * config.n_embd)
192
- self.adaLN_modulation.weight.data.zero_()
193
- self.adaLN_modulation.bias.data.zero_()
194
- def forward(self, x, c):
195
- shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
196
- x = v1_modulate(self.norm_final(x), shift, scale)
197
- x = self.linear(x)
198
- return x
199
-
200
- class V1_TimestepEmbedder(nn.Module):
201
- def __init__(self, hidden_size, frequency_embedding_size=256):
202
- super().__init__()
203
- self.mlp = nn.Sequential(
204
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
205
- nn.SiLU(),
206
- nn.Linear(hidden_size, hidden_size, bias=True),
207
- )
208
- self.frequency_embedding_size = frequency_embedding_size
209
- @staticmethod
210
- def timestep_embedding(t, dim, max_period=10000):
211
- half = dim // 2
212
- freqs = torch.exp(
213
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
214
- ).to(device=t.device)
215
- args = t[:, None].float() * freqs[None]
216
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
217
- if dim % 2:
218
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
219
- return embedding
220
- def forward(self, t):
221
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
222
- t_emb = self.mlp(t_freq)
223
- return t_emb
224
-
225
- class V1_GPT(nn.Module):
226
- def __init__(self, config):
227
- super().__init__()
228
- assert config.vocab_size is not None
229
- assert config.block_size is not None
230
- self.config = config
231
- self.sigma_map = V1_TimestepEmbedder(config.cond_dim)
232
- self.transformer = nn.ModuleDict(dict(
233
- wte = nn.Embedding(config.vocab_size, config.n_embd),
234
- wpe = nn.Embedding(config.block_size, config.n_embd),
235
- drop = nn.Dropout(config.dropout),
236
- h = nn.ModuleList([V1_DDiTBlock(config) for _ in range(config.n_layer)]),
237
- ln_f = nn.LayerNorm(config.n_embd, bias=config.bias),
238
- ))
239
- self.lm_head = V1_DDitFinalLayer(config)
240
- self.apply(self._init_weights)
241
- for pn, p in self.named_parameters():
242
- if pn.endswith('c_proj.weight'):
243
- torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
244
- def _init_weights(self, module):
245
- if isinstance(module, nn.Linear):
246
- torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
247
- if module.bias is not None:
248
- torch.nn.init.zeros_(module.bias)
249
- elif isinstance(module, nn.Embedding):
250
- torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
251
- def forward(self, idx, sigma):
252
- sigma = sigma.reshape(-1)
253
- b, t = idx.size()
254
- c = F.silu(self.sigma_map(sigma))
255
- assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
256
- pos = torch.arange(0, t, dtype=torch.long, device=device)
257
- tok_emb = self.transformer.wte(idx)
258
- pos_emb = self.transformer.wpe(pos)
259
- x = self.transformer.drop(tok_emb + pos_emb)
260
- for block in self.transformer.h:
261
- x = block(x, c)
262
- x = self.transformer.ln_f(x)
263
- x = self.lm_head(x, c)
264
- x = torch.scatter(x, -1, idx[..., None], torch.zeros_like(x[..., :1]))
265
- return x
266
-
267
- class V1_GeometricNoise:
268
- def __init__(self, sigma_min=1e-4, sigma_max=20):
269
- self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max]).to(device)
270
- def rate_noise(self, t):
271
- return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (self.sigmas[1].log() - self.sigmas[0].log())
272
- def total_noise(self, t):
273
- return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
274
- def __call__(self, t):
275
- return self.total_noise(t), self.rate_noise(t)
276
-
277
- # --- V1 Inference Logic ---
278
- def v1_transition(x_t: torch.Tensor, delta_sigma: torch.Tensor) -> torch.Tensor:
279
- base_prob = (1 - torch.exp(-delta_sigma[..., None])) / v1_vocab_size
280
- trans = torch.ones(*x_t.shape, v1_vocab_size, device=x_t.device) * base_prob
281
- trans = trans.scatter(-1, x_t[..., None], torch.zeros_like(trans))
282
- diag_fill = 1 - trans.sum(dim=-1, keepdim=True)
283
- trans = trans.scatter(-1, x_t[..., None], diag_fill)
284
- return trans
285
-
286
- def v1_staggered_score(score, delta_sigma):
287
- exp_factor = torch.exp(-delta_sigma)[..., None]
288
- correction = ((exp_factor - 1) / (v1_vocab_size * exp_factor)) * score.sum(dim=-1, keepdim=True)
289
- return correction + score / exp_factor
290
-
291
- def v1_sample_categorical(probs: torch.Tensor) -> torch.Tensor:
292
- eps = 1e-10
293
- gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs) + eps) + eps)
294
- return torch.argmax(torch.log(probs + eps) + gumbel_noise, dim=-1)
295
-
296
- # --- V1 Model Loading ---
297
- print("Initializing V1 Model...")
298
- v1_model_args = dict(n_layer=6, n_head=6, n_embd=384, cond_dim=64,
299
- bias=False, vocab_size=v1_vocab_size, block_size=v1_context_length, dropout=0.2)
300
- v1_config = V1_GPTConfig(**v1_model_args)
301
- v1_model = V1_GPT(v1_config)
302
- try:
303
- v1_model.load_state_dict(
304
- torch.hub.load_state_dict_from_url(
305
- 'https://huggingface.co/spaces/thejagstudio/diffusion-gpt/resolve/main/final_model.pth?download=true',
306
- map_location=device
307
- )
308
- )
309
- v1_model.to(device)
310
- v1_model.eval()
311
- print("V1 Model loaded successfully.")
312
- except Exception as e:
313
- print(f"Failed to load V1 model: {e}")
314
- v1_model = None
315
-
316
- v1_noise = V1_GeometricNoise(sigma_min=1e-4, sigma_max=20)
317
-
318
-
319
- def v1_generate_stream(steps, speed):
320
- """
321
- Generator function for V1 that yields frames directly.
322
- Combined logic of generation and replay to allow for immediate stopping.
323
- """
324
- if v1_model is None:
325
- yield "Error: V1 Model not loaded"
326
- return
327
-
328
- steps = int(steps)
329
- speed = float(speed)
330
- eps = 1e-5
331
-
332
- # Calculate delay based on speed slider (similar to V2)
333
- # 0.5 is base constant, speed scales it down
334
- delay = 0.5 / max(speed, 0.1)
335
-
336
- x = torch.randint(0, v1_vocab_size, (1, v1_context_length), device=device)
337
- initial_text = f"--- Initial Random Noise ---\n\n{wrap_text(v1_decode(x[0]))}"
338
- yield initial_text
339
- time.sleep(delay)
340
-
341
- timesteps = torch.linspace(1, eps, steps + 1, device=device)
342
- step_size = (1 - eps) / steps
343
-
344
- with torch.no_grad():
345
- for i in range(steps):
346
- t = timesteps[i] * torch.ones(x.shape[0], 1, device=device)
347
- curr_sigma_bar = v1_noise(t)[0]
348
-
349
- next_sigma_bar = v1_noise(t - step_size)[0]
350
- delta_sigma = curr_sigma_bar - next_sigma_bar
351
-
352
- log_score = v1_model(x, curr_sigma_bar)
353
- score = torch.exp(log_score)
354
-
355
- stag_score = v1_staggered_score(score, delta_sigma)
356
- probs = stag_score * v1_transition(x, delta_sigma)
357
- x = v1_sample_categorical(probs)
358
-
359
- progress_text = f"--- Denoising Step {i + 1}/{steps} ---\n\n{wrap_text(v1_decode(x[0]))}"
360
- yield progress_text
361
-
362
- # Artificial delay for visualization
363
- if speed < 20:
364
- time.sleep(delay)
365
-
366
- t = timesteps[steps] * torch.ones(x.shape[0], 1, device=device)
367
- curr_sigma_bar = v1_noise(t)[0]
368
- delta_sigma = curr_sigma_bar
369
-
370
- log_score = v1_model(x, curr_sigma_bar)
371
- score = torch.exp(log_score)
372
- stag_score = v1_staggered_score(score, delta_sigma)
373
- probs = stag_score * v1_transition(x, delta_sigma)
374
- x = v1_sample_categorical(probs)
375
-
376
- final_text = f"--- Final Denoised Text (Step {steps}) ---\n\n{wrap_text(v1_decode(x[0]))}"
377
- yield final_text
378
-
379
- # ==============================================================================
380
- # ---------------------- VERSION 2: ARCHITECTURE & LOGIC -----------------------
381
- # ==============================================================================
382
-
383
- # PLEASE UPDATE THIS PATH TO YOUR ACTUAL LOCAL FILE OR URL
384
- V2_MODEL_PATH = "checkpoints/model_fp32.pt"
385
-
386
- class V2_RMSNorm(nn.Module):
387
- def __init__(self, dim: int, eps: float = 1e-6):
388
- super().__init__()
389
- self.eps = eps
390
- self.weight = nn.Parameter(torch.ones(dim))
391
-
392
- def forward(self, x):
393
- var = x.pow(2).mean(-1, keepdim=True)
394
- x = x * torch.rsqrt(var + self.eps)
395
- return self.weight * x
396
-
397
- class V2_RotaryEmbedding(nn.Module):
398
- def __init__(self, dim, max_position_embeddings=16384, base=100000, scaling_factor=1.0):
399
- super().__init__()
400
- self.scaling_factor = scaling_factor
401
- self.dim = dim
402
- self.base = base
403
- self.max_position_embeddings = max_position_embeddings
404
- self.inv_freq = None
405
- self._cache = {}
406
-
407
- def _update_freqs(self, device):
408
- base = self.base * (self.scaling_factor ** (self.dim / (self.dim - 2)))
409
- inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
410
- self.inv_freq = inv_freq
411
-
412
- def forward(self, x, seq_len=None):
413
- if seq_len is None:
414
- seq_len = x.shape[-2]
415
-
416
- if self.inv_freq is None or self.inv_freq.device != x.device:
417
- self._update_freqs(x.device)
418
-
419
- cache_key = (seq_len, x.device, x.dtype)
420
- if cache_key in self._cache:
421
- return self._cache[cache_key]
422
-
423
- t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
424
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
425
- emb = torch.cat((freqs, freqs), dim=-1)
426
-
427
- cos = emb.cos()[None, None, :, :]
428
- sin = emb.sin()[None, None, :, :]
429
-
430
- self._cache[cache_key] = (cos, sin)
431
- if len(self._cache) > 10:
432
- self._cache.pop(next(iter(self._cache)))
433
-
434
- return cos, sin
435
-
436
- def v2_apply_rotary_pos_emb(q, k, cos, sin):
437
- def rotate_half(x):
438
- x1 = x[..., : x.shape[-1] // 2]
439
- x2 = x[..., x.shape[-1] // 2 :]
440
- return torch.cat((-x2, x1), dim=-1)
441
- q_embed = (q * cos) + (rotate_half(q) * sin)
442
- k_embed = (k * cos) + (rotate_half(k) * sin)
443
- return q_embed, k_embed
444
-
445
- class V2_DiffusionAttention(nn.Module):
446
- def __init__(self, config):
447
- super().__init__()
448
- self.hidden_size = config.hidden_size
449
- self.num_heads = config.num_attention_heads
450
- self.head_dim = self.hidden_size // self.num_heads
451
- self.num_key_value_heads = config.num_key_value_heads
452
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
453
- self.use_flash_attn = config.use_flash_attn
454
-
455
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
456
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
457
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
458
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
459
-
460
- def forward(self, hidden_states, freqs_cis, attention_mask=None, past_kv=None):
461
- bsz, q_len, _ = hidden_states.size()
462
-
463
- q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
464
- k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
465
- v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
466
-
467
- cos, sin = freqs_cis
468
- cos = cos[:, :, :q_len, :]
469
- sin = sin[:, :, :q_len, :]
470
- q, k = v2_apply_rotary_pos_emb(q, k, cos, sin)
471
-
472
- if past_kv is not None:
473
- cache_k, cache_v = past_kv
474
- k = torch.cat([cache_k, k], dim=2)
475
- v = torch.cat([cache_v, v], dim=2)
476
-
477
- current_kv = (k, v)
478
- k = k.repeat_interleave(self.num_key_value_groups, dim=1)
479
- v = v.repeat_interleave(self.num_key_value_groups, dim=1)
480
-
481
- attn_mask = None
482
- if attention_mask is not None:
483
- attn_mask = attention_mask[:, None, None, :].to(dtype=q.dtype)
484
- attn_mask = (1.0 - attn_mask) * torch.finfo(q.dtype).min
485
-
486
- output = F.scaled_dot_product_attention(
487
- q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
488
- )
489
-
490
- output = output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
491
- return self.o_proj(output), current_kv
492
-
493
- class V2_MLP(nn.Module):
494
- def __init__(self, config):
495
- super().__init__()
496
- self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
497
- self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
498
- self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
499
- self.act_fn = nn.SiLU()
500
-
501
- def forward(self, x):
502
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
503
-
504
- class V2_BlockDiffusionBlock(nn.Module):
505
- def __init__(self, config):
506
- super().__init__()
507
- self.self_attn = V2_DiffusionAttention(config)
508
- self.mlp = V2_MLP(config)
509
- self.input_layernorm = V2_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
510
- self.post_attention_layernorm = V2_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
511
- self.use_activation_checkpointing = config.use_activation_checkpointing
512
-
513
- def forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
514
- return self._forward(hidden_states, freqs_cis, attention_mask, past_kv)
515
-
516
- def _forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
517
- residual = hidden_states
518
- hidden_states = self.input_layernorm(hidden_states)
519
- attn_out, new_kv = self.self_attn(hidden_states, freqs_cis, attention_mask, past_kv)
520
- hidden_states = residual + attn_out
521
-
522
- residual = hidden_states
523
- hidden_states = self.post_attention_layernorm(hidden_states)
524
- hidden_states = residual + self.mlp(hidden_states)
525
- return hidden_states, new_kv
526
-
527
- @dataclass
528
- class V2_ModelConfig:
529
- vocab_size: int = 151936
530
- hidden_size: int = 1024
531
- intermediate_size: int = 2816
532
- num_hidden_layers: int = 16
533
- num_attention_heads: int = 16
534
- num_key_value_heads: int = 4
535
- max_position_embeddings: int = 16384
536
- rms_norm_eps: float = 1e-6
537
- rope_theta: float = 100000.0
538
- pad_token_id: int = 0
539
- mask_token_id: int = 1
540
- use_flash_attn: bool = True
541
- use_activation_checkpointing: bool = False
542
- attention_dropout: float = 0.0
543
- hidden_dropout: float = 0.0
544
-
545
- ModelConfig = V2_ModelConfig
546
-
547
- class V2_DiffusionLLM(nn.Module):
548
- def __init__(self, config: V2_ModelConfig):
549
- super().__init__()
550
- self.config = config
551
- pad_idx = config.pad_token_id if config.pad_token_id < config.vocab_size else None
552
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=pad_idx)
553
-
554
- self.layers = nn.ModuleList([V2_BlockDiffusionBlock(config) for _ in range(config.num_hidden_layers)])
555
- self.norm = V2_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
556
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
557
- self.rotary_emb = V2_RotaryEmbedding(
558
- config.hidden_size // config.num_attention_heads,
559
- config.max_position_embeddings
560
- )
561
- self.lm_head.weight = self.embed_tokens.weight
562
-
563
- def forward(self, input_ids, attention_mask=None, past_key_values=None):
564
- bsz, seqlen = input_ids.shape
565
- hidden_states = self.embed_tokens(input_ids)
566
- freqs_cis = self.rotary_emb(hidden_states, seq_len=seqlen)
567
-
568
- if past_key_values is None:
569
- past_key_values = [None] * len(self.layers)
570
-
571
- new_kvs = []
572
- for i, layer in enumerate(self.layers):
573
- hidden_states, kv = layer(hidden_states, freqs_cis, attention_mask, past_key_values[i])
574
- new_kvs.append(kv)
575
-
576
- hidden_states = self.norm(hidden_states)
577
- logits = self.lm_head(hidden_states)
578
- return logits, new_kvs
579
-
580
- DiffusionLLM = V2_DiffusionLLM
581
-
582
- # --- V2 Loading Logic ---
583
- print("Initializing V2 components...")
584
- v2_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
585
- if v2_tokenizer.pad_token is None:
586
- v2_tokenizer.pad_token = v2_tokenizer.eos_token
587
-
588
- v2_model = None
589
- v2_config = None
590
-
591
- if os.path.exists(V2_MODEL_PATH):
592
- print(f"Loading V2 model from {V2_MODEL_PATH}...")
593
- try:
594
- checkpoint = torch.load(V2_MODEL_PATH, map_location=device, weights_only=False)
595
- v2_config = checkpoint['config']
596
- v2_model = V2_DiffusionLLM(v2_config)
597
- state_dict = checkpoint['model_state']
598
- state_dict = {k: v.float() for k, v in state_dict.items()}
599
- v2_model.load_state_dict(state_dict)
600
- v2_model = v2_model.to(device)
601
- v2_model.eval()
602
- print("V2 Model loaded.")
603
- except Exception as e:
604
- print(f"Error loading V2 model: {e}")
605
- else:
606
- print(f"V2 Model file not found at {V2_MODEL_PATH}. Version 2 tab will not work without it.")
607
-
608
-
609
- @torch.no_grad()
610
- def v2_generate_block_diffusion(prompt, steps, block_size, max_new_tokens, replay_speed):
611
- """
612
- Refactored to yield frames for real-time streaming.
613
- """
614
- if v2_model is None:
615
- yield "Error: V2 Model not found. Check path."
616
- return
617
-
618
- v2_model.eval()
619
- # Handle inputs
620
- steps = int(steps)
621
- block_size = int(block_size)
622
- max_new_tokens = int(max_new_tokens)
623
- speed = float(replay_speed)
624
-
625
- prompt_ids = v2_tokenizer.encode(prompt, return_tensors="pt").to(device)
626
- config = v2_model.config
627
- num_blocks = max_new_tokens // block_size
628
-
629
- context_ids = prompt_ids
630
-
631
- # Helper params
632
- temperature = 1.0
633
- top_k = 40
634
- top_p = 0.9
635
- repetition_penalty = 1.2
636
-
637
- # Calculate delay based on speed slider
638
- delay = 0.5 / max(speed, 0.1)
639
-
640
- for block_idx in range(num_blocks):
641
- mask_block = torch.full((1, block_size), config.mask_token_id, device=device)
642
- is_masked = torch.ones(1, block_size, dtype=torch.bool, device=device)
643
-
644
- for step_idx in range(steps):
645
- # --- SNAPSHOT & YIELD ---
646
- # Decode context
647
- ctx_str = v2_tokenizer.decode(context_ids[0], skip_special_tokens=True)
648
-
649
- # Decode block with masking visual
650
- block_tokens = mask_block[0].tolist()
651
- block_vis = []
652
- for i, tid in enumerate(block_tokens):
653
- if is_masked[0, i]:
654
- block_vis.append("β–‘") # Mask symbol
655
- else:
656
- block_vis.append(v2_tokenizer.decode([tid], skip_special_tokens=False))
657
-
658
- block_str = "".join(block_vis)
659
-
660
- frame_text = (f"--- Generating Block {block_idx+1}/{num_blocks} | Step {step_idx+1}/{steps} ---\n\n"
661
- f"{ctx_str}{block_str}")
662
-
663
- yield frame_text
664
-
665
- # Artificial delay to visualize the step
666
- if speed < 20: # If max speed, skip sleep
667
- time.sleep(delay)
668
- # ------------------------
669
-
670
- full_input = torch.cat([context_ids, mask_block], dim=1)
671
- attention_mask = torch.ones_like(full_input, dtype=torch.float32)
672
-
673
- logits, _ = v2_model(full_input, attention_mask=attention_mask)
674
- block_logits = logits[:, -block_size:, :]
675
-
676
- # Repetition penalty
677
- if repetition_penalty != 1.0:
678
- seen_tokens = set(context_ids[0].tolist())
679
- for i in range(block_size):
680
- if not is_masked[0, i]:
681
- seen_tokens.add(mask_block[0, i].item())
682
- for token_id in seen_tokens:
683
- if token_id < block_logits.shape[-1]:
684
- if block_logits[0, :, token_id].mean() > 0:
685
- block_logits[:, :, token_id] /= repetition_penalty
686
- else:
687
- block_logits[:, :, token_id] *= repetition_penalty
688
-
689
- block_logits = block_logits / temperature
690
-
691
- # Top-K
692
- if top_k > 0:
693
- top_k_logits, top_k_indices = torch.topk(block_logits, top_k, dim=-1)
694
- block_logits = torch.full_like(block_logits, float('-inf'))
695
- block_logits.scatter_(-1, top_k_indices, top_k_logits)
696
-
697
- # Top-P
698
- if top_p < 1.0:
699
- sorted_logits, sorted_indices = torch.sort(block_logits, descending=True, dim=-1)
700
- cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
701
- sorted_indices_to_remove = cumulative_probs > top_p
702
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
703
- sorted_indices_to_remove[..., 0] = 0
704
- indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
705
- block_logits[indices_to_remove] = float('-inf')
706
-
707
- probs = F.softmax(block_logits, dim=-1)
708
- probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
709
- probs = probs.clamp(min=1e-10)
710
- probs = probs / probs.sum(dim=-1, keepdim=True)
711
-
712
- sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
713
- sampled_tokens = sampled_tokens.view(1, block_size)
714
-
715
- confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
716
-
717
- tokens_to_unmask = max(1, block_size // steps)
718
- if step_idx == steps - 1:
719
- tokens_to_unmask = is_masked.sum().item()
720
-
721
- if tokens_to_unmask > 0 and is_masked.sum() > 0:
722
- masked_confidence = confidence.clone()
723
- masked_confidence[~is_masked] = -1.0
724
- num_to_unmask = min(tokens_to_unmask, is_masked.sum().item())
725
- _, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)
726
-
727
- for idx in top_indices:
728
- mask_block[0, idx] = sampled_tokens[0, idx]
729
- is_masked[0, idx] = False
730
-
731
- context_ids = torch.cat([context_ids, mask_block], dim=1)
732
-
733
- generated_ids = context_ids[0].tolist()
734
- final_text = v2_tokenizer.decode(generated_ids, skip_special_tokens=True)
735
- yield final_text
736
-
737
-
738
- # ==============================================================================
739
- # ------------------------------- GRADIO UI ------------------------------------
740
- # ==============================================================================
741
-
742
- css = '''.gradio-container > .fillable {max-width: 900px !important}
743
- h3{margin-top: 1em}
744
- p{margin-top: 0}
745
- textarea{font-family: monospace; background-color: #1a1b1e; color: #e0e0e0}
746
- '''
747
-
748
- with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
749
- gr.Markdown("# Diffusion Language Models Playground")
750
-
751
- with gr.Tabs():
752
-
753
- # --- TAB 1: VERSION 1 (CHAR DIFFUSION) ---
754
- with gr.Tab("Version 1: Character Diffusion (NanoGPT)"):
755
- gr.Markdown("### Tiny 11M parameter character-based continuous diffusion.")
756
- with gr.Row():
757
- with gr.Column(scale=1):
758
- v1_steps = gr.Slider(64, 512, 128, step=1, label="Denoising Steps")
759
- v1_speed = gr.Slider(1, 20, 10, step=1, label="Generation/Replay Speed")
760
- with gr.Row():
761
- v1_btn = gr.Button("Generate", variant="primary")
762
- v1_stop = gr.Button("Stop", variant="stop")
763
- with gr.Column(scale=2):
764
- v1_out = gr.Textbox(label="Generated Text", lines=15, interactive=False)
765
-
766
- # V1 Logic: Merged generation and replay for proper stopping
767
- v1_event = v1_btn.click(v1_generate_stream, inputs=[v1_steps, v1_speed], outputs=[v1_out])
768
- v1_stop.click(fn=None, inputs=None, outputs=None, cancels=[v1_event])
769
-
770
- # --- TAB 2: VERSION 2 (BLOCK DIFFUSION) ---
771
- with gr.Tab("Version 2: Block Diffusion (LLaDA-style)"):
772
- gr.Markdown("### Block-based diffusion using Qwen tokenizer.")
773
- if v2_model is None:
774
- gr.Warning(f"V2 Model not loaded. Please check path: {V2_MODEL_PATH}")
775
-
776
- with gr.Row():
777
- with gr.Column(scale=1):
778
- v2_prompt = gr.Textbox(label="Prompt", value="The king went to the")
779
- v2_steps = gr.Slider(4, 64, 16, step=1, label="Steps per Block")
780
- v2_block_size = gr.Slider(8, 126, 32, step=8, label="Block Size")
781
- v2_max_tokens = gr.Slider(32, 1024, 128, step=32, label="Total New Tokens")
782
- v2_speed = gr.Slider(1, 20, 1, step=1, label="Generation/Replay Speed")
783
- with gr.Row():
784
- v2_btn = gr.Button("Generate", variant="primary")
785
- v2_stop = gr.Button("Stop", variant="stop")
786
- with gr.Column(scale=2):
787
- v2_out = gr.Textbox(label="Generated Text", lines=15, interactive=False)
788
-
789
- # V2 Logic
790
- v2_event = v2_btn.click(
791
- v2_generate_block_diffusion,
792
- inputs=[v2_prompt, v2_steps, v2_block_size, v2_max_tokens, v2_speed],
793
- outputs=[v2_out]
794
- )
795
- v2_stop.click(fn=None, inputs=None, outputs=None, cancels=[v2_event])
796
-
797
- if __name__ == "__main__":
798
  demo.launch()
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+ import math
8
+ import os
9
+ import pickle
10
+ import requests
11
+ import textwrap
12
+ import subprocess
13
+ import shutil
14
+ import time
15
+ from dataclasses import dataclass
16
+ from typing import Optional
17
+ from transformers import AutoTokenizer
18
+
19
+ # ==============================================================================
20
+ # ------------------------- VERSION 1: SHARED SETUP ----------------------------
21
+ # ==============================================================================
22
+
23
+ def setup_environment():
24
+ """Checks for and sets up the necessary data for V1."""
25
+ nano_gpt_repo_path = 'nanoGPT'
26
+ data_dir_path = 'shakespeare_char'
27
+ meta_path = os.path.join(data_dir_path, 'meta.pkl')
28
+
29
+ if os.path.exists(meta_path):
30
+ return
31
+
32
+ print("Required data not found. Starting one-time setup...")
33
+ if not os.path.exists(nano_gpt_repo_path):
34
+ try:
35
+ subprocess.run(['git', 'clone', 'https://github.com/karpathy/nanoGPT.git'], check=True, capture_output=True, text=True)
36
+ except subprocess.CalledProcessError as e:
37
+ print(f"Error cloning repository: {e.stderr}")
38
+ pass
39
+
40
+ source_data_dir = os.path.join(nano_gpt_repo_path, 'data', 'shakespeare_char')
41
+ if not os.path.exists(data_dir_path) and os.path.exists(source_data_dir):
42
+ shutil.copytree(source_data_dir, data_dir_path)
43
+
44
+ # Check if we can run prepare
45
+ prepare_script_path = os.path.join(data_dir_path, 'prepare.py')
46
+ if os.path.exists(prepare_script_path) and not os.path.exists(meta_path):
47
+ subprocess.run(['python', 'prepare.py'], check=True, cwd=data_dir_path, capture_output=True, text=True)
48
+
49
+ setup_environment()
50
+
51
+ def download_file(url, filename):
52
+ if os.path.exists(filename):
53
+ return
54
+ print(f"Downloading '{filename}'...")
55
+ try:
56
+ response = requests.get(url, stream=True)
57
+ response.raise_for_status()
58
+ with open(filename, 'wb') as f:
59
+ for chunk in response.iter_content(chunk_size=8192):
60
+ f.write(chunk)
61
+ except requests.exceptions.RequestException as e:
62
+ print(f"Error downloading {url}: {e}")
63
+
64
+ # ==============================================================================
65
+ # ---------------------- VERSION 1: ARCHITECTURE & LOGIC -----------------------
66
+ # ==============================================================================
67
+
68
+ # V1 Constants and Meta Loading
69
+ v1_data_dir = './shakespeare_char/'
70
+ v1_meta_url = 'https://huggingface.co/spaces/thejagstudio/NanoDiffusion/resolve/main/meta.pkl'
71
+ v1_meta_path = 'meta.pkl'
72
+ download_file(v1_meta_url, v1_meta_path)
73
+
74
+ v1_vocab_size = 65 # Fallback
75
+ v1_itos = {}
76
+ v1_stoi = {}
77
+
78
+ if os.path.exists(v1_meta_path):
79
+ with open(v1_meta_path, 'rb') as f:
80
+ meta = pickle.load(f)
81
+ v1_vocab_size = meta['vocab_size']
82
+ v1_itos = meta['itos']
83
+ v1_stoi = meta['stoi']
84
+
85
+ v1_context_length = 256
86
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
87
+
88
+ def v1_decode(indices_tensor: torch.Tensor):
89
+ if indices_tensor.dim() > 1:
90
+ indices_tensor = indices_tensor.squeeze(0)
91
+ indices = indices_tensor.cpu().numpy()
92
+ return ''.join([v1_itos.get(i, '?') for i in indices])
93
+
94
+ def wrap_text(long_text, width=80):
95
+ paragraphs = long_text.splitlines()
96
+ wrapped = [textwrap.fill(p, width=width) if p else '' for p in paragraphs]
97
+ return "\n".join(wrapped)
98
+
99
+ @dataclass
100
+ class V1_GPTConfig:
101
+ block_size: int = 1024
102
+ vocab_size: int = 50304
103
+ n_layer: int = 12
104
+ n_head: int = 12
105
+ n_embd: int = 768
106
+ cond_dim: int = 64
107
+ dropout: float = 0.0
108
+ bias: bool = False
109
+
110
+ class V1_MLP(nn.Module):
111
+ def __init__(self, config):
112
+ super().__init__()
113
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
114
+ self.gelu = nn.GELU()
115
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
116
+ self.dropout = nn.Dropout(config.dropout)
117
+ def forward(self, x):
118
+ x = self.c_fc(x)
119
+ x = self.gelu(x)
120
+ x = self.c_proj(x)
121
+ x = self.dropout(x)
122
+ return x
123
+
124
+ class V1_SelfAttention(nn.Module):
125
+ def __init__(self, config):
126
+ super().__init__()
127
+ assert config.n_embd % config.n_head == 0
128
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
129
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
130
+ self.attn_dropout = nn.Dropout(config.dropout)
131
+ self.resid_dropout = nn.Dropout(config.dropout)
132
+ self.n_head = config.n_head
133
+ self.n_embd = config.n_embd
134
+ self.dropout = config.dropout
135
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
136
+ def forward(self, x):
137
+ B, T, C = x.size()
138
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
139
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
140
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
141
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
142
+ if self.flash:
143
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False)
144
+ else:
145
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
146
+ att = F.softmax(att, dim=-1)
147
+ att = self.attn_dropout(att)
148
+ y = att @ v
149
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
150
+ y = self.resid_dropout(self.c_proj(y))
151
+ return y
152
+
153
+ def v1_modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
154
+ return x * (1 + scale) + shift
155
+
156
+ def v1_bias_add_scale(x: torch.Tensor, bias: Optional[torch.Tensor], scale: torch.Tensor, residual: Optional[torch.Tensor]) -> torch.Tensor:
157
+ if bias is not None:
158
+ out = scale * (x + bias)
159
+ else:
160
+ out = scale * x
161
+ if residual is not None:
162
+ out = residual + out
163
+ return out
164
+
165
+ class V1_DDiTBlock(nn.Module):
166
+ def __init__(self, config):
167
+ super().__init__()
168
+ self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
169
+ self.attn = V1_SelfAttention(config)
170
+ self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
171
+ self.mlp = V1_MLP(config)
172
+ self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd)
173
+ self.adaLN_modulation.weight.data.zero_()
174
+ self.adaLN_modulation.bias.data.zero_()
175
+ def forward(self, x, c):
176
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
177
+ x_skip = x
178
+ x = v1_modulate(self.ln_1(x), shift_msa, scale_msa)
179
+ x = self.attn(x)
180
+ x = v1_bias_add_scale(self.attn(self.ln_1(x)), None, gate_msa, x_skip)
181
+ x = v1_bias_add_scale(self.mlp(v1_modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x)
182
+ return x
183
+
184
+ class V1_DDitFinalLayer(nn.Module):
185
+ def __init__(self, config):
186
+ super().__init__()
187
+ self.norm_final = nn.LayerNorm(config.n_embd, bias=config.bias)
188
+ self.linear = nn.Linear(config.n_embd, config.vocab_size)
189
+ self.linear.weight.data.zero_()
190
+ self.linear.bias.data.zero_()
191
+ self.adaLN_modulation = nn.Linear(config.cond_dim, 2 * config.n_embd)
192
+ self.adaLN_modulation.weight.data.zero_()
193
+ self.adaLN_modulation.bias.data.zero_()
194
+ def forward(self, x, c):
195
+ shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
196
+ x = v1_modulate(self.norm_final(x), shift, scale)
197
+ x = self.linear(x)
198
+ return x
199
+
200
+ class V1_TimestepEmbedder(nn.Module):
201
+ def __init__(self, hidden_size, frequency_embedding_size=256):
202
+ super().__init__()
203
+ self.mlp = nn.Sequential(
204
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
205
+ nn.SiLU(),
206
+ nn.Linear(hidden_size, hidden_size, bias=True),
207
+ )
208
+ self.frequency_embedding_size = frequency_embedding_size
209
+ @staticmethod
210
+ def timestep_embedding(t, dim, max_period=10000):
211
+ half = dim // 2
212
+ freqs = torch.exp(
213
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
214
+ ).to(device=t.device)
215
+ args = t[:, None].float() * freqs[None]
216
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
217
+ if dim % 2:
218
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
219
+ return embedding
220
+ def forward(self, t):
221
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
222
+ t_emb = self.mlp(t_freq)
223
+ return t_emb
224
+
225
+ class V1_GPT(nn.Module):
226
+ def __init__(self, config):
227
+ super().__init__()
228
+ assert config.vocab_size is not None
229
+ assert config.block_size is not None
230
+ self.config = config
231
+ self.sigma_map = V1_TimestepEmbedder(config.cond_dim)
232
+ self.transformer = nn.ModuleDict(dict(
233
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
234
+ wpe = nn.Embedding(config.block_size, config.n_embd),
235
+ drop = nn.Dropout(config.dropout),
236
+ h = nn.ModuleList([V1_DDiTBlock(config) for _ in range(config.n_layer)]),
237
+ ln_f = nn.LayerNorm(config.n_embd, bias=config.bias),
238
+ ))
239
+ self.lm_head = V1_DDitFinalLayer(config)
240
+ self.apply(self._init_weights)
241
+ for pn, p in self.named_parameters():
242
+ if pn.endswith('c_proj.weight'):
243
+ torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
244
+ def _init_weights(self, module):
245
+ if isinstance(module, nn.Linear):
246
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
247
+ if module.bias is not None:
248
+ torch.nn.init.zeros_(module.bias)
249
+ elif isinstance(module, nn.Embedding):
250
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
251
+ def forward(self, idx, sigma):
252
+ sigma = sigma.reshape(-1)
253
+ b, t = idx.size()
254
+ c = F.silu(self.sigma_map(sigma))
255
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
256
+ pos = torch.arange(0, t, dtype=torch.long, device=device)
257
+ tok_emb = self.transformer.wte(idx)
258
+ pos_emb = self.transformer.wpe(pos)
259
+ x = self.transformer.drop(tok_emb + pos_emb)
260
+ for block in self.transformer.h:
261
+ x = block(x, c)
262
+ x = self.transformer.ln_f(x)
263
+ x = self.lm_head(x, c)
264
+ x = torch.scatter(x, -1, idx[..., None], torch.zeros_like(x[..., :1]))
265
+ return x
266
+
267
+ class V1_GeometricNoise:
268
+ def __init__(self, sigma_min=1e-4, sigma_max=20):
269
+ self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max]).to(device)
270
+ def rate_noise(self, t):
271
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (self.sigmas[1].log() - self.sigmas[0].log())
272
+ def total_noise(self, t):
273
+ return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t
274
+ def __call__(self, t):
275
+ return self.total_noise(t), self.rate_noise(t)
276
+
277
+ # --- V1 Inference Logic ---
278
+ def v1_transition(x_t: torch.Tensor, delta_sigma: torch.Tensor) -> torch.Tensor:
279
+ base_prob = (1 - torch.exp(-delta_sigma[..., None])) / v1_vocab_size
280
+ trans = torch.ones(*x_t.shape, v1_vocab_size, device=x_t.device) * base_prob
281
+ trans = trans.scatter(-1, x_t[..., None], torch.zeros_like(trans))
282
+ diag_fill = 1 - trans.sum(dim=-1, keepdim=True)
283
+ trans = trans.scatter(-1, x_t[..., None], diag_fill)
284
+ return trans
285
+
286
+ def v1_staggered_score(score, delta_sigma):
287
+ exp_factor = torch.exp(-delta_sigma)[..., None]
288
+ correction = ((exp_factor - 1) / (v1_vocab_size * exp_factor)) * score.sum(dim=-1, keepdim=True)
289
+ return correction + score / exp_factor
290
+
291
+ def v1_sample_categorical(probs: torch.Tensor) -> torch.Tensor:
292
+ eps = 1e-10
293
+ gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs) + eps) + eps)
294
+ return torch.argmax(torch.log(probs + eps) + gumbel_noise, dim=-1)
295
+
296
+ # --- V1 Model Loading ---
297
+ print("Initializing V1 Model...")
298
+ v1_model_args = dict(n_layer=6, n_head=6, n_embd=384, cond_dim=64,
299
+ bias=False, vocab_size=v1_vocab_size, block_size=v1_context_length, dropout=0.2)
300
+ v1_config = V1_GPTConfig(**v1_model_args)
301
+ v1_model = V1_GPT(v1_config)
302
+ try:
303
+ v1_model.load_state_dict(
304
+ torch.hub.load_state_dict_from_url(
305
+ 'https://huggingface.co/spaces/thejagstudio/diffusion-gpt/resolve/main/final_model.pth?download=true',
306
+ map_location=device
307
+ )
308
+ )
309
+ v1_model.to(device)
310
+ v1_model.eval()
311
+ print("V1 Model loaded successfully.")
312
+ except Exception as e:
313
+ print(f"Failed to load V1 model: {e}")
314
+ v1_model = None
315
+
316
+ v1_noise = V1_GeometricNoise(sigma_min=1e-4, sigma_max=20)
317
+
318
+
319
+ def v1_generate_stream(steps, speed):
320
+ """
321
+ Generator function for V1 that yields frames directly.
322
+ Combined logic of generation and replay to allow for immediate stopping.
323
+ """
324
+ if v1_model is None:
325
+ yield "Error: V1 Model not loaded"
326
+ return
327
+
328
+ steps = int(steps)
329
+ speed = float(speed)
330
+ eps = 1e-5
331
+
332
+ # Calculate delay based on speed slider (similar to V2)
333
+ # 0.5 is base constant, speed scales it down
334
+ delay = 0.5 / max(speed, 0.1)
335
+
336
+ x = torch.randint(0, v1_vocab_size, (1, v1_context_length), device=device)
337
+ initial_text = f"--- Initial Random Noise ---\n\n{wrap_text(v1_decode(x[0]))}"
338
+ yield initial_text
339
+ time.sleep(delay)
340
+
341
+ timesteps = torch.linspace(1, eps, steps + 1, device=device)
342
+ step_size = (1 - eps) / steps
343
+
344
+ with torch.no_grad():
345
+ for i in range(steps):
346
+ t = timesteps[i] * torch.ones(x.shape[0], 1, device=device)
347
+ curr_sigma_bar = v1_noise(t)[0]
348
+
349
+ next_sigma_bar = v1_noise(t - step_size)[0]
350
+ delta_sigma = curr_sigma_bar - next_sigma_bar
351
+
352
+ log_score = v1_model(x, curr_sigma_bar)
353
+ score = torch.exp(log_score)
354
+
355
+ stag_score = v1_staggered_score(score, delta_sigma)
356
+ probs = stag_score * v1_transition(x, delta_sigma)
357
+ x = v1_sample_categorical(probs)
358
+
359
+ progress_text = f"--- Denoising Step {i + 1}/{steps} ---\n\n{wrap_text(v1_decode(x[0]))}"
360
+ yield progress_text
361
+
362
+ # Artificial delay for visualization
363
+ if speed < 20:
364
+ time.sleep(delay)
365
+
366
+ t = timesteps[steps] * torch.ones(x.shape[0], 1, device=device)
367
+ curr_sigma_bar = v1_noise(t)[0]
368
+ delta_sigma = curr_sigma_bar
369
+
370
+ log_score = v1_model(x, curr_sigma_bar)
371
+ score = torch.exp(log_score)
372
+ stag_score = v1_staggered_score(score, delta_sigma)
373
+ probs = stag_score * v1_transition(x, delta_sigma)
374
+ x = v1_sample_categorical(probs)
375
+
376
+ final_text = f"--- Final Denoised Text (Step {steps}) ---\n\n{wrap_text(v1_decode(x[0]))}"
377
+ yield final_text
378
+
379
+ # ==============================================================================
380
+ # ---------------------- VERSION 2: ARCHITECTURE & LOGIC -----------------------
381
+ # ==============================================================================
382
+
383
+ # PLEASE UPDATE THIS PATH TO YOUR ACTUAL LOCAL FILE OR URL
384
+ V2_MODEL_PATH = "checkpoints/model_fp32.pt"
385
+
386
+ class V2_RMSNorm(nn.Module):
387
+ def __init__(self, dim: int, eps: float = 1e-6):
388
+ super().__init__()
389
+ self.eps = eps
390
+ self.weight = nn.Parameter(torch.ones(dim))
391
+
392
+ def forward(self, x):
393
+ var = x.pow(2).mean(-1, keepdim=True)
394
+ x = x * torch.rsqrt(var + self.eps)
395
+ return self.weight * x
396
+
397
+ class V2_RotaryEmbedding(nn.Module):
398
+ def __init__(self, dim, max_position_embeddings=16384, base=100000, scaling_factor=1.0):
399
+ super().__init__()
400
+ self.scaling_factor = scaling_factor
401
+ self.dim = dim
402
+ self.base = base
403
+ self.max_position_embeddings = max_position_embeddings
404
+ self.inv_freq = None
405
+ self._cache = {}
406
+
407
+ def _update_freqs(self, device):
408
+ base = self.base * (self.scaling_factor ** (self.dim / (self.dim - 2)))
409
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
410
+ self.inv_freq = inv_freq
411
+
412
+ def forward(self, x, seq_len=None):
413
+ if seq_len is None:
414
+ seq_len = x.shape[-2]
415
+
416
+ if self.inv_freq is None or self.inv_freq.device != x.device:
417
+ self._update_freqs(x.device)
418
+
419
+ cache_key = (seq_len, x.device, x.dtype)
420
+ if cache_key in self._cache:
421
+ return self._cache[cache_key]
422
+
423
+ t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
424
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
425
+ emb = torch.cat((freqs, freqs), dim=-1)
426
+
427
+ cos = emb.cos()[None, None, :, :]
428
+ sin = emb.sin()[None, None, :, :]
429
+
430
+ self._cache[cache_key] = (cos, sin)
431
+ if len(self._cache) > 10:
432
+ self._cache.pop(next(iter(self._cache)))
433
+
434
+ return cos, sin
435
+
436
+ def v2_apply_rotary_pos_emb(q, k, cos, sin):
437
+ def rotate_half(x):
438
+ x1 = x[..., : x.shape[-1] // 2]
439
+ x2 = x[..., x.shape[-1] // 2 :]
440
+ return torch.cat((-x2, x1), dim=-1)
441
+ q_embed = (q * cos) + (rotate_half(q) * sin)
442
+ k_embed = (k * cos) + (rotate_half(k) * sin)
443
+ return q_embed, k_embed
444
+
445
+ class V2_DiffusionAttention(nn.Module):
446
+ def __init__(self, config):
447
+ super().__init__()
448
+ self.hidden_size = config.hidden_size
449
+ self.num_heads = config.num_attention_heads
450
+ self.head_dim = self.hidden_size // self.num_heads
451
+ self.num_key_value_heads = config.num_key_value_heads
452
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
453
+ self.use_flash_attn = config.use_flash_attn
454
+
455
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
456
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
457
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
458
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
459
+
460
+ def forward(self, hidden_states, freqs_cis, attention_mask=None, past_kv=None):
461
+ bsz, q_len, _ = hidden_states.size()
462
+
463
+ q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
464
+ k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
465
+ v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
466
+
467
+ cos, sin = freqs_cis
468
+ cos = cos[:, :, :q_len, :]
469
+ sin = sin[:, :, :q_len, :]
470
+ q, k = v2_apply_rotary_pos_emb(q, k, cos, sin)
471
+
472
+ if past_kv is not None:
473
+ cache_k, cache_v = past_kv
474
+ k = torch.cat([cache_k, k], dim=2)
475
+ v = torch.cat([cache_v, v], dim=2)
476
+
477
+ current_kv = (k, v)
478
+ k = k.repeat_interleave(self.num_key_value_groups, dim=1)
479
+ v = v.repeat_interleave(self.num_key_value_groups, dim=1)
480
+
481
+ attn_mask = None
482
+ if attention_mask is not None:
483
+ attn_mask = attention_mask[:, None, None, :].to(dtype=q.dtype)
484
+ attn_mask = (1.0 - attn_mask) * torch.finfo(q.dtype).min
485
+
486
+ output = F.scaled_dot_product_attention(
487
+ q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
488
+ )
489
+
490
+ output = output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
491
+ return self.o_proj(output), current_kv
492
+
493
+ class V2_MLP(nn.Module):
494
+ def __init__(self, config):
495
+ super().__init__()
496
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
497
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
498
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
499
+ self.act_fn = nn.SiLU()
500
+
501
+ def forward(self, x):
502
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
503
+
504
+ class V2_BlockDiffusionBlock(nn.Module):
505
+ def __init__(self, config):
506
+ super().__init__()
507
+ self.self_attn = V2_DiffusionAttention(config)
508
+ self.mlp = V2_MLP(config)
509
+ self.input_layernorm = V2_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
510
+ self.post_attention_layernorm = V2_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
511
+ self.use_activation_checkpointing = config.use_activation_checkpointing
512
+
513
+ def forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
514
+ return self._forward(hidden_states, freqs_cis, attention_mask, past_kv)
515
+
516
+ def _forward(self, hidden_states, freqs_cis, attention_mask, past_kv):
517
+ residual = hidden_states
518
+ hidden_states = self.input_layernorm(hidden_states)
519
+ attn_out, new_kv = self.self_attn(hidden_states, freqs_cis, attention_mask, past_kv)
520
+ hidden_states = residual + attn_out
521
+
522
+ residual = hidden_states
523
+ hidden_states = self.post_attention_layernorm(hidden_states)
524
+ hidden_states = residual + self.mlp(hidden_states)
525
+ return hidden_states, new_kv
526
+
527
+ @dataclass
528
+ class V2_ModelConfig:
529
+ vocab_size: int = 151936
530
+ hidden_size: int = 1024
531
+ intermediate_size: int = 2816
532
+ num_hidden_layers: int = 16
533
+ num_attention_heads: int = 16
534
+ num_key_value_heads: int = 4
535
+ max_position_embeddings: int = 16384
536
+ rms_norm_eps: float = 1e-6
537
+ rope_theta: float = 100000.0
538
+ pad_token_id: int = 0
539
+ mask_token_id: int = 1
540
+ use_flash_attn: bool = True
541
+ use_activation_checkpointing: bool = False
542
+ attention_dropout: float = 0.0
543
+ hidden_dropout: float = 0.0
544
+
545
+ ModelConfig = V2_ModelConfig
546
+
547
+ class V2_DiffusionLLM(nn.Module):
548
+ def __init__(self, config: V2_ModelConfig):
549
+ super().__init__()
550
+ self.config = config
551
+ pad_idx = config.pad_token_id if config.pad_token_id < config.vocab_size else None
552
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=pad_idx)
553
+
554
+ self.layers = nn.ModuleList([V2_BlockDiffusionBlock(config) for _ in range(config.num_hidden_layers)])
555
+ self.norm = V2_RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
556
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
557
+ self.rotary_emb = V2_RotaryEmbedding(
558
+ config.hidden_size // config.num_attention_heads,
559
+ config.max_position_embeddings
560
+ )
561
+ self.lm_head.weight = self.embed_tokens.weight
562
+
563
+ def forward(self, input_ids, attention_mask=None, past_key_values=None):
564
+ bsz, seqlen = input_ids.shape
565
+ hidden_states = self.embed_tokens(input_ids)
566
+ freqs_cis = self.rotary_emb(hidden_states, seq_len=seqlen)
567
+
568
+ if past_key_values is None:
569
+ past_key_values = [None] * len(self.layers)
570
+
571
+ new_kvs = []
572
+ for i, layer in enumerate(self.layers):
573
+ hidden_states, kv = layer(hidden_states, freqs_cis, attention_mask, past_key_values[i])
574
+ new_kvs.append(kv)
575
+
576
+ hidden_states = self.norm(hidden_states)
577
+ logits = self.lm_head(hidden_states)
578
+ return logits, new_kvs
579
+
580
+ DiffusionLLM = V2_DiffusionLLM
581
+
582
+ # --- V2 Loading Logic ---
583
+ print("Initializing V2 components...")
584
+ v2_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
585
+ if v2_tokenizer.pad_token is None:
586
+ v2_tokenizer.pad_token = v2_tokenizer.eos_token
587
+
588
+ v2_model = None
589
+ v2_config = None
590
+
591
+ if os.path.exists(V2_MODEL_PATH):
592
+ print(f"Loading V2 model from {V2_MODEL_PATH}...")
593
+ try:
594
+ checkpoint = torch.load(V2_MODEL_PATH, map_location=device, weights_only=False)
595
+ v2_config = checkpoint['config']
596
+ v2_model = V2_DiffusionLLM(v2_config)
597
+ state_dict = checkpoint['model_state']
598
+ state_dict = {k: v.float() for k, v in state_dict.items()}
599
+ v2_model.load_state_dict(state_dict)
600
+ v2_model = v2_model.to(device)
601
+ v2_model.eval()
602
+ print("V2 Model loaded.")
603
+ except Exception as e:
604
+ print(f"Error loading V2 model: {e}")
605
+ else:
606
+ print(f"V2 Model file not found at {V2_MODEL_PATH}. Version 2 tab will not work without it.")
607
+
608
+
609
+ @torch.no_grad()
610
+ def v2_generate_block_diffusion(prompt, steps, block_size, max_new_tokens, replay_speed):
611
+ """
612
+ Refactored to yield frames for real-time streaming.
613
+ """
614
+ if v2_model is None:
615
+ yield "Error: V2 Model not found. Check path."
616
+ return
617
+
618
+ v2_model.eval()
619
+ # Handle inputs
620
+ steps = int(steps)
621
+ block_size = int(block_size)
622
+ max_new_tokens = int(max_new_tokens)
623
+ speed = float(replay_speed)
624
+
625
+ prompt_ids = v2_tokenizer.encode(prompt, return_tensors="pt").to(device)
626
+ config = v2_model.config
627
+ num_blocks = max_new_tokens // block_size
628
+
629
+ context_ids = prompt_ids
630
+
631
+ # Helper params
632
+ temperature = 1.0
633
+ top_k = 40
634
+ top_p = 0.9
635
+ repetition_penalty = 1.2
636
+
637
+ # Calculate delay based on speed slider
638
+ delay = 0.5 / max(speed, 0.1)
639
+
640
+ for block_idx in range(num_blocks):
641
+ mask_block = torch.full((1, block_size), config.mask_token_id, device=device)
642
+ is_masked = torch.ones(1, block_size, dtype=torch.bool, device=device)
643
+
644
+ for step_idx in range(steps):
645
+ # --- SNAPSHOT & YIELD ---
646
+ # Decode context
647
+ ctx_str = v2_tokenizer.decode(context_ids[0], skip_special_tokens=True)
648
+
649
+ # Decode block with masking visual
650
+ block_tokens = mask_block[0].tolist()
651
+ block_vis = []
652
+ for i, tid in enumerate(block_tokens):
653
+ if is_masked[0, i]:
654
+ block_vis.append("β–‘") # Mask symbol
655
+ else:
656
+ block_vis.append(v2_tokenizer.decode([tid], skip_special_tokens=False))
657
+
658
+ block_str = "".join(block_vis)
659
+
660
+ frame_text = (f"--- Generating Block {block_idx+1}/{num_blocks} | Step {step_idx+1}/{steps} ---\n\n"
661
+ f"{ctx_str}{block_str}")
662
+
663
+ yield frame_text
664
+
665
+ # Artificial delay to visualize the step
666
+ if speed < 20: # If max speed, skip sleep
667
+ time.sleep(delay)
668
+ # ------------------------
669
+
670
+ full_input = torch.cat([context_ids, mask_block], dim=1)
671
+ attention_mask = torch.ones_like(full_input, dtype=torch.float32)
672
+
673
+ logits, _ = v2_model(full_input, attention_mask=attention_mask)
674
+ block_logits = logits[:, -block_size:, :]
675
+
676
+ # Repetition penalty
677
+ if repetition_penalty != 1.0:
678
+ seen_tokens = set(context_ids[0].tolist())
679
+ for i in range(block_size):
680
+ if not is_masked[0, i]:
681
+ seen_tokens.add(mask_block[0, i].item())
682
+ for token_id in seen_tokens:
683
+ if token_id < block_logits.shape[-1]:
684
+ if block_logits[0, :, token_id].mean() > 0:
685
+ block_logits[:, :, token_id] /= repetition_penalty
686
+ else:
687
+ block_logits[:, :, token_id] *= repetition_penalty
688
+
689
+ block_logits = block_logits / temperature
690
+
691
+ # Top-K
692
+ if top_k > 0:
693
+ top_k_logits, top_k_indices = torch.topk(block_logits, top_k, dim=-1)
694
+ block_logits = torch.full_like(block_logits, float('-inf'))
695
+ block_logits.scatter_(-1, top_k_indices, top_k_logits)
696
+
697
+ # Top-P
698
+ if top_p < 1.0:
699
+ sorted_logits, sorted_indices = torch.sort(block_logits, descending=True, dim=-1)
700
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
701
+ sorted_indices_to_remove = cumulative_probs > top_p
702
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
703
+ sorted_indices_to_remove[..., 0] = 0
704
+ indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
705
+ block_logits[indices_to_remove] = float('-inf')
706
+
707
+ probs = F.softmax(block_logits, dim=-1)
708
+ probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
709
+ probs = probs.clamp(min=1e-10)
710
+ probs = probs / probs.sum(dim=-1, keepdim=True)
711
+
712
+ sampled_tokens = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)
713
+ sampled_tokens = sampled_tokens.view(1, block_size)
714
+
715
+ confidence = probs.gather(-1, sampled_tokens.unsqueeze(-1)).squeeze(-1)
716
+
717
+ tokens_to_unmask = max(1, block_size // steps)
718
+ if step_idx == steps - 1:
719
+ tokens_to_unmask = is_masked.sum().item()
720
+
721
+ if tokens_to_unmask > 0 and is_masked.sum() > 0:
722
+ masked_confidence = confidence.clone()
723
+ masked_confidence[~is_masked] = -1.0
724
+ num_to_unmask = min(tokens_to_unmask, is_masked.sum().item())
725
+ _, top_indices = torch.topk(masked_confidence.view(-1), num_to_unmask)
726
+
727
+ for idx in top_indices:
728
+ mask_block[0, idx] = sampled_tokens[0, idx]
729
+ is_masked[0, idx] = False
730
+
731
+ context_ids = torch.cat([context_ids, mask_block], dim=1)
732
+
733
+ generated_ids = context_ids[0].tolist()
734
+ final_text = v2_tokenizer.decode(generated_ids, skip_special_tokens=True)
735
+ yield final_text
736
+
737
+
738
+ # ==============================================================================
739
+ # ------------------------------- GRADIO UI ------------------------------------
740
+ # ==============================================================================
741
+
742
+ css = '''.gradio-container > .fillable {max-width: 900px !important}
743
+ h3{margin-top: 1em}
744
+ p{margin-top: 0}
745
+ textarea{font-family: monospace; background-color: #1a1b1e; color: #e0e0e0}
746
+ '''
747
+
748
+ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
749
+ gr.Markdown("# Diffusion Language Models Playground")
750
+
751
+ with gr.Tabs():
752
+
753
+ # --- TAB 1: VERSION 1 (CHAR DIFFUSION) ---
754
+ with gr.Tab("Version 1: Character Diffusion (NanoGPT)"):
755
+ gr.Markdown("### Tiny 11M parameter character-based continuous diffusion.")
756
+ with gr.Row():
757
+ with gr.Column(scale=1):
758
+ v1_steps = gr.Slider(64, 512, 128, step=1, label="Denoising Steps")
759
+ v1_speed = gr.Slider(1, 20, 10, step=1, label="Generation/Replay Speed")
760
+ with gr.Row():
761
+ v1_btn = gr.Button("Generate", variant="primary")
762
+ v1_stop = gr.Button("Stop", variant="stop")
763
+ with gr.Column(scale=2):
764
+ v1_out = gr.Textbox(label="Generated Text", lines=15, interactive=False)
765
+
766
+ # V1 Logic: Merged generation and replay for proper stopping
767
+ v1_event = v1_btn.click(v1_generate_stream, inputs=[v1_steps, v1_speed], outputs=[v1_out])
768
+ v1_stop.click(fn=None, inputs=None, outputs=None, cancels=[v1_event])
769
+
770
+ # --- TAB 2: VERSION 2 (BLOCK DIFFUSION) ---
771
+ with gr.Tab("Version 2: Block Diffusion (LLaDA-style)"):
772
+ gr.Markdown("### Block-based diffusion using Qwen tokenizer.")
773
+ if v2_model is None:
774
+ gr.Warning(f"V2 Model not loaded. Please check path: {V2_MODEL_PATH}")
775
+
776
+ with gr.Row():
777
+ with gr.Column(scale=1):
778
+ v2_prompt = gr.Textbox(label="Prompt", value="The king went to the")
779
+ v2_steps = gr.Slider(4, 64, 16, step=1, label="Steps per Block")
780
+ v2_block_size = gr.Slider(8, 126, 32, step=8, label="Block Size")
781
+ v2_max_tokens = gr.Slider(32, 1024, 128, step=32, label="Total New Tokens")
782
+ v2_speed = gr.Slider(1, 20, 1, step=1, label="Generation/Replay Speed")
783
+ with gr.Row():
784
+ v2_btn = gr.Button("Generate", variant="primary")
785
+ v2_stop = gr.Button("Stop", variant="stop")
786
+ with gr.Column(scale=2):
787
+ v2_out = gr.Textbox(label="Generated Text", lines=15, interactive=False)
788
+
789
+ # V2 Logic
790
+ v2_event = v2_btn.click(
791
+ v2_generate_block_diffusion,
792
+ inputs=[v2_prompt, v2_steps, v2_block_size, v2_max_tokens, v2_speed],
793
+ outputs=[v2_out]
794
+ )
795
+ v2_stop.click(fn=None, inputs=None, outputs=None, cancels=[v2_event])
796
+
797
+ if __name__ == "__main__":
798
  demo.launch()