| import torch
|
| import torch.nn.functional as F
|
| from model import MiniGPT
|
| from dataset import MiniBPETokenizr,SimpleTokenizr
|
| import json
|
| import os
|
| from tokenizers import Tokenizer
|
|
|
# Directory holding the trained artifacts (tokenizer.json and mini-gpt.pth).
MODEL_DIR = "./trained-mini-gpt"

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer saved next to the model; the model's vocab size is taken from it.
tokenizer = Tokenizer.from_file(os.path.join(MODEL_DIR, "tokenizer.json"))

model = MiniGPT(vocab_size=tokenizer.get_vocab_size())

# The checkpoint is handed straight to load_state_dict, so it is expected to be a
# plain state dict of tensors. weights_only=True restricts unpickling to tensors
# and primitive containers instead of executing arbitrary pickled objects.
checkpoint = torch.load(
    os.path.join(MODEL_DIR, "mini-gpt.pth"),
    map_location=device,
    weights_only=True,
)
model.load_state_dict(checkpoint)

model.to(device)
# Inference mode: changes the behaviour of layers such as dropout, if MiniGPT has any.
model.eval()

totalparams = sum(p.numel() for p in model.parameters())
print(f"Model total params: {totalparams:,}")
|
|
|
def sample_token(logits, temperature=1.0):
    """Sample one token id from a tensor of next-token logits.

    Args:
        logits: Tensor of unnormalized scores whose last dimension is the
            vocabulary. ``.item()`` is called on the result, so the tensor
            must describe exactly one distribution (e.g. shape ``(vocab,)``
            or ``(1, vocab)``).
        temperature: Softmax temperature. Values above 1 flatten the
            distribution and values below 1 sharpen it. ``0`` selects the
            highest-scoring token deterministically (greedy decoding).

    Returns:
        The chosen token id as a Python ``int``.

    Raises:
        ValueError: If ``temperature`` is negative (that would invert the
            ranking of the logits).
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    if temperature == 0:
        # logits / 0 would turn every positive score into +inf (and 0/0 into NaN),
        # making ties indistinguishable, so take the argmax directly. NaNs are
        # demoted first because argmax would otherwise treat them as the maximum.
        return torch.argmax(torch.nan_to_num(logits, nan=-1e9), dim=-1).item()

    # Scale, then push NaNs to a very low score so they get ~0 probability.
    # nan_to_num also clamps +/-inf (e.g. from a tiny temperature) to finite values.
    logits = torch.nan_to_num(logits / temperature, nan=-1e9)
    probs = F.softmax(logits, dim=-1)

    # Defensive guard so torch.multinomial never receives an invalid distribution.
    if not torch.all(torch.isfinite(probs)) or torch.any(probs < 0):
        print("⚠️ Invalid probs detected. Using uniform fallback.")
        probs = torch.ones_like(probs) / probs.size(-1)

    return torch.multinomial(probs, num_samples=1).item()
|
|
|
def generate_reply(prompt, max_tokens=100, temperature=1.0, max_context=None):
    """Stream a model continuation of *prompt* to stdout and return it.

    Tokens are sampled one at a time with ``sample_token``. After each step
    the whole generated id sequence is decoded, and only the newly produced
    text is printed. This joins subword pieces the way the tokenizer's
    decoder intends, rather than printing each token followed by a space.
    Generation stops after ``max_tokens`` tokens or when the ``<END>`` token
    is sampled; ``<END>`` itself is neither printed nor returned.

    Args:
        prompt: Text given to the model as context.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature forwarded to ``sample_token``.
        max_context: Optional positive int. If given, only the last
            ``max_context`` token ids are fed to the model at each step.
            NOTE(review): set this to MiniGPT's maximum sequence length if it
            has one (not visible from this file) so long chats do not run
            past it.

    Returns:
        The decoded reply text, or ``""`` if the prompt encodes to no tokens.
    """
    tokens = tokenizer.encode(prompt).ids
    if not tokens:
        print("⚠️ Empty prompt after encoding.")
        return ""

    # None if the vocabulary has no "<END>" entry; then only max_tokens stops generation.
    end_id = tokenizer.token_to_id("<END>")

    input_ids = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
    generated = []
    printed = ""

    with torch.no_grad():
        for _ in range(max_tokens):
            context = input_ids if max_context is None else input_ids[:, -max_context:]
            logits = model(context)[:, -1, :]  # scores for the last position only
            next_token = sample_token(logits, temperature)

            # Check the stop token before emitting anything so "<END>" is never shown.
            if next_token == end_id:
                break

            generated.append(next_token)
            input_ids = torch.cat(
                [input_ids, torch.tensor([[next_token]], dtype=torch.long, device=device)],
                dim=1,
            )

            text = tokenizer.decode(generated)
            # A trailing U+FFFD typically means a multi-byte character is split
            # across tokens; hold it back until a later token completes it.
            if text.endswith("\ufffd"):
                continue
            print(text[len(printed):], end="", flush=True)
            printed = text

    reply = tokenizer.decode(generated)
    # Flush any held-back text and terminate the line.
    print(reply[len(printed):])
    return reply
|
|
|
|
|
if __name__ == "__main__":
    # Interactive REPL: read a user line, wrap it in the chat prompt format and
    # stream the model's reply. Guarded so importing this module does not block on input().
    print("🧠 MiniGPT Chat (type 'exit' to quit)")
    while True:
        try:
            user_input = input("User: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C end the session cleanly instead of raising a traceback.
            print()
            break

        if user_input.strip().lower() == "exit":
            break
        if not user_input.strip():
            # Nothing to respond to; ask again rather than generating from an empty turn.
            continue

        # Prompt layout is kept byte-for-byte; presumably it mirrors the training data format.
        prompt = f"^User: {user_input}\nMiniGPT:"
        print("MiniGPT: ", end="", flush=True)
        generate_reply(prompt)