# Source: Hugging Face file view, commit fb3221e ("pushed 17m param model") by ritikraj2425; file size 7.1 kB.
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from train import MaskedDiffusionModel
from tokenizer import MAX_SEQ_LENGTH
import os
def load_tokenizer_full(vocab_file="subword_tokenizer.json"):
    """Load the subword tokenizer and build both vocabulary lookup tables.

    Args:
        vocab_file: Path to the serialized tokenizer JSON file.

    Returns:
        A ``(tokenizer, vocab, id2word)`` tuple: the tokenizer read from
        *vocab_file*, its token -> id mapping, and the inverse id -> token
        mapping with integer keys.
    """
    tok = Tokenizer.from_file(vocab_file)
    token_to_index = tok.get_vocab()

    # Invert the vocabulary so ids can be mapped back to their token strings.
    index_to_token = {}
    for token, index in token_to_index.items():
        index_to_token[int(index)] = token

    return tok, token_to_index, index_to_token
def decode_response(token_ids, tokenizer):
    """Turn predicted token ids into display text.

    The ids of the special tokens ([PAD], [BOS], [EOS], [MASK], [UNK]) are
    removed first, the remainder is handed to the tokenizer's own ``decode()``
    so subword pieces are reassembled by the tokenizer, and finally the single
    space in front of common punctuation marks is dropped.

    Args:
        token_ids: Iterable of integer token ids.
        tokenizer: Object exposing ``token_to_id(str)`` and ``decode(list)``.

    Returns:
        The cleaned-up text, or ``""`` when only special tokens were given.
    """
    special_names = ("[PAD]", "[BOS]", "[EOS]", "[MASK]", "[UNK]")
    skip = {tokenizer.token_to_id(name) for name in special_names}

    kept = list(filter(lambda tid: tid not in skip, token_ids))
    if len(kept) == 0:
        return ""

    # The tokenizer performs the subword reassembly; we only trim the result.
    decoded = tokenizer.decode(kept).strip()

    # Remove the space the decoder leaves in front of punctuation.
    for mark in ".,?!':":
        decoded = decoded.replace(" " + mark, mark)
    return decoded
