josesho committed on
Commit
5139b04
·
verified ·
1 Parent(s): 14195db

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +463 -0
app.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random, time, ast
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import gradio as gr
5
+ from wonderwords import RandomWord
6
+ from transformers import AutoTokenizer, AutoModel
7
+
8
+
9
+
10
+
11
+
12
+
13
# Select the inference device once at import time. The result is always
# stored in the module-level name DEVICE, which the rest of the file reads.
if torch.cuda.is_available():
    # An Nvidia GPU is present: use it for inference.
    # (The original assigned lowercase `device` here, which left DEVICE
    # undefined and raised NameError on CUDA machines.)
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    # Apple Silicon: take advantage of the integrated GPU.
    DEVICE = "mps"
else:
    # Otherwise fall back to the CPU.
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")
25
+
26
+
27
+
28
import os

# Upper-limit ratio for the MPS allocator. A bare Python variable is never
# seen by PyTorch, so the value is also exported through the process
# environment. setdefault() keeps any value the user already exported.
# NOTE(review): PyTorch may read this only when the MPS allocator is first
# initialised; if it has no effect here, set it before `import torch`.
PYTORCH_MPS_HIGH_WATERMARK_RATIO = 0.0
os.environ.setdefault(
    "PYTORCH_MPS_HIGH_WATERMARK_RATIO", str(PYTORCH_MPS_HIGH_WATERMARK_RATIO)
)

try:
    # Load model and tokenizer
    TOKENIZER = AutoTokenizer.from_pretrained(
        "GSAI-ML/LLaDA-8B-Base", trust_remote_code=True
    )
    MODEL = AutoModel.from_pretrained(
        "GSAI-ML/LLaDA-8B-Base",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).to(DEVICE)
    print("Model and Tokenizer loaded.")
except Exception as e:
    # Best effort: the failure is reported and the module keeps loading.
    # MODEL and/or TOKENIZER stay undefined in that case, so any later use
    # of them raises NameError.
    error_msg = f"Error: {str(e)}"
    print(error_msg)

# Constants
MASK_TOKEN = "[MASK]"
MASK_ID = 126336  # The token ID of [MASK] in LLaDA
47
+
48
+
49
+
50
+
51
+
52
+ rw = RandomWord()
53
+
54
def random_sample_without_replacement(sample_size: int,
                                      population_size: int) -> list:
    """Return ``sample_size`` distinct random indices from ``range(population_size)``.

    Args:
        sample_size: How many indices to draw. Integral floats such as
            ``3.0`` are accepted and converted to ``int``.
        population_size: Exclusive upper bound of the index range.

    Returns:
        A list of unique integers in ``[0, population_size)``, in random order.

    Raises:
        ValueError: If either argument is not a whole number, or if
            ``sample_size`` is not between 1 and ``population_size``.
            The check runs when the function is called, not lazily on
            first iteration.
    """
    k = int(sample_size)
    n = int(population_size)
    if k != sample_size or n != population_size:
        raise ValueError("Sample size and population size must be whole numbers.")
    if not (1 <= k <= n):
        raise ValueError("Sample size must be between 1 and population size.")

    # random.sample draws without replacement, so no manual de-duplication
    # loop is needed.
    return random.sample(range(n), k)
65
+
66
def format_constraints(num_words: int,
                       max_gen_length: int) -> dict:
    """Pick random words and assign each a distinct random position.

    Args:
        num_words: Number of random words to draw.
        max_gen_length: Positions are drawn from ``range(max_gen_length)``.

    Returns:
        A dict mapping position (int) to word (str), one entry per word.
    """
    # UI number widgets may deliver floats (e.g. 3.0); the word generator and
    # the sampler both need real integers.
    num_words = int(num_words)
    max_gen_length = int(max_gen_length)

    word_list = rw.random_words(num_words)
    positions = list(random_sample_without_replacement(num_words,
                                                       max_gen_length))

    # Pair the i-th sampled position with the i-th word.
    return dict(zip(positions, word_list))
