import gradio as gr
import torch
import numpy as np
import time
import os

from huggingface_hub import hf_hub_download
import spaces
from dotenv import load_dotenv
from infer import (
    load_trained_model,
    find_answer_start,
    get_noising_schedule,
    noisify_answer,
    filter_logits,
)
# Not referenced directly below; presumably needed so the custom model
# classes are importable when the checkpoint is loaded.
from models import CustomTransformerModel
from model_config import CustomTransformerConfig

# Fall back to a local .env file when HF_TOKEN is not already set in the
# environment (e.g. when running outside Hugging Face Spaces).
if os.getenv("HF_TOKEN") is None:
    load_dotenv()

hf_token = os.getenv("HF_TOKEN")
if hf_token is None:
    raise ValueError("HF_TOKEN is not set")

rng = np.random.default_rng()

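# --- Diffusion decoding ---
# Each round below runs the full (partially masked) sequence through the
# model, samples every answer position in parallel from top-k/top-p filtered
# logits, then re-masks part of the answer according to a schedule and repeats.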
@spaces.GPU
def generate_diffusion_text(input_ids, top_p, top_k):
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)

        # Mixed-precision forward pass.
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            logits = model(input_ids=input_tensor)["logits"]

        logits = filter_logits(logits, top_k=top_k, top_p=top_p)
        # Guard against overflow/underflow before the softmax.
        logits = logits.clamp(min=-1e8, max=1e4)
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        probs = torch.clamp(probs, min=1e-8, max=1.0)

        # One multinomial draw per sequence position; keep each draw's
        # probability as a confidence score.
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
        conf = probs[torch.arange(len(sampled)), sampled].cpu().numpy()
    return sampled, conf

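
# `filter_logits` comes from infer.py. For reference, a minimal sketch of the
# standard top-k / top-p (nucleus) filtering it is assumed to implement; the
# project's actual implementation may differ. Not used by the app.
def _filter_logits_sketch(logits, top_k=0, top_p=1.0):
    if top_k > 0:
        # Mask everything below the k-th largest logit per position.
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p < 1.0:
        # Drop tokens once the cumulative probability of the tokens ranked
        # above them already exceeds top_p (the top token is always kept).
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        drop = (sorted_probs.cumsum(dim=-1) - sorted_probs) > top_p
        drop = torch.zeros_like(drop).scatter(-1, sorted_idx, drop)
        logits = logits.masked_fill(drop, float("-inf"))
    return logits
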

def format_chat_prompt(question):
    # Llama-3-style chat template with a fixed system prompt.
    return (
        "<|begin_of_text|>\n"
        "<|start_header_id|>system<|end_header_id|>\n"
        "You are a helpful assistant.\n"
        "<|start_header_id|>user<|end_header_id|>\n"
        f"{question}\n"
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )
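
# For example, format_chat_prompt("Hi") produces:
#   <|begin_of_text|>
#   <|start_header_id|>system<|end_header_id|>
#   You are a helpful assistant.
#   <|start_header_id|>user<|end_header_id|>
#   Hi
#   <|start_header_id|>assistant<|end_header_id|>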

def render_html(label, text):
    return f"<b>{label}</b><br><div style='white-space: pre-wrap; line-height:1.8'>{text}</div>"

def highlight_tokens(token_ids, answer_start, changed_indices, color):
    # Render tokens back to text, skipping EOS and coloring changed positions.
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    highlighted = []
    for j, tok in enumerate(tokens):
        if tokenizer.convert_tokens_to_ids(tok) == eos_token_id:
            continue
        tok_str = tokenizer.convert_tokens_to_string([tok])
        if (answer_start + j) in changed_indices:
            highlighted.append(f'<span style="color:{color}">{tok_str}</span>')
        else:
            highlighted.append(tok_str)
    return "".join(highlighted)
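
# diffusion_chat streams HTML snapshots of the answer region as it evolves:
# green marks tokens that were just (re)generated, red marks tokens that were
# just re-masked before the next denoising step.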
def diffusion_chat(question, noising, enable_pause, max_it):
    # Fixed decoding hyperparameters for this demo.
    sharpness = 3.0    # shape of the noising schedule
    noise_start = 0.5  # starting noise level for re-noising (see noisify_answer)
    top_p = 1.0        # 1.0 disables nucleus filtering
    top_k = 10
    clustering = False
    pause_length = 1.0 if enable_pause else 0.0  # seconds between snapshots

    if question.strip() == "":
        question = "What do you know about Amsterdam?"

    prompt = format_chat_prompt(question)
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield render_html("Error", "Could not find Assistant marker in input.")
        return

    # Pad with MASK tokens to a fixed length of 256 (and truncate prompts
    # that are already longer than that).
    input_ids = (input_ids + [mask_token_id] * (256 - len(input_ids)))[:256]
    ori_input_tokens = input_ids

    # Iteration 0: fully mask the answer region.
    current_tokens, just_noised_indices = noisify_answer(
        input_ids, answer_start, tokenizer, threshold=1.0, clustering=clustering, noise_start=1.0
    )
    yield render_html("Iteration 0 (initial noise)",
                      highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"))

    start = time.perf_counter()
    last_tokens = []
    prev_decoded = []
    for i in range(max_it):
        # Denoise: sample a fresh answer for every position in parallel.
        generated_tokens, _ = generate_diffusion_text(current_tokens, top_p, top_k)
        current_tokens = ori_input_tokens[:answer_start] + generated_tokens[answer_start:]

        # Collect the answer positions whose token changed this iteration.
        new_decoded = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
        diff_indices = {
            answer_start + j for j, tok in enumerate(new_decoded)
            if j >= len(prev_decoded) or tok != prev_decoded[j]
        }
        prev_decoded = new_decoded

        # Keep a steady rhythm between snapshots when pausing is enabled.
        time.sleep(max(pause_length - (time.perf_counter() - start), 0))
        yield render_html(f"Iteration {i + 1}/{max_it} (after generation)",
                          highlight_tokens(current_tokens[answer_start:], answer_start, diff_indices, color="green"))
        time.sleep(pause_length)

        # Stop early once three consecutive iterations produce identical tokens.
        last_tokens.append(current_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield render_html("Stopped early", f"After {i + 1} iterations.")
            break

        # Re-noise part of the answer, except on the final iteration.
        if i < max_it - 1 and noising:
            threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
            noised_answer, just_noised_indices = noisify_answer(
                current_tokens, answer_start, tokenizer,
                threshold=threshold, clustering=clustering, noise_start=noise_start
            )
            yield render_html(f"Iteration {i + 1}/{max_it} (before noising)",
                              highlight_tokens(current_tokens[answer_start:], answer_start, just_noised_indices, color="red"))
            start = time.perf_counter()
            current_tokens = ori_input_tokens[:answer_start] + noised_answer[answer_start:]

    # Trim the answer at the first EOS token (if present) and decode.
    answer_ids = current_tokens[answer_start:]
    try:
        final_ids = answer_ids[:answer_ids.index(eos_token_id)]
    except ValueError:
        final_ids = answer_ids

    final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
    yield render_html(f"Final Output ({len(final_ids)} tokens after {i + 1} iterations)", final_output)

def is_running_on_spaces():
    # SPACE_ID is set automatically by the Hugging Face Spaces runtime.
    return os.getenv("SPACE_ID") is not None


print("Loading model...")

if is_running_on_spaces():
    # On Spaces: fetch the checkpoint from the Hub.
    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_model",
        filename="diffusion-model-8B.pth",
        token=os.getenv("HF_TOKEN")
    )
else:
    # Local run: expect the checkpoint in the working directory.
    ckpt_path = "diffusion-model-8B.pth"

model, tokenizer = load_trained_model(checkpoint_path=ckpt_path)
print("✅ Model loaded.")
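
# To run locally: place diffusion-model-8B.pth in the working directory and
# set HF_TOKEN in the environment or in a .env file.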

vocab_size = len(tokenizer)
eos_token_id = tokenizer.eos_token_id
# Assumes the literal string 'MASK' encodes to exactly one token in this
# tokenizer; only the first id is used.
mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>\n", add_special_tokens=False)
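
# Optional sanity check (added here as an assumption, not part of the original
# pipeline): fail fast if 'MASK' is split into several tokens, since only its
# first id is used as the mask token above.
assert len(tokenizer.encode('MASK', add_special_tokens=False)) == 1, \
    "'MASK' must encode to a single token"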

demo = gr.Interface(
    fn=diffusion_chat,
    inputs=[
        gr.Textbox(
            label="User Question",
            lines=2,
            placeholder="What do you know about Amsterdam?",
        ),
        gr.Checkbox(label="Enable intermediate noising", value=True),
        gr.Checkbox(label="Pause between iterations", value=False),
        gr.Slider(1, 512, value=64, step=1, label="Maximum number of iterations"),
    ],
    outputs=gr.HTML(label="Diffusion Output"),
    title="LAD Chat",
    allow_flagging="never",
    live=False
)

# share=True is ignored on Spaces (the platform serves the app itself); when
# running locally it requests a temporary public gradio.live link.
demo.launch(share=True, allowed_paths=["."], ssr_mode=False)