| """ |
| generate.py |
| =========== |
| Interactive text generation with the trained MiniLM model. |
| |
| Type a prompt and the model will complete it. |
| Type 'quit' or press Ctrl+C to exit. |
| |
| Author : André Costa |
| License : MIT |
| |
| Usage: |
| python3 generate.py |
| python3 generate.py --max-tokens 100 |
| python3 generate.py --temperature 0.9 --top-k 50 |
| """ |
|
|
| import argparse |
| import torch |
| from transformer import MiniLM, ModelConfig |
| from bpe_tokenizer import BPETokenizer |
|
|
|
|
def load_model(checkpoint_path: str, tokenizer_path: str):
    """Load a trained MiniLM checkpoint and its BPE tokenizer.

    Args:
        checkpoint_path: Path to a torch checkpoint file containing
            "model_config" and "model_state" entries.
        tokenizer_path: Path the BPETokenizer was saved to.

    Returns:
        Tuple of (model, tokenizer, device). The model is in eval mode
        and moved to CUDA when available, otherwise CPU.
    """
    print("Loading tokenizer...")
    tokenizer = BPETokenizer.load(tokenizer_path)

    print("Loading model...")
    # weights_only=True restricts unpickling to tensors/containers,
    # avoiding arbitrary-code execution from an untrusted checkpoint.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    # Drop derived/legacy keys that ModelConfig no longer accepts.
    cfg_dict = ckpt["model_config"]
    cfg_dict.pop("d_head", None)
    config = ModelConfig(**cfg_dict)

    model = MiniLM(config)

    # Checkpoints saved from a torch.compile()-wrapped model prefix every
    # parameter name with "_orig_mod."; strip it so keys match the raw
    # module. Strip as a *prefix* only — str.replace would also mangle any
    # key that merely contained the substring elsewhere.
    state_dict = ckpt["model_state"]
    prefix = "_orig_mod."
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

    model.load_state_dict(state_dict)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    print(f"Model ready — {config.n_params / 1e6:.1f}M parameters | device: {device}")
    print(f"Vocab: {config.vocab_size} tokens | Context: {config.seq_len} tokens\n")

    return model, tokenizer, device
|
|
|
|
def generate(
    model,
    tokenizer,
    device,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
) -> str:
    """Complete *prompt* with the model and return the decoded text.

    The returned string includes the prompt followed by up to
    ``max_new_tokens`` sampled tokens.
    """
    # Encode the prompt into a single-row batch on the target device.
    prompt_ids = tokenizer.encode(prompt)
    batch = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    # Sampling only — no gradients required.
    with torch.no_grad():
        generated = model.generate(
            batch,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    # Decode the full sequence (prompt + continuation) back to text.
    return tokenizer.decode(generated[0].tolist())
|
|
|
|
def main():
    """Parse CLI flags, load the model, and run the interactive REPL."""
    parser = argparse.ArgumentParser(description="MiniLM — Interactive text generation")
    parser.add_argument("--checkpoint", type=str, default="./checkpoints/best_model.pt")
    parser.add_argument("--tokenizer", type=str, default="./tokenizer")
    parser.add_argument("--max-tokens", type=int, default=80)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--top-p", type=float, default=0.9)
    args = parser.parse_args()

    model, tokenizer, device = load_model(args.checkpoint, args.tokenizer)

    banner = "=" * 55
    print(banner)
    print(" MiniLM — Text Generation")
    print(" Type a prompt and press Enter.")
    print(" Type 'quit' to exit.")
    print(banner)
    print()

    while True:
        # Ctrl+C / Ctrl+D both exit cleanly.
        try:
            prompt = input("Prompt: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\nGoodbye!")
            break

        # Ignore empty input; honor the documented exit commands.
        if not prompt:
            continue
        if prompt.lower() in ("quit", "exit", "q"):
            print("Goodbye!")
            break

        completion = generate(
            model,
            tokenizer,
            device,
            prompt=prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )

        print(f"\n{completion}\n")
        print("-" * 55)


if __name__ == "__main__":
    main()
|
|