# Goedel-mHC-1B / inference.py
# Uploaded with huggingface_hub (commit 8789bc4, verified)
#!/usr/bin/env python3
"""
Goedel-mHC-1B — Self-contained inference script.
No dependencies on the training codebase. Just torch and tiktoken.
Works on CUDA, MPS, and CPU.
Usage:
pip install torch tiktoken
python inference.py --checkpoint ckpt_best.pt "The capital of France is"
python inference.py --checkpoint ckpt_best.pt --interactive
"""
from __future__ import annotations
import argparse
import math
import time
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
# ============================================================================
# Model Configuration
# ============================================================================
@dataclass
class ModelConfig:
    """Architecture hyperparameters for Goedel-mHC-1B (defaults match the checkpoint)."""
    dim: int = 2048                   # hidden size of one stream
    n_layers: int = 24                # number of transformer blocks
    vocab_size: int = 50304           # padded GPT-2 vocabulary size
    num_heads: int = 16               # query heads
    num_kv_heads: int = 4             # KV heads (GQA: 16 // 4 = 4 query heads per KV head)
    head_dim: int = 128               # per-head dimension
    intermediate_mult: float = 2.667  # FFN width multiplier (rounded up to a multiple of 256)
    n_streams: int = 4                # mHC residual streams
    sinkhorn_iters: int = 5           # Sinkhorn-Knopp iterations for the stream-mixing matrix
    rope_theta: float = 10000.0       # RoPE base frequency
    max_seq_len: int = 4096           # length of the precomputed RoPE tables
    qk_norm: bool = True              # apply RMSNorm to Q and K per head
# ============================================================================
# RoPE
# ============================================================================
class RotaryEmbedding(nn.Module):
    """Precomputed rotary position-embedding tables, stored in bfloat16."""

    def __init__(self, dim: int, theta: float = 10000.0, max_seq_len: int = 4096):
        super().__init__()
        # One inverse frequency per (even, odd) feature pair.
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (theta ** exponents)
        positions = torch.arange(max_seq_len, dtype=inv_freq.dtype)
        angles = torch.outer(positions, inv_freq)  # (max_seq_len, dim // 2)
        self.register_buffer("_cos", angles.cos().to(torch.bfloat16))
        self.register_buffer("_sin", angles.sin().to(torch.bfloat16))

    def forward(self, seq_len: int, device: torch.device):
        """Return the (cos, sin) tables for the first seq_len positions on `device`."""
        cos = self._cos[:seq_len].to(device)
        sin = self._sin[:seq_len].to(device)
        return cos, sin
def apply_rotary(x: torch.Tensor, cos_sin: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """Rotate interleaved (even, odd) feature pairs of x by the given angles.

    x: (B, n_heads, S, head_dim); cos/sin: (S, head_dim // 2).
    """
    cos, sin = cos_sin
    # Broadcast the (S, D/2) tables over batch and head dimensions.
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]
    even = x[..., ::2]
    odd = x[..., 1::2]
    rot_even = even * cos - odd * sin
    rot_odd = even * sin + odd * cos
    # Re-interleave so each pair lands back at indices (2i, 2i + 1).
    return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)
