| import gradio as gr |
| import torch |
| import torch.nn.functional as F |
| import time |
| from tokenizers import Tokenizer |
| from train2 import MaskedDiffusionModel |
|
|
# Fixed length, in tokens, of every sequence fed to the model: it is passed to
# the model constructor as max_seq_len and used by predict() to size and pad
# its input.
MAX_SEQ_LENGTH = 64
|
|
def load_model_and_tokenizer():
    """Build the model on CPU and load its tokenizer and checkpoint.

    The tokenizer is read from ``subword_tokenizer2.json`` and its vocabulary
    size determines the model's ``vocab_size``. Weights are then loaded from
    ``diffusion_model_between.pth``. Loading is best-effort: if it fails for
    any reason the error is printed and the model is returned with whatever
    weights its constructor produced.

    Returns:
        A ``(model, tokenizer, device)`` tuple; the model is in eval mode.
    """
    cpu = torch.device("cpu")
    tok = Tokenizer.from_file("subword_tokenizer2.json")

    hyperparams = {
        "vocab_size": len(tok.get_vocab()),
        "d_model": 256,
        "nhead": 8,
        "num_layers": 6,
        "max_seq_len": MAX_SEQ_LENGTH,
        "dropout": 0.2,
    }
    net = MaskedDiffusionModel(**hyperparams).to(cpu)

    try:
        net.load_state_dict(
            torch.load("diffusion_model_between.pth", map_location=cpu)
        )
    except Exception as e:
        # Deliberately non-fatal so the caller still gets a model object.
        print(f"FAILED TO LOAD MODEL: {e}")

    net.eval()
    return net, tok, cpu
|
|
# Loaded once at import time; decode_with_masks() and predict() read these
# module-level globals.
model, tokenizer, device = load_model_and_tokenizer()
|
|
def decode_with_masks(tensor, is_final=False):
    """Turn a tensor of token ids into display text.

    Args:
        tensor: Token ids to render (indexed and iterated as a flat sequence).
        is_final: When true, everything from the first ``[EOS]`` id onward is
            dropped before decoding.

    Returns:
        The decoded text with ``[PAD]``, ``[BOS]``, ``[EOS]`` and ``[UNK]``
        ids removed, every literal ``[MASK]`` shown as ``█``, and the space
        before common punctuation removed. Returns ``""`` when no ids remain
        after filtering.
    """
    eos_id = tokenizer.token_to_id("[EOS]")
    if is_final:
        eos_positions = (tensor == eos_id).nonzero(as_tuple=True)[0]
        if eos_positions.numel() > 0:
            tensor = tensor[:eos_positions[0]]

    # [MASK] is intentionally not in this set, so unresolved slots stay visible.
    hidden = {
        tokenizer.token_to_id(name)
        for name in ("[PAD]", "[BOS]", "[EOS]", "[UNK]")
    }
    kept = [tid for tid in tensor.tolist() if tid not in hidden]
    if not kept:
        return ""

    text = tokenizer.decode(kept, skip_special_tokens=False).strip()
    text = text.replace("[MASK]", "█")

    # Re-attach punctuation that decoding left separated by a space.
    for mark in ".,?!':":
        text = text.replace(" " + mark, mark)

    return text.strip()
