#!/usr/bin/env python3
"""
K-Simplex Language Model - Inference Script

Loads a trained k-simplex LLM checkpoint and generates text using
geometrically-validated autoregressive sampling.

Usage:
    python inference.py --checkpoint checkpoint_epoch_008.pt --prompt "ROMEO: "
    python inference.py --repo AbstractPhil/ksimplex-llm-prototype --prompt "To be or not"
"""

import argparse
import json
import math
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# The tokenizer and Hub client are only needed by load_model(); keep the model
# and geometry code importable when they are not installed.
try:
    import tiktoken
except ImportError:
    tiktoken = None

try:
    from huggingface_hub import hf_hub_download
except ImportError:
    hf_hub_download = None


# =============================================================================
# GEOMETRIC CORE
# =============================================================================

def factorial(n: int) -> int:
    """Return n! (thin wrapper around ``math.factorial``)."""
    return math.factorial(n)


def cayley_menger_volume_squared(vertices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute squared simplex volume via the Cayley-Menger determinant.

    Args:
        vertices: [*, nv, edim] vertex coordinates

    Returns:
        d2: [*, n_pairs] squared distances for every vertex pair (i < j)
        vol2: [*] squared volume (can be <= 0 for degenerate inputs or
            through floating-point error)
    """
    nv = vertices.shape[-2]
    k = nv - 1  # simplex dimension

    # Pairwise squared distances: [*, nv, nv]
    diff = vertices.unsqueeze(-2) - vertices.unsqueeze(-3)
    d2_matrix = diff.pow(2).sum(-1)

    # Upper triangle (unique pairs); build the indices on the input's device
    rows, cols = torch.triu_indices(nv, nv, offset=1, device=vertices.device)
    d2 = d2_matrix[..., rows, cols]

    # Cayley-Menger matrix: bordered distance matrix of size (nv + 1)
    size = nv + 1
    cm = vertices.new_zeros(*vertices.shape[:-2], size, size)
    cm[..., 0, 1:] = 1.0
    cm[..., 1:, 0] = 1.0
    cm[..., 1:, 1:] = d2_matrix  # its diagonal is already zero

    det = torch.linalg.det(cm)

    # Vol^2 = (-1)^(k+1) * det(CM) / (2^k * (k!)^2)
    sign = (-1) ** (k + 1)
    denom = (2 ** k) * (factorial(k) ** 2)
    vol2 = sign * det / denom

    return d2, vol2


# =============================================================================
# MODEL COMPONENTS
# =============================================================================

class SimplexTemplate(nn.Module):
    """
    Fixed template vertices for a k-simplex, stored as the buffer ``template``.

    The nv = k + 1 vertices are placed at equal angles on a circle of radius
    ``scale`` in the first two coordinates; coordinate 2 carries a 0.3-scaled
    cosine of twice the angle and every further coordinate a 0.1-scaled sine
    harmonic. With edim == 1 only the cosine coordinate is written.
    """

    def __init__(self, k: int, edim: int, scale: float = 1.0):
        super().__init__()
        self.k = k
        self.nv = k + 1
        self.edim = edim

        vertices = torch.zeros(self.nv, edim)
        for i in range(self.nv):
            angle = 2 * math.pi * i / self.nv
            vertices[i, 0] = scale * math.cos(angle)
            if edim > 1:
                vertices[i, 1] = scale * math.sin(angle)
            if edim > 2:
                vertices[i, 2] = scale * 0.3 * math.cos(angle * 2)
            for d in range(3, edim):
                vertices[i, d] = scale * 0.1 * math.sin(angle * (d + 1))

        self.register_buffer('template', vertices)

    def forward(self) -> torch.Tensor:
        """Return the [nv, edim] template vertices."""
        return self.template


class KSimplexChannel(nn.Module):
    """Single k-simplex channel with geometric validation."""

    def __init__(self, k: int, edim: int, hidden: int, feat_dim: int, base_deform: float = 0.05):
        super().__init__()
        self.k = k
        self.nv = k + 1
        self.edim = edim
        self.feat_dim = feat_dim
        self.base_deform = base_deform

        # Template
        self.template = SimplexTemplate(k, edim)

        # Projections: per-vertex coordinate offsets and per-vertex features
        self._to_coords = nn.Linear(hidden, self.nv * edim)
        self._to_feats = nn.Linear(hidden, self.nv * feat_dim)

        # Geometry dimension: n_pairs squared distances + 1 (vol^2)
        n_pairs = (self.nv * (self.nv - 1)) // 2
        self.geo_dim = n_pairs + 1

        # Geometric gate
        self._geo_gate = nn.Sequential(
            nn.Linear(self.geo_dim, feat_dim),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [*, hidden]

        Returns:
            out: [*, feat_dim + geo_dim] gated features + geometry
            vol2: [*] squared volume for validity loss
            mean_d2: [*] mean squared distance
        """
        # Vertices = fixed template + small learned deformation
        coords = self._to_coords(x).unflatten(-1, (self.nv, self.edim))
        verts = self.template() + self.base_deform * coords

        # Vertex features
        vert_feats = self._to_feats(x).unflatten(-1, (self.nv, self.feat_dim))

        # Cayley-Menger geometry
        d2, vol2 = cayley_menger_volume_squared(verts)
        geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)

        # Learned gate from the geometry, times a steep sigmoid on vol^2 that
        # goes to ~0 for negative volumes and ~1 for positive ones
        gate = self._geo_gate(geo)
        validity = torch.sigmoid(vol2 * 1e6).unsqueeze(-1)

        # Aggregate vertex features (mean over vertices), then gate
        feat_agg = vert_feats.mean(dim=-2) * gate * validity

        out = torch.cat([feat_agg, geo], dim=-1)
        return out, vol2, d2.mean(dim=-1)


