# File size: 2,190 Bytes (revision 18be545)
from __future__ import annotations
import argparse
import torch
from src.model import GPTLanguageModel, config_from_dict
from src.tokenizer import VisdomTokenizer
from src.utils import get_device, load_json, resolve_path, set_seed
def main() -> None:
    """Command-line entry point: sample text from a saved VISDOM checkpoint.

    Parses the CLI options, seeds the run, rebuilds the model from the
    configuration stored inside the checkpoint, and prints the decoded
    output of ``model.generate`` for the supplied prompt.

    Raises:
        FileNotFoundError: If the resolved checkpoint path does not exist.
    """
    cli = argparse.ArgumentParser(description="Generate text from a VISDOM checkpoint.")
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--checkpoint", {"default": "checkpoints/latest.pt"}),
        ("--prompt", {"default": "The future of AI is"}),
        ("--max_new_tokens", {"type": int, "default": 120}),
        ("--temperature", {"type": float, "default": 0.6}),
        ("--top_k", {"type": int, "default": 20}),
        ("--top_p", {"type": float, "default": 0.9}),
        ("--repetition_penalty", {"type": float, "default": 1.15}),
        ("--seed", {"type": int, "default": 1337}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # Seed before any model construction or sampling happens.
    set_seed(opts.seed)

    weights_file = resolve_path(opts.checkpoint)
    if not weights_file.exists():
        raise FileNotFoundError(
            f"Checkpoint not found: {weights_file}. Train first with python train.py --config config.yaml"
        )

    # Load onto the CPU first; the model is moved to the target device below.
    state = torch.load(weights_file, map_location="cpu")
    settings = state["config"]
    target = get_device(str(settings.get("device", "cuda")))

    # The vocabulary size and tokenizer model come from the metadata file
    # named in the checkpoint config, not from the config itself.
    vocab_info = load_json(settings["meta_file"])
    settings["vocab_size"] = int(vocab_info["vocab_size"])
    text_codec = VisdomTokenizer(vocab_info["tokenizer_model"])

    net = GPTLanguageModel(config_from_dict(settings))
    net.load_state_dict(state["model_state_dict"])
    net.eval()
    net.to(target)

    # Encode the prompt and add a leading dimension before generation.
    prompt_ids = text_codec.encode(opts.prompt, add_bos=True)
    context = torch.tensor(prompt_ids, dtype=torch.long, device=target).unsqueeze(0)

    sampling = {
        "max_new_tokens": opts.max_new_tokens,
        "temperature": opts.temperature,
        "top_k": opts.top_k,
        "top_p": opts.top_p,
        "repetition_penalty": opts.repetition_penalty,
    }
    # float16 autocast is only switched on when running on CUDA.
    use_half = target.type == "cuda"
    with torch.no_grad(), torch.autocast(
        device_type=target.type, dtype=torch.float16, enabled=use_half
    ):
        generated = net.generate(context, **sampling)

    print(text_codec.decode(generated[0].tolist()))
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|