#!/usr/bin/env python3
"""Run inference from the Hugging Face model repo without cloning it.

Usage:
    pip install torch tokenizers huggingface-hub
    python examples/inference_from_hf.py "Polska jest" 80

The first argument is the prompt (default "Polska jest"); the second is the
maximum number of new tokens to sample (default 80).
"""

from __future__ import annotations

import importlib.util
import sys
import time
from pathlib import Path
from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F

if TYPE_CHECKING:
    # Only needed for the annotation on ``generate``. The runtime import lives
    # in ``main`` so the sampling helpers can be imported without the
    # Hugging Face packages installed.
    from tokenizers import Tokenizer

REPO_ID = "SlayerLab/slayer-gpt-tokenizer-model"

# Sampling knobs used by ``generate``.
TEMP = 0.7      # logits are divided by this before top-k / top-p filtering
TOP_K = 40      # keep at most this many highest-scoring tokens
TOP_P = 0.92    # nucleus threshold on cumulative probability mass
REP_PEN = 1.15  # penalty on the logits of tokens already in the context
NGRAM = 3       # forbid repeating any n-gram of this size (< 1 disables)
EOT = 0         # sampling this token id stops generation (it is not appended)


def load_model_module(path: str):
    """Import the Python source file at *path* as module ``slayer_gpt_model``.

    The module is registered in ``sys.modules`` before it executes, following
    the importlib "import a source file directly" recipe.

    NOTE(security): this executes code downloaded from the Hub; only use it
    with a repository you trust.

    Raises:
        RuntimeError: if no import spec/loader can be created for *path*.
    """
    spec = importlib.util.spec_from_file_location("slayer_gpt_model", path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Could not load model module from {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return module


def banned_next_tokens(seq: list[int], n: int) -> set[int]:
    """Return token ids that would complete an n-gram already present in *seq*.

    The last ``n - 1`` tokens of *seq* form the current prefix. Every earlier
    occurrence of that prefix contributes the token that followed it, so
    sampling any returned id would repeat an existing n-gram.

    ``n == 1`` bans every token already in *seq*; ``n < 1`` disables the
    check and returns an empty set.
    """
    if n < 1:
        return set()
    prefix_len = n - 1
    if len(seq) < prefix_len:
        return set()
    # Slice from an explicit start index: ``seq[-0:]`` would be the whole list.
    prefix = tuple(seq[len(seq) - prefix_len:])
    return {
        seq[i + prefix_len]
        for i in range(len(seq) - prefix_len)
        if tuple(seq[i:i + prefix_len]) == prefix
    }


@torch.no_grad()
def generate(model, tokenizer: Tokenizer, prompt: str, max_new_tokens: int,
             block_size: int, device: str) -> tuple[str, float]:
    """Sample up to *max_new_tokens* tokens continuing *prompt*.

    Each step feeds the last *block_size* tokens to *model* and takes the
    final-position logits. It then applies, in order:

    1. a repetition penalty (``REP_PEN``) on every token already in the context;
    2. an n-gram ban (``NGRAM``);
    3. temperature (``TEMP``);
    4. top-k (``TOP_K``) and top-p (``TOP_P``) filtering.

    Finally it draws one token with ``torch.multinomial``. Sampling ``EOT``
    stops generation early.

    Args:
        model: callable returning ``(logits, _)``. Assumes logits are
            (batch, seq, vocab) with batch 1; only ``[:, -1, :]`` is used.
        tokenizer: object providing ``encode(text).ids`` and ``decode(ids)``.
        prompt: text to continue.
        max_new_tokens: upper bound on sampling steps.
        block_size: maximum context length passed to *model*.
        device: torch device string the token tensor is created on.

    Returns:
        The decoded prompt plus continuation, and the number of sampled
        tokens per second. A stopping ``EOT`` counts as sampled but is not
        included in the text.

    Raises:
        ValueError: if *prompt* encodes to zero tokens, which would otherwise
            call *model* on an empty sequence.
    """
    prompt_ids = tokenizer.encode(prompt).ids
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token")
    idx = torch.tensor(prompt_ids, dtype=torch.long, device=device)[None]

    start = time.perf_counter()
    generated = 0
    for _ in range(max_new_tokens):
        logits, _ = model(idx[:, -block_size:])
        logits = logits[:, -1, :].float()
        history = idx[0].tolist()

        # Repetition penalty in one indexed op: shrink positive logits and
        # push non-positive ones further down. A per-token Python loop would
        # force a device sync for every seen token on CUDA/MPS.
        seen = torch.tensor(list(set(history)), dtype=torch.long,
                            device=logits.device)
        seen_logits = logits[0, seen]
        logits[0, seen] = torch.where(
            seen_logits > 0, seen_logits / REP_PEN, seen_logits * REP_PEN
        )

        banned = banned_next_tokens(history, NGRAM)
        if banned:
            logits[0, list(banned)] = -float("inf")

        logits /= TEMP

        # Clamp k: torch.topk raises if k exceeds the vocabulary size.
        k = min(TOP_K, logits.size(-1))
        kth = torch.topk(logits, k).values[..., -1, None]
        logits[logits < kth] = -float("inf")

        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cumulative > TOP_P
        # Shift right so the token that crosses TOP_P is kept, and always keep
        # the single most likely token.
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        logits[0, sorted_indices[0][remove[0]]] = -float("inf")

        next_id = torch.multinomial(F.softmax(logits, dim=-1), 1)
        generated += 1
        if next_id.item() == EOT:
            break
        idx = torch.cat([idx, next_id], dim=1)

    tokens_per_second = generated / max(time.perf_counter() - start, 1e-6)
    return tokenizer.decode(idx[0].tolist()), tokens_per_second


def main() -> None:
    """CLI entry point: download assets from ``REPO_ID``, load them, and sample."""
    # Imported here so ``generate`` and ``banned_next_tokens`` can be
    # imported and reused without these packages.
    from huggingface_hub import hf_hub_download
    from tokenizers import Tokenizer

    prompt = sys.argv[1] if len(sys.argv) > 1 else "Polska jest"
    max_new_tokens = int(sys.argv[2]) if len(sys.argv) > 2 else 80
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    model_py = hf_hub_download(REPO_ID, "scripts/model.py")
    ckpt_path = hf_hub_download(REPO_ID, "model/ckpt.pt")
    tokenizer_path = hf_hub_download(REPO_ID, "tokenizers/polish_bpe_32k.json")

    model_module = load_model_module(model_py)
    # NOTE(security): torch.load unpickles the checkpoint; only load files
    # from a repository you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model_args = ckpt["model_args"]
    model = model_module.GPT(model_module.GPTConfig(**model_args))

    # Keys saved from a torch.compile()'d module carry an "_orig_mod." prefix;
    # strip it so they match the plain GPT module.
    compiled_prefix = "_orig_mod."
    state_dict = ckpt["model"]
    for key in list(state_dict):
        if key.startswith(compiled_prefix):
            state_dict[key[len(compiled_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    model.eval().to(device)

    tokenizer = Tokenizer.from_file(tokenizer_path)
    text, tps = generate(
        model,
        tokenizer,
        prompt,
        max_new_tokens,
        model_args["block_size"],
        device,
    )
    print(f"[repo={REPO_ID} device={device} {tps:.1f} tok/s]\n")
    print(text)


if __name__ == "__main__":
    main()