# ============================================================================
# Sinkhorn normalization
# ============================================================================
def _sinkhorn(logits: torch.Tensor, iters: int = 5) -> torch.Tensor:
"""Project to doubly-stochastic matrix via Sinkhorn-Knopp."""
M = logits.exp()
for _ in range(iters):
M = M / M.sum(dim=-1, keepdim=True)
M = M / M.sum(dim=-2, keepdim=True)
return M
# ============================================================================
# mHC (Manifold Hyper-Connections)
# ============================================================================
class mHC(nn.Module):
"""Manifold-constrained hyper-connections (full-dim streams).
Each of n_streams streams carries the full hidden dim D. Between
blocks the hidden state is (B, S, n*D). The model must call
expand() after embedding and contract() before the final norm.
"""
def __init__(self, dim: int, n_streams: int = 4, sinkhorn_iters: int = 5):
super().__init__()
self.dim = dim
self.n = n_streams
self.sinkhorn_iters = sinkhorn_iters
self.norm = nn.RMSNorm(dim)
self.logits_res = nn.Parameter(10.0 * torch.eye(n_streams))
w_pre_init = math.log(1.0 / (n_streams - 1))
self.w_pre = nn.Parameter(torch.full((n_streams,), w_pre_init))
self.w_post = nn.Parameter(torch.zeros(n_streams))
self._H_res_cache = None
def expand(self, x: torch.Tensor) -> torch.Tensor:
"""(B, S, D) -> (B, S, n*D): replicate into n streams."""
B, S, D = x.shape
return x.unsqueeze(2).expand(B, S, self.n, D).reshape(B, S, self.n * D)
def contract(self, x: torch.Tensor) -> torch.Tensor:
"""(B, S, n*D) -> (B, S, D): average across streams."""
B, S, _ = x.shape
return x.view(B, S, self.n, self.dim).mean(dim=2)
def _get_H_res(self):
if self._H_res_cache is None or self.training:
self._H_res_cache = _sinkhorn(self.logits_res, self.sinkhorn_iters).clone()
return self._H_res_cache
def forward(self, x: torch.Tensor, sublayer) -> torch.Tensor:
B, S, _ = x.shape
n, D = self.n, self.dim
H_res = self._get_H_res()
h_pre = torch.sigmoid(self.w_pre)
h_post = 2 * torch.sigmoid(self.w_post)
streams = x.view(B, S, n, D)
mixed = torch.einsum("ij,bsjd->bsid", H_res, streams)
sublayer_in = torch.einsum("i,bsid->bsd", h_pre, mixed)
y = sublayer(self.norm(sublayer_in))
y_dist = y.unsqueeze(2) * h_post.view(1, 1, n, 1)
out = mixed + y_dist
return out.reshape(B, S, n * D)
# ============================================================================
# GatedGQA (Grouped Query Attention with sigmoid output gate)
# ============================================================================
class GatedGQA(nn.Module):
    """Grouped Query Attention with RoPE, QK-norm, and sigmoid output gate."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.num_heads = cfg.num_heads
        self.num_kv_heads = cfg.num_kv_heads
        self.head_dim = cfg.head_dim
        # How many query heads share each KV head.
        self.kv_group_size = cfg.num_heads // cfg.num_kv_heads
        self.q_proj = nn.Linear(cfg.dim, cfg.num_heads * cfg.head_dim, bias=False)
        self.k_proj = nn.Linear(cfg.dim, cfg.num_kv_heads * cfg.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.dim, cfg.num_kv_heads * cfg.head_dim, bias=False)
        self.o_proj = nn.Linear(cfg.num_heads * cfg.head_dim, cfg.dim, bias=False)
        self.gate_proj = nn.Linear(cfg.dim, cfg.num_heads * cfg.head_dim, bias=False)
        self.qk_norm = nn.RMSNorm(cfg.head_dim) if cfg.qk_norm else None
        self.rope = RotaryEmbedding(cfg.head_dim, theta=cfg.rope_theta,
                                    max_seq_len=cfg.max_seq_len)

    def forward(self, x: torch.Tensor, cache: tuple | None = None,
                start_pos: int = 0, return_cache: bool = False) -> torch.Tensor | tuple:
        """Attention over x, optionally continuing from a KV cache.

        Args:
            x: (B, S, dim) input.
            cache: optional (K, V) tensors from previous steps, each
                (B, num_kv_heads, S_past, head_dim), or None for prefill.
            start_pos: RoPE position offset for cached generation.
            return_cache: if True, also return the updated (K, V).

        Returns:
            (B, S, dim) output, or (output, (K, V)) when a cache is supplied
            or requested.
        """
        B, S, _ = x.shape
        q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        if self.qk_norm:
            q = self.qk_norm(q).to(v.dtype)
            k = self.qk_norm(k).to(v.dtype)
        # RoPE with position offset for cached generation.
        rope_len = start_pos + S
        freqs_full = self.rope(rope_len, x.device)
        freqs = (freqs_full[0][start_pos:start_pos + S],
                 freqs_full[1][start_pos:start_pos + S])
        q = apply_rotary(q, freqs)
        k = apply_rotary(k, freqs)
        # KV cache: concat with previous cache if present.
        # Cache stores unexpanded KV heads (num_kv_heads, not num_heads).
        if cache is not None:
            k_cache, v_cache = cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)
        # Prefill needs a causal mask; single-token decode with cache does not.
        use_causal = (cache is None) and (S > 1)
        # GQA: expand KV heads for SDPA compatibility on CPU / MPS / older torch.
        k_exp = k.repeat_interleave(self.kv_group_size, dim=1) if self.kv_group_size > 1 else k
        v_exp = v.repeat_interleave(self.kv_group_size, dim=1) if self.kv_group_size > 1 else v
        out = F.scaled_dot_product_attention(q, k_exp, v_exp, is_causal=use_causal)
        # Sigmoid output gate. SDPA preserves the query layout, so `out` is always
        # (B, num_heads, S, head_dim); the old runtime shape check was dead code.
        gate = self.gate_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        if gate.dtype != out.dtype:
            gate = gate.to(out.dtype)
        out = out * torch.sigmoid(gate)
        out = out.transpose(1, 2).contiguous().view(B, S, -1)
        result = self.o_proj(out)
        if cache is not None or return_cache:
            return result, (k, v)
        return result
# ============================================================================
# ReLU-squared FFN
# ============================================================================
class ReLU2(nn.Module):
    """ReLU-squared FFN: relu(x @ W_up)^2 @ W_down"""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        # Round the hidden width up to the next multiple of 256 (ceil division).
        raw_hidden = int(cfg.dim * cfg.intermediate_mult)
        hidden = -(-raw_hidden // 256) * 256
        self.w_up = nn.Linear(cfg.dim, hidden, bias=False)
        self.w_down = nn.Linear(hidden, cfg.dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.relu(self.w_up(x))
        return self.w_down(activated * activated)
# ============================================================================
# TransformerBlock
# ============================================================================
class TransformerBlock(nn.Module):
    """One layer: gated GQA then ReLU^2 FFN, each wrapped in an mHC residual."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.attn = GatedGQA(cfg)
        self.ffn = ReLU2(cfg)
        self.residual_attn = mHC(cfg.dim, cfg.n_streams, cfg.sinkhorn_iters)
        self.residual_ffn = mHC(cfg.dim, cfg.n_streams, cfg.sinkhorn_iters)

    def forward(self, x: torch.Tensor, cache: tuple | None = None,
                start_pos: int = 0,
                return_cache: bool = False) -> torch.Tensor | tuple:
        # The attention sublayer may hand back an updated KV cache; capture it
        # via a closure cell, since mHC only sees the sublayer's output tensor.
        captured = []

        def run_attn(normed_x):
            out = self.attn(normed_x, cache=cache, start_pos=start_pos,
                            return_cache=return_cache)
            if isinstance(out, tuple):
                out, new_cache = out
                captured.append(new_cache)
            return out

        x = self.residual_attn(x, run_attn)
        x = self.residual_ffn(x, self.ffn)
        if captured:
            return x, captured[0]
        return x