def generate_response(
    model, tokenizer, id2word,
    prompt,
    max_response_length=24,
    sampling_steps=40,
    temperature=0.5,
    top_k=15
):
    """
    Timestep-aware iterative denoising with running confidence remasking.

    The prompt is formatted as ``user: <prompt> bot:``, followed by a block of
    [MASK] tokens that is repeatedly re-predicted. After every step the
    positions with the lowest running confidence are masked again, and fewer
    positions are re-masked as the steps progress.

    Key points:
    1. The timestep t is passed to the model at each step.
    2. t is the fraction of response tokens still masked in the model input
       (1.0 at the first step, shrinking towards 0, clamped at 0.05).
    3. Running confidence remasking: an exponential moving average of the
       per-position confidence decides which positions stay revealed.

    Args:
        model: Callable as ``model(ids, t, src_key_padding_mask=...)`` returning
            logits indexed as ``[batch, position, vocab]``.
        tokenizer: Tokenizer providing ``token_to_id`` and ``encode``; also used
            for the final decode.
        id2word: Unused; kept so existing callers keep working.
        prompt: Raw user text.
        max_response_length: Upper bound on the number of response tokens.
        sampling_steps: Number of denoising iterations.
        temperature: Softmax temperature used for sampling (floored at 1e-6).
        top_k: Keep only the k highest logits per position; ``0``/``None``
            disables the filter. Values above the vocabulary size are clamped.

    Returns:
        The decoded response text, or ``""`` if there is no room for a response.
    """
    model.eval()
    device = next(model.parameters()).device

    bos_id = tokenizer.token_to_id("[BOS]")
    eos_id = tokenizer.token_to_id("[EOS]")
    mask_id = tokenizer.token_to_id("[MASK]")
    pad_id = tokenizer.token_to_id("[PAD]")

    formatted = f"user: {prompt.lower().strip()} bot:"
    input_ids = tokenizer.encode(formatted).ids

    # Clamp response length to available space (2 slots go to [BOS] and [EOS])
    max_resp = min(max_response_length, MAX_SEQ_LENGTH - len(input_ids) - 2)
    if max_resp <= 0:
        print("Prompt too long.")
        return ""

    sequence = [bos_id] + input_ids + [mask_id] * max_resp + [eos_id]
    sequence += [pad_id] * (MAX_SEQ_LENGTH - len(sequence))
    seq_tensor = torch.tensor([sequence], dtype=torch.long, device=device)

    # The response occupies one contiguous span right after [BOS] + prompt.
    response_start = 1 + len(input_ids)
    response_end = response_start + max_resp
    num_masks = max_resp

    # Running confidence: tracks cumulative confidence per token position.
    # Tokens that are consistently predicted with high confidence get revealed first.
    running_confidence = torch.zeros(num_masks, device=device)

    for step in range(1, sampling_steps + 1):
        # ── Timestep: fraction of tokens still masked in this step's input ──
        # Step 1:  t = 1.0   (everything masked, early stage)
        # Step 40: t ≈ 0.025 (few tokens masked, final refinement)
        # The input of step s was re-masked at the end of step s-1, hence
        # (step - 1). Clamp to [0.05, 1.0] to stay within training range.
        t_val = max(1.0 - (step - 1) / sampling_steps, 0.05)
        t = torch.tensor([t_val], device=device)

        # Padding mask so attention ignores PAD positions. Recomputed every
        # step because a response slot may itself be predicted as [PAD].
        pad_mask = (seq_tensor == pad_id)

        with torch.no_grad():
            logits = model(seq_tensor, t, src_key_padding_mask=pad_mask)

        response_logits = logits[0, response_start:response_end]  # [num_masks, vocab_size]

        # Top-k filtering — push everything below the k-th best logit to -inf
        if top_k and top_k > 0:
            # torch.topk raises if k exceeds the size of the dimension.
            k = min(top_k, response_logits.size(-1))
            top_k_vals, _ = torch.topk(response_logits, k, dim=-1)
            min_top_k = top_k_vals[:, -1].unsqueeze(-1)
            response_logits = response_logits.masked_fill(
                response_logits < min_top_k, float('-inf')
            )

        # Temperature scaling and sampling
        scaled = response_logits / max(temperature, 1e-6)
        probs = F.softmax(scaled, dim=-1)

        # At final step use greedy (argmax) for cleaner output
        if step == sampling_steps:
            predicted = torch.argmax(probs, dim=-1)
        else:
            predicted = torch.multinomial(probs, 1).squeeze(-1)

        # Confidence scoring on the (top-k filtered) logits *without*
        # temperature scaling, so the temperature does not distort the scores.
        true_probs = F.softmax(response_logits, dim=-1)
        confidences = true_probs.gather(-1, predicted.unsqueeze(-1)).squeeze(-1)

        # Update running confidence (exponential moving average).
        # This smooths out noisy single-step confidence and gives a better
        # signal for which tokens should be revealed vs remasked.
        running_confidence = 0.7 * running_confidence + 0.3 * confidences

        # Write this step's predictions into the whole response span.
        current = seq_tensor.squeeze(0).clone()
        current[response_start:response_end] = predicted

        # Progressive remasking based on running confidence
        if step < sampling_steps:
            # Reveal more tokens as we progress
            target_revealed = int(num_masks * step / sampling_steps)
            num_remask = num_masks - target_revealed
            if num_remask > 0:
                # Re-mask the tokens with LOWEST running confidence
                _, low_idx = torch.topk(running_confidence, k=num_remask, largest=False)
                current[response_start + low_idx] = mask_id

        seq_tensor = current.unsqueeze(0)

    # Decode final response
    response_ids = seq_tensor[0][response_start:response_end].tolist()
    return decode_response(response_ids, tokenizer)
if __name__ == "__main__":
    # Demo entry point: load tokenizer + checkpoint, then answer a few prompts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Inference on: {device}\n")

    tokenizer, vocab, id2word = load_tokenizer_full()

    # Must match train.py architecture exactly
    model = MaskedDiffusionModel(
        vocab_size=len(vocab),
        d_model=256, nhead=8, num_layers=6,
        max_seq_len=MAX_SEQ_LENGTH,
        dropout=0.0,  # No dropout at inference
    ).to(device)

    # Prefer the "best" checkpoint when present, otherwise fall back to the
    # regular one.
    # NOTE(review): the success message below always says "EMA weights", even
    # for the fallback file — confirm both checkpoints really hold EMA weights.
    ckpt = "diffusion_model_best.pth" if os.path.exists("diffusion_model_best.pth") else "diffusion_model.pth"
    try:
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        state_dict = torch.load(ckpt, map_location=device)
        model.load_state_dict(state_dict)
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        print("Make sure to retrain with the new architecture first!")
        # Exit with a non-zero status so callers can detect the failure;
        # the bare exit() helper reports success and depends on the site module.
        raise SystemExit(1)
    else:
        print(f"Loaded: {ckpt} (EMA weights)\n")

    test_prompts = [
        "hi",
        "how are you",
        "what is your name",
        "tell me a joke",
        "what do you do for fun",
        "i had a bad day",
        " Can we go now?",
    ]

    for prompt in test_prompts:
        response = generate_response(
            model, tokenizer, id2word,
            prompt=prompt,
            max_response_length=24,
            sampling_steps=40,
            temperature=0.5,
            top_k=15
        )
        print(f"User : {prompt}")
        print(f"Bot  : {response}")
        print("-" * 40)