import gradio as gr import torch import torch.nn.functional as F import time from tokenizers import Tokenizer from train2 import MaskedDiffusionModel MAX_SEQ_LENGTH = 64 def load_model_and_tokenizer(): device = torch.device("cpu") tokenizer = Tokenizer.from_file("subword_tokenizer2.json") vocab = tokenizer.get_vocab() model = MaskedDiffusionModel( vocab_size=len(vocab), d_model=256, nhead=8, num_layers=6, max_seq_len=MAX_SEQ_LENGTH, dropout=0.2 ).to(device) try: state_dict = torch.load("diffusion_model_between.pth", map_location=device) model.load_state_dict(state_dict) except Exception as e: print(f"FAILED TO LOAD MODEL: {e}") model.eval() return model, tokenizer, device model, tokenizer, device = load_model_and_tokenizer() def decode_with_masks(tensor, is_final=False): eos_id = tokenizer.token_to_id("[EOS]") if is_final: eos_indices = (tensor == eos_id).nonzero(as_tuple=True)[0] if len(eos_indices) > 0: tensor = tensor[:eos_indices[0]] special_ids = {tokenizer.token_to_id("[PAD]"), tokenizer.token_to_id("[BOS]"), tokenizer.token_to_id("[EOS]"), tokenizer.token_to_id("[UNK]")} filtered_ids = [tid for tid in tensor.tolist() if tid not in special_ids] if not filtered_ids: return "" text = tokenizer.decode(filtered_ids, skip_special_tokens=False).strip() text = text.replace("[MASK]", "█") for p in [".", ",", "?", "!", "'", ":"]: text = text.replace(f" {p}", p) return text.strip() def predict(message, history): try: steps = 15 temp = 0.3 top_k = 10 bos_id = tokenizer.token_to_id("[BOS]") eos_id = tokenizer.token_to_id("[EOS]") mask_id = tokenizer.token_to_id("[MASK]") pad_id = tokenizer.token_to_id("[PAD]") formatted_prompt = f"user: {message.lower().strip()} bot:" input_ids = tokenizer.encode(formatted_prompt).ids max_resp = min(40, MAX_SEQ_LENGTH - len(input_ids) - 2) sequence = [bos_id] + input_ids + [mask_id] * max_resp + [eos_id] sequence += [pad_id] * (MAX_SEQ_LENGTH - len(sequence)) seq_tensor = torch.tensor([sequence], dtype=torch.long, device=device) response_start = 1 + len(input_ids) response_end = response_start + max_resp mask_indices = list(range(response_start, response_end)) num_masks = len(mask_indices) running_confidence = torch.zeros(num_masks, device=device) current_seq = seq_tensor.squeeze(0).clone() output_text = "" for step in range(1, steps + 1): t_val = max(1.0 - step / steps, 0.05) t = torch.tensor([t_val], device=device) with torch.no_grad(): logits = model(seq_tensor, t, src_key_padding_mask=(seq_tensor == pad_id)) response_logits = logits[0, mask_indices] if step > 1: unique_tokens, counts = torch.unique(current_seq[mask_indices], return_counts=True) for i, tok_id in enumerate(unique_tokens): t_id = tok_id.item() if t_id not in [bos_id, eos_id, mask_id, pad_id] and counts[i] > 1: # Logit subtraction entirely prevents structural duplication loops natively response_logits[:, t_id] -= 10.0 * (counts[i].item() - 1) if top_k > 0: v, _ = torch.topk(response_logits, top_k) response_logits[response_logits < v[:, -1].unsqueeze(-1)] = -float('Inf') probs = F.softmax(response_logits / temp, dim=-1) if step == steps: predicted = torch.argmax(probs, dim=-1) else: predicted = torch.multinomial(probs, 1).squeeze(-1) confidences = torch.gather(F.softmax(response_logits, dim=-1), 1, predicted.unsqueeze(-1)).squeeze(-1) running_confidence = 0.7 * running_confidence + 0.3 * confidences for i, idx in enumerate(mask_indices): current_seq[idx] = predicted[i] if step < steps: target_reveal = int(num_masks * step / steps) remask_count = num_masks - target_reveal if remask_count > 0: _, low_idx = torch.topk(running_confidence, k=remask_count, largest=False) for li in low_idx: current_seq[mask_indices[li]] = mask_id seq_tensor = current_seq.unsqueeze(0) output_text = decode_with_masks(current_seq[response_start:response_end], is_final=(step == steps)) yield output_text except Exception as e: yield f"Error: {str(e)}" custom_css = """ @import url('https://fonts.googleapis.com/css2?family=Fira+Code:wght@400;600&display=swap'); body, .gradio-container { background-color: #1e1e1e !important; font-family: 'Fira Code', monospace !important; color: #d4d4d4 !important; } .hero-container { padding: 2rem 5vw; border-bottom: 2px solid #333; background: #191919; } .hero-brand { color: #569cd6; font-size: 1rem; margin-bottom: 0.5rem; } .hero-brand::before { content: "<"; color: #808080; } .hero-brand::after { content: "/>"; color: #808080; } .hero-title { font-size: 2.5rem; color: #ce9178; margin: 0 0 1rem 0; font-weight: 600; } .hero-description { color: #6a9955; line-height: 1.6; font-size: 1rem; background: transparent; padding: 0; border: none; } .hero-description strong { color: #c586c0; } .hero-description code { color: #dcdcaa; background: #2d2d2d; padding: 2px 6px; border-radius: 3px; } .app-container { padding: 0 5vw 5vh 5vw; } /* Customizing ChatInterface objects */ .bubble-wrap { font-family: 'Fira Code', monospace !important; } .message-wrap .user { background: #252526 !important; border: 1px solid #3c3c3c !important; color: #9cdcfe !important; } .message-wrap .bot { background: transparent !important; border: none !important; color: #d4d4d4 !important; } """ with gr.Blocks(css=custom_css, fill_height=True) as demo: with gr.Column(elem_classes="hero-container"): gr.HTML("""
O(1) scaling factor. Because generation relies on parallel iterations,bavard/personachat_truecased