78
+
79
+
80
def add_gumbel_noise(logits, temperature):
    """Perturb logits so that an argmax over them samples a category.

    The Gumbel max is a method for sampling categorical distributions.
    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves
    perplexity score but reduces generation quality, so the computation is
    done in float32 rather than in the model's dtype.

    The original form ``exp(logits) / (-log(u)) ** temperature`` overflows to
    ``inf`` in float32 once a logit exceeds roughly 88, which makes the argmax
    meaningless. This version returns the logarithm of that quantity,
    ``logits - temperature * log(-log(u))``. The logarithm is monotonic, so
    the argmax is the same, and no exponential is taken.

    Args:
        logits: Tensor of unnormalised scores.
        temperature: Sampling temperature. Values ``<= 0`` disable the noise.

    Returns:
        ``logits`` itself when ``temperature <= 0``; otherwise a float32
        tensor of the same shape holding noisy scores. Only their ordering
        (argmax) is meaningful.
    """
    if temperature <= 0:
        return logits

    logits = logits.to(torch.float32)
    uniform = torch.rand_like(logits, dtype=torch.float32)
    # -log(u) > 0 because u is drawn from [0, 1). u == 0 gives +inf here and
    # therefore a score of -inf, i.e. that entry is never selected, matching
    # the original's exp(logits) / inf == 0.
    neg_log_uniform = -torch.log(uniform)
    return logits - temperature * torch.log(neg_log_uniform)
93
+
94
+
95
def get_num_transfer_tokens(mask_index, steps):
    """Precompute how many masked tokens to reveal at each denoising step.

    In the reverse process the interval [0, 1] is split uniformly into
    ``steps`` intervals, and LLaDA uses a linear noise schedule (Eq. (8)), so
    the expected number of tokens transitioned per step should be constant.
    Each row's masked-token count is therefore divided evenly across the
    steps, with any leftover tokens given one apiece to the earliest steps.

    Args:
        mask_index: 2-D boolean tensor; True marks a masked position.
        steps: Number of steps to spread the masked tokens over.

    Returns:
        An int64 tensor with one row per row of ``mask_index`` and ``steps``
        columns, on the same device as ``mask_index``. Each row sums to that
        row's number of masked positions.
    """
    masked_per_row = mask_index.sum(dim=1, keepdim=True)
    quotient = masked_per_row // steps
    leftover = masked_per_row % steps

    # Column k receives one extra token exactly when k < leftover for that row.
    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)
    extra = (step_ids < leftover).to(torch.int64)

    return quotient.to(torch.int64).expand(-1, steps) + extra
119
+
120
+
121
def _parse_constraints(constraints, tokenizer):
    """Convert user constraints into a ``{relative_position: token_id}`` dict."""
    if constraints is None:
        parsed = {}
    elif isinstance(constraints, dict):
        parsed = constraints
    else:
        text = str(constraints).strip()
        # An empty textbox means "no constraints". ast.literal_eval only
        # evaluates Python literals, never arbitrary expressions.
        parsed = ast.literal_eval(text) if text else {}

    if not isinstance(parsed, dict):
        raise ValueError(
            "Constraints must be a dictionary mapping positions to words."
        )

    token_constraints = {}
    for pos, word in parsed.items():
        token_ids = tokenizer.encode(" " + word, add_special_tokens=False)
        # A word that spans several tokens occupies consecutive positions.
        for offset, token_id in enumerate(token_ids):
            token_constraints[int(pos) + offset] = token_id
    return token_constraints


def _apply_constraints(x, token_constraints, prompt_length):
    """Write constraint tokens into the response part of ``x`` in place."""
    for pos, token_id in token_constraints.items():
        absolute_pos = prompt_length + pos
        # Ignore positions that fall outside the generated region.
        if prompt_length <= absolute_pos < x.shape[1]:
            x[:, absolute_pos] = token_id


def _snapshot_state(x, old_x, x0_p, tokenizer, prompt_length, gen_length):
    """Build one visualization frame as a list of ``(token, color)`` pairs."""
    state = []
    for offset in range(gen_length):
        pos = prompt_length + offset  # Absolute position in the sequence
        token_id = x[0, pos].item()

        if token_id == MASK_ID:
            # Still masked
            state.append((MASK_TOKEN, "#444444"))  # Dark gray for masks
            continue

        token = tokenizer.decode([token_id], skip_special_tokens=True)
        if old_x[0, pos].item() == MASK_ID:
            # Newly revealed in this step: color based on confidence
            score = float(x0_p[0, pos].cpu())
            if score < 0.3:
                color = "#FF6666"  # Light red
            elif score < 0.7:
                color = "#FFAA33"  # Orange
            else:
                color = "#66CC66"  # Light green
        else:
            # Previously revealed
            color = "#6699CC"  # Light blue
        state.append((token, color))
    return state


