| import torch |
| import torch.nn.functional as F |
| from tokenizers import Tokenizer |
| from train import MaskedDiffusionModel |
| from tokenizer import MAX_SEQ_LENGTH |
| import os |
|
|
|
|
def load_tokenizer_full(vocab_file="subword_tokenizer.json"):
    """Load the tokenizer and build both directions of the vocabulary map.

    Args:
        vocab_file: Path to the serialized tokenizer JSON file.

    Returns:
        A ``(tokenizer, vocab, id2word)`` tuple where ``vocab`` is the
        token -> id mapping reported by the tokenizer and ``id2word`` is the
        inverse id -> token mapping with ids coerced to ``int``.
    """
    tok = Tokenizer.from_file(vocab_file)
    token_to_id = tok.get_vocab()

    # Invert the vocabulary so ids can be looked up directly.
    id_to_token = {}
    for token, index in token_to_id.items():
        id_to_token[int(index)] = token

    return tok, token_to_id, id_to_token
|
|
|
|
def decode_response(token_ids, tokenizer):
    """Decode response token ids into text, dropping special tokens.

    Decoding is delegated to ``tokenizer.decode`` so subword pieces are
    merged by the tokenizer itself; this function only filters special
    tokens beforehand and tidies whitespace afterwards.

    Args:
        token_ids: Iterable of token ids. Elements may be plain ints or any
            value accepted by ``int()`` (e.g. 0-dim tensors), so a tensor
            slice can be passed without calling ``.tolist()`` first.
        tokenizer: Object exposing ``token_to_id(str)`` and ``decode(ids)``.

    Returns:
        The decoded string, or ``""`` when nothing but special tokens
        remains after filtering.
    """
    special_ids = {
        tokenizer.token_to_id(name)
        for name in ("[PAD]", "[BOS]", "[EOS]", "[MASK]", "[UNK]")
    }

    # Coerce to int before the membership test: tensor scalars hash by
    # identity, so without this they would never match a special id.
    filtered_ids = [tid for tid in map(int, token_ids) if tid not in special_ids]
    if not filtered_ids:
        return ""

    text = tokenizer.decode(filtered_ids).strip()

    # Remove the space the decoder leaves in front of punctuation.
    for p in (".", ",", "?", "!", "'", ":"):
        text = text.replace(f" {p}", p)

    return text
