""" inference_hf.py — Self-contained inference script for Erebus models on HuggingFace. This file has zero dependency on the rest of the erebus repo. Copy it anywhere and run it as long as you have: pip install torch tiktoken huggingface_hub safetensors Usage ----- # From HuggingFace Hub python inference_hf.py --hf_repo Rzoro/erebus-small --prompt "The future of AI" # Interactive python inference_hf.py --hf_repo Rzoro/erebus-small --interactive """ from __future__ import annotations import argparse import json import math from dataclasses import dataclass from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F # ── Model definition (self-contained copy) ──────────────────────────────────── @dataclass class ErebusConfig: vocab_size: int = 50257 d_model: int = 768 n_heads: int = 12 n_layers: int = 12 d_ff: int = 3072 max_seq_len: int = 1024 dropout: float = 0.0 class RotaryPositionEmbedding(nn.Module): def __init__(self, head_dim: int, max_seq_len: int = 4096): super().__init__() inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim)) positions = torch.arange(max_seq_len).float() freqs = torch.outer(positions, inv_freq) cos = freqs.cos().repeat_interleave(2, dim=-1).unsqueeze(0).unsqueeze(0) sin = freqs.sin().repeat_interleave(2, dim=-1).unsqueeze(0).unsqueeze(0) self.register_buffer("cos_cached", cos, persistent=False) self.register_buffer("sin_cached", sin, persistent=False) @staticmethod def _rotate_half(x): x1, x2 = x[..., 0::2], x[..., 1::2] return torch.stack([-x2, x1], dim=-1).flatten(-2) def forward(self, q, k): T = q.size(2) cos, sin = self.cos_cached[:, :, :T], self.sin_cached[:, :, :T] return q * cos + self._rotate_half(q) * sin, k * cos + self._rotate_half(k) * sin class MultiHeadAttention(nn.Module): def __init__(self, d_model, n_heads, max_seq_len, dropout=0.0): super().__init__() self.n_heads = n_heads self.head_dim = d_model // n_heads self.q_proj = nn.Linear(d_model, d_model, bias=False) self.k_proj = nn.Linear(d_model, d_model, bias=False) self.v_proj = nn.Linear(d_model, d_model, bias=False) self.out_proj = nn.Linear(d_model, d_model, bias=False) self.rope = RotaryPositionEmbedding(self.head_dim, max_seq_len) self._flash = hasattr(F, "scaled_dot_product_attention") def forward(self, x): B, T, C = x.shape def split(t): return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) Q, K, V = split(self.q_proj(x)), split(self.k_proj(x)), split(self.v_proj(x)) Q, K = self.rope(Q, K) if self._flash: out = F.scaled_dot_product_attention(Q, K, V, is_causal=True) else: scale = math.sqrt(self.head_dim) scores = (Q @ K.transpose(-2, -1)) / scale causal = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool)) scores = scores.masked_fill(~causal, float("-inf")) out = torch.softmax(scores, dim=-1) @ V return self.out_proj(out.transpose(1, 2).contiguous().view(B, T, C)) class SwiGLU(nn.Module): def __init__(self, d_model, d_ff): super().__init__() d_ff = (d_ff // 64) * 64 self.w1 = nn.Linear(d_model, d_ff, bias=False) self.w3 = nn.Linear(d_model, d_ff, bias=False) self.w2 = nn.Linear(d_ff, d_model, bias=False) def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) class TransformerBlock(nn.Module): def __init__(self, cfg: ErebusConfig): super().__init__() self.norm1 = nn.RMSNorm(cfg.d_model) self.attn = MultiHeadAttention(cfg.d_model, cfg.n_heads, cfg.max_seq_len) self.norm2 = nn.RMSNorm(cfg.d_model) self.ffn = SwiGLU(cfg.d_model, cfg.d_ff) def forward(self, x): x = x + self.attn(self.norm1(x)) x = x + self.ffn(self.norm2(x)) return x class Erebus(nn.Module): def __init__(self, cfg: ErebusConfig): super().__init__() self.cfg = cfg self.token_emb = nn.Embedding(cfg.vocab_size, cfg.d_model) self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)]) self.norm = nn.RMSNorm(cfg.d_model) self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False) self.lm_head.weight = self.token_emb.weight @torch.no_grad() def generate( self, input_ids: torch.Tensor, max_new_tokens: int = 200, temperature: float = 0.8, top_k: int = 50, top_p: float = 0.95, repetition_penalty: float = 1.2, eos_token_id: Optional[int] = None, ) -> torch.Tensor: self.eval() for _ in range(max_new_tokens): ctx = input_ids[:, -self.cfg.max_seq_len:] x = self.token_emb(ctx) for block in self.blocks: x = block(x) logits = self.lm_head(self.norm(x))[:, -1, :] if repetition_penalty != 1.0: for tok in input_ids[0].unique(): logits[0, tok] /= repetition_penalty logits = logits / max(temperature, 1e-8) if top_k > 0: cutoff, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < cutoff[:, [-1]]] = float("-inf") if top_p < 1.0: sorted_logits, sorted_idx = torch.sort(logits, descending=True) cum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) sorted_logits[cum - F.softmax(sorted_logits, dim=-1) > top_p] = float("-inf") logits.scatter_(1, sorted_idx, sorted_logits) next_tok = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) input_ids = torch.cat([input_ids, next_tok], dim=1) if eos_token_id is not None and next_tok.item() == eos_token_id: break return input_ids # ── Loading helpers ─────────────────────────────────────────────────────────── def load_from_hf(repo_id: str, device: torch.device) -> Erebus: from huggingface_hub import hf_hub_download from safetensors.torch import load_file print(f"Downloading {repo_id} from HuggingFace Hub …") cfg_path = hf_hub_download(repo_id, "config.json") weights_path = hf_hub_download(repo_id, "model.safetensors") with open(cfg_path) as f: cfg = ErebusConfig(**json.load(f)) model = Erebus(cfg) model.load_state_dict(load_file(weights_path), strict=False) model.eval().to(device) n = sum(p.numel() for p in model.parameters()) print(f"Loaded : {repo_id} ({n/1e6:.1f} M params)\n") return model def load_from_checkpoint(path: str, device: torch.device) -> Erebus: ckpt = torch.load(path, map_location="cpu", weights_only=False) model = Erebus(ckpt["config"]) model.load_state_dict(ckpt["model_state_dict"]) model.eval().to(device) n = sum(p.numel() for p in model.parameters()) print(f"Loaded : {path} ({n/1e6:.1f} M params, step={ckpt.get('step','?')})\n") return model # ── CLI ─────────────────────────────────────────────────────────────────────── def parse_args(): p = argparse.ArgumentParser(description="Erebus inference — works with local or HF weights.") src = p.add_mutually_exclusive_group(required=True) src.add_argument("--hf_repo", help="HuggingFace repo id e.g. Rzoro/erebus-small") src.add_argument("--checkpoint", help="Local .pt checkpoint path") inp = p.add_mutually_exclusive_group() inp.add_argument("--prompt", default=None) inp.add_argument("--interactive", action="store_true") p.add_argument("--max_new_tokens", type=int, default=200) p.add_argument("--temperature", type=float, default=0.8) p.add_argument("--top_k", type=int, default=50) p.add_argument("--top_p", type=float, default=0.95) p.add_argument("--repetition_penalty", type=float, default=1.2) p.add_argument("--device", default=None) return p.parse_args() def main(): import tiktoken args = parse_args() device = torch.device( args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu") ) print(f"Device : {device}") model = load_from_hf(args.hf_repo, device) if args.hf_repo \ else load_from_checkpoint(args.checkpoint, device) enc = tiktoken.get_encoding("gpt2") def run(prompt: str) -> str: ids = torch.tensor([enc.encode(prompt)], dtype=torch.long).to(device) out = model.generate( ids, max_new_tokens=args.max_new_tokens, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, repetition_penalty=args.repetition_penalty, eos_token_id=enc.eot_token, ) return enc.decode(out[0].tolist()) if args.interactive: print("═" * 60) print("Erebus — interactive mode (quit / Ctrl-C to exit)") print("═" * 60) while True: try: prompt = input("\nPrompt > ").strip() except (EOFError, KeyboardInterrupt): print("\nBye!"); break if not prompt or prompt.lower() in ("quit", "exit", "q"): print("Bye!"); break print("\n" + "─" * 60) print(run(prompt)) print("─" * 60) else: prompt = args.prompt or input("Prompt > ").strip() print("\n" + "─" * 60) print(run(prompt)) print("─" * 60) if __name__ == "__main__": main()