|
|
def predict(message, history):
    """Stream a reply to ``message`` by iteratively filling in masked tokens.

    The prompt is rendered as ``user: <message> bot:`` and followed by a run
    of ``[MASK]`` tokens. Over a fixed number of steps the model re-predicts
    every reply position; after each non-final step the positions with the
    lowest running confidence are masked again, with fewer positions masked
    as the steps progress. The final step takes the arg-max prediction and
    masks nothing.

    Args:
        message: The user's latest chat message.
        history: Earlier turns supplied by the chat UI; not used.

    Yields:
        The decoded reply after each step (as produced by
        ``decode_with_masks``), or a single ``"Error: ..."`` string if the
        prompt leaves no room for a reply or generation raises.
    """
    try:
        # Number of denoising passes, sampling temperature, and how many
        # top-scoring candidates survive per position.
        steps = 15
        temp = 0.3
        top_k = 10

        bos_id = tokenizer.token_to_id("[BOS]")
        eos_id = tokenizer.token_to_id("[EOS]")
        mask_id = tokenizer.token_to_id("[MASK]")
        pad_id = tokenizer.token_to_id("[PAD]")
        # Ids the repetition penalty below must never touch.
        special_ids = {bos_id, eos_id, mask_id, pad_id}

        formatted_prompt = f"user: {message.lower().strip()} bot:"
        input_ids = tokenizer.encode(formatted_prompt).ids

        # Two positions are reserved for [BOS] and [EOS]; the reply gets what
        # is left, capped at 40 tokens.
        max_resp = min(40, MAX_SEQ_LENGTH - len(input_ids) - 2)
        if max_resp < 1:
            # A zero or negative budget would build a sequence with no mask
            # positions (and possibly longer than MAX_SEQ_LENGTH), which only
            # surfaces later as an opaque tensor error or an empty reply.
            yield (
                f"Error: message is too long ({len(input_ids)} prompt tokens); "
                f"at most {MAX_SEQ_LENGTH - 3} fit in the "
                f"{MAX_SEQ_LENGTH}-token context"
            )
            return

        # Layout: [BOS] prompt [MASK]*max_resp [EOS] [PAD]...
        sequence = [bos_id] + input_ids + [mask_id] * max_resp + [eos_id]
        sequence += [pad_id] * (MAX_SEQ_LENGTH - len(sequence))
        seq_tensor = torch.tensor([sequence], dtype=torch.long, device=device)

        response_start = 1 + len(input_ids)
        response_end = response_start + max_resp
        response = slice(response_start, response_end)
        num_masks = max_resp

        # Exponential moving average of each reply position's confidence.
        running_confidence = torch.zeros(num_masks, device=device)
        current_seq = seq_tensor.squeeze(0).clone()

        for step in range(1, steps + 1):
            # Falls from just under 1 towards 0 and is floored at 0.05.
            # NOTE(review): presumably the diffusion time / noise level the
            # model was trained with -- confirm against MaskedDiffusionModel.
            t_val = max(1.0 - step / steps, 0.05)
            t = torch.tensor([t_val], device=device)

            with torch.no_grad():
                logits = model(
                    seq_tensor, t, src_key_padding_mask=(seq_tensor == pad_id)
                )

            # Clone so the in-place edits below never touch the model output.
            response_logits = logits[0, response].clone()

            if step > 1:
                # Repetition penalty: each extra occurrence of a non-special
                # token in the current reply lowers its logit by 10 at every
                # position.
                unique_tokens, counts = torch.unique(
                    current_seq[response], return_counts=True
                )
                for tok_id, count in zip(unique_tokens.tolist(), counts.tolist()):
                    if count > 1 and tok_id not in special_ids:
                        response_logits[:, tok_id] -= 10.0 * (count - 1)

            if top_k > 0:
                # Clamp k so a vocabulary smaller than top_k cannot make
                # torch.topk raise.
                k = min(top_k, response_logits.size(-1))
                v, _ = torch.topk(response_logits, k)
                response_logits[response_logits < v[:, -1].unsqueeze(-1)] = -float('Inf')

            probs = F.softmax(response_logits / temp, dim=-1)

            if step == steps:
                # Last pass: commit to the most likely token everywhere.
                predicted = torch.argmax(probs, dim=-1)
            else:
                predicted = torch.multinomial(probs, 1).squeeze(-1)

            # Confidence is read at temperature 1, not the sampling temperature.
            confidences = torch.gather(
                F.softmax(response_logits, dim=-1), 1, predicted.unsqueeze(-1)
            ).squeeze(-1)
            running_confidence = 0.7 * running_confidence + 0.3 * confidences

            # Every reply position is overwritten each step, including ones
            # that were left unmasked by the previous step.
            current_seq[response] = predicted

            if step < steps:
                # Keep a growing share of positions revealed; re-mask the
                # least confident remainder for the next pass.
                target_reveal = int(num_masks * step / steps)
                remask_count = num_masks - target_reveal
                if remask_count > 0:
                    _, low_idx = torch.topk(
                        running_confidence, k=remask_count, largest=False
                    )
                    current_seq[response_start + low_idx] = mask_id

            seq_tensor = current_seq.unsqueeze(0)
            yield decode_with_masks(current_seq[response], is_final=(step == steps))

    except Exception as e:
        # UI boundary: surface the failure in the chat instead of crashing.
        yield f"Error: {str(e)}"