class TokenToKChannels(nn.Module):
    """Project token embeddings to k-simplex channels (k = 1 .. depth)."""

    def __init__(self, embed_dim: int, hidden: int, depth: int, edim: int, feat_dim: int):
        super().__init__()
        self.depth = depth

        self._proj = nn.Linear(embed_dim, hidden)
        self._channels = nn.ModuleList([
            KSimplexChannel(k=k + 1, edim=edim, hidden=hidden, feat_dim=feat_dim)
            for k in range(depth)
        ])

        # Output width differs per k-level; everything is brought to the max
        self.out_dims = [ch.feat_dim + ch.geo_dim for ch in self._channels]
        self.max_dim = max(self.out_dims)

        # Linear "padding" projections to equalize dimensions
        self._pads = nn.ModuleList([
            nn.Linear(d, self.max_dim) if d != self.max_dim else nn.Identity()
            for d in self.out_dims
        ])

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
        """
        Args:
            x: [B, T, embed_dim]

        Returns:
            out: [B, T, K, max_dim]
            vol2_list: list of [B, T] squared volumes, one per k
            d2_list: list of [B, T] mean squared distances, one per k
        """
        h = self._proj(x)  # [B, T, hidden]

        outputs = []
        vol2_list = []
        d2_list = []
        for channel, pad in zip(self._channels, self._pads):
            out, vol2, d2 = channel(h)
            outputs.append(pad(out))
            vol2_list.append(vol2)
            d2_list.append(d2)

        # Stack: [B, T, K, max_dim]
        return torch.stack(outputs, dim=-2), vol2_list, d2_list


class KChannelCrossAttention(nn.Module):
    """Attention between k-levels at each position."""

    def __init__(self, dim: int, num_heads: int = 4, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, K, D]

        Returns:
            [B, T, K, D]
        """
        B, T, K, D = x.shape

        # Every (batch, position) pair becomes its own length-K sequence
        x_flat = x.reshape(B * T, K, D)
        attn_out, _ = self.attn(x_flat, x_flat, x_flat)

        # Residual + norm
        out = self.norm(x_flat + attn_out)
        return out.view(B, T, K, D)


class CausalSequenceAttention(nn.Module):
    """
    Causal attention across sequence positions.

    ``dim`` is the width of the *flattened* representation the attention runs
    on, i.e. K * D for an input of shape [B, T, K, D].
    """

    def __init__(self, dim: int, num_heads: int, max_seq_len: int, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)

        # Lower-triangular mask: True where attention is allowed
        mask = torch.tril(torch.ones(max_seq_len, max_seq_len)).bool()
        self.register_buffer('_causal_mask', mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, K, D] with K * D equal to the ``dim`` given at construction

        Returns:
            [B, T, K, D]
        """
        B, T, K, D = x.shape

        # Flatten K into D: [B, T, K*D]
        x_flat = x.reshape(B, T, K * D)

        # Additive mask: 0 where allowed, -inf on future positions
        blocked = ~self._causal_mask[:T, :T]
        attn_mask = torch.zeros(T, T, dtype=x_flat.dtype, device=x_flat.device)
        attn_mask = attn_mask.masked_fill(blocked, float('-inf'))

        attn_out, _ = self.attn(x_flat, x_flat, x_flat, attn_mask=attn_mask)

        # Residual + norm
        out = self.norm(x_flat + attn_out)
        return out.view(B, T, K, D)