# ============================================================================
# GPT Model
# ============================================================================
class GPT(nn.Module):
    """Goedel-mHC GPT: embedding -> n_layers TransformerBlocks -> norm -> LM head.

    Between blocks the hidden state is the mHC-expanded (B, S, n_streams*dim)
    tensor; expand/contract happen once at the model boundaries.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.dim)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.norm = nn.RMSNorm(cfg.dim)
        self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
        # Weight tying: share embedding and output projection
        self.embed.weight = self.lm_head.weight
        # mHC uses expanded hidden state
        self._needs_expand = True  # always true for mHC architecture

    def forward(self, input_ids: torch.Tensor,
                kv_cache: list | None = None,
                start_pos: int = 0,
                return_cache: bool = True) -> tuple[torch.Tensor, list | None]:
        """Forward pass returning (logits, kv_cache).

        Args:
            input_ids: (B, S) token indices
            kv_cache: list of (K, V) tuples per layer, or None for prefill
            start_pos: position offset for RoPE in cached generation
            return_cache: if True, always return new KV caches (needed for prefill)
        """
        x = self.embed(input_ids)
        if self._needs_expand:
            # Replicate the hidden state into n_streams full-width streams.
            x = self.blocks[0].residual_attn.expand(x)
        new_kv_cache = [] if return_cache else None
        for i, block in enumerate(self.blocks):
            layer_cache = kv_cache[i] if kv_cache is not None else None
            result = block(x, cache=layer_cache, start_pos=start_pos,
                           return_cache=return_cache)
            if isinstance(result, tuple):
                x, layer_new_cache = result
                if new_kv_cache is not None:
                    new_kv_cache.append(layer_new_cache)
            else:
                x = result
                if new_kv_cache is not None:
                    # Append a placeholder so per-layer cache indices stay aligned.
                    new_kv_cache.append(None)
        if self._needs_expand:
            # Average the streams back down to (B, S, dim) before the final norm.
            x = self.blocks[0].residual_attn.contract(x)
        x = self.norm(x)
        logits = self.lm_head(x)
        return logits, new_kv_cache
# ============================================================================
# Sampling utilities
# ============================================================================
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus (top-p) sampling: sample from the smallest probability-sorted
    prefix of tokens whose cumulative mass exceeds p."""
    ranked, order = torch.sort(probs, dim=-1, descending=True)
    cum = torch.cumsum(ranked, dim=-1)
    # Zero out tokens whose preceding cumulative mass already exceeds p;
    # this always keeps at least the single most likely token.
    ranked = torch.where(cum - ranked > p, torch.zeros_like(ranked), ranked)
    ranked = ranked / ranked.sum(dim=-1, keepdim=True)
    pick = torch.multinomial(ranked, num_samples=1)
    return torch.gather(order, -1, pick)
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Top-k sampling: restrict to the k most likely tokens, renormalize, sample."""
    kept, kept_idx = torch.topk(probs, k, dim=-1)
    kept = kept / kept.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(kept, num_samples=1)
    return torch.gather(kept_idx, -1, choice)
# ============================================================================
# Generation
# ============================================================================
@torch.inference_mode()
def generate(
    model: GPT,
    prompt_tokens: list[int],
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 0,
    stop_tokens: set[int] | None = None,
):
    """Generate tokens autoregressively with KV caching.

    Yields (token_id, tokens_per_second) tuples as tokens are generated.
    """
    device = next(model.parameters()).device
    tokens = torch.tensor([prompt_tokens], dtype=torch.long, device=device)
    with torch.inference_mode():
        # Prefill: run the whole prompt through the model in one pass.
        logits, kv_cache = model(tokens, kv_cache=None, start_pos=0)
        logits = logits[:, -1, :]  # keep only the last position
        start_time = time.perf_counter()
        prompt_len = tokens.shape[1]
        for step in range(max_new_tokens):
            # Pick the next token from the current logits.
            if temperature <= 0:
                next_token = logits.argmax(dim=-1, keepdim=True)  # greedy
            else:
                probs = F.softmax(logits / temperature, dim=-1)
                if top_k > 0:
                    next_token = sample_top_k(probs, top_k)
                elif top_p < 1.0:
                    next_token = sample_top_p(probs, top_p)
                else:
                    next_token = torch.multinomial(probs, num_samples=1)
            token_id = next_token.item()
            if stop_tokens and token_id in stop_tokens:
                break
            elapsed = time.perf_counter() - start_time
            rate = (step + 1) / elapsed if elapsed > 0 else 0.0
            yield token_id, rate
            # Decode step: feed the single new token back with the KV cache.
            logits, kv_cache = model(next_token.view(1, 1),
                                     kv_cache=kv_cache,
                                     start_pos=prompt_len + step)
            logits = logits[:, -1, :]
# ============================================================================
# Checkpoint loading
# ============================================================================
def load_model(checkpoint_path: str, device: str = "cpu",
               dtype: torch.dtype = torch.float32) -> GPT:
    """Load a Goedel-mHC-1B checkpoint.

    Accepts either a training checkpoint ({"model": state_dict, ...}) or a raw
    state_dict, moves the model to `device`/`dtype`, switches to eval mode, and
    pre-warms the Sinkhorn mixing matrices so they are cached before generation.
    """
    cfg = ModelConfig()
    model = GPT(cfg)
    print(f"Loading checkpoint from {checkpoint_path} ...", flush=True)
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    # Support both {"model": state_dict, ...} and raw state_dict
    if isinstance(ckpt, dict) and "model" in ckpt:
        state_dict = ckpt["model"]
        step = ckpt.get("step", "?")
        val_loss = ckpt.get("val_loss", "?")
        print(f"  Step: {step}, Val loss: {val_loss}")
    else:
        state_dict = ckpt
    # Handle weight tying: remove embed.weight if lm_head.weight is present
    # (they share the same tensor, so only one is needed)
    if "embed.weight" in state_dict and "lm_head.weight" in state_dict:
        del state_dict["embed.weight"]
    result = model.load_state_dict(state_dict, strict=False)
    if result.unexpected_keys:
        print(f"  Warning: unexpected keys: {result.unexpected_keys}")
    if result.missing_keys:
        # embed.weight is expected to be missing due to weight tying
        missing = [k for k in result.missing_keys if k != "embed.weight"]
        if missing:
            print(f"  Warning: missing keys: {missing}")
    model = model.to(dtype=dtype, device=device)
    model.eval()
    # Precompute Sinkhorn matrices for all mHC modules
    for block in model.blocks:
        block.residual_attn._get_H_res()
        block.residual_ffn._get_H_res()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"  Parameters: {total_params:,} ({total_params / 1e9:.2f}B)")
    print(f"  Device: {device}, Dtype: {dtype}")
    return model
# ============================================================================
# Tokenizer
# ============================================================================
def get_tokenizer():
    """Return the GPT-2 BPE tokenizer (tiktoken imported lazily)."""
    import tiktoken
    encoding = tiktoken.get_encoding("gpt2")
    return encoding
# ============================================================================
# CLI
# ============================================================================
def main():
    """CLI entry point: parse args, pick device/dtype, load the model, generate."""
    parser = argparse.ArgumentParser(
        description="Goedel-mHC-1B inference",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python inference.py --checkpoint ckpt_best.pt "The capital of France is"
  python inference.py --checkpoint ckpt_best.pt --interactive
  python inference.py --checkpoint ckpt_best.pt --temperature 0 "def fibonacci(n):"
  python inference.py --checkpoint ckpt_best.pt --device cpu "Once upon a time"
""",
    )
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to checkpoint file (.pt)")
    parser.add_argument("--interactive", action="store_true",
                        help="Enter interactive REPL mode")
    parser.add_argument("--max-tokens", type=int, default=256,
                        help="Maximum tokens to generate (default: 256)")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Sampling temperature (0 = greedy, default: 0.7)")
    parser.add_argument("--top-p", type=float, default=0.9,
                        help="Nucleus sampling threshold (default: 0.9)")
    parser.add_argument("--top-k", type=int, default=0,
                        help="Top-k sampling (0 = disabled, default: 0)")
    parser.add_argument("--device", type=str, default=None,
                        help="Device (default: cuda if available, else cpu)")
    parser.add_argument("prompt", nargs="*", type=str,
                        help="Prompt text (ignored in interactive mode)")
    args = parser.parse_args()
    # Device selection: explicit flag wins, then CUDA, then MPS, then CPU.
    if args.device:
        device = args.device
    elif torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    # Dtype: bf16 on CUDA, fp16 on MPS, fp32 on CPU
    if device == "cuda":
        dtype = torch.bfloat16
    elif device == "mps":
        dtype = torch.float16
    else:
        dtype = torch.float32
    # Load model
    model = load_model(args.checkpoint, device=device, dtype=dtype)
    # Load tokenizer
    enc = get_tokenizer()
    eot_token = enc.eot_token  # <|endoftext|>

    def run_generation(prompt_text: str):
        """Encode the prompt, stream generated text to stdout, print stats."""
        tokens = enc.encode(prompt_text)
        if not tokens:
            print("(empty prompt)")
            return
        print(prompt_text, end="", flush=True)
        generated = []
        tps = 0.0
        for token_id, tps in generate(
            model, tokens,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            stop_tokens={eot_token},
        ):
            text = enc.decode([token_id])
            print(text, end="", flush=True)
            generated.append(token_id)
        print()
        print(f"\n--- {len(generated)} tokens, {tps:.1f} tok/s ---")

    if args.interactive:
        print("=" * 60)
        print("Goedel-mHC-1B Interactive Mode")
        print("Type your prompt and press Enter. Ctrl+C or 'quit' to exit.")
        print("=" * 60)
        while True:
            try:
                prompt = input("\n>>> ")
            except (KeyboardInterrupt, EOFError):
                print("\nBye!")
                break
            prompt = prompt.strip()
            if prompt.lower() in ("quit", "exit"):
                print("Bye!")
                break
            if not prompt:
                # Bug fix: an empty line used to fall through to `break` and
                # silently exit the REPL; re-prompt instead.
                continue
            print()
            run_generation(prompt)
    else:
        prompt_text = " ".join(args.prompt)
        if not prompt_text:
            parser.error("Provide a prompt or use --interactive")
        run_generation(prompt_text)
# Script entry point guard: run the CLI only when executed directly.
if __name__ == "__main__":
    main()