slayer-gpt-tokenizer-model / examples /inference_from_hf.py
kacperwikiel's picture
Upload Slayer GPT tokenizer model archive
4012ebc verified
Raw
History Blame Contribute Delete
3.9 kB
#!/usr/bin/env python3
"""Run inference from the Hugging Face model repo without cloning it.
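
Downloads the model code, checkpoint and tokenizer from the Hub, then samples
a continuation of the given prompt.

Usage:
    pip install torch tokenizers huggingface-hub
    python examples/inference_from_hf.py "Polska jest" 80

Both arguments are optional: the prompt text (default "Polska jest") and the
maximum number of new tokens to sample (default 80).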
"""
from __future__ import annotations
import importlib.util
import sys
import time
from pathlib import Path
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer
# Hugging Face Hub repository hosting the model code, checkpoint and tokenizer.
REPO_ID: str = "SlayerLab/slayer-gpt-tokenizer-model"

# Sampling hyper-parameters read by generate().
TEMP: float = 0.7  # logits are divided by this temperature before top-k/top-p
TOP_K: int = 40  # keep at most this many highest-scoring tokens
TOP_P: float = 0.92  # nucleus (cumulative-probability) threshold
REP_PEN: float = 1.15  # repetition penalty for token ids already in the context
NGRAM: int = 3  # forbid a next token that would repeat an existing n-gram of this order
EOT: int = 0  # sampling stops when this id is drawn; presumably end-of-text — confirm against the tokenizer
def load_model_module(path: str):
    """Import the Python source file at ``path`` as module ``slayer_gpt_model``.

    The module is registered in ``sys.modules`` before it is executed, so code
    inside it that resolves itself through ``sys.modules`` works while loading.

    Args:
        path: filesystem path of the Python source file to execute.

    Returns:
        The executed module object (also stored in ``sys.modules``).

    Raises:
        RuntimeError: if importlib cannot build a spec/loader for ``path``
            (e.g. the file does not have a Python source suffix).
        Exception: anything raised while executing the module. In that case
            the partially initialised module is removed from ``sys.modules``.
    """
    spec = importlib.util.spec_from_file_location("slayer_gpt_model", path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Could not load model module from {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Mirror the regular import system: never leave a half-initialised
        # module cached under this name after a failed execution.
        if sys.modules.get(spec.name) is module:
            del sys.modules[spec.name]
        raise
    return module
def banned_next_tokens(seq: list[int], n: int) -> set[int]:
    """Return token ids that would recreate an n-gram already present in ``seq``.

    The last ``n - 1`` tokens of ``seq`` form a prefix. Every token that
    followed an earlier occurrence of that same prefix is returned, because
    sampling it next would repeat an n-gram that already appears in ``seq``.
    For ``n == 1`` the prefix is empty, so every token in ``seq`` is banned.

    Args:
        seq: token ids seen so far.
        n: n-gram order; must be >= 1.

    Returns:
        Set of token ids to exclude. Empty when ``seq`` is shorter than ``n - 1``.

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"n-gram order must be >= 1, got {n}")
    k = n - 1
    if len(seq) < k:
        return set()
    # Explicit start index: ``seq[-k:]`` would return the whole list when k == 0.
    prefix = seq[len(seq) - k:]
    # Window i is seq[i:i + k]. Stop before the final window (the prefix
    # itself), which has no following token.
    return {seq[i + k] for i in range(len(seq) - k) if seq[i:i + k] == prefix}
@torch.no_grad()
def generate(model, tokenizer: Tokenizer, prompt: str, max_new_tokens: int, block_size: int, device: str) -> tuple[str, float]:
    """Sample a continuation of ``prompt`` one token at a time.

    Each step runs ``model`` on the last ``block_size`` token ids, takes the
    logits for the final position and then, in this order:

    1. applies the repetition penalty ``REP_PEN`` to every token id already in
       the context (positive logits are divided by it, others multiplied);
    2. sets to ``-inf`` every token that would complete an ``NGRAM``-gram
       already present in the context (via ``banned_next_tokens``);
    3. divides by the temperature ``TEMP``;
    4. keeps only the ``TOP_K`` highest logits (k clamped to the vocab size);
    5. keeps the highest-probability tokens up to and including the first one
       whose cumulative probability exceeds ``TOP_P``;
    6. draws one token with ``torch.multinomial``.

    Generation stops after ``max_new_tokens`` draws, or as soon as token id
    ``EOT`` is drawn. The EOT token itself is not appended to the output.

    Args:
        model: callable taking a ``(1, T)`` LongTensor of token ids and
            returning ``(logits, _)``; presumably logits are ``(1, T, vocab)``.
            Only the last position is read.
        tokenizer: tokenizer used to encode ``prompt`` and decode the result.
        prompt: text to continue; must encode to at least one token.
        max_new_tokens: maximum number of tokens to draw.
        block_size: maximum context length fed to the model; older tokens are
            cropped from the left.
        device: torch device string the token ids are placed on.

    Returns:
        ``(text, tokens_per_second)``: the decoded prompt plus continuation,
        and the number of draws (including a final EOT draw) per wall-clock
        second.

    Raises:
        ValueError: if ``prompt`` encodes to zero tokens.
    """
    prompt_ids = tokenizer.encode(prompt).ids
    if not prompt_ids:
        raise ValueError("prompt encodes to zero tokens; the model needs at least one token of context")
    idx = torch.tensor(prompt_ids, dtype=torch.long, device=device)[None]
    start = time.perf_counter()  # monotonic, high-resolution clock for throughput
    generated = 0
    for _ in range(max_new_tokens):
        logits, _ = model(idx[:, -block_size:])
        logits = logits[:, -1, :].float()
        vocab_size = logits.size(-1)

        # 1. Repetition penalty, vectorised with a mask. The per-token Python
        #    loop evaluated ``logit > 0`` as a bool, forcing a device sync for
        #    every distinct context token.
        seen = torch.zeros(vocab_size, dtype=torch.bool, device=logits.device)
        seen[idx[0].to(logits.device)] = True
        penalized = torch.where(logits > 0, logits / REP_PEN, logits * REP_PEN)
        logits = torch.where(seen, penalized, logits)

        # 2. n-gram blocking over the whole context seen so far.
        banned = banned_next_tokens(idx[0].tolist(), NGRAM)
        if banned:
            logits[0, sorted(banned)] = -float("inf")

        # 3. Temperature.
        logits = logits / TEMP

        # 4. Top-k; clamp k so a vocabulary smaller than TOP_K doesn't make
        #    torch.topk raise.
        kth = torch.topk(logits, min(TOP_K, vocab_size))[0][..., -1, None]
        logits[logits < kth] = -float("inf")

        # 5. Top-p (nucleus) filtering.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cumulative > TOP_P
        # Shift right by one so the token that first crosses TOP_P is kept,
        # and always keep the single most likely token.
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        logits[0, sorted_indices[0][remove[0]]] = -float("inf")

        # 6. Sample the next token.
        next_id = torch.multinomial(F.softmax(logits, dim=-1), 1)
        generated += 1
        if next_id.item() == EOT:
            break
        idx = torch.cat([idx, next_id], dim=1)
    tokens_per_second = generated / max(time.perf_counter() - start, 1e-6)
    return tokenizer.decode(idx[0].tolist()), tokens_per_second
def main() -> None:
    """Fetch the model, checkpoint and tokenizer from ``REPO_ID`` and print one sample.

    Optional command-line arguments: ``sys.argv[1]`` is the prompt (default
    ``"Polska jest"``) and ``sys.argv[2]`` the maximum number of new tokens
    (default 80). Prints a header line with the repo, device and measured
    throughput, followed by the generated text.
    """
    cli_args = sys.argv[1:]
    prompt = cli_args[0] if cli_args else "Polska jest"
    max_new_tokens = int(cli_args[1]) if len(cli_args) >= 2 else 80

    # Prefer Apple MPS, then CUDA, and fall back to CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    model_source = hf_hub_download(REPO_ID, "scripts/model.py")
    checkpoint_file = hf_hub_download(REPO_ID, "model/ckpt.pt")
    tokenizer_file = hf_hub_download(REPO_ID, "tokenizers/polish_bpe_32k.json")

    # NOTE(review): this executes Python source downloaded from the Hub, and
    # torch.load() unpickles the checkpoint. Only point REPO_ID at a repo you
    # trust; consider pinning a revision and passing weights_only=True on
    # torch versions that support it.
    gpt_module = load_model_module(model_source)
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    model_args = checkpoint["model_args"]
    model = gpt_module.GPT(gpt_module.GPTConfig(**model_args))

    # Strip the "_orig_mod." key prefix so the keys match the plain GPT module.
    # The prefix presumably comes from a torch.compile()-wrapped model.
    weights = checkpoint["model"]
    prefix = "_orig_mod."
    for compiled_key in [key for key in weights if key.startswith(prefix)]:
        weights[compiled_key[len(prefix):]] = weights.pop(compiled_key)
    model.load_state_dict(weights)
    model.eval().to(device)

    tokenizer = Tokenizer.from_file(tokenizer_file)
    text, tps = generate(model, tokenizer, prompt, max_new_tokens, model_args["block_size"], device)
    print(f"[repo={REPO_ID} device={device} {tps:.1f} tok/s]\n")
    print(text)
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()