class GeoBlock(nn.Module):
    """Geometric block: k-channel attention + causal sequence attention + MLP."""

    def __init__(self, dim: int, num_heads: int, max_seq_len: int, depth: int, dropout: float = 0.1):
        super().__init__()
        flat_dim = dim * depth

        self.k_attn = KChannelCrossAttention(dim, num_heads=4, dropout=dropout)
        # Sequence attention runs on the flattened [B, T, K*D] tensor, so it
        # must be built with width dim * depth (like the MLP below), not dim.
        self.seq_attn = CausalSequenceAttention(flat_dim, num_heads, max_seq_len, dropout)

        self.mlp = nn.Sequential(
            nn.Linear(flat_dim, flat_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(flat_dim * 4, flat_dim),
            nn.Dropout(dropout),
        )
        self.mlp_norm = nn.LayerNorm(flat_dim)
        self.depth = depth

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, T, K, D]

        Returns:
            [B, T, K, D]
        """
        x = self.k_attn(x)    # mix k-levels within each position
        x = self.seq_attn(x)  # mix positions causally

        # MLP on flattened k-channels, residual + norm
        B, T, K, D = x.shape
        x_flat = x.reshape(B, T, K * D)
        x_flat = self.mlp_norm(x_flat + self.mlp(x_flat))
        return x_flat.view(B, T, K, D)


class KSimplexLM(nn.Module):
    """
    K-Simplex Language Model.

    NOTE(review): nn.MultiheadAttention requires its width to be divisible by
    the head count. With the defaults below the per-level width is
    feat_dim + geo_dim = 96 + 11 = 107, which the hard-coded 4 heads of the
    k-channel attention do not divide - confirm the defaults against the
    configuration the checkpoint was trained with.
    """

    def __init__(
        self,
        vocab_size: int = 50257,
        max_seq_len: int = 256,
        embed_dim: int = 384,
        depth: int = 4,
        edim: int = 16,
        feat_dim: int = 96,
        hidden: int = 384,
        num_heads: int = 8,
        num_blocks: int = 8,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.depth = depth

        # Token + learned position embeddings
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
        self.embed_drop = nn.Dropout(dropout)

        # Token to k-channels
        self.to_k_channels = TokenToKChannels(embed_dim, hidden, depth, edim, feat_dim)

        # Geometric blocks
        k_dim = self.to_k_channels.max_dim
        self.blocks = nn.ModuleList([
            GeoBlock(k_dim, num_heads, max_seq_len, depth, dropout)
            for _ in range(num_blocks)
        ])

        # LM head (not weight-tied to the embedding)
        self.ln_f = nn.LayerNorm(k_dim * depth)
        self.lm_head = nn.Linear(k_dim * depth, vocab_size, bias=False)

        self._init_weights()

    def _init_weights(self):
        """Normal(0, 0.02) for Linear/Embedding weights, zeros for Linear biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """
        Args:
            x: [B, T] token indices, T <= max_seq_len

        Returns:
            logits: [B, T, vocab_size]
            geo_info: dict with 'vol2' and 'd2' lists (one entry per k-level)

        Raises:
            ValueError: if T exceeds max_seq_len.
        """
        B, T = x.shape
        if T > self.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {self.max_seq_len}"
            )

        # Embeddings
        pos = torch.arange(T, device=x.device).unsqueeze(0)
        h = self.embed_drop(self.embed(x) + self.pos_embed(pos))

        # To k-channels
        h, vol2_list, d2_list = self.to_k_channels(h)

        # Geo blocks
        for block in self.blocks:
            h = block(h)

        # LM head on the flattened k-channels
        h_flat = self.ln_f(h.reshape(B, T, -1))
        logits = self.lm_head(h_flat)

        geo_info = {
            'vol2': vol2_list,
            'd2': d2_list,
        }
        return logits, geo_info


