| |
| """ |
| Goedel-mHC-1B — Self-contained inference script. |
| |
| No dependencies on the training codebase. Just torch and tiktoken. |
| Works on CUDA, MPS, and CPU. |
| |
| Usage: |
| pip install torch tiktoken |
| python inference.py --checkpoint ckpt_best.pt "The capital of France is" |
| python inference.py --checkpoint ckpt_best.pt --interactive |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import math |
| import time |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class ModelConfig:
    """Architecture hyperparameters for Goedel-mHC-1B (inference defaults)."""

    dim: int = 2048                    # model hidden width D
    n_layers: int = 24                 # number of transformer blocks
    vocab_size: int = 50304            # output vocabulary size
    num_heads: int = 16                # query heads per attention layer
    num_kv_heads: int = 4              # shared key/value heads (GQA)
    head_dim: int = 128                # per-head feature dimension
    intermediate_mult: float = 2.667   # FFN hidden = dim * this, rounded up to a multiple of 256
    n_streams: int = 4                 # mHC residual streams
    sinkhorn_iters: int = 5            # Sinkhorn-Knopp iterations for the stream-mixing matrix
    rope_theta: float = 10000.0        # RoPE base frequency
    max_seq_len: int = 4096            # length of the precomputed RoPE tables
    qk_norm: bool = True               # apply per-head RMSNorm to Q and K
|
|
|
|
| |
| |
| |
|
|
class RotaryEmbedding(nn.Module):
    """Precomputed RoPE cos/sin tables for positions up to ``max_seq_len``.

    Tables are stored once as bfloat16 buffers of shape
    (max_seq_len, dim // 2) and sliced per call.
    """

    def __init__(self, dim: int, theta: float = 10000.0, max_seq_len: int = 4096):
        super().__init__()
        # Standard RoPE frequencies: theta^(-2i/dim) for i in [0, dim/2).
        frequencies = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len, dtype=frequencies.dtype)
        angles = torch.outer(positions, frequencies)
        # Buffer names are state_dict keys — do not rename.
        self.register_buffer("_cos", angles.cos().to(torch.bfloat16))
        self.register_buffer("_sin", angles.sin().to(torch.bfloat16))

    def forward(self, seq_len: int, device: torch.device):
        """Return (cos, sin) tables for the first ``seq_len`` positions."""
        return self._cos[:seq_len].to(device), self._sin[:seq_len].to(device)
|
|
|
|
def apply_rotary(x: torch.Tensor, cos_sin: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """Rotate query/key features with precomputed cos/sin tables.

    Args:
        x: (B, n_heads, S, head_dim); even/odd feature pairs are treated as
           2-D coordinates and rotated.
        cos_sin: (cos, sin) tables of shape (S, head_dim // 2).

    Returns:
        Tensor of the same shape as ``x``.
    """
    cos, sin = cos_sin
    # Broadcast the (S, head_dim/2) tables over batch and head dims.
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]
    even, odd = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack(
        (even * cos - odd * sin, even * sin + odd * cos), dim=-1
    )
    # Re-interleave the rotated pairs back into head_dim.
    return rotated.flatten(-2)
|
|
|
|
| |
| |
| |
|
|
| def _sinkhorn(logits: torch.Tensor, iters: int = 5) -> torch.Tensor: |
| """Project to doubly-stochastic matrix via Sinkhorn-Knopp.""" |
| M = logits.exp() |
| for _ in range(iters): |
| M = M / M.sum(dim=-1, keepdim=True) |
| M = M / M.sum(dim=-2, keepdim=True) |
| return M |
|
|
|
|
| |
| |
| |
|
|
| class mHC(nn.Module): |
| """Manifold-constrained hyper-connections (full-dim streams). |
| |
| Each of n_streams streams carries the full hidden dim D. Between |
| blocks the hidden state is (B, S, n*D). The model must call |
| expand() after embedding and contract() before the final norm. |
| """ |
|
|
| def __init__(self, dim: int, n_streams: int = 4, sinkhorn_iters: int = 5): |
| super().__init__() |
| self.dim = dim |
| self.n = n_streams |
| self.sinkhorn_iters = sinkhorn_iters |
| self.norm = nn.RMSNorm(dim) |
|
|
| self.logits_res = nn.Parameter(10.0 * torch.eye(n_streams)) |
|
|
| w_pre_init = math.log(1.0 / (n_streams - 1)) |
| self.w_pre = nn.Parameter(torch.full((n_streams,), w_pre_init)) |
| self.w_post = nn.Parameter(torch.zeros(n_streams)) |
|
|
| self._H_res_cache = None |
|
|
| def expand(self, x: torch.Tensor) -> torch.Tensor: |
| """(B, S, D) -> (B, S, n*D): replicate into n streams.""" |
| B, S, D = x.shape |
| return x.unsqueeze(2).expand(B, S, self.n, D).reshape(B, S, self.n * D) |
|
|
| def contract(self, x: torch.Tensor) -> torch.Tensor: |
| """(B, S, n*D) -> (B, S, D): average across streams.""" |
| B, S, _ = x.shape |
| return x.view(B, S, self.n, self.dim).mean(dim=2) |
|
|
| def _get_H_res(self): |
| if self._H_res_cache is None or self.training: |
| self._H_res_cache = _sinkhorn(self.logits_res, self.sinkhorn_iters).clone() |
| return self._H_res_cache |
|
|
| def forward(self, x: torch.Tensor, sublayer) -> torch.Tensor: |
| B, S, _ = x.shape |
| n, D = self.n, self.dim |
|
|
| H_res = self._get_H_res() |
| h_pre = torch.sigmoid(self.w_pre) |
| h_post = 2 * torch.sigmoid(self.w_post) |
|
|
| streams = x.view(B, S, n, D) |
| mixed = torch.einsum("ij,bsjd->bsid", H_res, streams) |
|
|
| sublayer_in = torch.einsum("i,bsid->bsd", h_pre, mixed) |
| y = sublayer(self.norm(sublayer_in)) |
|
|
| y_dist = y.unsqueeze(2) * h_post.view(1, 1, n, 1) |
| out = mixed + y_dist |
| return out.reshape(B, S, n * D) |
|
|
|
|
| |
| |
| |
|
|
class GatedGQA(nn.Module):
    """Grouped Query Attention with RoPE, optional QK-norm, and a sigmoid
    output gate.

    ``num_heads`` query heads share ``num_kv_heads`` key/value heads; KV
    heads are repeat-interleaved up to the query head count before SDPA.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.num_heads = cfg.num_heads
        self.num_kv_heads = cfg.num_kv_heads
        self.head_dim = cfg.head_dim
        # How many query heads share each KV head.
        self.kv_group_size = cfg.num_heads // cfg.num_kv_heads

        self.q_proj = nn.Linear(cfg.dim, cfg.num_heads * cfg.head_dim, bias=False)
        self.k_proj = nn.Linear(cfg.dim, cfg.num_kv_heads * cfg.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.dim, cfg.num_kv_heads * cfg.head_dim, bias=False)
        self.o_proj = nn.Linear(cfg.num_heads * cfg.head_dim, cfg.dim, bias=False)
        # Per-feature sigmoid gate applied to the attention output.
        self.gate_proj = nn.Linear(cfg.dim, cfg.num_heads * cfg.head_dim, bias=False)

        self.qk_norm = nn.RMSNorm(cfg.head_dim) if cfg.qk_norm else None
        self.rope = RotaryEmbedding(cfg.head_dim, theta=cfg.rope_theta,
                                    max_seq_len=cfg.max_seq_len)

    def forward(self, x: torch.Tensor, cache: tuple | None = None,
                start_pos: int = 0, return_cache: bool = False) -> torch.Tensor | tuple:
        """Attention over ``x`` (B, S, dim).

        Args:
            x: input hidden states.
            cache: optional (K, V) from previous steps, each
                (B, num_kv_heads, S_past, head_dim).
            start_pos: absolute position of x[:, 0] for RoPE.
            return_cache: request the updated (K, V) in the return value.

        Returns:
            output (B, S, dim), or (output, (K, V)) when ``cache`` was
            provided or ``return_cache`` is True.
        """
        B, S, _ = x.shape

        q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Fix: explicit None test rather than relying on nn.Module truthiness.
        if self.qk_norm is not None:
            q = self.qk_norm(q).to(v.dtype)
            k = self.qk_norm(k).to(v.dtype)

        # RoPE for absolute positions [start_pos, start_pos + S).
        cos_full, sin_full = self.rope(start_pos + S, x.device)
        freqs = (cos_full[start_pos:start_pos + S],
                 sin_full[start_pos:start_pos + S])
        q = apply_rotary(q, freqs)
        k = apply_rotary(k, freqs)

        # Append the new K/V to the cache during incremental decoding.
        if cache is not None:
            k_cache, v_cache = cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)

        # Causal masking is only needed for multi-token prefill; with a cache
        # the new query may attend to every cached position.
        use_causal = (cache is None) and (S > 1)

        # Expand KV heads to match the query head count (GQA).
        if self.kv_group_size > 1:
            k_exp = k.repeat_interleave(self.kv_group_size, dim=1)
            v_exp = v.repeat_interleave(self.kv_group_size, dim=1)
        else:
            k_exp, v_exp = k, v

        out = F.scaled_dot_product_attention(q, k_exp, v_exp, is_causal=use_causal)

        # SDPA preserves q's (B, num_heads, S, head_dim) layout, so the gate
        # is always built in the same layout.  (A dead else-branch for
        # out.shape[1] != num_heads, and an unused new_cache local, were
        # removed — no behavior change.)
        gate = self.gate_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        out = out * torch.sigmoid(gate.to(out.dtype))

        out = out.transpose(1, 2).contiguous().view(B, S, -1)
        result = self.o_proj(out)

        # Backward-compat: a tuple is returned whenever a cache was passed in,
        # even with return_cache=False — TransformerBlock unpacks it.
        if cache is not None or return_cache:
            return result, (k, v)
        return result
|
|
|
|
| |
| |
| |
|
|
class ReLU2(nn.Module):
    """Feed-forward block with the ReLU-squared activation.

    Computes ``relu(x @ W_up)^2 @ W_down``; the hidden width is
    dim * intermediate_mult rounded up to a multiple of 256.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        raw_width = int(cfg.dim * cfg.intermediate_mult)
        hidden = 256 * ((raw_width + 255) // 256)  # round up to multiple of 256
        self.w_up = nn.Linear(cfg.dim, hidden, bias=False)
        self.w_down = nn.Linear(hidden, cfg.dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.relu(self.w_up(x))
        return self.w_down(activated * activated)
|
|
|
|
| |
| |
| |
|
|
class TransformerBlock(nn.Module):
    """One layer: gated GQA attention and ReLU^2 FFN, each wrapped in its
    own mHC residual mixer."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.attn = GatedGQA(cfg)
        self.ffn = ReLU2(cfg)
        self.residual_attn = mHC(cfg.dim, cfg.n_streams, cfg.sinkhorn_iters)
        self.residual_ffn = mHC(cfg.dim, cfg.n_streams, cfg.sinkhorn_iters)

    def forward(self, x: torch.Tensor, cache: tuple | None = None,
                start_pos: int = 0,
                return_cache: bool = False) -> torch.Tensor | tuple:
        """Run attention then FFN through their residual mixers.

        Returns (x, kv_cache) when the attention produced an updated KV
        cache, otherwise just x.
        """
        captured = []

        def run_attention(normed):
            # The attention may return (output, kv); stash the kv so it can
            # be carried past the mHC wrapper, which only sees the output.
            result = self.attn(normed, cache=cache, start_pos=start_pos,
                               return_cache=return_cache)
            if isinstance(result, tuple):
                result, kv = result
                captured.append(kv)
            return result

        x = self.residual_attn(x, run_attention)
        x = self.residual_ffn(x, self.ffn)

        if captured and captured[0] is not None:
            return x, captured[0]
        return x
|
|
|
|
| |
| |
| |
|
|
class GPT(nn.Module):
    """Goedel-mHC decoder-only language model with tied embeddings and mHC
    residual streams carried between blocks."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.embed = nn.Embedding(cfg.vocab_size, cfg.dim)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.norm = nn.RMSNorm(cfg.dim)
        self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
        # Weight tying: the embedding shares the lm_head parameter.
        self.embed.weight = self.lm_head.weight

        # Expand into / contract from the mHC streams around the block stack.
        self._needs_expand = True

    def forward(self, input_ids: torch.Tensor,
                kv_cache: list | None = None,
                start_pos: int = 0,
                return_cache: bool = True) -> tuple[torch.Tensor, list | None]:
        """Forward pass returning (logits, kv_cache).

        Args:
            input_ids: (B, S) token indices
            kv_cache: list of (K, V) tuples per layer, or None for prefill
            start_pos: position offset for RoPE in cached generation
            return_cache: if True, always return new KV caches (needed for prefill)
        """
        hidden = self.embed(input_ids)

        # Replicate the embedding into the residual streams; any block's
        # mixer provides the helper (it only reads its own n/dim).
        if self._needs_expand:
            hidden = self.blocks[0].residual_attn.expand(hidden)

        caches = [] if return_cache else None

        for idx, block in enumerate(self.blocks):
            layer_cache = None if kv_cache is None else kv_cache[idx]
            out = block(hidden, cache=layer_cache, start_pos=start_pos,
                        return_cache=return_cache)
            if isinstance(out, tuple):
                hidden, layer_kv = out
                if caches is not None:
                    caches.append(layer_kv)
            else:
                hidden = out
                if caches is not None:
                    caches.append(None)

        # Collapse the streams back to (B, S, D) before the output head.
        if self._needs_expand:
            hidden = self.blocks[0].residual_attn.contract(hidden)

        logits = self.lm_head(self.norm(hidden))
        return logits, caches
|
|
|
|
| |
| |
| |
|
|
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Nucleus (top-p) sampling.

    Samples from the smallest prefix of the descending-sorted distribution
    whose cumulative mass reaches ``p``; the top token is always kept.

    Args:
        probs: (..., vocab) probability distribution.
        p: cumulative-probability threshold.

    Returns:
        (..., 1) tensor of sampled token indices.
    """
    ranked, order = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(ranked, dim=-1)
    # Zero every token whose *preceding* cumulative mass already exceeds p.
    ranked = torch.where(cumulative - ranked > p, torch.zeros_like(ranked), ranked)
    ranked = ranked / ranked.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(ranked, num_samples=1)
    # Map positions in the sorted order back to vocabulary indices.
    return torch.gather(order, -1, choice)
|
|
|
|
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Top-k sampling: renormalize the k most likely tokens and sample one.

    Args:
        probs: (..., vocab) probability distribution.
        k: number of candidates to keep (must be >= 1).

    Returns:
        (..., 1) tensor of sampled token indices.
    """
    vals, idxs = torch.topk(probs, k, dim=-1)
    vals = vals / vals.sum(dim=-1, keepdim=True)
    pick = torch.multinomial(vals, num_samples=1)
    return torch.gather(idxs, -1, pick)
|
|
|
|
| |
| |
| |
|
|
@torch.inference_mode()
def generate(
    model: GPT,
    prompt_tokens: list[int],
    max_new_tokens: int = 256,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 0,
    stop_tokens: set[int] | None = None,
):
    """Generate tokens autoregressively with KV caching.

    Args:
        model: language model called as ``model(ids, kv_cache=..., start_pos=...)``
            returning (logits, kv_cache).
        prompt_tokens: non-empty list of prompt token ids.
        max_new_tokens: upper bound on generated tokens.
        temperature: <= 0 selects greedy decoding; otherwise softmax temperature.
        top_k: if > 0, top-k sampling (takes precedence over top_p).
        top_p: nucleus threshold, used when top_k == 0 and top_p < 1.
        stop_tokens: token ids that end generation; the stop token itself
            is not yielded.

    Yields:
        (token_id, tokens_per_second) tuples as tokens are generated.
    """
    device = next(model.parameters()).device
    tokens = torch.tensor([prompt_tokens], dtype=torch.long, device=device)

    # Prefill: run the whole prompt once to build the KV cache.  (The
    # decorator already establishes inference mode; a redundant nested
    # `with torch.inference_mode():` was removed.)
    logits, kv_cache = model(tokens, kv_cache=None, start_pos=0)
    logits = logits[:, -1, :]

    # Throughput is measured for the decode loop only (prefill excluded).
    start_time = time.perf_counter()
    seq_len = tokens.shape[1]

    for i in range(max_new_tokens):
        if temperature <= 0:
            next_token = logits.argmax(dim=-1, keepdim=True)  # greedy
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            if top_k > 0:
                next_token = sample_top_k(probs, top_k)
            elif top_p < 1.0:
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.multinomial(probs, num_samples=1)

        token_id = next_token.item()
        if stop_tokens and token_id in stop_tokens:
            break

        elapsed = time.perf_counter() - start_time
        tps = (i + 1) / elapsed if elapsed > 0 else 0.0
        yield token_id, tps

        # Decode step: feed only the new token at its absolute position.
        logits, kv_cache = model(next_token.view(1, 1), kv_cache=kv_cache,
                                 start_pos=seq_len + i)
        logits = logits[:, -1, :]
|
|
|
|
| |
| |
| |
|
|
def load_model(checkpoint_path: str, device: str = "cpu",
               dtype: torch.dtype = torch.float32) -> GPT:
    """Load a Goedel-mHC-1B checkpoint.

    Args:
        checkpoint_path: path to a .pt file containing either a raw state
            dict or a training checkpoint dict with a "model" key (plus
            optional "step" / "val_loss" metadata).
        device: target device string ("cpu", "cuda", "mps").
        dtype: dtype to cast the model parameters to.

    Returns:
        The model in eval mode on the requested device/dtype.
    """
    cfg = ModelConfig()
    model = GPT(cfg)

    print(f"Loading checkpoint from {checkpoint_path} ...", flush=True)
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects;
    # only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    # Training checkpoints wrap the weights under "model"; plain state dicts
    # are used as-is.
    if isinstance(ckpt, dict) and "model" in ckpt:
        state_dict = ckpt["model"]
        step = ckpt.get("step", "?")
        val_loss = ckpt.get("val_loss", "?")
        print(f" Step: {step}, Val loss: {val_loss}")
    else:
        state_dict = ckpt

    # embed.weight is tied to lm_head.weight in GPT.__init__, so a saved
    # embed.weight is redundant; drop it before loading.
    if "embed.weight" in state_dict and "lm_head.weight" in state_dict:
        del state_dict["embed.weight"]

    result = model.load_state_dict(state_dict, strict=False)
    if result.unexpected_keys:
        print(f" Warning: unexpected keys: {result.unexpected_keys}")
    if result.missing_keys:
        # embed.weight is expected to be reported missing because of the
        # deletion above; only warn about anything else.
        missing = [k for k in result.missing_keys if k != "embed.weight"]
        if missing:
            print(f" Warning: missing keys: {missing}")

    model = model.to(dtype=dtype, device=device)
    model.eval()

    # Warm each block's Sinkhorn mixing-matrix cache so the first forward
    # pass does not pay for the projection (cached while in eval mode).
    for block in model.blocks:
        block.residual_attn._get_H_res()
        block.residual_ffn._get_H_res()

    total_params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {total_params:,} ({total_params / 1e9:.2f}B)")
    print(f" Device: {device}, Dtype: {dtype}")
    return model
|
|
|
|
| |
| |
| |
|
|
def get_tokenizer():
    """Get the GPT-2 BPE tokenizer via tiktoken.

    tiktoken is imported locally, so it is only required when a tokenizer
    is actually requested.
    """
    import tiktoken
    return tiktoken.get_encoding("gpt2")
|
|
|
|
| |
| |
| |
|
|
def _pick_device(requested: str | None) -> str:
    """Resolve the compute device: explicit flag > CUDA > MPS > CPU."""
    if requested:
        return requested
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _pick_dtype(device: str) -> torch.dtype:
    """Default dtype per device: bf16 on CUDA, fp16 on MPS, fp32 on CPU."""
    if device == "cuda":
        return torch.bfloat16
    if device == "mps":
        return torch.float16
    return torch.float32


def main():
    """CLI entry point: parse args, load the model, and run one-shot or
    interactive generation."""
    parser = argparse.ArgumentParser(
        description="Goedel-mHC-1B inference",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python inference.py --checkpoint ckpt_best.pt "The capital of France is"
  python inference.py --checkpoint ckpt_best.pt --interactive
  python inference.py --checkpoint ckpt_best.pt --temperature 0 "def fibonacci(n):"
  python inference.py --checkpoint ckpt_best.pt --device cpu "Once upon a time"
""",
    )
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to checkpoint file (.pt)")
    parser.add_argument("--interactive", action="store_true",
                        help="Enter interactive REPL mode")
    parser.add_argument("--max-tokens", type=int, default=256,
                        help="Maximum tokens to generate (default: 256)")
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="Sampling temperature (0 = greedy, default: 0.7)")
    parser.add_argument("--top-p", type=float, default=0.9,
                        help="Nucleus sampling threshold (default: 0.9)")
    parser.add_argument("--top-k", type=int, default=0,
                        help="Top-k sampling (0 = disabled, default: 0)")
    parser.add_argument("--device", type=str, default=None,
                        help="Device (default: cuda if available, else cpu)")
    parser.add_argument("prompt", nargs="*", type=str,
                        help="Prompt text (ignored in interactive mode)")
    args = parser.parse_args()

    device = _pick_device(args.device)
    dtype = _pick_dtype(device)

    model = load_model(args.checkpoint, device=device, dtype=dtype)

    enc = get_tokenizer()
    eot_token = enc.eot_token

    def run_generation(prompt_text: str):
        """Encode one prompt, stream-print the generation, then print stats."""
        tokens = enc.encode(prompt_text)
        if not tokens:
            print("(empty prompt)")
            return

        print(prompt_text, end="", flush=True)

        generated = []
        tps = 0.0
        for token_id, tps in generate(
            model, tokens,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            stop_tokens={eot_token},
        ):
            print(enc.decode([token_id]), end="", flush=True)
            generated.append(token_id)

        print()
        print(f"\n--- {len(generated)} tokens, {tps:.1f} tok/s ---")

    if args.interactive:
        print("=" * 60)
        print("Goedel-mHC-1B Interactive Mode")
        print("Type your prompt and press Enter. Ctrl+C or 'quit' to exit.")
        print("=" * 60)
        while True:
            try:
                prompt = input("\n>>> ").strip()
            except (KeyboardInterrupt, EOFError):
                print("\nBye!")
                break
            # Fix: an empty line re-prompts instead of ending the session —
            # the banner promises exit only via Ctrl+C or 'quit'.
            if not prompt:
                continue
            if prompt.lower() in ("quit", "exit"):
                print("Bye!")
                break
            print()
            run_generation(prompt)
    else:
        prompt_text = " ".join(args.prompt)
        if not prompt_text:
            parser.error("Provide a prompt or use --interactive")
        run_generation(prompt_text)
|
|
|
|
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|