|
|
|
|
def generate_response(
    model, tokenizer, id2word,
    prompt,
    max_response_length=24,
    sampling_steps=40,
    temperature=0.5,
    top_k=15
):
    """Generate a reply by iterative unmasking with confidence-based remasking.

    The input sequence is laid out as
    ``[BOS] "user: <prompt> bot:" [MASK]*n [EOS] [PAD]...`` (padded to
    ``MAX_SEQ_LENGTH``). At every step the model is called with the current
    sequence and a timestep ``t``, a token is drawn for every response
    position, and the positions with the lowest running confidence are
    re-masked so that the number of revealed tokens grows linearly with the
    step index. The last step picks the argmax and re-masks nothing.

    Args:
        model: Callable as ``model(seq, t, src_key_padding_mask=...)`` that
            returns logits indexed ``[batch, position, vocab]``.
        tokenizer: Tokenizer exposing ``token_to_id`` and ``encode(...).ids``.
        id2word: Unused here; kept so existing callers keep working.
        prompt: User text; it is lower-cased and stripped before encoding.
        max_response_length: Upper bound on the number of response tokens.
        sampling_steps: Number of denoising iterations.
        temperature: Sampling temperature (clamped to at least ``1e-6``).
        top_k: Keep only the ``top_k`` highest logits per position before
            sampling; ``0`` or ``None`` disables the filter. Values larger
            than the vocabulary size are clamped.

    Returns:
        The decoded response text, or ``""`` if the prompt leaves no room
        for a response (in which case "Prompt too long." is printed).
    """
    model.eval()
    device = next(model.parameters()).device

    bos_id = tokenizer.token_to_id("[BOS]")
    eos_id = tokenizer.token_to_id("[EOS]")
    mask_id = tokenizer.token_to_id("[MASK]")
    pad_id = tokenizer.token_to_id("[PAD]")

    formatted = f"user: {prompt.lower().strip()} bot:"
    input_ids = tokenizer.encode(formatted).ids

    # Reserve two slots for [BOS] and [EOS].
    max_resp = min(max_response_length, MAX_SEQ_LENGTH - len(input_ids) - 2)
    if max_resp <= 0:
        print("Prompt too long.")
        return ""

    sequence = [bos_id] + input_ids + [mask_id] * max_resp + [eos_id]
    sequence += [pad_id] * (MAX_SEQ_LENGTH - len(sequence))
    seq_tensor = torch.tensor([sequence], dtype=torch.long, device=device)

    # The response occupies one contiguous span right after the prompt.
    response_start = 1 + len(input_ids)
    response_end = response_start + max_resp
    num_masks = max_resp

    # Exponential moving average of each position's confidence across steps.
    running_confidence = torch.zeros(num_masks, device=device)

    safe_temperature = max(temperature, 1e-6)

    for step in range(1, sampling_steps + 1):
        # Timestep handed to the model, floored at 0.05.
        # NOTE(review): this is the mask fraction targeted *after* this step;
        # the sequence fed to the model is still masked at roughly
        # 1 - (step - 1) / sampling_steps. Confirm which convention the
        # model was trained with.
        t_val = max(1.0 - step / sampling_steps, 0.05)
        t = torch.tensor([t_val], device=device)

        # Recomputed every step because the sequence contents change.
        pad_mask = (seq_tensor == pad_id)

        with torch.no_grad():
            logits = model(seq_tensor, t, src_key_padding_mask=pad_mask)

        response_logits = logits[0, response_start:response_end]
        if top_k is not None and top_k > 0:
            # torch.topk raises if k exceeds the size of the last dimension.
            k = min(top_k, response_logits.size(-1))
            top_k_vals, _ = torch.topk(response_logits, k, dim=-1)
            min_top_k = top_k_vals[:, -1].unsqueeze(-1)
            response_logits = response_logits.masked_fill(
                response_logits < min_top_k, float('-inf')
            )

        probs = F.softmax(response_logits / safe_temperature, dim=-1)

        # Sample while refining; commit to the most likely token at the end.
        if step == sampling_steps:
            predicted = torch.argmax(probs, dim=-1)
        else:
            predicted = torch.multinomial(probs, 1).squeeze(-1)

        # Confidence is the probability of the chosen token at temperature 1
        # (computed on the top-k-filtered logits).
        true_probs = F.softmax(response_logits, dim=-1)
        confidences = true_probs.gather(-1, predicted.unsqueeze(-1)).squeeze(-1)

        running_confidence = 0.7 * running_confidence + 0.3 * confidences

        # Every response position is overwritten with this step's prediction.
        current = seq_tensor.squeeze(0).clone()
        current[response_start:response_end] = predicted

        if step < sampling_steps:
            # Linear schedule: keep `target_revealed` tokens and re-mask the
            # positions with the lowest running confidence.
            target_revealed = int(num_masks * step / sampling_steps)
            num_remask = num_masks - target_revealed
            if num_remask > 0:
                _, low_idx = torch.topk(running_confidence, k=num_remask, largest=False)
                current[response_start + low_idx] = mask_id

        seq_tensor = current.unsqueeze(0)

    response_ids = seq_tensor[0][response_start:response_end].tolist()
    return decode_response(response_ids, tokenizer)
|
|
|
|
if __name__ == "__main__":
    # Smoke-test script: load the tokenizer and a checkpoint, then print the
    # model's reply to a handful of fixed prompts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Inference on: {device}\n")

    tokenizer, vocab, id2word = load_tokenizer_full()

    # NOTE(review): these hyperparameters presumably have to match the ones
    # used in training, otherwise load_state_dict below will fail.
    model = MaskedDiffusionModel(
        vocab_size=len(vocab),
        d_model=256, nhead=8, num_layers=6,
        max_seq_len=MAX_SEQ_LENGTH,
        dropout=0.0,
    ).to(device)

    # Prefer the "best" checkpoint when it exists; otherwise use the default.
    best_ckpt = "diffusion_model_best.pth"
    ckpt = best_ckpt if os.path.exists(best_ckpt) else "diffusion_model.pth"
    try:
        # NOTE: torch.load unpickles the file; only load checkpoints you
        # trust (consider weights_only=True where the torch version has it).
        model.load_state_dict(torch.load(ckpt, map_location=device))
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        print("Make sure to retrain with the new architecture first!")
        # Exit with a non-zero status so callers can detect the failure; the
        # bare exit() helper is injected by `site` and returns status 0.
        raise SystemExit(1)
    print(f"Loaded: {ckpt} (EMA weights)\n")

    test_prompts = [
        "hi",
        "how are you",
        "what is your name",
        "tell me a joke",
        "what do you do for fun",
        "i had a bad day",
        " Can we go now?",
    ]

    for prompt in test_prompts:
        response = generate_response(
            model, tokenizer, id2word,
            prompt=prompt,
            max_response_length=24,
            sampling_steps=40,
            temperature=0.5,
            top_k=15
        )
        print(f"User : {prompt}")
        print(f"Bot : {response}")
        print("-" * 40)