# nextShakespeare / inference.py
# (Upload metadata from the original hosting page: LiManshu,
#  "Add files using upload-large-folder tool", commit bf6be45 verified.)
import argparse
from pathlib import Path
import torch
import yaml
from llm.data.tokenizer import CharTokenizer
from llm.inference.generate import greedy_decode, sample_decode
from llm.model.transformer import Transformer
from llm.utils.checkpoint import load_model_only
def load_yaml(path: Path):
    """Read *path* as UTF-8 text and return the parsed YAML document."""
    text = path.read_text(encoding="utf-8")
    return yaml.safe_load(text)
def main():
    """Generate text from a trained character-level Transformer.

    Parses CLI arguments, loads the model config, tokenizer vocabulary and
    checkpoint, encodes the prompt, and prints either a greedy or a sampled
    continuation depending on --temperature.
    """
    parser = argparse.ArgumentParser(description="Shakespeare text generation")
    # BUGFIX: the default was "First Citizen:\\n", i.e. a literal backslash
    # followed by 'n'. A char-level vocab contains the newline character but
    # not '\', so the default prompt encoded incorrectly. Use a real newline.
    parser.add_argument("--prompt", type=str, default="First Citizen:\n")
    parser.add_argument("--max_length", type=int, default=200)
    # temperature == 0 selects deterministic greedy decoding below.
    parser.add_argument("--temperature", type=float, default=0.8)
    # top_k/top_p <= 0 disable the respective filter (passed as None).
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best_model.pt")
    parser.add_argument("--config", type=str, default="configs/model.yaml")
    parser.add_argument("--vocab", type=str, default="data/vocab/char_vocab.json")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the model from config; vocab size must match the tokenizer's.
    model_config = load_yaml(Path(args.config))
    tokenizer = CharTokenizer(vocab_path=args.vocab)
    model_config["vocab_size"] = tokenizer.vocab_size
    model = Transformer(model_config)
    load_model_only(model, args.checkpoint)
    model.to(device)
    model.eval()

    input_ids = tokenizer.encode(args.prompt)
    if not input_ids:
        # Fall back to a single token if the prompt encodes to nothing
        # (e.g. every prompt character is out of vocabulary).
        input_ids = [0]
    input_ids = torch.tensor([input_ids], dtype=torch.long)

    with torch.no_grad():
        if args.temperature == 0:
            generated_ids = greedy_decode(
                model, input_ids, max_length=args.max_length, device=device
            )
        else:
            generated_ids = sample_decode(
                model,
                input_ids,
                max_length=args.max_length,
                temperature=args.temperature,
                top_k=args.top_k if args.top_k > 0 else None,
                top_p=args.top_p if args.top_p > 0 else None,
                device=device,
            )

    # generated_ids is batched; decode the single sequence and print it.
    text = tokenizer.decode(generated_ids[0])
    print(text)
if __name__ == "__main__":
main()