| import argparse | |
| from pathlib import Path | |
| import torch | |
| import yaml | |
| from llm.data.tokenizer import CharTokenizer | |
| from llm.inference.generate import greedy_decode, sample_decode | |
| from llm.model.transformer import Transformer | |
| from llm.utils.checkpoint import load_model_only | |
def load_yaml(path: Path):
    """Read *path* as UTF-8 text and return its parsed YAML content.

    Uses ``yaml.safe_load`` so only plain data types are constructed
    (no arbitrary object instantiation from untrusted files).
    """
    text = path.read_text(encoding="utf-8")
    return yaml.safe_load(text)
def main():
    """CLI entry point: load a trained Transformer checkpoint and generate text.

    Reads the model config and character vocabulary from disk, restores the
    checkpoint weights, encodes the prompt, and prints the decoded generation.
    Temperature <= 0 selects greedy decoding; otherwise sampling with the
    given temperature and optional top-k / top-p truncation.
    """
    parser = argparse.ArgumentParser()
    # BUG FIX: the default was "First Citizen:\\n", i.e. a literal backslash
    # followed by 'n' — two characters a char-level vocab likely cannot encode.
    # The intended default is a real newline.
    parser.add_argument("--prompt", type=str, default="First Citizen:\n")
    parser.add_argument("--max_length", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best_model.pt")
    parser.add_argument("--config", type=str, default="configs/model.yaml")
    parser.add_argument("--vocab", type=str, default="data/vocab/char_vocab.json")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_config = load_yaml(Path(args.config))
    tokenizer = CharTokenizer(vocab_path=args.vocab)
    # Vocab size comes from the tokenizer, not the config, so the embedding
    # table always matches the vocabulary actually on disk.
    model_config["vocab_size"] = tokenizer.vocab_size

    model = Transformer(model_config)
    load_model_only(model, args.checkpoint)
    model.to(device)
    model.eval()

    input_ids = tokenizer.encode(args.prompt)
    if not input_ids:
        # Empty (or fully unencodable) prompt: seed generation with token 0
        # so the decode loop has at least one input position.
        input_ids = [0]
    input_ids = torch.tensor([input_ids], dtype=torch.long)  # batch of 1

    with torch.no_grad():
        # ROBUSTNESS: was `== 0` (exact float compare); any non-positive
        # temperature now falls back to greedy instead of reaching
        # sample_decode with an invalid value.
        if args.temperature <= 0:
            generated_ids = greedy_decode(
                model, input_ids, max_length=args.max_length, device=device
            )
        else:
            generated_ids = sample_decode(
                model,
                input_ids,
                max_length=args.max_length,
                temperature=args.temperature,
                # Non-positive top_k / top_p mean "disabled".
                top_k=args.top_k if args.top_k > 0 else None,
                top_p=args.top_p if args.top_p > 0 else None,
                device=device,
            )

    # NOTE(review): assumes tokenizer.decode accepts the first row of the
    # generated batch directly (tensor or list) — confirm against CharTokenizer.
    text = tokenizer.decode(generated_ids[0])
    print(text)


if __name__ == "__main__":
    main()