"""Generate text from a trained character-level Transformer checkpoint.

Loads the model config and char vocabulary, restores the checkpoint,
encodes a prompt, and prints greedy- or sampling-decoded text.
"""

import argparse
from pathlib import Path

import torch
import yaml

from llm.data.tokenizer import CharTokenizer
from llm.inference.generate import greedy_decode, sample_decode
from llm.model.transformer import Transformer
from llm.utils.checkpoint import load_model_only


def load_yaml(path: Path) -> dict:
    """Parse a YAML file and return its contents (safe_load — no arbitrary objects)."""
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser. A temperature of exactly 0 selects greedy decoding."""
    parser = argparse.ArgumentParser()
    # BUG FIX: the default used to be "First Citizen:\\n" — a literal backslash
    # followed by 'n', which a char-level tokenizer would encode as two spurious
    # characters. It should be a real newline.
    parser.add_argument("--prompt", type=str, default="First Citizen:\n")
    parser.add_argument("--max_length", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.8)
    # top_k / top_p values <= 0 are treated as "disabled" (passed as None below).
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best_model.pt")
    parser.add_argument("--config", type=str, default="configs/model.yaml")
    parser.add_argument("--vocab", type=str, default="data/vocab/char_vocab.json")
    return parser


def main() -> None:
    """Restore the model, decode from the prompt, and print the generated text."""
    args = _build_parser().parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # vocab_size is taken from the tokenizer, not the YAML, so the model head
    # always matches the vocabulary actually on disk.
    model_config = load_yaml(Path(args.config))
    tokenizer = CharTokenizer(vocab_path=args.vocab)
    model_config["vocab_size"] = tokenizer.vocab_size

    model = Transformer(model_config)
    load_model_only(model, args.checkpoint)
    model.to(device)
    model.eval()

    input_ids = tokenizer.encode(args.prompt)
    if not input_ids:
        # An empty prompt would give decoding nothing to condition on;
        # seed with token 0 instead.
        input_ids = [0]
    input_ids = torch.tensor([input_ids], dtype=torch.long)

    with torch.no_grad():
        if args.temperature == 0:
            generated_ids = greedy_decode(
                model, input_ids, max_length=args.max_length, device=device
            )
        else:
            generated_ids = sample_decode(
                model,
                input_ids,
                max_length=args.max_length,
                temperature=args.temperature,
                top_k=args.top_k if args.top_k > 0 else None,
                top_p=args.top_p if args.top_p > 0 else None,
                device=device,
            )

    text = tokenizer.decode(generated_ids[0])
    print(text)


if __name__ == "__main__":
    main()