# Author: josesho
# Commit: dce0883 ("Update app.py", verified)
import random, time, ast
import torch
import torch.nn.functional as F
import gradio as gr
from wonderwords import RandomWord
from transformers import AutoTokenizer, AutoModel
# Pick the inference device once at import time. Every branch must bind the
# same upper-case name, because the rest of the module reads DEVICE.
if torch.cuda.is_available():
    # An NVIDIA GPU is available: run inference on it.
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    # Apple Silicon: use the integrated GPU through the MPS backend.
    DEVICE = "mps"
else:
    # No supported accelerator found: fall back to the CPU.
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")
# NOTE(review): presumably a reminder that exporting the environment variable
# PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 before launch lifts PyTorch's MPS
# memory limit, which an 8B model may need on Apple Silicon - confirm.
# PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0
MODEL_ID = "GSAI-ML/LLaDA-8B-Base"

# Pre-bind both globals so a failed load leaves them as None rather than
# undefined (which would surface later as a confusing NameError).
TOKENIZER = None
MODEL = None
try:
    # Load model and tokenizer. trust_remote_code=True runs the modelling code
    # shipped in the hub repository, which the LLaDA checkpoint requires.
    TOKENIZER = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    MODEL = AutoModel.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).to(DEVICE)
    print("Model and Tokenizer loaded.")
except Exception as e:
    # Best effort: keep the UI launchable and report the failure; generation
    # requests will then fail with an error shown in the output box.
    error_msg = f"Error: {str(e)}"
    print(error_msg)
    # Reset both so a tokenizer that loaded is not paired with a missing model.
    TOKENIZER = None
    MODEL = None
# Constants
# Text placed in the visualization for positions that are still masked.
MASK_TOKEN = "[MASK]"
# Token ID of [MASK] in the LLaDA vocabulary; masked positions hold this ID.
MASK_ID = 126_336
# Shared wonderwords generator that format_constraints draws words from.
rw = RandomWord()
def random_sample_without_replacement(sample_size: int,
                                      population_size: int) -> list:
    """Return ``sample_size`` distinct integers drawn from ``range(population_size)``.

    The result is a list in random order. Validation happens immediately at
    call time, not lazily on first iteration.

    Args:
        sample_size: How many distinct indices to draw. Whole-number floats
            (e.g. ``3.0``, as a UI number widget may deliver) are accepted.
        population_size: Indices are drawn from ``0 .. population_size - 1``.

    Returns:
        A list of ``sample_size`` unique ints.

    Raises:
        ValueError: if ``sample_size`` is not between 1 and ``population_size``.
    """
    if not (1 <= sample_size <= population_size):
        raise ValueError("Sample size must be between 1 and population size.")
    # random.sample draws without replacement in O(sample_size). A
    # rejection-sampling loop slows down badly as sample_size nears
    # population_size. int() is needed because range()/sample() reject floats.
    return random.sample(range(int(population_size)), int(sample_size))
def format_constraints(num_words: int,
                       max_gen_length: int) -> dict:
    """Pick random words and distinct random positions to insert them at.

    Args:
        num_words: How many random words to draw. Values coming from UI number
            widgets may be whole-number floats (e.g. ``3.0``), so the value is
            converted to ``int``.
        max_gen_length: Size of the generated region. Positions are drawn
            without replacement from ``range(max_gen_length)``.

    Returns:
        A ``{position: word}`` dict with ``num_words`` distinct int positions,
        e.g. ``{12: 'apple', 40: 'river'}``. Shown in a text box, its text is
        the ``repr`` of that dict.

    Raises:
        ValueError: if ``num_words`` is not between 1 and ``max_gen_length``.
    """
    num_words = int(num_words)
    max_gen_length = int(max_gen_length)
    words = rw.random_words(num_words)
    positions = random_sample_without_replacement(num_words, max_gen_length)
    return dict(zip(positions, words))
def add_gumbel_noise(logits, temperature):
    """
    The Gumbel max is a method for sampling categorical distributions.
    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
    Thus, we use float64.

    The returned tensor is ``exp(logits) / (-log u) ** temperature`` with
    ``u ~ Uniform(0, 1)``. Only its argmax is meaningful: that argmax equals
    ``argmax(logits + temperature * g)`` with ``g`` standard Gumbel noise.

    Args:
        logits: Score tensor. It is returned untouched when no noise is added.
        temperature: Noise scale. Values ``<= 0`` disable noise (greedy).

    Returns:
        ``logits`` itself when ``temperature <= 0``. Otherwise a new tensor
        with the same shape, in float64 (float32 on MPS).
    """
    if temperature <= 0:
        return logits
    # The MPS backend has no float64 support, so fall back to float32 there.
    dtype = torch.float32 if logits.device.type == "mps" else torch.float64
    logits = logits.to(dtype)
    noise = torch.rand_like(logits, dtype=dtype)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise
