# minilm/generate.py
# Uploaded by André Costa via huggingface_hub (revision ca7da53, verified).
"""
generate.py
===========
Interactive text generation with the trained MiniLM model.
Type a prompt and the model will complete it.
Type 'quit' or press Ctrl+C to exit.
Author : André Costa
License : MIT
Usage:
python3 generate.py
python3 generate.py --max-tokens 100
python3 generate.py --temperature 0.9 --top-k 50
"""
import argparse
import torch
from transformer import MiniLM, ModelConfig
from bpe_tokenizer import BPETokenizer
def load_model(checkpoint_path: str, tokenizer_path: str):
    """Restore a trained MiniLM checkpoint together with its BPE tokenizer.

    Args:
        checkpoint_path: Path to the ``.pt`` checkpoint produced by training.
        tokenizer_path: Directory the BPE tokenizer was saved to.

    Returns:
        A ``(model, tokenizer, device)`` triple; the model is in eval mode
        and placed on CUDA when available, otherwise CPU.
    """
    print("Loading tokenizer...")
    tokenizer = BPETokenizer.load(tokenizer_path)

    print("Loading model...")
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    # Rebuild the architecture from the stored config; drop the legacy
    # "d_head" entry, which ModelConfig no longer accepts as a kwarg.
    cfg_kwargs = checkpoint["model_config"]
    cfg_kwargs.pop("d_head", None)
    config = ModelConfig(**cfg_kwargs)
    model = MiniLM(config)

    # Checkpoints saved from a torch.compile()-wrapped model carry an
    # "_orig_mod." prefix on every parameter name — strip it so the plain
    # module can accept the weights.
    weights = checkpoint["model_state"]
    prefix = "_orig_mod."
    if any(key.startswith(prefix) for key in weights):
        weights = {key.replace(prefix, ""): tensor for key, tensor in weights.items()}
    model.load_state_dict(weights)
    model.eval()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    print(f"Model ready — {config.n_params / 1e6:.1f}M parameters | device: {device}")
    print(f"Vocab: {config.vocab_size} tokens | Context: {config.seq_len} tokens\n")
    return model, tokenizer, device
def generate(
    model,
    tokenizer,
    device,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
) -> str:
    """Complete *prompt* with the model and return the full decoded text.

    Args:
        model: Model exposing a ``generate(ids, ...)`` sampling method.
        tokenizer: Tokenizer with ``encode``/``decode`` methods.
        device: Torch device the input batch should live on.
        prompt: Seed text to continue.
        max_new_tokens: Number of tokens to sample beyond the prompt.
        temperature: Softmax temperature for sampling.
        top_k: Keep only the k most likely tokens at each step.
        top_p: Nucleus-sampling probability mass cutoff.

    Returns:
        The decoded string, prompt included.
    """
    # Encode to a batch of one sequence on the target device.
    token_ids = tokenizer.encode(prompt)
    batch = torch.tensor([token_ids], dtype=torch.long, device=device)

    # Sampling needs no gradients.
    with torch.no_grad():
        completed = model.generate(
            batch,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    return tokenizer.decode(completed[0].tolist())
def main():
    """Parse CLI flags, load the model, and run the interactive REPL.

    Reads prompts from stdin until the user types 'quit'/'exit'/'q' or
    sends Ctrl+C / Ctrl+D.
    """
    parser = argparse.ArgumentParser(description="MiniLM — Interactive text generation")
    parser.add_argument("--checkpoint", type=str, default="./checkpoints/best_model.pt")
    parser.add_argument("--tokenizer", type=str, default="./tokenizer")
    parser.add_argument("--max-tokens", type=int, default=80)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=50)
    parser.add_argument("--top-p", type=float, default=0.9)
    args = parser.parse_args()

    model, tokenizer, device = load_model(args.checkpoint, args.tokenizer)

    banner = "=" * 55
    print(banner)
    print(" MiniLM — Text Generation")
    print(" Type a prompt and press Enter.")
    print(" Type 'quit' to exit.")
    print(banner)
    print()

    while True:
        # Ctrl+C / Ctrl+D both end the session cleanly.
        try:
            user_text = input("Prompt: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\nGoodbye!")
            return

        if not user_text:
            continue
        if user_text.lower() in ("quit", "exit", "q"):
            print("Goodbye!")
            return

        completion = generate(
            model, tokenizer, device,
            prompt=user_text,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        print(f"\n{completion}\n")
        print("-" * 55)
# Entry point: only start the interactive loop when run as a script,
# not when imported as a module.
if __name__ == "__main__":
    main()