def generate_response_with_visualization(
    model,
    tokenizer,
    device,
    prompt,
    gen_length=64,
    steps=32,
    constraints=None,
    temperature=0.0,
    cfg_scale=0.0,
    block_length=32,
    remasking="low_confidence",
):
    """
    Generate text with LLaDA model with visualization using the same sampling as in generate.py

    Args:
        model: Callable model; ``model(x).logits`` must give per-position logits
        tokenizer: Tokenizer used to encode the prompt and decode the output
        device: Device the tensors are placed on
        prompt: The prompt
        gen_length: Length of text to generate
        steps: Number of denoising steps
        constraints: Dictionary mapping positions to words, or the string
            form of such a dictionary; None or an empty string means none
        temperature: Sampling temperature
        cfg_scale: Classifier-free guidance scale
        block_length: Block length for semi-autoregressive generation
        remasking: Remasking strategy ('low_confidence' or 'random')

    Returns:
        Tuple ``(visualization_states, final_text)``: the list of frames
        showing the progression (each a list of ``(token, color)`` pairs),
        and the decoded response text.

    Raises:
        ValueError: If the constraints do not evaluate to a dictionary.
        NotImplementedError: If ``remasking`` is not a supported strategy.
    """
    # UI sliders may deliver floats; tensor sizes and ranges need integers.
    gen_length = int(gen_length)
    steps = int(steps)
    block_length = int(block_length)

    # Convert constraints (positions -> words) to positions -> token IDs
    processed_constraints = _parse_constraints(constraints, tokenizer)

    # Tokenize the prompt
    input_ids = tokenizer(prompt)["input_ids"]
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)
    prompt_length = input_ids.shape[1]

    # Initialize the sequence with masks for the response part
    x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(
        device
    )
    x[:, :prompt_length] = input_ids.clone()

    # The first frame shows the whole response as masked
    visualization_states = [[(MASK_TOKEN, "#444444") for _ in range(gen_length)]]

    # Apply constraints to the initial state
    _apply_constraints(x, processed_constraints, prompt_length)

    # Every non-mask position (prompt and constraints) is blanked out in the
    # unconditional branch of classifier-free guidance
    prompt_index = x != MASK_ID

    # Ensure block_length is valid
    if block_length > gen_length:
        block_length = gen_length

    # Number of blocks, rounding up so a partial last block is included
    num_blocks = gen_length // block_length
    if gen_length % block_length != 0:
        num_blocks += 1

    # At least one step per block
    steps_per_block = max(steps // num_blocks, 1)

    # Inference only: skip autograd bookkeeping for the forward passes
    with torch.no_grad():
        for num_block in range(num_blocks):
            # Start and end indices of the current block
            block_start = prompt_length + num_block * block_length
            block_end = min(
                prompt_length + (num_block + 1) * block_length, x.shape[1]
            )

            block_mask_index = x[:, block_start:block_end] == MASK_ID

            # Skip if no masks in this block
            if not block_mask_index.any():
                continue

            # Number of tokens to unmask at each step
            num_transfer_tokens = get_num_transfer_tokens(
                block_mask_index, steps_per_block
            )

            for step in range(steps_per_block):
                # All mask positions in the current sequence
                mask_index = x == MASK_ID
                if not mask_index.any():
                    break

                # Apply classifier-free guidance if enabled
                if cfg_scale > 0.0:
                    un_x = x.clone()
                    un_x[prompt_index] = MASK_ID
                    x_ = torch.cat([x, un_x], dim=0)
                    logits = model(x_).logits
                    logits, un_logits = torch.chunk(logits, 2, dim=0)
                    logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
                else:
                    logits = model(x).logits

                # Apply Gumbel noise for sampling
                logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
                x0 = torch.argmax(logits_with_noise, dim=-1)

                # Confidence scores used to decide what stays masked
                if remasking == "low_confidence":
                    p = F.softmax(logits.to(torch.float32), dim=-1)
                    x0_p = torch.squeeze(
                        torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1
                    )  # b, l
                elif remasking == "random":
                    x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
                else:
                    raise NotImplementedError(
                        f"Remasking strategy '{remasking}' not implemented"
                    )

                # Don't consider positions beyond the current block
                x0_p[:, block_end:] = -float("inf")

                # Apply predictions where we have masks
                old_x = x.clone()
                x0 = torch.where(mask_index, x0, x)
                confidence = torch.where(mask_index, x0_p, -float("inf"))

                # Select tokens to unmask based on confidence
                transfer_index = torch.zeros_like(
                    x0, dtype=torch.bool, device=x0.device
                )
                for j in range(confidence.shape[0]):
                    # Only positions within the current block may be unmasked
                    block_confidence = confidence[j, block_start:block_end]
                    if step < steps_per_block - 1:  # Not the last step
                        # Take top-k confidences
                        _, select_indices = torch.topk(
                            block_confidence,
                            k=min(
                                num_transfer_tokens[j, step].item(),
                                block_confidence.numel(),
                            ),
                        )
                        # Adjust indices to global positions
                        select_indices = select_indices + block_start
                        transfer_index[j, select_indices] = True
                    else:  # Last step - unmask everything remaining
                        transfer_index[j, block_start:block_end] = mask_index[
                            j, block_start:block_end
                        ]

                # Apply the selected tokens
                x = torch.where(transfer_index, x0, x)

                # Ensure constraints are maintained
                _apply_constraints(x, processed_constraints, prompt_length)

                # Record a frame for the response part only
                visualization_states.append(
                    _snapshot_state(
                        x, old_x, x0_p, tokenizer, prompt_length, gen_length
                    )
                )

    # Extract final text (everything after the prompt)
    response_tokens = x[0, prompt_length:]
    final_text = tokenizer.decode(
        response_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    return visualization_states, final_text
342
+
343
def display_animation(prompt,
                      constraints,
                      gen_length,
                      steps,
                      temperature,
                      cfg_scale,
                      block_length,
                      remasking,
                      delay):
    """Yield visualization frames one by one, pausing ``delay`` seconds between them.

    The whole generation runs first. Its frames are then replayed: the first
    frame is yielded without waiting and every later frame follows a sleep of
    ``delay`` seconds. If anything raises, the error is printed and a single
    frame containing the error message in red is yielded instead.
    """
    try:
        frames, _final_text = generate_response_with_visualization(
            model=MODEL,
            tokenizer=TOKENIZER,
            device=DEVICE,
            prompt=prompt,
            gen_length=gen_length,
            steps=steps,
            constraints=constraints,
            temperature=temperature,
            cfg_scale=cfg_scale,
            block_length=block_length,
            remasking=remasking,
        )

        # Opening frame goes out straight away
        opening_frame = frames[0]
        yield opening_frame

        # Remaining frames are paced by the requested delay
        for frame in frames[1:]:
            time.sleep(delay)
            yield frame

    except Exception as exc:
        message = f"Error: {str(exc)}"
        print(message)

        # Surface the error in the visualization box
        yield [(message, "red")]
384
+
385
+
386
+
387
# Gradio UI: pick random constraint words, enter a prompt, tune the
# generation settings, then watch the denoising progress frame by frame.
with gr.Blocks() as demo:
    gr.Markdown("# LLaDA - Large Language Diffusion Model")

    # precision=0 makes the component deliver an int; without it the value
    # is believed to reach format_constraints as a float.
    num_random_words = gr.Number(minimum=1,
                                 maximum=10,
                                 value=3,
                                 step=1,
                                 precision=0,
                                 label="Number of random words")

    len_gen_text = gr.Slider(minimum=10,
                             maximum=64,
                             value=32,
                             step=1,
                             label="Length of generated text")

    # Receives the dict returned by format_constraints; the same textbox is
    # later passed to display_animation as the constraints argument.
    random_constraints = gr.Textbox(label="Random words and their positions")

    generate_btn = gr.Button("Generate random words for insertion")
    generate_btn.click(
        fn=format_constraints,
        inputs=[num_random_words, len_gen_text],
        outputs=[random_constraints])

    prompt = gr.Textbox(max_lines=10, label="Your prompt")

    with gr.Accordion("Generation Settings", open=False):
        with gr.Row():
            steps = gr.Slider(
                minimum=8, maximum=64, value=16, step=4, label="Denoising Steps"
            )
            temperature = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Temperature"
            )
            cfg_scale = gr.Slider(
                minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale"
            )
        with gr.Row():
            block_length = gr.Slider(
                minimum=8, maximum=64, value=16, step=8, label="Block Length"
            )
            remasking_strategy = gr.Radio(
                choices=["low_confidence", "random"],
                value="low_confidence",
                label="Remasking Strategy",
            )
        with gr.Row():
            visualization_delay = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.8,
                step=0.1,
                label="Visualization Delay (seconds)",
            )

    continue_btn = gr.Button("Continue the prompt!")

    vizbox = gr.HighlightedText(label="Output",
                                combine_adjacent=False,
                                show_legend=True)

    # display_animation is a generator, so each yielded frame updates vizbox.
    continue_btn.click(fn=display_animation,
                       inputs=[prompt,
                               random_constraints,
                               len_gen_text,
                               steps,
                               temperature,
                               cfg_scale,
                               block_length,
                               remasking_strategy,
                               visualization_delay],
                       outputs=vizbox)


if __name__ == "__main__":
    demo.launch(share=True)