"""
inference_hf.py — Self-contained inference script for Erebus models on HuggingFace.

This file has zero dependency on the rest of the erebus repo.
Copy it anywhere and run it as long as you have:
    pip install torch tiktoken huggingface_hub safetensors

Usage
-----
    # From HuggingFace Hub
    python inference_hf.py --hf_repo Rzoro/erebus-small --prompt "The future of AI"

    # Interactive
    python inference_hf.py --hf_repo Rzoro/erebus-small --interactive
"""

from __future__ import annotations

import argparse
import json
import math
from dataclasses import dataclass, fields
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ErebusConfig:
    """Hyperparameters describing an Erebus decoder-only transformer.

    Instances are built either from the defaults below, from the keys of a
    ``config.json`` file, or taken from a saved checkpoint.

    Raises:
        ValueError: if ``d_model`` cannot be split into ``n_heads`` heads of
            even size (rotary embeddings rotate features in pairs).
    """

    vocab_size: int = 50257   # size of the embedding table / output layer
    d_model: int = 768        # width of the residual stream
    n_heads: int = 12         # attention heads per block
    n_layers: int = 12        # number of transformer blocks
    d_ff: int = 3072          # requested feed-forward hidden width
    max_seq_len: int = 1024   # longest context the model attends over
    dropout: float = 0.0      # kept for config compatibility; unused at inference

    def __post_init__(self) -> None:
        # Fail early with a readable message instead of an opaque shape
        # error deep inside the attention forward pass.
        if self.n_heads <= 0 or self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )
        if (self.d_model // self.n_heads) % 2 != 0:
            raise ValueError(
                f"head dimension ({self.d_model // self.n_heads}) must be even "
                "for rotary position embeddings"
            )


class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding (RoPE) backed by precomputed cos/sin tables.

    Adjacent feature pairs ``(x[2i], x[2i + 1])`` at position ``p`` are rotated
    by the angle ``p * 10000 ** (-2i / head_dim)``.

    Args:
        head_dim: size of the last dimension of the tensors to rotate; must
            be even because features are rotated in pairs.
        max_seq_len: number of positions to precompute.

    Raises:
        ValueError: if ``head_dim`` is odd.
    """

    def __init__(self, head_dim: int, max_seq_len: int = 4096):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError(f"head_dim must be even for RoPE, got {head_dim}")
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        positions = torch.arange(max_seq_len).float()
        angles = torch.outer(positions, inv_freq)  # (max_seq_len, head_dim // 2)
        # Duplicate each angle so both members of a feature pair share it,
        # then add two leading broadcast dimensions.
        cos = angles.cos().repeat_interleave(2, dim=-1).unsqueeze(0).unsqueeze(0)
        sin = angles.sin().repeat_interleave(2, dim=-1).unsqueeze(0).unsqueeze(0)
        # Non-persistent: the tables are rebuilt here, never read from or
        # written to a state dict.
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    @staticmethod
    def _rotate_half(x):
        """Map every feature pair ``(a, b)`` to ``(-b, a)``."""
        even, odd = x[..., 0::2], x[..., 1::2]
        return torch.stack([-odd, even], dim=-1).flatten(-2)

    def forward(self, q, k):
        """Rotate ``q`` and ``k`` by their positions along dimension 2.

        Args:
            q, k: tensors whose dimension 2 is the sequence position and whose
                last dimension is ``head_dim``.

        Returns:
            Tuple ``(q_rotated, k_rotated)`` with the shapes and dtype of the
            inputs.

        Raises:
            ValueError: if the sequence is longer than the precomputed tables.
        """
        T = q.size(2)
        if T > self.cos_cached.size(2):
            raise ValueError(
                f"sequence length {T} exceeds the RoPE table size "
                f"{self.cos_cached.size(2)}"
            )
        # The tables are float32; cast so half-precision inputs are not
        # silently promoted (a no-op for float32 inputs).
        cos = self.cos_cached[:, :, :T].to(q.dtype)
        sin = self.sin_cached[:, :, :T].to(q.dtype)
        q_rot = q * cos + self._rotate_half(q) * sin
        k_rot = k * cos + self._rotate_half(k) * sin
        return q_rot, k_rot


class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Args:
        d_model: width of the input and output features.
        n_heads: number of attention heads; must divide ``d_model``.
        max_seq_len: number of positions for which RoPE tables are built.
        dropout: accepted for signature compatibility; this module does not
            apply dropout.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, max_seq_len, dropout=0.0):
        super().__init__()
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.rope = RotaryPositionEmbedding(self.head_dim, max_seq_len)
        # Use the fused kernel when this torch build provides it.
        self._flash = hasattr(F, "scaled_dot_product_attention")

    def _split_heads(self, t):
        """Reshape ``(B, T, d_model)`` into ``(B, n_heads, T, head_dim)``."""
        B, T, _ = t.shape
        return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape ``(B, T, d_model)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        B, T, C = x.shape
        Q = self._split_heads(self.q_proj(x))
        K = self._split_heads(self.k_proj(x))
        V = self._split_heads(self.v_proj(x))
        Q, K = self.rope(Q, K)
        if self._flash:
            out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
        else:
            # Manual fallback: scaled dot product with a lower-triangular mask
            # so each position attends only to itself and earlier positions.
            scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.head_dim)
            causal = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
            scores = scores.masked_fill(~causal, float("-inf"))
            out = torch.softmax(scores, dim=-1) @ V
        # Merge the heads back into a single feature dimension.
        return self.out_proj(out.transpose(1, 2).contiguous().view(B, T, C))


class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        d_model: width of the input and output features.
        d_ff: requested hidden width. It is rounded DOWN to a multiple of 64,
            so values below 64 give an empty hidden layer.
            NOTE(review): presumably this mirrors the training code so the
            checkpoint shapes line up — confirm before changing.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        hidden = (d_ff // 64) * 64
        self.w1 = nn.Linear(d_model, hidden, bias=False)  # gate projection
        self.w3 = nn.Linear(d_model, hidden, bias=False)  # value projection
        self.w2 = nn.Linear(hidden, d_model, bias=False)  # output projection

    def forward(self, x):
        """Return the gated feed-forward transform of ``x`` (same shape)."""
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then SwiGLU, each in a residual.

    Args:
        cfg: model configuration; ``d_model``, ``n_heads``, ``max_seq_len``
            and ``d_ff`` are read from it.
    """

    def __init__(self, cfg: ErebusConfig):
        super().__init__()
        # Attribute names are part of the checkpoint's state-dict keys.
        self.norm1 = nn.RMSNorm(cfg.d_model)
        self.attn = MultiHeadAttention(cfg.d_model, cfg.n_heads, cfg.max_seq_len)
        self.norm2 = nn.RMSNorm(cfg.d_model)
        self.ffn = SwiGLU(cfg.d_model, cfg.d_ff)

    def forward(self, x):
        """Apply both residual sub-layers to ``x`` and return the result."""
        x = x + self.attn(self.norm1(x))
        return x + self.ffn(self.norm2(x))


class Erebus(nn.Module):
    """Decoder-only transformer language model with tied input/output embeddings.

    Args:
        cfg: model configuration; also stored as ``self.cfg``.
    """

    def __init__(self, cfg: ErebusConfig):
        super().__init__()
        self.cfg = cfg
        self.token_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.norm = nn.RMSNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.token_emb.weight

    def _hidden_states(self, input_ids):
        """Embed ``input_ids`` and run them through every transformer block."""
        x = self.token_emb(input_ids)
        for block in self.blocks:
            x = block(x)
        return x

    def forward(self, input_ids):
        """Return next-token logits of shape ``(batch, seq_len, vocab_size)``."""
        return self.lm_head(self.norm(self._hidden_states(input_ids)))

    @staticmethod
    def _apply_repetition_penalty(logits, input_ids, penalty):
        """Make every token already present in a row less likely for that row."""
        seen = logits.gather(1, input_ids)
        # Dividing a negative logit by penalty > 1 would RAISE its probability,
        # so negative logits are multiplied instead.
        seen = torch.where(seen < 0, seen * penalty, seen / penalty)
        # Duplicate token ids write identical values, so the result is
        # well defined.
        return logits.scatter(1, input_ids, seen)

    @staticmethod
    def _filter_top_k(logits, top_k):
        """Keep the ``top_k`` largest logits per row; set the rest to -inf."""
        k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, k).values[:, -1:]
        return logits.masked_fill(logits < kth_largest, float("-inf"))

    @staticmethod
    def _filter_top_p(logits, top_p):
        """Nucleus filtering: drop tokens once the mass before them exceeds ``top_p``."""
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        probs = F.softmax(sorted_logits, dim=-1)
        # Probability mass strictly before each token; the most likely token
        # has zero mass before it and is therefore always kept.
        mass_before = torch.cumsum(probs, dim=-1) - probs
        sorted_logits = sorted_logits.masked_fill(mass_before > top_p, float("-inf"))
        return logits.scatter(1, sorted_idx, sorted_logits)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 200,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.2,
        eos_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Sample a continuation of ``input_ids`` one token at a time.

        Puts the module into eval mode as a side effect.

        Args:
            input_ids: integer token ids of shape ``(batch, seq_len)`` with
                ``seq_len >= 1``.
            max_new_tokens: upper bound on the number of tokens appended.
            temperature: logits are divided by ``max(temperature, 1e-8)``.
            top_k: keep only the ``top_k`` most likely tokens; ``<= 0`` disables.
            top_p: nucleus threshold; ``>= 1.0`` disables.
            repetition_penalty: penalty for tokens already in the sequence;
                ``1.0`` disables.
            eos_token_id: if given, a row that emits it is padded with it from
                then on, and generation stops once every row has emitted it.

        Returns:
            ``input_ids`` with the sampled tokens appended along dimension 1.

        Raises:
            ValueError: if ``input_ids`` is not 2-D or has no tokens.
        """
        self.eval()
        if input_ids.dim() != 2 or input_ids.size(1) == 0:
            raise ValueError(
                "input_ids must have shape (batch, seq_len) with seq_len >= 1, "
                f"got {tuple(input_ids.shape)}"
            )
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)

        for _ in range(max_new_tokens):
            # Only the most recent max_seq_len tokens fit in the context.
            ctx = input_ids[:, -self.cfg.max_seq_len:]
            # Only the last position is needed, so project just that one to
            # vocabulary size instead of the whole sequence.
            last_hidden = self._hidden_states(ctx)[:, -1, :]
            logits = self.lm_head(self.norm(last_hidden))

            if repetition_penalty != 1.0:
                logits = self._apply_repetition_penalty(logits, input_ids, repetition_penalty)

            logits = logits / max(temperature, 1e-8)

            if top_k > 0:
                logits = self._filter_top_k(logits, top_k)
            if top_p < 1.0:
                logits = self._filter_top_p(logits, top_p)

            next_tok = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            if eos_token_id is not None:
                # Rows that already ended keep emitting EOS as padding.
                next_tok = next_tok.masked_fill(finished.unsqueeze(1), eos_token_id)
                finished = finished | (next_tok.squeeze(1) == eos_token_id)
            input_ids = torch.cat([input_ids, next_tok], dim=1)
            if eos_token_id is not None and bool(finished.all()):
                break
        return input_ids


def load_from_hf(repo_id: str, device: torch.device) -> Erebus:
    """Download an Erebus model from the HuggingFace Hub and prepare it for inference.

    Fetches ``config.json`` and ``model.safetensors`` from ``repo_id``, builds
    the model, loads the weights, and moves the model to ``device`` in eval
    mode.

    Args:
        repo_id: Hub repository id, e.g. ``"Rzoro/erebus-small"``.
        device: device the returned model is moved to.

    Returns:
        The loaded model.
    """
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file

    print(f"Downloading {repo_id} from HuggingFace Hub …")
    cfg_path = hf_hub_download(repo_id, "config.json")
    weights_path = hf_hub_download(repo_id, "model.safetensors")

    with open(cfg_path, encoding="utf-8") as f:
        raw_cfg = json.load(f)

    # Hub config files may carry extra metadata keys; passing those straight
    # to the dataclass constructor would raise TypeError.
    known = {field.name for field in fields(ErebusConfig)}
    ignored = sorted(set(raw_cfg) - known)
    if ignored:
        print(f"Warning: ignoring unknown config keys: {ignored}")
    cfg = ErebusConfig(**{key: val for key, val in raw_cfg.items() if key in known})

    model = Erebus(cfg)
    # strict=False is kept because lm_head.weight and token_emb.weight are
    # the same tensor, so a file holding only one of them still fully
    # initialises the model. Anything else that is missing is reported
    # rather than silently left at its random initial value.
    result = model.load_state_dict(load_file(weights_path), strict=False)
    tied = {"token_emb.weight", "lm_head.weight"}
    missing = set(result.missing_keys)
    if not tied <= missing:
        missing -= tied
    if missing:
        print(f"Warning: weights missing from checkpoint (left at random init): {sorted(missing)}")
    if result.unexpected_keys:
        print(f"Warning: unused weights in checkpoint: {sorted(result.unexpected_keys)}")

    model.eval().to(device)
    n = sum(p.numel() for p in model.parameters())
    print(f"Loaded : {repo_id} ({n/1e6:.1f} M params)\n")
    return model


def load_from_checkpoint(path: str, device: torch.device) -> Erebus:
    """Load an Erebus model from a local ``.pt`` training checkpoint.

    The checkpoint must be a mapping with a ``"config"`` entry (an
    ``ErebusConfig``-like object, or a plain dict of its fields) and a
    ``"model_state_dict"`` entry; an optional ``"step"`` entry is shown in
    the status line.

    Args:
        path: path of the checkpoint file.
        device: device the returned model is moved to.

    Returns:
        The loaded model in eval mode.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file. It is needed here because the
    # checkpoint stores its config as a pickled object — only load
    # checkpoints from a source you trust.
    # NOTE(review): unpickling presumably requires the module that defined
    # the pickled config class to be importable — confirm for standalone use.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)

    cfg = ckpt["config"]
    if isinstance(cfg, dict):
        cfg = ErebusConfig(**cfg)

    model = Erebus(cfg)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval().to(device)
    n = sum(p.numel() for p in model.parameters())
    print(f"Loaded : {path} ({n/1e6:.1f} M params, step={ckpt.get('step','?')})\n")
    return model


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the inference script.

    Exactly one weight source (``--hf_repo`` or ``--checkpoint``) is required;
    ``--prompt`` and ``--interactive`` are mutually exclusive.

    Args:
        argv: argument list to parse; ``None`` (the default) reads
            ``sys.argv[1:]``.

    Returns:
        The parsed namespace.
    """
    p = argparse.ArgumentParser(description="Erebus inference — works with local or HF weights.")

    src = p.add_mutually_exclusive_group(required=True)
    src.add_argument("--hf_repo", help="HuggingFace repo id e.g. Rzoro/erebus-small")
    src.add_argument("--checkpoint", help="Local .pt checkpoint path")

    inp = p.add_mutually_exclusive_group()
    inp.add_argument("--prompt", default=None, help="Text to continue")
    inp.add_argument("--interactive", action="store_true", help="Read prompts in a loop")

    p.add_argument("--max_new_tokens", type=int, default=200, help="Maximum tokens to generate")
    p.add_argument("--temperature", type=float, default=0.8, help="Sampling temperature")
    p.add_argument("--top_k", type=int, default=50, help="Top-k cutoff (0 disables)")
    p.add_argument("--top_p", type=float, default=0.95, help="Nucleus threshold (1.0 disables)")
    p.add_argument("--repetition_penalty", type=float, default=1.2, help="Repetition penalty (1.0 disables)")
    p.add_argument("--device", default=None, help="Torch device; default is cuda if available, else cpu")
    return p.parse_args(argv)


def main():
    """Command-line entry point: load a model, then generate from prompts.

    With ``--interactive`` prompts are read in a loop until the user enters
    an empty line, ``quit``/``exit``/``q``, or sends EOF / Ctrl-C. Otherwise
    a single prompt is taken from ``--prompt`` or read from stdin.

    Raises:
        SystemExit: in single-prompt mode when no prompt text is supplied.
    """
    import tiktoken

    args = parse_args()
    device = torch.device(
        args.device if args.device
        else ("cuda" if torch.cuda.is_available() else "cpu")
    )
    print(f"Device : {device}")

    if args.hf_repo:
        model = load_from_hf(args.hf_repo, device)
    else:
        model = load_from_checkpoint(args.checkpoint, device)

    enc = tiktoken.get_encoding("gpt2")
    rule = "─" * 60

    def run(prompt: str) -> str:
        """Return the prompt plus its sampled continuation, decoded to text."""
        # disallowed_special=() encodes text such as "<|endoftext|>" typed by
        # the user as ordinary characters instead of raising.
        token_ids = enc.encode(prompt, disallowed_special=())
        ids = torch.tensor([token_ids], dtype=torch.long).to(device)
        out = model.generate(
            ids,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            eos_token_id=enc.eot_token,
        )
        return enc.decode(out[0].tolist())

    def show(prompt: str) -> None:
        """Print the generation for ``prompt`` between separator rules."""
        print("\n" + rule)
        print(run(prompt))
        print(rule)

    if args.interactive:
        print(rule)
        print("Erebus — interactive mode (quit / Ctrl-C to exit)")
        print(rule)
        while True:
            try:
                prompt = input("\nPrompt > ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nBye!")
                break
            if not prompt or prompt.lower() in ("quit", "exit", "q"):
                print("Bye!")
                break
            show(prompt)
        return

    prompt = args.prompt or input("Prompt > ").strip()
    if not prompt:
        # An empty prompt encodes to zero tokens, which the model cannot
        # continue; stop with a clear message instead of a tensor error.
        raise SystemExit("No prompt given; nothing to generate.")
    show(prompt)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()