def get_num_transfer_tokens(mask_index, steps):
    """Spread each row's masked-token count as evenly as possible over ``steps``.

    LLaDA's reverse process discretizes [0, 1] into ``steps`` equal intervals
    with a linear noise schedule (Eq. (8) of the paper). The expected number of
    tokens unmasked per step is therefore constant, and this precomputes it.
    Each row gets ``count // steps`` in every column, plus one extra token in
    each of the first ``count % steps`` columns.

    Args:
        mask_index: Tensor that is truthy where a position is still masked,
            with one row per sequence (summed along dim 1).
        steps: Number of denoising steps.

    Returns:
        int64 tensor of shape ``(rows, steps)``. Each row sums to that row's
        masked count.
    """
    masked_per_row = mask_index.sum(dim=1, keepdim=True)
    quotient = masked_per_row // steps
    leftover = masked_per_row % steps
    step_ids = torch.arange(steps, device=mask_index.device)
    # Broadcast (rows, 1) against (steps,): the first `leftover` columns of
    # each row receive one extra token.
    bonus = (step_ids < leftover).to(torch.int64)
    return quotient + bonus
def _parse_constraints(constraints):
    """Normalise the ``constraints`` argument into a ``{position: word}`` dict.

    Accepts ``None``, an empty or whitespace-only string (what an untouched
    text box sends), a ``dict``, or the text of a dict literal such as
    ``"{12: 'apple', 40: 'river'}"``.

    Raises:
        ValueError / SyntaxError: if the text is not a dict literal.
    """
    if constraints is None:
        return {}
    if isinstance(constraints, dict):
        parsed = constraints
    else:
        text = str(constraints).strip()
        if not text:
            return {}
        # ast.literal_eval only evaluates Python literals, so user-typed text
        # cannot execute code here.
        parsed = ast.literal_eval(text)
        if not isinstance(parsed, dict):
            raise ValueError(
                "Constraints must look like {position: 'word', ...}, "
                f"got {type(parsed).__name__}"
            )
    return {int(pos): str(word) for pos, word in parsed.items()}


def _constraint_token_ids(tokenizer, word_constraints, gen_length):
    """Map constrained words to token IDs at consecutive response positions.

    A word that tokenizes to several tokens occupies ``pos, pos + 1, ...``.
    Positions outside ``[0, gen_length)`` are dropped, so a constraint can
    neither overwrite the prompt nor index past the end of the sequence.
    Later entries overwrite earlier ones at the same position.
    """
    token_constraints = {}
    for pos, word in word_constraints.items():
        # Leading space: presumably so the word tokenizes as it would
        # mid-sentence rather than at the start of a text.
        token_ids = tokenizer.encode(" " + word, add_special_tokens=False)
        for offset, token_id in enumerate(token_ids):
            rel_pos = pos + offset
            if 0 <= rel_pos < gen_length:
                token_constraints[rel_pos] = token_id
    return token_constraints


def _apply_constraints(x, prompt_length, token_constraints):
    """Write constrained token IDs into ``x`` in place, offset past the prompt."""
    for rel_pos, token_id in token_constraints.items():
        x[:, prompt_length + rel_pos] = token_id


def _predict_logits(model, x, prompt_index, cfg_scale):
    """Run the model on ``x``, applying classifier-free guidance if ``cfg_scale > 0``."""
    if cfg_scale > 0.0:
        # Unconditional branch: the same sequence with every position that
        # was unmasked at the start (prompt and constraints) re-masked. Both
        # branches share one batched forward pass.
        un_x = x.clone()
        un_x[prompt_index] = MASK_ID
        logits = model(torch.cat([x, un_x], dim=0)).logits
        cond_logits, un_logits = torch.chunk(logits, 2, dim=0)
        return un_logits + (cfg_scale + 1) * (cond_logits - un_logits)
    return model(x).logits


def _visualization_frame(tokenizer, new_ids, old_ids, scores):
    """Build one list of ``(token, colour)`` pairs for the response region.

    Colours: dark gray = still masked; light blue = revealed in an earlier
    step. Tokens revealed in this step are red, orange or green for a score
    below 0.3, below 0.7, or otherwise. The score is the model probability
    with "low_confidence" remasking and a random number with "random".
    """
    frame = []
    for new_id, old_id, score in zip(new_ids, old_ids, scores):
        if new_id == MASK_ID:
            frame.append((MASK_TOKEN, "#444444"))  # Dark gray for masks
            continue
        token = tokenizer.decode([new_id], skip_special_tokens=True)
        if old_id != MASK_ID:
            colour = "#6699CC"  # Light blue: previously revealed
        elif score < 0.3:
            colour = "#FF6666"  # Light red
        elif score < 0.7:
            colour = "#FFAA33"  # Orange
        else:
            colour = "#66CC66"  # Light green
        frame.append((token, colour))
    return frame