# =============================================================================
# INFERENCE UTILITIES
# =============================================================================

def load_model(
    checkpoint_path: Optional[str] = None,
    repo_id: Optional[str] = None,
    device: Optional[str] = None,
) -> "tuple[KSimplexLM, tiktoken.Encoding]":
    """
    Load model from checkpoint or HuggingFace Hub.

    When both are given, ``repo_id`` takes precedence.

    Args:
        checkpoint_path: Local path to checkpoint
        repo_id: HuggingFace repo ID (e.g., "AbstractPhil/ksimplex-llm-prototype")
        device: Device to load to (defaults to CUDA when available, else CPU)

    Returns:
        model: KSimplexLM in eval mode on ``device``
        tokenizer: tiktoken "gpt2" encoding

    Raises:
        ValueError: if neither ``checkpoint_path`` nor ``repo_id`` is given.
        ImportError: if a required optional package is not installed.
    """
    if not repo_id and not checkpoint_path:
        raise ValueError("Must provide checkpoint_path or repo_id")
    if tiktoken is None:
        raise ImportError("tiktoken is required to load the tokenizer")

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # SECURITY: torch.load unpickles the file. Only load checkpoints from
    # sources you trust.
    if repo_id:
        if hf_hub_download is None:
            raise ImportError("huggingface_hub is required to load from a repo")
        checkpoint_path = hf_hub_download(repo_id, "checkpoint_latest.pt")
        config_path = hf_hub_download(repo_id, "config.json")
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        checkpoint = torch.load(checkpoint_path, map_location=device)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        config = checkpoint.get('config', {}).get('model', {})

    # Build model
    model = KSimplexLM(
        vocab_size=config.get('vocab_size', 50257),
        max_seq_len=config.get('max_seq_len', 256),
        embed_dim=config.get('embed_dim', 384),
        depth=config.get('depth', 4),
        edim=config.get('edim', 16),
        feat_dim=config.get('feat_dim', 96),
        hidden=config.get('hidden', 384),
        num_heads=config.get('num_heads', 8),
        num_blocks=config.get('num_blocks', 8),
        dropout=0.0,  # No dropout at inference
    )

    # A checkpoint is either a dict holding 'model_state_dict' or the bare
    # state dict itself
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    tokenizer = tiktoken.get_encoding("gpt2")
    return model, tokenizer


def _sample_next_token(
    logits: torch.Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
) -> torch.Tensor:
    """Pick the next token id ([B, 1]) from last-position logits ([B, V])."""
    # temperature <= 0 cannot be divided by; treat it as greedy decoding
    if temperature <= 0:
        return logits.argmax(dim=-1, keepdim=True)

    logits = logits / temperature

    # Top-k: drop everything below the k-th largest logit
    if top_k > 0:
        kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
        logits = logits.masked_fill(logits < kth, float('-inf'))

    # Top-p (nucleus): keep the smallest prefix whose mass exceeds top_p
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        drop_sorted = cumulative > top_p
        # Shift right so the token that crosses the threshold is kept, and
        # always keep the most likely token
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = False

        drop = torch.zeros_like(drop_sorted).scatter(1, sorted_indices, drop_sorted)
        logits = logits.masked_fill(drop, float('-inf'))

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


