import ast
import random
import time

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from wonderwords import RandomWord
if torch.cuda.is_available():
    # Use an NVIDIA GPU for inference if one is available.
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    # On Apple Silicon, use the integrated GPU via Metal (MPS).
    DEVICE = "mps"
else:
    # Otherwise, fall back to the CPU.
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# On MPS, setting the environment variable PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0
# disables PyTorch's memory allocation cap, which this 8B model may need.
try:
    # Load the model and tokenizer.
    TOKENIZER = AutoTokenizer.from_pretrained(
        "GSAI-ML/LLaDA-8B-Base", trust_remote_code=True
    )
    MODEL = AutoModel.from_pretrained(
        "GSAI-ML/LLaDA-8B-Base",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).to(DEVICE)
    print("Model and tokenizer loaded.")
except Exception as e:
    print(f"Error loading model: {e}")
    raise  # The app cannot run without MODEL and TOKENIZER.
# Constants
MASK_TOKEN = "[MASK]"
MASK_ID = 126336  # The token ID of [MASK] in LLaDA

rw = RandomWord()
def random_sample_without_replacement(sample_size: int, population_size: int):
    """Yield `sample_size` distinct indices drawn uniformly from range(population_size)."""
    if not (1 <= sample_size <= population_size):
        raise ValueError("Sample size must be between 1 and population size.")
    selected_indices = set()
    while len(selected_indices) < sample_size:
        index = random.randrange(population_size)
        if index not in selected_indices:
            selected_indices.add(index)
            yield index
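# Usage sketch (illustrative; output is random): draw three distinct
# positions out of ten.
#   list(random_sample_without_replacement(3, 10))  # e.g. [7, 2, 9]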
def format_constraints(num_words: int, max_gen_length: int) -> dict:
    """Pick `num_words` random words and distinct random positions, returning
    a dict mapping position -> word (its repr is what the UI textbox shows)."""
    num_words = int(num_words)  # Gradio's Number component may deliver a float
    out = {}
    word_list = rw.random_words(num_words)
    positions = list(random_sample_without_replacement(num_words, max_gen_length))
    for j, position in enumerate(positions):
        out[position] = word_list[j]
    return out
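# Usage sketch (illustrative; the words and positions are random):
#   format_constraints(3, 64)  # e.g. {17: "lantern", 4: "river", 42: "marble"}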
def add_gumbel_noise(logits, temperature):
    """
    Gumbel-max sampling for categorical distributions.
    According to arXiv:2409.02908, for MDMs, low-precision Gumbel-max improves
    the perplexity score but reduces generation quality, so float64 is used here.
    """
    if temperature <= 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise
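# Why this works (a sketch, not part of the app): for u ~ Uniform(0, 1) and
# temperature T > 0, taking logarithms shows that
#   argmax(exp(logits) / (-log(u)) ** T) == argmax(logits / T - log(-log(u))),
# and -log(-log(u)) is a standard Gumbel variable, so the argmax is a sample
# from softmax(logits / T). For example:
#   logits = torch.tensor([[1.0, 2.0, 0.5]])
#   token_id = torch.argmax(add_gumbel_noise(logits, temperature=0.7), dim=-1)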
def get_num_transfer_tokens(mask_index, steps):
    """
    In the reverse process, the interval [0, 1] is uniformly discretized into
    `steps` intervals. Because LLaDA employs a linear noise schedule (as
    defined in Eq. (8)), the expected number of tokens transitioned at each
    step should be the same. This function precomputes the number of tokens
    to transition at each step.
    """
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps
    num_transfer_tokens = (
        torch.zeros(
            mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64
        )
        + base
    )
    # Spread the remainder over the first `remainder` steps.
    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, : remainder[i]] += 1
    return num_transfer_tokens
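# Worked example (illustrative): 7 masked tokens denoised over 3 steps split
# as evenly as possible, with the remainder front-loaded:
#   get_num_transfer_tokens(torch.tensor([[True] * 7]), steps=3)
#   # -> tensor([[3, 2, 2]])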
def generate_response_with_visualization(
    model,
    tokenizer,
    device,
    prompt,
    gen_length=64,
    steps=32,
    constraints=None,
    temperature=0.0,
    cfg_scale=0.0,
    block_length=32,
    remasking="low_confidence",
):
    """
    Generate text with the LLaDA model, recording visualization states, using
    the same sampling procedure as generate.py.

    Args:
        prompt: The prompt text
        gen_length: Length of text to generate
        steps: Number of denoising steps
        constraints: String repr of a dict mapping positions to words, or None
        temperature: Sampling temperature
        cfg_scale: Classifier-free guidance scale
        block_length: Block length for semi-autoregressive generation
        remasking: Remasking strategy ('low_confidence' or 'random')

    Returns:
        A list of visualization states showing the progression, plus the final text
    """
    # Process constraints (the UI passes the dict's string repr)
    if not constraints:
        constraints = {}
    else:
        constraints = ast.literal_eval(constraints)

    # Convert word constraints to token IDs (a word may span several tokens)
    processed_constraints = {}
    for pos, word in constraints.items():
        tokens = tokenizer.encode(" " + word, add_special_tokens=False)
        for i, token_id in enumerate(tokens):
            processed_constraints[pos + i] = token_id

    # Tokenize the prompt
    input_ids = tokenizer(prompt)["input_ids"]
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)
    prompt_length = input_ids.shape[1]

    # Initialize the sequence with masks for the response part
    x = torch.full((1, prompt_length + gen_length), MASK_ID, dtype=torch.long).to(
        device
    )
    x[:, :prompt_length] = input_ids.clone()

    # Initialize visualization states, starting from the all-masked response
    visualization_states = []
    initial_state = [(MASK_TOKEN, "#444444") for _ in range(gen_length)]
    visualization_states.append(initial_state)
    # Write the constraint tokens into the initial sequence
    for pos, token_id in processed_constraints.items():
        absolute_pos = prompt_length + pos
        if absolute_pos < x.shape[1]:
            x[:, absolute_pos] = token_id

    # Mark the non-mask (prompt and constraint) positions; classifier-free
    # guidance masks exactly these to build its unconditional branch
    prompt_index = x != MASK_ID
    # Ensure block_length is valid
    if block_length > gen_length:
        block_length = gen_length

    # Calculate the number of semi-autoregressive blocks
    num_blocks = gen_length // block_length
    if gen_length % block_length != 0:
        num_blocks += 1

    # Split the step budget evenly across blocks
    steps_per_block = max(1, steps // num_blocks)
    # Process each block
    for num_block in range(num_blocks):
        # Start and end indices of the current block
        block_start = prompt_length + num_block * block_length
        block_end = min(prompt_length + (num_block + 1) * block_length, x.shape[1])

        # Get mask indices for the current block
        block_mask_index = x[:, block_start:block_end] == MASK_ID

        # Skip if no masks in this block
        if not block_mask_index.any():
            continue

        # Number of tokens to unmask at each step of this block
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)
        # Process each step
        for i in range(steps_per_block):
            print(f"Processing step {i}")  # for logging and debugging

            # Get all mask positions in the current sequence
            mask_index = x == MASK_ID

            # Skip if no masks
            if not mask_index.any():
                break

            # Apply classifier-free guidance if enabled
            if cfg_scale > 0.0:
                un_x = x.clone()
                un_x[prompt_index] = MASK_ID
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits
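            # A note on the guidance line above (a sketch of the standard CFG
            # extrapolation): with w = cfg_scale,
            #   logits = un_logits + (w + 1) * (logits - un_logits)
            # pushes the conditional logits away from the unconditional ones
            # (prompt masked out); w = 0 recovers the plain conditional logits.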
            # Apply Gumbel noise for sampling
            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)

            # Calculate confidence scores for remasking
            if remasking == "low_confidence":
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1
                )  # b, l
            elif remasking == "random":
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(
                    f"Remasking strategy '{remasking}' not implemented"
                )

            # Don't consider positions beyond the current block
            x0_p[:, block_end:] = -float("inf")

            # Apply predictions where we have masks
            old_x = x.clone()
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -float("inf"))
            # Select tokens to unmask based on confidence
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                # Only consider positions within the current block for unmasking
                block_confidence = confidence[j, block_start:block_end]
                if i < steps_per_block - 1:  # Not the last step
                    # Take the top-k confidences
                    _, select_indices = torch.topk(
                        block_confidence,
                        k=min(
                            num_transfer_tokens[j, i].item(), block_confidence.numel()
                        ),
                    )
                    # Adjust indices to global positions
                    select_indices = select_indices + block_start
                    transfer_index[j, select_indices] = True
                else:  # Last step: unmask everything remaining in the block
                    transfer_index[j, block_start:block_end] = mask_index[
                        j, block_start:block_end
                    ]

            # Apply the selected tokens
            x = torch.where(transfer_index, x0, x)
            # Ensure constraints are maintained
            for pos, token_id in processed_constraints.items():
                absolute_pos = prompt_length + pos
                if absolute_pos < x.shape[1]:
                    x[:, absolute_pos] = token_id
            # Create a visualization state for the response part only
            # (pos_offset avoids shadowing the step counter `i` above)
            current_state = []
            for pos_offset in range(gen_length):
                pos = prompt_length + pos_offset  # Absolute position in the sequence
                if x[0, pos] == MASK_ID:
                    # Still masked
                    current_state.append((MASK_TOKEN, "#444444"))  # Dark gray
                elif old_x[0, pos] == MASK_ID:
                    # Newly revealed in this step: color by confidence
                    token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
                    token_confidence = float(x0_p[0, pos].cpu())
                    if token_confidence < 0.3:
                        color = "#FF6666"  # Light red
                    elif token_confidence < 0.7:
                        color = "#FFAA33"  # Orange
                    else:
                        color = "#66CC66"  # Light green
                    current_state.append((token, color))
                else:
                    # Previously revealed
                    token = tokenizer.decode([x[0, pos].item()], skip_special_tokens=True)
                    current_state.append((token, "#6699CC"))  # Light blue
            visualization_states.append(current_state)
    # Extract the final text (just the response part)
    response_tokens = x[0, prompt_length:]
    final_text = tokenizer.decode(
        response_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    return visualization_states, final_text
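# Example call (a sketch; assumes the MODEL and TOKENIZER loaded above):
#   states, text = generate_response_with_visualization(
#       MODEL, TOKENIZER, DEVICE, "Once upon a time",
#       gen_length=64, steps=32, constraints="{5: 'dragon'}",
#   )
#   # `states` holds one (token, color) list per denoising step, plus the
#   # initial all-masked state; `text` is the decoded response.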
def display_animation(
    prompt,
    constraints,
    gen_length,
    steps,
    temperature,
    cfg_scale,
    block_length,
    remasking,
    delay,
):
    """Stream the visualization states one at a time so Gradio animates them."""
    try:
        vis_states, response_text = generate_response_with_visualization(
            model=MODEL,
            tokenizer=TOKENIZER,
            device=DEVICE,
            prompt=prompt,
            gen_length=gen_length,
            steps=steps,
            constraints=constraints,
            temperature=temperature,
            cfg_scale=cfg_scale,
            block_length=block_length,
            remasking=remasking,
        )
        # Return the initial state immediately
        yield vis_states[0]
        # Then animate through the remaining visualization states
        for state in vis_states[1:]:
            time.sleep(delay)
            yield state
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        print(error_msg)
        # Show the error in the visualization component
        yield [(error_msg, "red")]
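# Gradio treats a generator handler as a streaming event: each `yield` updates
# the bound output component in place, which is what produces the denoising
# animation in the HighlightedText box below.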
with gr.Blocks() as demo:
    gr.Markdown("# LLaDA - Large Language Diffusion Model")

    num_random_words = gr.Number(
        minimum=1, maximum=10, value=3, step=1, label="Number of random words"
    )
    len_gen_text = gr.Slider(
        minimum=10, maximum=128, value=64, step=1, label="Length of generated text"
    )
    random_constraints = gr.Textbox(label="Random words and their positions")
    generate_btn = gr.Button("Generate random words for insertion")
    generate_btn.click(
        fn=format_constraints,
        inputs=[num_random_words, len_gen_text],
        outputs=[random_constraints],
    )

    prompt = gr.Textbox(max_lines=10, label="Your prompt")

    with gr.Accordion("Generation Settings", open=False):
        with gr.Row():
            steps = gr.Slider(
                minimum=8, maximum=64, value=16, step=4, label="Denoising Steps"
            )
            temperature = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="Temperature"
            )
            cfg_scale = gr.Slider(
                minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale"
            )
        with gr.Row():
            block_length = gr.Slider(
                minimum=8, maximum=64, value=32, step=8, label="Block Length"
            )
            remasking_strategy = gr.Radio(
                choices=["low_confidence", "random"],
                value="low_confidence",
                label="Remasking Strategy",
            )
        with gr.Row():
            visualization_delay = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.8,
                step=0.1,
                label="Visualization Delay (seconds)",
            )

    continue_btn = gr.Button("Continue the prompt!")
    vizbox = gr.HighlightedText(
        label="Output", combine_adjacent=False, show_legend=True
    )
    continue_btn.click(
        fn=display_animation,
        inputs=[
            prompt,
            random_constraints,
            len_gen_text,
            steps,
            temperature,
            cfg_scale,
            block_length,
            remasking_strategy,
            visualization_delay,
        ],
        outputs=vizbox,
    )

if __name__ == "__main__":
    demo.launch(share=True)