"""Gradio demo for LLaDA: continue a prompt while forcing random words at fixed positions.

The response is produced by iterative unmasking (LLaDA's reverse diffusion
process); every denoising step is recorded and replayed as an animation.
"""
import ast
import random
import time

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from wonderwords import RandomWord

# Pick the best available accelerator for inference.
if torch.cuda.is_available():
    # NVIDIA GPU.
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    # Apple Silicon integrated GPU.
    DEVICE = "mps"
else:
    # Fall back to the CPU.
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")
# Hint: on Apple Silicon, exporting PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 before
# launch disables PyTorch's MPS allocation cap, which an 8B model may need.

MODEL_NAME = "GSAI-ML/LLaDA-8B-Base"

# Bind the globals up front so a failed load leaves them as None (reported at
# generation time) instead of undefined (an opaque NameError later on).
TOKENIZER = None
MODEL = None
try:
    # Load model and tokenizer
    TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    MODEL = AutoModel.from_pretrained(
        MODEL_NAME, trust_remote_code=True, torch_dtype=torch.bfloat16
    ).to(DEVICE)
    print("Model and Tokenizer loaded.")
except Exception as e:  # keep the UI usable; generation reports the problem
    error_msg = f"Error: {str(e)}"
    print(error_msg)

# Constants
MASK_TOKEN = "[MASK]"
MASK_ID = 126336  # The token ID of [MASK] in LLaDA

rw = RandomWord()


def random_sample_without_replacement(sample_size: int, population_size: int) -> list:
    """Return ``sample_size`` distinct random indices from ``range(population_size)``.

    Raises:
        ValueError: if ``sample_size`` is not between 1 and ``population_size``
            (checked eagerly, before any sampling).
    """
    if not (1 <= sample_size <= population_size):
        raise ValueError("Sample size must be between 1 and population size.")
    return random.sample(range(population_size), sample_size)


def format_constraints(num_words: int, max_gen_length: int) -> dict:
    """Pick ``num_words`` random words and give each a distinct random position.

    Positions are 0-based offsets into the generated text, drawn from
    ``range(max_gen_length)``. Returns a ``{position: word}`` dict. Gradio puts
    it into the constraints textbox as text (its dict-literal ``str``), which
    ``generate_response_with_visualization`` parses back.
    """
    # Gradio number/slider widgets may deliver floats; the samplers need ints.
    num_words = int(num_words)
    max_gen_length = int(max_gen_length)
    words = rw.random_words(num_words)
    positions = random_sample_without_replacement(num_words, max_gen_length)
    return dict(zip(positions, words))


def _precise_dtype(device):
    """Return float64 where supported; the MPS backend has no float64, so use float32 there."""
    return torch.float32 if torch.device(device).type == "mps" else torch.float64


def add_gumbel_noise(logits, temperature):
    """
    The Gumbel max is a method for sampling categorical distributions.
    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves
    perplexity score but reduces generation quality. Thus, we use float64
    (float32 on MPS, which does not support float64).

    Returns ``logits`` unchanged when ``temperature <= 0`` (greedy decoding);
    otherwise a tensor whose argmax is a temperature-scaled sample.
    """
    if temperature <= 0:
        return logits
    dtype = _precise_dtype(logits.device)
    logits = logits.to(dtype)
    noise = torch.rand_like(logits, dtype=dtype)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def get_num_transfer_tokens(mask_index, steps):
    """
    In the reverse process, the interval [0, 1] is uniformly discretized into
    steps intervals. Furthermore, because LLaDA employs a linear noise schedule
    (as defined in Eq. (8)), the expected number of tokens transitioned at each
    step should be consistent.

    This function is designed to precompute the number of tokens that need to
    be transitioned at each step: a (batch, steps) int64 tensor whose rows sum
    to the number of masked positions in the corresponding row of mask_index.
    """
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps
    num_transfer_tokens = (
        torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64)
        + base
    )
    # Spread the remainder over the earliest steps.
    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, : remainder[i]] += 1
    return num_transfer_tokens


