| |
| """Run inference from the Hugging Face model repo without cloning it. |
| |
| Usage: |
| pip install torch tokenizers huggingface-hub |
| python examples/inference_from_hf.py "Polska jest" 80 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import importlib.util |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
| from huggingface_hub import hf_hub_download |
| from tokenizers import Tokenizer |
|
|
|
|
# Hugging Face repo that hosts the model code, checkpoint and tokenizer.
REPO_ID: str = "SlayerLab/slayer-gpt-tokenizer-model"

# Sampling defaults consumed by generate().
TEMP: float = 0.7      # softmax temperature (logits are divided by it)
TOP_K: int = 40        # keep only the K highest logits
TOP_P: float = 0.92    # nucleus sampling cumulative-probability cutoff
REP_PEN: float = 1.15  # repetition penalty for tokens already in the context
NGRAM: int = 3         # forbid re-emitting any n-gram of this length
EOT: int = 0           # sampling this id stops generation (assumed end-of-text id — TODO confirm with tokenizer)
|
|
|
|
def load_model_module(path: str):
    """Import the Python source file at ``path`` under the name ``slayer_gpt_model``.

    The module object is registered in ``sys.modules`` *before* its code runs,
    so code that looks the module up by name while executing still works.

    Raises:
        RuntimeError: if importlib cannot build a spec/loader for ``path``.
    """
    module_name = "slayer_gpt_model"
    spec = importlib.util.spec_from_file_location(module_name, path)
    loader = None if spec is None else spec.loader
    if loader is None:
        raise RuntimeError(f"Could not load model module from {path}")

    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    loader.exec_module(module)
    return module
|
|
|
|
def banned_next_tokens(seq: list[int], n: int) -> set[int]:
    """Return the token ids that would complete an n-gram already present in ``seq``.

    The last ``n - 1`` tokens of ``seq`` form the current prefix. Every earlier
    occurrence of that prefix bans the token that followed it, because emitting
    that token again would repeat an existing n-gram verbatim.

    Args:
        seq: token ids generated so far (prompt included).
        n: n-gram length. ``n == 1`` bans every token already in ``seq``;
            ``n < 1`` disables blocking.

    Returns:
        The set of banned next-token ids; empty when ``seq`` is too short or
        blocking is disabled.
    """
    if n < 1 or len(seq) < n - 1:
        return set()
    width = n - 1
    # Index from an explicit start: ``seq[-0:]`` would be the whole list and
    # silently disable the n == 1 case.
    prefix = tuple(seq[len(seq) - width:])
    # Window seq[i:i+width] is followed by seq[i+width]; the trailing prefix has
    # no follower, so it is never compared with itself.
    return {
        seq[i + width]
        for i in range(len(seq) - width)
        if tuple(seq[i:i + width]) == prefix
    }
|
|
|
|
def _apply_repetition_penalty(logits: torch.Tensor, token_ids: list[int], penalty: float) -> None:
    """Penalize, in place, the row-0 logit of every id in ``token_ids``.

    Positive logits are divided by ``penalty`` and non-positive ones multiplied
    by it, so with ``penalty > 1`` a repeated token always becomes less likely.
    Done as one vectorized update instead of a per-token Python loop (which
    forced a device sync for every comparison).
    """
    if not token_ids:
        return
    seen = torch.tensor(sorted(set(token_ids)), dtype=torch.long, device=logits.device)
    scores = logits[0, seen]
    logits[0, seen] = torch.where(scores > 0, scores / penalty, scores * penalty)


def _filter_top_k(logits: torch.Tensor, k: int) -> None:
    """Set, in place, every logit below the k-th largest of its row to -inf.

    ``k <= 0`` disables the filter; ``k`` larger than the vocabulary is clamped
    (``torch.topk`` would otherwise raise).
    """
    if k <= 0:
        return
    k = min(k, logits.size(-1))
    kth = torch.topk(logits, k).values[..., -1, None]
    logits[logits < kth] = -float("inf")


def _filter_top_p(logits: torch.Tensor, p: float) -> None:
    """Nucleus filtering in place; ``p >= 1`` disables it.

    Keeps the shortest run of highest-probability tokens whose cumulative
    probability exceeds ``p`` and sets every other logit to -inf.
    """
    if p >= 1.0:
        return
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative > p
    # Shift right so the token that crosses the threshold survives, and always
    # keep the single most likely token.
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    # Map the mask from sorted order back to vocabulary order.
    logits.masked_fill_(remove.scatter(-1, sorted_indices, remove), -float("inf"))


@torch.no_grad()
def generate(
    model,
    tokenizer: Tokenizer,
    prompt: str,
    max_new_tokens: int,
    block_size: int,
    device: str,
    *,
    temperature: float | None = None,
    top_k: int | None = None,
    top_p: float | None = None,
    repetition_penalty: float | None = None,
    ngram: int | None = None,
    eot_id: int | None = None,
) -> tuple[str, float]:
    """Sample a continuation of ``prompt`` and return it with the sampling speed.

    Each step applies, in order: repetition penalty, n-gram blocking,
    temperature, top-k and top-p filtering, then samples one token.

    Args:
        model: callable returning ``(logits, _)`` where ``logits[:, -1, :]`` is
            the next-token distribution (assumed shape (1, T, vocab) — TODO
            confirm against the repo's model.py).
        tokenizer: tokenizer used to encode the prompt and decode the result.
        prompt: text to condition on; must encode to at least one token.
        max_new_tokens: upper bound on sampled tokens.
        block_size: at most this many trailing tokens are fed to the model.
        device: torch device string for the token tensor.
        temperature, top_k, top_p, repetition_penalty, ngram, eot_id: sampling
            overrides; ``None`` uses the module-level TEMP, TOP_K, TOP_P,
            REP_PEN, NGRAM and EOT respectively.

    Returns:
        ``(text, tokens_per_second)``: the decoded prompt plus generated tokens
        (the stop token itself is not included), and sampled tokens per second
        (a sampled stop token is counted).

    Raises:
        ValueError: if the prompt encodes to no tokens or temperature <= 0.
    """
    temperature = TEMP if temperature is None else temperature
    top_k = TOP_K if top_k is None else top_k
    top_p = TOP_P if top_p is None else top_p
    repetition_penalty = REP_PEN if repetition_penalty is None else repetition_penalty
    ngram = NGRAM if ngram is None else ngram
    eot_id = EOT if eot_id is None else eot_id
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    prompt_ids = list(tokenizer.encode(prompt).ids)
    if not prompt_ids:
        raise ValueError("prompt encodes to zero tokens; nothing to condition on")
    idx = torch.tensor(prompt_ids, dtype=torch.long, device=device)[None]
    # Host-side mirror of idx[0], so penalties/n-gram checks need no per-step copy.
    history = list(prompt_ids)
    start = time.time()
    generated = 0

    for _ in range(max_new_tokens):
        # Feed at most the last block_size tokens (the model's configured context length).
        logits, _ = model(idx[:, -block_size:])
        logits = logits[:, -1, :].float()

        _apply_repetition_penalty(logits, history, repetition_penalty)
        banned = banned_next_tokens(history, ngram)
        if banned:
            logits[0, sorted(banned)] = -float("inf")

        logits /= temperature
        _filter_top_k(logits, top_k)
        _filter_top_p(logits, top_p)

        next_id = torch.multinomial(F.softmax(logits, dim=-1), 1)
        generated += 1
        token = next_id.item()
        if token == eot_id:
            break
        idx = torch.cat([idx, next_id], dim=1)
        history.append(token)

    tokens_per_second = generated / max(time.time() - start, 1e-6)
    return tokenizer.decode(history), tokens_per_second
|
|
|
|
def main() -> None:
    """Download model code, checkpoint and tokenizer from REPO_ID and print a sample.

    Command line: ``[prompt] [max_new_tokens]`` (defaults: "Polska jest", 80).
    """
    prompt = sys.argv[1] if len(sys.argv) > 1 else "Polska jest"
    max_new_tokens = int(sys.argv[2]) if len(sys.argv) > 2 else 80
    # Prefer Apple MPS, then CUDA, then fall back to CPU.
    device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(security): model.py is fetched from the Hub and executed as Python
    # code by load_model_module; only run this against a repo you trust.
    model_py = hf_hub_download(REPO_ID, "scripts/model.py")
    ckpt_path = hf_hub_download(REPO_ID, "model/ckpt.pt")
    tokenizer_path = hf_hub_download(REPO_ID, "tokenizers/polish_bpe_32k.json")

    model_module = load_model_module(model_py)
    # The checkpoint is a pickle downloaded from a remote repo: restrict
    # unpickling to tensors and plain containers rather than arbitrary objects.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model_args = ckpt["model_args"]
    model = model_module.GPT(model_module.GPTConfig(**model_args))

    # State dicts saved from a torch.compile'd model carry an "_orig_mod."
    # key prefix; strip it so keys match the plain module.
    prefix = "_orig_mod."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in ckpt["model"].items()
    }
    model.load_state_dict(state_dict)
    block_size = model_args["block_size"]
    # Weights are now copied into the model; release the checkpoint (and
    # anything else it holds, e.g. optimizer state if saved) before generating.
    del ckpt, state_dict
    model.eval().to(device)
    tokenizer = Tokenizer.from_file(tokenizer_path)

    text, tps = generate(
        model,
        tokenizer,
        prompt,
        max_new_tokens,
        block_size,
        device,
    )
    print(f"[repo={REPO_ID} device={device} {tps:.1f} tok/s]\n")
    print(text)
|
|
|
|
if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status is 0 on success.
    raise SystemExit(main())
|
|
|
|