def generate_response_with_visualization(
    model,
    tokenizer,
    device,
    prompt,
    gen_length=64,
    steps=32,
    constraints=None,
    temperature=0.0,
    cfg_scale=0.0,
    block_length=32,
    remasking="low_confidence",
):
    """
    Generate text with LLaDA model with visualization using the same sampling as in generate.py

    The response region starts fully masked. It is filled block by block
    (semi-autoregressive). Inside each block, every denoising step unmasks a
    scheduled number of positions, chosen by ``remasking``.

    Args:
        model: Model whose call returns an object with ``.logits``.
        tokenizer: Tokenizer used to encode the prompt and constraints and to
            decode tokens.
        device: Torch device the tensors are created on.
        prompt: The prompt.
        gen_length: Length of text to generate (number of response tokens).
        steps: Total number of denoising steps, split evenly across blocks
            (at least one step per block).
        constraints: ``None``, a ``{position: word}`` dict, or its text form.
            Empty text means no constraints.
        temperature: Sampling temperature (0 = greedy).
        cfg_scale: Classifier-free guidance scale (0 disables it).
        block_length: Block length for semi-autoregressive generation.
            Clamped to ``gen_length``.
        remasking: Remasking strategy ('low_confidence' or 'random').

    Returns:
        ``(visualization_states, final_text)``. The first state is all-masked,
        and one more state is added per executed denoising step.

    Raises:
        RuntimeError: if the model or tokenizer is missing.
        NotImplementedError: for an unknown remasking strategy.
        ValueError: for non-positive ``gen_length`` / ``block_length`` or
            malformed constraints.
    """
    if model is None or tokenizer is None:
        raise RuntimeError("Model and tokenizer must be loaded before generating.")
    # Validate before any expensive forward pass.
    if remasking not in ("low_confidence", "random"):
        raise NotImplementedError(
            f"Remasking strategy '{remasking}' not implemented"
        )
    if gen_length < 1:
        raise ValueError("gen_length must be at least 1.")
    if block_length < 1:
        raise ValueError("block_length must be at least 1.")

    token_constraints = _constraint_token_ids(
        tokenizer, _parse_constraints(constraints), gen_length
    )

    input_ids = torch.tensor(tokenizer(prompt)["input_ids"], device=device).unsqueeze(0)
    prompt_length = input_ids.shape[1]

    # Prompt followed by gen_length [MASK] tokens.
    x = torch.full(
        (1, prompt_length + gen_length), MASK_ID, dtype=torch.long, device=device
    )
    x[:, :prompt_length] = input_ids

    # First frame: the whole response region masked.
    visualization_states = [[(MASK_TOKEN, "#444444") for _ in range(gen_length)]]

    _apply_constraints(x, prompt_length, token_constraints)
    # Everything unmasked now (prompt and constraints) is masked in CFG's
    # unconditional branch.
    prompt_index = x != MASK_ID

    block_length = min(block_length, gen_length)
    num_blocks = -(-gen_length // block_length)  # ceiling division
    steps_per_block = max(1, steps // num_blocks)

    # Inference only: without no_grad every forward pass would build an
    # autograd graph and keep its activations alive.
    with torch.no_grad():
        for num_block in range(num_blocks):
            block_start = prompt_length + num_block * block_length
            block_end = min(block_start + block_length, x.shape[1])

            block_mask_index = x[:, block_start:block_end] == MASK_ID
            if not block_mask_index.any():
                continue
            num_transfer_tokens = get_num_transfer_tokens(
                block_mask_index, steps_per_block
            )

            for step in range(steps_per_block):
                print(f"Processing step{step}")  ## for logging and debugging...
                mask_index = x == MASK_ID
                if not mask_index.any():
                    break

                logits = _predict_logits(model, x, prompt_index, cfg_scale)
                x0 = torch.argmax(
                    add_gumbel_noise(logits, temperature=temperature), dim=-1
                )

                # Score for each predicted token, used to rank which positions to unmask.
                if remasking == "low_confidence":
                    # float64 for precision, but MPS has no float64 support.
                    prob_dtype = (
                        torch.float32 if logits.device.type == "mps" else torch.float64
                    )
                    p = F.softmax(logits.to(prob_dtype), dim=-1)
                    x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)  # b, l
                else:  # "random"
                    x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
                # Never unmask positions beyond the current block.
                x0_p[:, block_end:] = -float("inf")

                old_x = x.clone()
                x0 = torch.where(mask_index, x0, x)
                confidence = torch.where(mask_index, x0_p, -float("inf"))

                transfer_index = torch.zeros_like(x0, dtype=torch.bool)
                is_last_step = step == steps_per_block - 1
                for row in range(confidence.shape[0]):
                    if is_last_step:
                        # Last step: unmask everything still masked in the block.
                        transfer_index[row, block_start:block_end] = mask_index[
                            row, block_start:block_end
                        ]
                    else:
                        # Unmask the scheduled number of most-confident positions in the block.
                        block_confidence = confidence[row, block_start:block_end]
                        k = min(
                            num_transfer_tokens[row, step].item(),
                            block_confidence.numel(),
                        )
                        _, select_indices = torch.topk(block_confidence, k=k)
                        transfer_index[row, select_indices + block_start] = True

                x = torch.where(transfer_index, x0, x)
                # Re-impose constraints so they are never overwritten.
                _apply_constraints(x, prompt_length, token_constraints)

                # One device->host copy per row instead of per-position syncs.
                visualization_states.append(
                    _visualization_frame(
                        tokenizer,
                        x[0, prompt_length:].tolist(),
                        old_x[0, prompt_length:].tolist(),
                        x0_p[0, prompt_length:].tolist(),
                    )
                )

    # Extract final text (just the assistant's response)
    final_text = tokenizer.decode(
        x[0, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return visualization_states, final_text
def display_animation(prompt,
                      constraints,
                      gen_length,
                      steps,
                      temperature,
                      cfg_scale,
                      block_length,
                      remasking,
                      delay):
    """Run a generation, then replay its visualization frames one by one.

    Generation finishes completely before the first frame is yielded. The
    remaining frames follow with ``delay`` seconds of sleep before each. The
    final decoded text is not emitted. Any exception is printed and reported
    as a single red-labelled frame instead of propagating.

    Yields:
        Lists of ``(text, colour)`` pairs for a HighlightedText output.
    """
    try:
        frames, _final_text = generate_response_with_visualization(
            model=MODEL,
            tokenizer=TOKENIZER,
            device=DEVICE,
            prompt=prompt,
            gen_length=gen_length,
            steps=steps,
            constraints=constraints,
            temperature=temperature,
            cfg_scale=cfg_scale,
            block_length=block_length,
            remasking=remasking,
        )
        opening_frame, later_frames = frames[0], frames[1:]
        # The opening (all-masked) frame is shown without waiting.
        yield opening_frame
        for frame in later_frames:
            time.sleep(delay)
            yield frame
    except Exception as e:
        message = f"Error: {str(e)}"
        print(message)
        yield [(message, "red")]
with gr.Blocks() as demo:
    gr.Markdown("# LLaDA - Large Language Diffusion Model")

    # Random-word constraints.
    # precision=0 makes gr.Number pass an int to the callback (it passes
    # floats by default); the word count is used as an integer.
    num_random_words = gr.Number(minimum=1,
                                 maximum=10,
                                 value=3,
                                 step=1,
                                 precision=0,
                                 label="Number of random words")
    len_gen_text = gr.Slider(minimum=10,
                             maximum=128,
                             value=64,
                             step=1,
                             label="Length of generated text")
    random_constraints = gr.Textbox(label="Random words and their positions")
    generate_btn = gr.Button("Generate random words for insertion")
    generate_btn.click(
        fn=format_constraints,
        inputs=[num_random_words, len_gen_text],
        outputs=[random_constraints])

    prompt = gr.Textbox(max_lines=10, label="Your prompt")

    with gr.Accordion("Generation Settings", open=False):
        with gr.Row():
            steps = gr.Slider(
                minimum=8, maximum=64, value=16, step=4, label="Denoising Steps"
            )
            temperature = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Temperature"
            )
            cfg_scale = gr.Slider(
                minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale"
            )
        with gr.Row():
            block_length = gr.Slider(
                minimum=8, maximum=64, value=32, step=8, label="Block Length"
            )
            remasking_strategy = gr.Radio(
                choices=["low_confidence", "random"],
                value="low_confidence",
                label="Remasking Strategy",
            )
        with gr.Row():
            visualization_delay = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.8,
                step=0.1,
                label="Visualization Delay (seconds)",
            )

    continue_btn = gr.Button("Continue the prompt!")
    vizbox = gr.HighlightedText(label="Output",
                                combine_adjacent=False,
                                show_legend=True)
    # display_animation is a generator, so each frame it yields updates the
    # highlighted output in turn.
    continue_btn.click(fn=display_animation,
                       inputs=[prompt,
                               random_constraints,
                               len_gen_text,
                               steps,
                               temperature,
                               cfg_scale,
                               block_length,
                               remasking_strategy,
                               visualization_delay],
                       outputs=vizbox)

if __name__ == "__main__":
    # share=True also requests a public Gradio share link.
    demo.launch(share=True)