def _parse_constraints(constraints) -> dict:
    """Normalise user constraints into a ``{position: word}`` dict with int keys.

    Accepts ``None`` or a blank string (no constraints), a dict, a dict-literal
    string such as ``"{3: 'apple'}"`` (what ``format_constraints`` produces),
    or the comma-separated form ``"3:apple, 10:river"``.

    Raises:
        ValueError: if the text cannot be parsed or a position is negative.
    """
    if constraints is None:
        return {}
    if isinstance(constraints, dict):
        raw = constraints
    else:
        text = str(constraints).strip()
        if not text:
            return {}
        try:
            # literal_eval only evaluates Python literals, never arbitrary code.
            raw = ast.literal_eval(text)
        except (ValueError, SyntaxError):
            raw = {}
            for part in text.split(","):
                pos_text, sep, word = part.partition(":")
                if not sep:
                    raise ValueError(
                        f"Invalid constraint {part.strip()!r}; expected 'position:word'."
                    ) from None
                raw[pos_text.strip()] = word.strip()
        if not isinstance(raw, dict):
            raise ValueError("Constraints must map positions to words, e.g. {3: 'apple'}.")

    parsed = {}
    for pos, word in raw.items():
        try:
            pos = int(pos)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid constraint position {pos!r}.") from None
        if pos < 0:
            # A negative offset would overwrite prompt tokens.
            raise ValueError(f"Constraint position must be >= 0, got {pos}.")
        word = str(word).strip()
        if word:
            parsed[pos] = word
    return parsed


def _apply_constraints(x, processed_constraints, prompt_length):
    """Write constraint token ids into ``x`` in place, skipping positions past its end."""
    for pos, token_id in processed_constraints.items():
        absolute_pos = prompt_length + pos
        if absolute_pos < x.shape[1]:
            x[:, absolute_pos] = token_id


def _compute_logits(model, x, prompt_index, cfg_scale):
    """Run the model on ``x``, applying classifier-free guidance when ``cfg_scale > 0``."""
    if cfg_scale > 0.0:
        # Unconditional branch: the same sequence with the "given" positions masked.
        un_x = x.clone()
        un_x[prompt_index] = MASK_ID
        logits = model(torch.cat([x, un_x], dim=0)).logits
        logits, un_logits = torch.chunk(logits, 2, dim=0)
        return un_logits + (cfg_scale + 1) * (logits - un_logits)
    return model(x).logits


def _confidence_color(confidence):
    """Map a token's confidence score to its display colour."""
    if confidence < 0.3:
        return "#FF6666"  # Light red
    if confidence < 0.7:
        return "#FFAA33"  # Orange
    return "#66CC66"  # Light green


def _build_visualization_state(tokenizer, x, old_x, x0_p, prompt_length, gen_length):
    """Return one ``(text, colour)`` pair per response position for the current step."""
    end = prompt_length + gen_length
    # Move the response slice to Python once instead of one device read per position.
    new_ids = x[0, prompt_length:end].tolist()
    old_ids = old_x[0, prompt_length:end].tolist()
    scores = x0_p[0, prompt_length:end].tolist()

    state = []
    for token_id, old_id, score in zip(new_ids, old_ids, scores):
        if token_id == MASK_ID:
            state.append((MASK_TOKEN, "#444444"))  # Dark gray: still masked
            continue
        token = tokenizer.decode([token_id], skip_special_tokens=True)
        if old_id == MASK_ID:
            # Newly revealed in this step: colour by confidence.
            state.append((token, _confidence_color(float(score))))
        else:
            state.append((token, "#6699CC"))  # Light blue: revealed earlier
    return state


