#!/usr/bin/env python3 """ Goedel-mHC-1B — Self-contained inference script. No dependencies on the training codebase. Just torch and tiktoken. Works on CUDA, MPS, and CPU. Usage: pip install torch tiktoken python inference.py --checkpoint ckpt_best.pt "The capital of France is" python inference.py --checkpoint ckpt_best.pt --interactive """ from __future__ import annotations import argparse import math import time from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F # ============================================================================ # Model Configuration # ============================================================================ @dataclass class ModelConfig: dim: int = 2048 n_layers: int = 24 vocab_size: int = 50304 num_heads: int = 16 num_kv_heads: int = 4 head_dim: int = 128 intermediate_mult: float = 2.667 n_streams: int = 4 sinkhorn_iters: int = 5 rope_theta: float = 10000.0 max_seq_len: int = 4096 qk_norm: bool = True # ============================================================================ # RoPE # ============================================================================ class RotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0, max_seq_len: int = 4096): super().__init__() inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) t = torch.arange(max_seq_len, dtype=inv_freq.dtype) freqs = torch.outer(t, inv_freq) self.register_buffer("_cos", freqs.cos().to(torch.bfloat16)) self.register_buffer("_sin", freqs.sin().to(torch.bfloat16)) def forward(self, seq_len: int, device: torch.device): return self._cos[:seq_len].to(device), self._sin[:seq_len].to(device) def apply_rotary(x: torch.Tensor, cos_sin: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: """Apply rotary embeddings. 
x: (B, n_heads, S, head_dim).""" cos, sin = cos_sin cos = cos.unsqueeze(0).unsqueeze(0) sin = sin.unsqueeze(0).unsqueeze(0) x1 = x[..., 0::2] x2 = x[..., 1::2] out1 = x1 * cos - x2 * sin out2 = x1 * sin + x2 * cos return torch.stack([out1, out2], dim=-1).flatten(-2) # ============================================================================ # Sinkhorn normalization # ============================================================================ def _sinkhorn(logits: torch.Tensor, iters: int = 5) -> torch.Tensor: """Project to doubly-stochastic matrix via Sinkhorn-Knopp.""" M = logits.exp() for _ in range(iters): M = M / M.sum(dim=-1, keepdim=True) M = M / M.sum(dim=-2, keepdim=True) return M # ============================================================================ # mHC (Manifold Hyper-Connections) # ============================================================================ class mHC(nn.Module): """Manifold-constrained hyper-connections (full-dim streams). Each of n_streams streams carries the full hidden dim D. Between blocks the hidden state is (B, S, n*D). The model must call expand() after embedding and contract() before the final norm. 
""" def __init__(self, dim: int, n_streams: int = 4, sinkhorn_iters: int = 5): super().__init__() self.dim = dim self.n = n_streams self.sinkhorn_iters = sinkhorn_iters self.norm = nn.RMSNorm(dim) self.logits_res = nn.Parameter(10.0 * torch.eye(n_streams)) w_pre_init = math.log(1.0 / (n_streams - 1)) self.w_pre = nn.Parameter(torch.full((n_streams,), w_pre_init)) self.w_post = nn.Parameter(torch.zeros(n_streams)) self._H_res_cache = None def expand(self, x: torch.Tensor) -> torch.Tensor: """(B, S, D) -> (B, S, n*D): replicate into n streams.""" B, S, D = x.shape return x.unsqueeze(2).expand(B, S, self.n, D).reshape(B, S, self.n * D) def contract(self, x: torch.Tensor) -> torch.Tensor: """(B, S, n*D) -> (B, S, D): average across streams.""" B, S, _ = x.shape return x.view(B, S, self.n, self.dim).mean(dim=2) def _get_H_res(self): if self._H_res_cache is None or self.training: self._H_res_cache = _sinkhorn(self.logits_res, self.sinkhorn_iters).clone() return self._H_res_cache def forward(self, x: torch.Tensor, sublayer) -> torch.Tensor: B, S, _ = x.shape n, D = self.n, self.dim H_res = self._get_H_res() h_pre = torch.sigmoid(self.w_pre) h_post = 2 * torch.sigmoid(self.w_post) streams = x.view(B, S, n, D) mixed = torch.einsum("ij,bsjd->bsid", H_res, streams) sublayer_in = torch.einsum("i,bsid->bsd", h_pre, mixed) y = sublayer(self.norm(sublayer_in)) y_dist = y.unsqueeze(2) * h_post.view(1, 1, n, 1) out = mixed + y_dist return out.reshape(B, S, n * D) # ============================================================================ # GatedGQA (Grouped Query Attention with sigmoid output gate) # ============================================================================ class GatedGQA(nn.Module): """Grouped Query Attention with RoPE, QK-norm, and sigmoid output gate.""" def __init__(self, cfg: ModelConfig): super().__init__() self.num_heads = cfg.num_heads self.num_kv_heads = cfg.num_kv_heads self.head_dim = cfg.head_dim self.kv_group_size = cfg.num_heads // 
cfg.num_kv_heads self.q_proj = nn.Linear(cfg.dim, cfg.num_heads * cfg.head_dim, bias=False) self.k_proj = nn.Linear(cfg.dim, cfg.num_kv_heads * cfg.head_dim, bias=False) self.v_proj = nn.Linear(cfg.dim, cfg.num_kv_heads * cfg.head_dim, bias=False) self.o_proj = nn.Linear(cfg.num_heads * cfg.head_dim, cfg.dim, bias=False) self.gate_proj = nn.Linear(cfg.dim, cfg.num_heads * cfg.head_dim, bias=False) self.qk_norm = nn.RMSNorm(cfg.head_dim) if cfg.qk_norm else None self.rope = RotaryEmbedding(cfg.head_dim, theta=cfg.rope_theta, max_seq_len=cfg.max_seq_len) def forward(self, x: torch.Tensor, cache: tuple | None = None, start_pos: int = 0, return_cache: bool = False) -> torch.Tensor | tuple: B, S, _ = x.shape q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) if self.qk_norm: q = self.qk_norm(q).to(v.dtype) k = self.qk_norm(k).to(v.dtype) # RoPE with position offset for cached generation rope_len = start_pos + S freqs_full = self.rope(rope_len, x.device) freqs = (freqs_full[0][start_pos:start_pos + S], freqs_full[1][start_pos:start_pos + S]) q = apply_rotary(q, freqs) k = apply_rotary(k, freqs) # KV cache: concat with previous cache if present # Cache stores unexpanded KV heads (num_kv_heads, not num_heads) if cache is not None: k_cache, v_cache = cache k = torch.cat([k_cache, k], dim=2) v = torch.cat([v_cache, v], dim=2) new_cache = (k, v) if return_cache else None # Prefill needs causal mask; single-token decode with cache does not use_causal = (cache is None) and (S > 1) # GQA: expand KV heads for SDPA compatibility on CPU / MPS / older torch k_exp = k.repeat_interleave(self.kv_group_size, dim=1) if self.kv_group_size > 1 else k v_exp = v.repeat_interleave(self.kv_group_size, dim=1) if self.kv_group_size > 1 else v out = F.scaled_dot_product_attention(q, k_exp, v_exp, 
is_causal=use_causal) # Sigmoid output gate if out.shape[1] == self.num_heads: # (B, H, S, D) layout gate = self.gate_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) else: gate = self.gate_proj(x).view(B, S, self.num_heads, self.head_dim) if gate.dtype != out.dtype: gate = gate.to(out.dtype) out = out * torch.sigmoid(gate) out = out.transpose(1, 2).contiguous().view(B, S, -1) result = self.o_proj(out) if cache is not None or return_cache: return result, (k, v) return result # ============================================================================ # ReLU-squared FFN # ============================================================================ class ReLU2(nn.Module): """ReLU-squared FFN: relu(x @ W_up)^2 @ W_down""" def __init__(self, cfg: ModelConfig): super().__init__() hidden = int(cfg.dim * cfg.intermediate_mult) hidden = ((hidden + 255) // 256) * 256 # round to multiple of 256 self.w_up = nn.Linear(cfg.dim, hidden, bias=False) self.w_down = nn.Linear(hidden, cfg.dim, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w_down(F.relu(self.w_up(x)) ** 2) # ============================================================================ # TransformerBlock # ============================================================================ class TransformerBlock(nn.Module): def __init__(self, cfg: ModelConfig): super().__init__() self.attn = GatedGQA(cfg) self.ffn = ReLU2(cfg) self.residual_attn = mHC(cfg.dim, cfg.n_streams, cfg.sinkhorn_iters) self.residual_ffn = mHC(cfg.dim, cfg.n_streams, cfg.sinkhorn_iters) def forward(self, x: torch.Tensor, cache: tuple | None = None, start_pos: int = 0, return_cache: bool = False) -> torch.Tensor | tuple: _side = {} def _attn_fn(normed_x): out = self.attn(normed_x, cache=cache, start_pos=start_pos, return_cache=return_cache) if isinstance(out, tuple): out, _side['cache'] = out return out x = self.residual_attn(x, _attn_fn) x = self.residual_ffn(x, self.ffn) new_cache = _side.get('cache') if 
new_cache is not None: return x, new_cache return x # ============================================================================ # GPT Model # ============================================================================ class GPT(nn.Module): def __init__(self, cfg: ModelConfig): super().__init__() self.cfg = cfg self.embed = nn.Embedding(cfg.vocab_size, cfg.dim) self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)]) self.norm = nn.RMSNorm(cfg.dim) self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) # Weight tying: share embedding and output projection self.embed.weight = self.lm_head.weight # mHC uses expanded hidden state self._needs_expand = True # always true for mHC architecture def forward(self, input_ids: torch.Tensor, kv_cache: list | None = None, start_pos: int = 0, return_cache: bool = True) -> tuple[torch.Tensor, list | None]: """Forward pass returning (logits, kv_cache). Args: input_ids: (B, S) token indices kv_cache: list of (K, V) tuples per layer, or None for prefill start_pos: position offset for RoPE in cached generation return_cache: if True, always return new KV caches (needed for prefill) """ x = self.embed(input_ids) if self._needs_expand: x = self.blocks[0].residual_attn.expand(x) new_kv_cache = [] if return_cache else None for i, block in enumerate(self.blocks): layer_cache = kv_cache[i] if kv_cache is not None else None result = block(x, cache=layer_cache, start_pos=start_pos, return_cache=return_cache) if isinstance(result, tuple): x, layer_new_cache = result if new_kv_cache is not None: new_kv_cache.append(layer_new_cache) else: x = result if new_kv_cache is not None: new_kv_cache.append(None) if self._needs_expand: x = self.blocks[0].residual_attn.contract(x) x = self.norm(x) logits = self.lm_head(x) return logits, new_kv_cache # ============================================================================ # Sampling utilities # 
# ============================================================================
# Sampling utilities
# ============================================================================

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus (top-p) sampling.

    Args:
        probs: (B, vocab) probability distribution.
        p: cumulative-probability threshold in (0, 1].

    Returns:
        (B, 1) sampled token indices. At least one token is always kept.
    """
    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # Remove tokens with cumulative prob above the threshold; subtracting the
    # token's own prob guarantees at least one token survives.
    mask = cumulative_probs - sorted_probs > p
    sorted_probs[mask] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    next_token = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_indices, -1, next_token)


