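# NOTE: the next six lines are file-viewer page residue captured with the
# source; they are not part of the program.
# ritikraj2425's picture
# Restored app-container frame around ChatInterface
# f33f0ec
# Raw
# History Blame Contribute Delete
# 7.87 kB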
import gradio as gr
import torch
import torch.nn.functional as F
import time
from tokenizers import Tokenizer
from train2 import MaskedDiffusionModel
# Fixed token length of every model input: prompts are padded up to this size
# and the same value is passed to the model as `max_seq_len`.
MAX_SEQ_LENGTH = 64
def load_model_and_tokenizer():
    """Build the tokenizer and diffusion model on CPU and load the checkpoint.

    The tokenizer is read from ``subword_tokenizer2.json`` and the weights
    from ``diffusion_model_between.pth``. A failure while loading the weights
    is printed and otherwise ignored, so the returned model then keeps
    whatever parameters its constructor produced.

    Returns:
        A ``(model, tokenizer, device)`` tuple; the model is in eval mode.
    """
    cpu = torch.device("cpu")
    tok = Tokenizer.from_file("subword_tokenizer2.json")

    # Vocabulary size is taken from the tokenizer; the other hyperparameters
    # are fixed constants that have to match the saved checkpoint.
    hyperparams = dict(
        vocab_size=len(tok.get_vocab()),
        d_model=256,
        nhead=8,
        num_layers=6,
        max_seq_len=MAX_SEQ_LENGTH,
        dropout=0.2,
    )
    net = MaskedDiffusionModel(**hyperparams).to(cpu)

    try:
        net.load_state_dict(
            torch.load("diffusion_model_between.pth", map_location=cpu)
        )
    except Exception as e:
        # Best effort: report the problem but still return a usable object.
        print(f"FAILED TO LOAD MODEL: {e}")

    net.eval()
    return net, tok, cpu
# Loaded once at import time; decode_with_masks() and predict() read these
# module-level globals directly.
model, tokenizer, device = load_model_and_tokenizer()
def decode_with_masks(tensor, is_final=False):
    """Turn a tensor of response token ids into display text.

    ``[PAD]``, ``[BOS]``, ``[EOS]`` and ``[UNK]`` ids are dropped, every
    remaining ``[MASK]`` is shown as a block character, and the space the
    decoder leaves before common punctuation marks is removed.

    Args:
        tensor: Tensor of token ids; expected to be 1-D (the caller passes a
            slice of the response region).
        is_final: When true, everything from the first ``[EOS]`` onward is
            discarded before decoding.

    Returns:
        The cleaned-up text, or ``""`` if no displayable token remains.
    """
    pad_id, bos_id, eos_id, unk_id = (
        tokenizer.token_to_id(name)
        for name in ("[PAD]", "[BOS]", "[EOS]", "[UNK]")
    )

    if is_final:
        eos_positions = (tensor == eos_id).nonzero(as_tuple=True)[0]
        if len(eos_positions):
            tensor = tensor[: eos_positions[0]]

    dropped = {pad_id, bos_id, eos_id, unk_id}
    kept = [tok for tok in tensor.tolist() if tok not in dropped]
    if not kept:
        return ""

    # Special tokens are kept during decoding so that "[MASK]" survives and
    # can be swapped for the placeholder glyph.
    text = tokenizer.decode(kept, skip_special_tokens=False).strip()
    text = text.replace("[MASK]", "█")
    for mark in ".,?!':":
        text = text.replace(" " + mark, mark)
    return text.strip()
def predict(message, history):
    """Stream a reply to ``message`` via iterative masked denoising.

    The prompt is wrapped as ``"user: <message> bot:"``, followed by a run of
    ``[MASK]`` slots. On each of a fixed number of steps the model predicts
    every slot at once; the least-confident predictions are masked again and
    the partially revealed text is yielded, so the UI can show the reply
    taking shape. The last step takes the arg-max for every slot and keeps
    them all.

    Args:
        message: The user's text; it is lower-cased and stripped before
            tokenization.
        history: Conversation history supplied by the chat UI. It is not used
            by this function.

    Yields:
        The current response text after each step (masked slots rendered by
        ``decode_with_masks``). If the prompt leaves no room for a response,
        or any exception occurs, a single ``"Error: ..."`` string is yielded
        instead.
    """
    try:
        steps = 15
        temp = 0.3
        top_k = 10

        bos_id = tokenizer.token_to_id("[BOS]")
        eos_id = tokenizer.token_to_id("[EOS]")
        mask_id = tokenizer.token_to_id("[MASK]")
        pad_id = tokenizer.token_to_id("[PAD]")
        # Ids that are never subject to the repetition penalty.
        penalty_exempt = {bos_id, eos_id, mask_id, pad_id}

        formatted_prompt = f"user: {message.lower().strip()} bot:"
        input_ids = tokenizer.encode(formatted_prompt).ids

        # Two positions are reserved for [BOS] and [EOS].
        max_resp = min(40, MAX_SEQ_LENGTH - len(input_ids) - 2)
        if max_resp < 1:
            # Without this guard the sequence would have no response slots
            # (and could exceed MAX_SEQ_LENGTH), producing an opaque failure.
            yield (
                f"Error: message is too long for the model's "
                f"{MAX_SEQ_LENGTH}-token context; please shorten it."
            )
            return

        sequence = [bos_id] + input_ids + [mask_id] * max_resp + [eos_id]
        sequence += [pad_id] * (MAX_SEQ_LENGTH - len(sequence))
        seq_tensor = torch.tensor([sequence], dtype=torch.long, device=device)

        # The response occupies the contiguous span [response_start, response_end).
        response_start = 1 + len(input_ids)
        response_end = response_start + max_resp
        mask_indices = list(range(response_start, response_end))
        num_masks = len(mask_indices)

        # Exponential moving average of per-slot confidence across steps.
        running_confidence = torch.zeros(num_masks, device=device)
        current_seq = seq_tensor.squeeze(0).clone()

        for step in range(1, steps + 1):
            # Time value given to the model: shrinks each step, floored at 0.05.
            t_val = max(1.0 - step / steps, 0.05)
            t = torch.tensor([t_val], device=device)

            with torch.no_grad():
                logits = model(
                    seq_tensor, t, src_key_padding_mask=(seq_tensor == pad_id)
                )

            # Indexing with a list copies, so the in-place edits below do not
            # write back into `logits`.
            response_logits = logits[0, mask_indices]

            if step > 1:
                # Repetition penalty: for each ordinary token that currently
                # appears more than once in the response, lower its logit at
                # every slot in proportion to the number of extra copies.
                unique_tokens, counts = torch.unique(
                    current_seq[mask_indices], return_counts=True
                )
                for t_id, count in zip(unique_tokens.tolist(), counts.tolist()):
                    if count > 1 and t_id not in penalty_exempt:
                        response_logits[:, t_id] -= 10.0 * (count - 1)

            if top_k > 0:
                # Clamp so a vocabulary smaller than top_k cannot make topk raise.
                k = min(top_k, response_logits.size(-1))
                kth_best, _ = torch.topk(response_logits, k)
                cutoff = kth_best[:, -1].unsqueeze(-1)
                response_logits[response_logits < cutoff] = -float("Inf")

            probs = F.softmax(response_logits / temp, dim=-1)
            if step == steps:
                # Final step is deterministic.
                predicted = torch.argmax(probs, dim=-1)
            else:
                predicted = torch.multinomial(probs, 1).squeeze(-1)

            # Confidence is measured at temperature 1 (after top-k filtering).
            confidences = torch.gather(
                F.softmax(response_logits, dim=-1), 1, predicted.unsqueeze(-1)
            ).squeeze(-1)
            running_confidence = 0.7 * running_confidence + 0.3 * confidences

            current_seq[response_start:response_end] = predicted

            if step < steps:
                # Keep a linearly growing share of slots revealed; re-mask the
                # rest, choosing the slots with the lowest running confidence.
                target_reveal = int(num_masks * step / steps)
                remask_count = num_masks - target_reveal
                if remask_count > 0:
                    _, low_idx = torch.topk(
                        running_confidence, k=remask_count, largest=False
                    )
                    current_seq[response_start + low_idx] = mask_id

            seq_tensor = current_seq.unsqueeze(0)
            yield decode_with_masks(
                current_seq[response_start:response_end],
                is_final=(step == steps),
            )

    except Exception as e:
        # UI boundary: surface the failure in the chat instead of crashing.
        yield f"Error: {str(e)}"
# Stylesheet handed to gr.Blocks(css=...): dark background, Fira Code
# monospace font, and rules for the hero banner (.hero-*), the chat wrapper
# (.app-container) and the chat message bubbles.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Fira+Code:wght@400;600&display=swap');
body, .gradio-container {
background-color: #1e1e1e !important;
font-family: 'Fira Code', monospace !important;
color: #d4d4d4 !important;
}
.hero-container {
padding: 2rem 5vw;
border-bottom: 2px solid #333;
background: #191919;
}
.hero-brand {
color: #569cd6;
font-size: 1rem;
margin-bottom: 0.5rem;
}
.hero-brand::before { content: "<"; color: #808080; }
.hero-brand::after { content: "/>"; color: #808080; }
.hero-title {
font-size: 2.5rem;
color: #ce9178;
margin: 0 0 1rem 0;
font-weight: 600;
}
.hero-description {
color: #6a9955;
line-height: 1.6;
font-size: 1rem;
background: transparent;
padding: 0;
border: none;
}
.hero-description strong {
color: #c586c0;
}
.hero-description code {
color: #dcdcaa;
background: #2d2d2d;
padding: 2px 6px;
border-radius: 3px;
}
.app-container {
padding: 0 5vw 5vh 5vw;
}
/* Customizing ChatInterface objects */
.bubble-wrap {
font-family: 'Fira Code', monospace !important;
}
.message-wrap .user {
background: #252526 !important;
border: 1px solid #3c3c3c !important;
color: #9cdcfe !important;
}
.message-wrap .bot {
background: transparent !important;
border: none !important;
color: #d4d4d4 !important;
}
"""
# Page layout: a static hero banner stacked above the streaming chat widget.
with gr.Blocks(fill_height=True, css=custom_css) as demo:
    # Banner: project name, title, and a short description styled as a
    # block comment.
    with gr.Column(elem_classes="hero-container"):
        gr.HTML("""
<div class="hero-brand">
persona-chat-mdlm
</div>
<h1 class="hero-title">PersonaChat MDLM</h1>
<div class="hero-description">
/*<br/>
&nbsp;* <strong>Architecture:</strong> 17 Million Parameter Masked Discrete Diffusion Language Model<br/><br/>
&nbsp;* Unlike traditional autoregressive models that guess words strictly left-to-right, this model employs <strong>Parallel Denoising Generation</strong>.<br/>
&nbsp;* It maps out the structural sequence space instantly and iteratively normalizes masks into tokens.<br/><br/>
&nbsp;* <strong>Speed Paradigm:</strong> True <code>O(1)</code> scaling factor. Because generation relies on parallel iterations,<br/>
&nbsp;* computing a 10-token array demands the exact same temporal footprint as computing a 100-token array.<br/><br/>
&nbsp;* <strong>Dataset Pipeline:</strong> <code>bavard/personachat_truecased</code><br/>
&nbsp;*/
</div>
""")

    # Chat area: predict() is a generator, so each yielded string replaces
    # the reply shown so far.
    with gr.Column(elem_classes="app-container"):
        gr.ChatInterface(
            fn=predict,
            examples=[
                "Hi, how are you doing today?",
                "How are you doing?",
                "Do you have any pets?",
                "What kind of music do you like?",
            ],
        )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()