@torch.no_grad()  # inference only: don't keep activations for a backward pass
def generate_response_with_visualization(
    model,
    tokenizer,
    device,
    prompt,
    gen_length=64,
    steps=32,
    constraints=None,
    temperature=0.0,
    cfg_scale=0.0,
    block_length=32,
    remasking="low_confidence",
):
    """
    Generate text with the LLaDA model, recording every denoising step.

    Args:
        prompt: The prompt
        gen_length: Length of text to generate (tokens)
        steps: Total number of denoising steps, split evenly across blocks
        constraints: Words to force at given 0-based response positions: None,
            a dict, a dict-literal string, or "position:word, position:word"
        temperature: Sampling temperature (0 = greedy)
        cfg_scale: Classifier-free guidance scale (0 = disabled)
        block_length: Block length for semi-autoregressive generation
        remasking: Remasking strategy ('low_confidence' or 'random')

    Returns:
        (visualization_states, final_text): a list of states (each a list of
        (token_text, colour) pairs, one per response position, starting with
        the all-masked state) and the decoded response text.

    Raises:
        RuntimeError: if the model or tokenizer is not loaded.
        ValueError: on invalid lengths/steps or unparseable constraints.
        NotImplementedError: for an unknown remasking strategy.
    """
    if model is None or tokenizer is None:
        raise RuntimeError("Model and tokenizer are not loaded; check the startup log.")
    if remasking not in ("low_confidence", "random"):
        raise NotImplementedError(f"Remasking strategy '{remasking}' not implemented")
    # Gradio widgets may deliver floats; tensor shapes and ranges need ints.
    gen_length, steps, block_length = int(gen_length), int(steps), int(block_length)
    if gen_length < 1 or steps < 1 or block_length < 1:
        raise ValueError("gen_length, steps and block_length must all be >= 1.")

    # Convert constraint words to token ids; a multi-token word spills over
    # into the following positions.
    processed_constraints = {}
    for pos, word in _parse_constraints(constraints).items():
        token_ids = tokenizer.encode(" " + word, add_special_tokens=False)
        for offset, token_id in enumerate(token_ids):
            processed_constraints[pos + offset] = token_id

    input_ids = torch.tensor(tokenizer(prompt)["input_ids"], device=device).unsqueeze(0)
    prompt_length = input_ids.shape[1]
    total_length = prompt_length + gen_length

    # The prompt followed by a fully masked response, with constraints pinned.
    x = torch.full((1, total_length), MASK_ID, dtype=torch.long, device=device)
    x[:, :prompt_length] = input_ids
    _apply_constraints(x, processed_constraints, prompt_length)

    # Initial visualization state: every response position masked.
    visualization_states = [[(MASK_TOKEN, "#444444") for _ in range(gen_length)]]

    # Everything not masked (prompt AND constraints) is treated as given when
    # building the unconditional branch for classifier-free guidance.
    prompt_index = x != MASK_ID

    block_length = min(block_length, gen_length)
    num_blocks = (gen_length + block_length - 1) // block_length  # ceil division
    steps_per_block = max(1, steps // num_blocks)

    for num_block in range(num_blocks):
        block_start = prompt_length + num_block * block_length
        block_end = min(block_start + block_length, total_length)

        block_mask_index = x[:, block_start:block_end] == MASK_ID
        if not block_mask_index.any():
            continue  # e.g. the whole block is covered by constraints

        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        for step in range(steps_per_block):
            print(f"Processing step{step}")  ## for logging and debugging...
            mask_index = x == MASK_ID
            if not mask_index.any():
                break

            logits = _compute_logits(model, x, prompt_index, cfg_scale)
            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)

            # Confidence scores used to decide which predictions to keep.
            if remasking == "low_confidence":
                p = F.softmax(logits.to(_precise_dtype(logits.device)), dim=-1)
                x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)
            else:  # "random"
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)

            # Never unmask beyond the current block.
            x0_p[:, block_end:] = -float("inf")

            old_x = x.clone()
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -float("inf"))

            transfer_index = torch.zeros_like(x0, dtype=torch.bool)
            is_last_step = step == steps_per_block - 1
            for j in range(confidence.shape[0]):
                if is_last_step:
                    # Unmask everything still masked in this block.
                    transfer_index[j, block_start:block_end] = mask_index[
                        j, block_start:block_end
                    ]
                else:
                    # Keep the top-k most confident predictions within the block.
                    block_confidence = confidence[j, block_start:block_end]
                    k = min(num_transfer_tokens[j, step].item(), block_confidence.numel())
                    _, select_indices = torch.topk(block_confidence, k=k)
                    transfer_index[j, select_indices + block_start] = True

            x = torch.where(transfer_index, x0, x)
            _apply_constraints(x, processed_constraints, prompt_length)

            visualization_states.append(
                _build_visualization_state(
                    tokenizer, x, old_x, x0_p, prompt_length, gen_length
                )
            )

    # Extract final text (just the response)
    final_text = tokenizer.decode(
        x[0, prompt_length:], skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    return visualization_states, final_text


def display_animation(prompt, constraints, gen_length, steps, temperature, cfg_scale, block_length, remasking, delay):
    """Run generation, then yield each visualization state ``delay`` seconds apart.

    Used as a Gradio generator callback. Any error is printed and yielded as a
    single red-highlighted entry instead of being raised.
    """
    try:
        vis_states, _response_text = generate_response_with_visualization(
            model=MODEL,
            tokenizer=TOKENIZER,
            device=DEVICE,
            prompt=prompt,
            gen_length=gen_length,
            steps=steps,
            constraints=constraints,
            temperature=temperature,
            cfg_scale=cfg_scale,
            block_length=block_length,
            remasking=remasking,
        )
        # Show the initial (all-masked) state immediately...
        yield vis_states[0]
        # ...then animate through the denoising steps.
        for state in vis_states[1:]:
            time.sleep(delay)
            yield state
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        print(error_msg)
        # Show the error in the visualization widget.
        yield [(error_msg, "red")]


with gr.Blocks() as demo:
    gr.Markdown("# LLaDA - Large Language Diffusion Model")

    # precision=0 makes Gradio deliver an int (the default is a float).
    num_random_words = gr.Number(
        minimum=1, maximum=10, value=3, step=1, precision=0, label="Number of random words"
    )
    len_gen_text = gr.Slider(
        minimum=10, maximum=128, value=64, step=1, label="Length of generated text"
    )
    random_constraints = gr.Textbox(
        label="Random words and their positions",
        placeholder="{3: 'apple', 10: 'river'}  or  3:apple, 10:river",
    )
    generate_btn = gr.Button("Generate random words for insertion")
    generate_btn.click(
        fn=format_constraints,
        inputs=[num_random_words, len_gen_text],
        outputs=[random_constraints],
    )

    prompt = gr.Textbox(max_lines=10, label="Your prompt")

    with gr.Accordion("Generation Settings", open=False):
        with gr.Row():
            steps = gr.Slider(
                minimum=8, maximum=64, value=16, step=4, label="Denoising Steps"
            )
            temperature = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Temperature"
            )
            cfg_scale = gr.Slider(
                minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale"
            )
        with gr.Row():
            block_length = gr.Slider(
                minimum=8, maximum=64, value=32, step=8, label="Block Length"
            )
            remasking_strategy = gr.Radio(
                choices=["low_confidence", "random"],
                value="low_confidence",
                label="Remasking Strategy",
            )
        with gr.Row():
            visualization_delay = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.8,
                step=0.1,
                label="Visualization Delay (seconds)",
            )

    continue_btn = gr.Button("Continue the prompt!")
    vizbox = gr.HighlightedText(label="Output", combine_adjacent=False, show_legend=True)

    continue_btn.click(
        fn=display_animation,
        inputs=[
            prompt,
            random_constraints,
            len_gen_text,
            steps,
            temperature,
            cfg_scale,
            block_length,
            remasking_strategy,
            visualization_delay,
        ],
        outputs=vizbox,
    )


if __name__ == "__main__":
    demo.launch(share=True)