"""
generate.py
===========
Interactive text generation with the trained MiniLM model.
Type a prompt and the model will complete it.
Type 'quit' or press Ctrl+C to exit.
Author : André Costa
License : MIT
Usage:
python3 generate.py
python3 generate.py --max-tokens 100
python3 generate.py --temperature 0.9 --top-k 50
"""
import argparse
import torch
from transformer import MiniLM, ModelConfig
from bpe_tokenizer import BPETokenizer
def load_model(checkpoint_path: str, tokenizer_path: str):
    """Restore a trained MiniLM and its BPE tokenizer from disk.

    Returns a ``(model, tokenizer, device)`` triple; the model is in
    eval mode and placed on CUDA when available, otherwise CPU.
    """
    print("Loading tokenizer...")
    tokenizer = BPETokenizer.load(tokenizer_path)

    print("Loading model...")
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    cfg_dict = ckpt["model_config"]
    # Drop a stale config key so older checkpoints still construct cleanly.
    cfg_dict.pop("d_head", None)
    config = ModelConfig(**cfg_dict)

    model = MiniLM(config)
    weights = ckpt["model_state"]
    # torch.compile saves parameters under an "_orig_mod." prefix; strip it
    # so the keys line up with the uncompiled module.
    if any(name.startswith("_orig_mod.") for name in weights):
        weights = {
            name.replace("_orig_mod.", ""): tensor
            for name, tensor in weights.items()
        }
    model.load_state_dict(weights)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    print(f"Model ready — {config.n_params / 1e6:.1f}M parameters | device: {device}")
    print(f"Vocab: {config.vocab_size} tokens | Context: {config.seq_len} tokens\n")
    return model, tokenizer, device
def generate(
    model,
    tokenizer,
    device,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
) -> str:
    """Complete *prompt* with the model and return the decoded text.

    The prompt is tokenized, run through ``model.generate`` under
    ``torch.no_grad`` with the given sampling knobs, and the full
    output sequence (prompt included) is decoded back to a string.
    """
    encoded = tokenizer.encode(prompt)
    input_ids = torch.tensor([encoded], dtype=torch.long, device=device)

    with torch.no_grad():
        completed = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    return tokenizer.decode(completed[0].tolist())
def main():
    """Run the interactive prompt → completion loop until the user quits."""
    parser = argparse.ArgumentParser(description="MiniLM — Interactive text generation")
    parser.add_argument("--checkpoint", type=str, default="./checkpoints/best_model.pt")
    parser.add_argument("--tokenizer", type=str, default="./tokenizer")
    parser.add_argument("--max-tokens", type=int, default=80)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--top-p", type=float, default=0.9)
    args = parser.parse_args()

    model, tokenizer, device = load_model(args.checkpoint, args.tokenizer)

    rule = "=" * 55
    print(rule)
    print(" MiniLM — Text Generation")
    print(" Type a prompt and press Enter.")
    print(" Type 'quit' to exit.")
    print(rule)
    print()

    while True:
        # Ctrl+C / Ctrl+D both end the session gracefully.
        try:
            prompt = input("Prompt: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\nGoodbye!")
            break

        if not prompt:
            continue
        if prompt.lower() in ("quit", "exit", "q"):
            print("Goodbye!")
            break

        completion = generate(
            model,
            tokenizer,
            device,
            prompt=prompt,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        print(f"\n{completion}\n")
        print("-" * 55)


if __name__ == "__main__":
    main()