@torch.no_grad()
def generate(
    model: KSimplexLM,
    tokenizer: "tiktoken.Encoding",
    prompt: str,
    max_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 0.9,
    device: Optional[str] = None,
) -> str:
    """
    Generate text from prompt.

    Args:
        model: KSimplexLM model
        tokenizer: tiktoken encoding (needs encode, decode and eot_token)
        prompt: Input text prompt
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature; <= 0 selects greedy decoding
        top_k: Top-k sampling (0 disables)
        top_p: Nucleus sampling threshold (>= 1.0 disables)
        device: Device (defaults to the model's device)

    Returns:
        Generated text including the full prompt

    Raises:
        ValueError: if the prompt encodes to no tokens.
    """
    if device is None:
        device = next(model.parameters()).device

    prompt_ids = tokenizer.encode(prompt)
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token")
    tokens = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    for _ in range(max_tokens):
        # Only the model input is windowed; `tokens` keeps the whole history
        # so the returned text still starts with the prompt.
        context = tokens[:, -model.max_seq_len:]

        logits, _ = model(context)
        next_token = _sample_next_token(logits[:, -1, :], temperature, top_k, top_p)

        tokens = torch.cat([tokens, next_token], dim=1)

        # Stop on EOS (optional)
        if next_token.item() == tokenizer.eot_token:
            break

    return tokenizer.decode(tokens[0].tolist())


@torch.no_grad()
def analyze_geometry(
    model: KSimplexLM,
    tokenizer: "tiktoken.Encoding",
    text: str,
    device: Optional[str] = None,
) -> dict:
    """
    Analyze geometric properties of text encoding.

    Args:
        model: KSimplexLM model
        tokenizer: tiktoken encoding
        text: Input text
        device: Device (defaults to the model's device)

    Returns:
        Dictionary keyed 'k1', 'k2', ... with vol2 mean/std/min/max, the
        fraction of positions with vol2 > 0, and the mean squared distance

    Raises:
        ValueError: if the text encodes to no tokens.
    """
    if device is None:
        device = next(model.parameters()).device

    token_ids = tokenizer.encode(text)
    if not token_ids:
        raise ValueError("text must encode to at least one token")
    tokens = torch.tensor([token_ids], dtype=torch.long, device=device)

    _, geo_info = model(tokens)

    stats = {}
    for k, (vol2, d2) in enumerate(zip(geo_info['vol2'], geo_info['d2']), 1):
        vol2_np = vol2.cpu().numpy()
        d2_np = d2.cpu().numpy()
        stats[f'k{k}'] = {
            'vol2_mean': float(vol2_np.mean()),
            'vol2_std': float(vol2_np.std()),
            'vol2_min': float(vol2_np.min()),
            'vol2_max': float(vol2_np.max()),
            'validity_rate': float((vol2_np > 0).mean()),
            'd2_mean': float(d2_np.mean()),
        }

    return stats


# =============================================================================
# CLI
# =============================================================================

def main():
    """Parse CLI arguments, load the model, then generate or analyze."""
    parser = argparse.ArgumentParser(description='K-Simplex LLM Inference')
    parser.add_argument('--checkpoint', type=str, help='Path to checkpoint file')
    parser.add_argument('--repo', type=str, default='AbstractPhil/ksimplex-llm-prototype',
                        help='HuggingFace repo ID')
    parser.add_argument('--prompt', type=str, default='ROMEO: ', help='Text prompt')
    parser.add_argument('--max_tokens', type=int, default=100, help='Maximum tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.8, help='Sampling temperature')
    parser.add_argument('--top_k', type=int, default=50, help='Top-k sampling')
    parser.add_argument('--top_p', type=float, default=0.9, help='Nucleus sampling threshold')
    parser.add_argument('--analyze', action='store_true',
                        help='Analyze geometric properties instead of generating')
    args = parser.parse_args()

    print("Loading model...")
    # A local checkpoint, when given, wins over the default repo
    model, tokenizer = load_model(
        checkpoint_path=args.checkpoint,
        repo_id=args.repo if not args.checkpoint else None,
    )
    print(f"Model loaded on {next(model.parameters()).device}")

    if args.analyze:
        print(f"\nAnalyzing: {args.prompt}")
        stats = analyze_geometry(model, tokenizer, args.prompt)
        for k, kstats in stats.items():
            print(f"\n{k}:")
            for name, value in kstats.items():
                print(f" {name}: {value:.6f}")
    else:
        print(f"\nGenerating from: {args.prompt}")
        text = generate(
            model,
            tokenizer,
            args.prompt,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        print("\n" + "=" * 60)
        print(text)
        print("=" * 60)


if __name__ == '__main__':
    main()