def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Top-k sampling: renormalize the k most likely tokens and sample.

    Args:
        probs: (B, vocab) probability distribution.
        k: number of candidates to keep (k >= 1).

    Returns:
        (B, 1) sampled token indices.
    """
    top_k_probs, top_k_indices = torch.topk(probs, k, dim=-1)
    top_k_probs /= top_k_probs.sum(dim=-1, keepdim=True)
    idx = torch.multinomial(top_k_probs, num_samples=1)
    return torch.gather(top_k_indices, -1, idx)


# ============================================================================
# Generation
# ============================================================================

@torch.inference_mode()
def generate(
    model: GPT,
    prompt_tokens: list[int],
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 0,
    stop_tokens: set[int] | None = None,
):
    """Generate tokens autoregressively with KV caching.

    Yields (token_id, tokens_per_second) tuples as tokens are generated.

    Args:
        model: the GPT model (already on its target device, in eval mode).
        prompt_tokens: prompt as a list of token ids.
        max_new_tokens: hard cap on generated tokens.
        temperature: 0 or less = greedy argmax decoding.
        top_p: nucleus threshold, used when top_k == 0 and top_p < 1.
        top_k: if > 0, top-k sampling takes precedence over top-p.
        stop_tokens: token ids that end generation (stop token not yielded).
    """
    # The @torch.inference_mode() decorator covers the whole generator
    # (torch re-enters the mode on every resume), so no inner `with` is needed.
    device = next(model.parameters()).device
    tokens = torch.tensor([prompt_tokens], dtype=torch.long, device=device)

    # Prefill: process the entire prompt at once
    logits, kv_cache = model(tokens, kv_cache=None, start_pos=0)
    logits = logits[:, -1, :]  # last position

    start_time = time.perf_counter()
    seq_len = tokens.shape[1]

    for i in range(max_new_tokens):
        # Sample from logits
        if temperature <= 0:
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            if top_k > 0:
                next_token = sample_top_k(probs, top_k)
            elif top_p < 1.0:
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.multinomial(probs, num_samples=1)

        token_id = next_token.item()
        if stop_tokens and token_id in stop_tokens:
            break

        elapsed = time.perf_counter() - start_time
        tps = (i + 1) / elapsed if elapsed > 0 else 0.0
        yield token_id, tps

        # Decode: single token with KV cache. FIX: skip the forward pass after
        # the final token — the original always ran one wasted model call.
        if i + 1 < max_new_tokens:
            start_pos = seq_len + i
            logits, kv_cache = model(next_token.view(1, 1), kv_cache=kv_cache,
                                     start_pos=start_pos)
            logits = logits[:, -1, :]


# ============================================================================
# Checkpoint loading
# ============================================================================

def load_model(checkpoint_path: str, device: str = "cpu",
               dtype: torch.dtype = torch.float32) -> GPT:
    """Load a Goedel-mHC-1B checkpoint.

    Args:
        checkpoint_path: path to a .pt file, either {"model": state_dict, ...}
            or a raw state_dict.
        device: target device string ("cpu", "cuda", "mps").
        dtype: target parameter dtype.

    Returns:
        The model in eval mode on the requested device/dtype.
    """
    cfg = ModelConfig()
    model = GPT(cfg)

    print(f"Loading checkpoint from {checkpoint_path} ...", flush=True)
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects — only
    # load checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    # Support both {"model": state_dict, ...} and raw state_dict
    if isinstance(ckpt, dict) and "model" in ckpt:
        state_dict = ckpt["model"]
        step = ckpt.get("step", "?")
        val_loss = ckpt.get("val_loss", "?")
        print(f"  Step: {step}, Val loss: {val_loss}")
    else:
        state_dict = ckpt

    # Handle weight tying: remove embed.weight if lm_head.weight is present
    # (they share the same tensor, so only one is needed)
    if "embed.weight" in state_dict and "lm_head.weight" in state_dict:
        del state_dict["embed.weight"]

    result = model.load_state_dict(state_dict, strict=False)
    if result.unexpected_keys:
        print(f"  Warning: unexpected keys: {result.unexpected_keys}")
    if result.missing_keys:
        # embed.weight is expected to be missing due to weight tying
        missing = [k for k in result.missing_keys if k != "embed.weight"]
        if missing:
            print(f"  Warning: missing keys: {missing}")

    model = model.to(dtype=dtype, device=device)
    model.eval()

    # Precompute Sinkhorn matrices for all mHC modules (after the device move,
    # so the cached matrices live on the right device)
    for block in model.blocks:
        block.residual_attn._get_H_res()
        block.residual_ffn._get_H_res()

    total_params = sum(p.numel() for p in model.parameters())
    print(f"  Parameters: {total_params:,} ({total_params / 1e9:.2f}B)")
    print(f"  Device: {device}, Dtype: {dtype}")
    return model


# ============================================================================
# Tokenizer
# ============================================================================

def get_tokenizer():
    """Get GPT-2 tokenizer via tiktoken (imported lazily so the model code
    has no hard dependency on it)."""
    import tiktoken
    return tiktoken.get_encoding("gpt2")


# ============================================================================
# CLI
# ============================================================================

def main():
    """Parse CLI arguments, load the model, and run one-shot or REPL generation."""
    parser = argparse.ArgumentParser(
        description="Goedel-mHC-1B inference",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python inference.py --checkpoint ckpt_best.pt "The capital of France is"
  python inference.py --checkpoint ckpt_best.pt --interactive
  python inference.py --checkpoint ckpt_best.pt --temperature 0 "def fibonacci(n):"
  python inference.py --checkpoint ckpt_best.pt --device cpu "Once upon a time"
""",
    )
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to checkpoint file (.pt)")
    parser.add_argument("--interactive", action="store_true",
                        help="Enter interactive REPL mode")
    parser.add_argument("--max-tokens", type=int, default=256,
                        help="Maximum tokens to generate (default: 256)")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Sampling temperature (0 = greedy, default: 0.7)")
    parser.add_argument("--top-p", type=float, default=0.9,
                        help="Nucleus sampling threshold (default: 0.9)")
    parser.add_argument("--top-k", type=int, default=0,
                        help="Top-k sampling (0 = disabled, default: 0)")
    parser.add_argument("--device", type=str, default=None,
                        help="Device (default: cuda if available, else cpu)")
    parser.add_argument("prompt", nargs="*", type=str,
                        help="Prompt text (ignored in interactive mode)")
    args = parser.parse_args()

    # Device selection: explicit flag wins, then CUDA, then MPS, then CPU
    if args.device:
        device = args.device
    elif torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # Dtype: bf16 on CUDA, fp16 on MPS, fp32 on CPU
    if device == "cuda":
        dtype = torch.bfloat16
    elif device == "mps":
        dtype = torch.float16
    else:
        dtype = torch.float32

    # Load model
    model = load_model(args.checkpoint, device=device, dtype=dtype)

    # Load tokenizer
    enc = get_tokenizer()
    eot_token = enc.eot_token  # <|endoftext|>

    def run_generation(prompt_text: str):
        """Encode the prompt, stream tokens to stdout, print a summary line."""
        tokens = enc.encode(prompt_text)
        if not tokens:
            print("(empty prompt)")
            return
        print(prompt_text, end="", flush=True)
        generated = []
        tps = 0.0
        for token_id, tps in generate(
            model,
            tokens,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            stop_tokens={eot_token},
        ):
            text = enc.decode([token_id])
            print(text, end="", flush=True)
            generated.append(token_id)
        print()
        print(f"\n--- {len(generated)} tokens, {tps:.1f} tok/s ---")

    if args.interactive:
        print("=" * 60)
        print("Goedel-mHC-1B Interactive Mode")
        print("Type your prompt and press Enter. Ctrl+C or 'quit' to exit.")
        print("=" * 60)
        while True:
            try:
                prompt = input("\n>>> ")
            except (KeyboardInterrupt, EOFError):
                print("\nBye!")
                break
            prompt = prompt.strip()
            if prompt.lower() in ("quit", "exit"):
                print("Bye!")
                break
            if not prompt:
                # FIX: empty input used to fall through to `break` and silently
                # exit the REPL; just prompt again instead.
                continue
            print()
            run_generation(prompt)
    else:
        prompt_text = " ".join(args.prompt)
        if not prompt_text:
            parser.error("Provide a prompt or use --interactive")
        run_generation(prompt_text)


if __name__ == "__main__":
    main()