|
|
# Stylesheet passed to gr.Blocks(css=...): a dark, monospace theme covering
# the page body, the hero header (.hero-*), the chat container
# (.app-container) and the chat message bubbles.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Fira+Code:wght@400;600&display=swap');

body, .gradio-container {
background-color: #1e1e1e !important;
font-family: 'Fira Code', monospace !important;
color: #d4d4d4 !important;
}

.hero-container {
padding: 2rem 5vw;
border-bottom: 2px solid #333;
background: #191919;
}

.hero-brand {
color: #569cd6;
font-size: 1rem;
margin-bottom: 0.5rem;
}

.hero-brand::before { content: "<"; color: #808080; }
.hero-brand::after { content: "/>"; color: #808080; }

.hero-title {
font-size: 2.5rem;
color: #ce9178;
margin: 0 0 1rem 0;
font-weight: 600;
}

.hero-description {
color: #6a9955;
line-height: 1.6;
font-size: 1rem;
background: transparent;
padding: 0;
border: none;
}

.hero-description strong {
color: #c586c0;
}

.hero-description code {
color: #dcdcaa;
background: #2d2d2d;
padding: 2px 6px;
border-radius: 3px;
}

.app-container {
padding: 0 5vw 5vh 5vw;
}

/* Customizing ChatInterface objects */
.bubble-wrap {
font-family: 'Fira Code', monospace !important;
}

.message-wrap .user {
background: #252526 !important;
border: 1px solid #3c3c3c !important;
color: #9cdcfe !important;
}

.message-wrap .bot {
background: transparent !important;
border: none !important;
color: #d4d4d4 !important;
}
"""
|
|
# Page layout: a static hero header on top, the streaming chat underneath.
with gr.Blocks(fill_height=True, css=custom_css) as demo:
    # Hero header; styled by the .hero-* rules in custom_css.
    with gr.Column(elem_classes="hero-container"):
        gr.HTML(value="""
<div class="hero-brand">
persona-chat-mdlm
</div>
<h1 class="hero-title">PersonaChat MDLM</h1>
<div class="hero-description">
/*<br/>
* <strong>Architecture:</strong> 17 Million Parameter Masked Discrete Diffusion Language Model<br/><br/>
* Unlike traditional autoregressive models that guess words strictly left-to-right, this model employs <strong>Parallel Denoising Generation</strong>.<br/>
* It maps out the structural sequence space instantly and iteratively normalizes masks into tokens.<br/><br/>
* <strong>Speed Paradigm:</strong> True <code>O(1)</code> scaling factor. Because generation relies on parallel iterations,<br/>
* computing a 10-token array demands the exact same temporal footprint as computing a 100-token array.<br/><br/>
* <strong>Dataset Pipeline:</strong> <code>bavard/personachat_truecased</code><br/>
*/
</div>
""")

    # Chat area; predict() is a generator, so each yielded string replaces
    # the bot message shown so far.
    with gr.Column(elem_classes="app-container"):
        gr.ChatInterface(
            fn=predict,
            examples=[
                "Hi, how are you doing today?",
                "How are you doing?",
                "Do you have any pets?",
                "What kind of music do you like?",
            ],
        )
|
|
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()