#!/usr/bin/env python3 from __future__ import annotations import json import math import os import re import shutil import socket import string import sys import threading import webbrowser from dataclasses import dataclass from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple from urllib.parse import quote, unquote, urlparse from urllib.request import Request, urlopen import hashlib import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint # --------------------------------------------------------------------------- # Config (from ailay.config) # --------------------------------------------------------------------------- @dataclass class ModelConfig: dim: int = 128 n_unique_layers: int = 8 n_logical_layers: int = 16 n_heads: int = 4 n_kv_heads: int = 2 ffn_dim: int = 224 dropout: float = 0.0 seq_len: int = 2048 sliding_window_size: int = 512 mtp_horizons: Tuple[int, ...] = (2, 3, 4) rope_fraction: float = 0.5 embed_scale: bool = True logit_soft_cap: float = -1.0 quantization: str = "nvfp4" # Engram (conditional memory) config engram_dim: int = 0 engram_heads: int = 4 engram_table_size: int = 8192 engram_max_ngram: int = 3 # mHC (Manifold-Constrained Hyper-Connections) config mhc_expansion: int = 1 @property def head_dim(self) -> int: return self.dim // self.n_heads model_config = ModelConfig() MODEL_SERIES = { "haiku": { "dim": 64, "n_unique_layers": 12, "n_logical_layers": 24, "n_heads": 4, "n_kv_heads": 2, "ffn_dim": 384, "dropout": 0.0, "seq_len": 2048, "mtp_horizons": (2, 3, 4), "rope_fraction": 0.5, "batch_size": 80, "grad_accum": 1, "lr": 8e-4, "min_lr": 1e-5, "sft_lr": 2e-4, "sft_min_lr": 1e-5, "warmup_steps": 300, "weight_decay": 0.02, "pretrain_passes": 2, "sft_passes": 3, "max_sft_target_chars": 0, "use_grad_checkpoint": True, "num_workers": 24, "prefetch_factor": 64, "shuffle_buffer": 8192, "max_pretrain_tokens": 0, "min_pretrain_tokens": 100_000_000, "quantization": "nvfp4", "engram_dim": 8, "engram_heads": 2, "engram_table_size": 64, "engram_max_ngram": 2, "mhc_expansion": 2, }, "sonnet": { "dim": 1024, "n_unique_layers": 20, "n_logical_layers": 40, "n_heads": 16, "n_kv_heads": 4, "ffn_dim": 4096, "dropout": 0.0, "seq_len": 2048, "mtp_horizons": (2,), "rope_fraction": 0.5, "batch_size": 24, "grad_accum": 1, "lr": 1e-4, "min_lr": 2e-5, "sft_lr": 5e-5, "sft_min_lr": 5e-6, "warmup_steps": 250, "weight_decay": 0.1, "pretrain_passes": 1, "sft_passes": 1, "max_sft_target_chars": 0, "use_grad_checkpoint": True, "num_workers": 32, "prefetch_factor": 64, "shuffle_buffer": 16384, "max_pretrain_tokens": 0, "min_pretrain_tokens": 100_000_000, "quantization": "nvfp4", "engram_dim": 32, "engram_heads": 8, "engram_table_size": 4096, "engram_max_ngram": 2, "mhc_expansion": 2, }, "opus": { "dim": 1536, "n_unique_layers": 18, "n_logical_layers": 36, "n_heads": 16, "n_kv_heads": 4, "ffn_dim": 5888, "dropout": 0.0, "seq_len": 2048, "mtp_horizons": (2,), "rope_fraction": 0.5, "batch_size": 24, "grad_accum": 1, "lr": 1.6e-4, "min_lr": 1.6e-5, "sft_lr": 3e-5, "sft_min_lr": 3e-6, "warmup_steps": 200, "weight_decay": 0.1, "pretrain_passes": 1, "sft_passes": 1, "max_sft_target_chars": 0, "use_grad_checkpoint": True, "num_workers": 48, "prefetch_factor": 64, "shuffle_buffer": 16384, "max_pretrain_tokens": 0, "min_pretrain_tokens": 100_000_000, "quantization": "nvfp4", "engram_dim": 64, "engram_heads": 8, "engram_table_size": 8192, "engram_max_ngram": 2, "mhc_expansion": 4, }, } # --------------------------------------------------------------------------- # Tokenizer (from ailay.tokenizer) # --------------------------------------------------------------------------- FORMAT_TOKENS = [ "<|user|>", "<|assistant|>", "<|system|>", "<|start_header_id|>", "<|end_header_id|>", "<|begin_of_thought|>", "<|end_of_thought|>", "<|begin_of_solution|>", "<|end_of_solution|>", ] class WordTokenizer: WORD_RE = re.compile( r"\s+|[^\W\d_]+(?:['\u2019][^\W\d_]+)?|\d+|[^\w\s]+", re.UNICODE ) def __init__( self, extra_chars: str = "", format_tokens: Optional[List[str]] = None ) -> None: base = string.ascii_letters + string.digits + string.punctuation + " \n\t\r" fallback_chars = sorted(set(base + extra_chars)) self.core_special = ["", "", "", ""] self.format_tokens = ( list(format_tokens) if format_tokens else list(FORMAT_TOKENS) ) self.special = list(self.core_special) + list(self.format_tokens) self.id_to_token: List[str] = ( list(self.core_special) + self.format_tokens + fallback_chars ) self.token_to_id: Dict[str, int] = { t: i for i, t in enumerate(self.id_to_token) } self.special_multi_tokens = sorted( [t for t in self.special if len(t) > 1], key=len, reverse=True ) self.multi_char_tokens = self.special_multi_tokens self.dynamic_additions = 0 @property def pad_id(self) -> int: return self.token_to_id[""] @property def bos_id(self) -> int: return self.token_to_id[""] @property def eos_id(self) -> int: return self.token_to_id[""] @property def unk_id(self) -> int: return self.token_to_id[""] @property def vocab_size(self) -> int: return len(self.id_to_token) def maybe_add_char(self, ch: str) -> bool: if ch in self.token_to_id: return False self.token_to_id[ch] = len(self.id_to_token) self.id_to_token.append(ch) self.dynamic_additions += 1 return True def maybe_add_token(self, token: str) -> bool: if token in self.token_to_id: return False self.token_to_id[token] = len(self.id_to_token) self.id_to_token.append(token) self.dynamic_additions += 1 return True def iter_lexical_tokens(self, text: str) -> Iterator[str]: i = 0 n = len(text) while i < n: matched_special = False for token in self.special_multi_tokens: if text.startswith(token, i): yield token i += len(token) matched_special = True break if matched_special: continue m = self.WORD_RE.match(text, i) if m is None: yield text[i] i += 1 continue tok = m.group(0) yield tok i = m.end() def encode( self, text: str, add_bos: bool = False, add_eos: bool = False ) -> List[int]: out: List[int] = [] if add_bos: out.append(self.bos_id) unk = self.unk_id t2i = self.token_to_id for tok in self.iter_lexical_tokens(text): tid = t2i.get(tok) if tid is not None: out.append(tid) continue for ch in tok: out.append(t2i.get(ch, unk)) if add_eos: out.append(self.eos_id) return out def decode(self, ids: Sequence[int], skip_special: bool = True) -> str: pieces: List[str] = [] for idx in ids: if int(idx) < 0 or int(idx) >= len(self.id_to_token): continue tok = self.id_to_token[int(idx)] if skip_special and tok in self.special: continue pieces.append(tok) return "".join(pieces) def save(self, path: Path) -> None: with path.open("w", encoding="utf-8") as f: json.dump( { "id_to_token": self.id_to_token, "format_tokens": self.format_tokens, "core_special": self.core_special, "tokenizer_type": "word_level_v1", }, f, ensure_ascii=False, indent=2, ) @classmethod def load(cls, path: Path) -> WordTokenizer: with path.open("r", encoding="utf-8") as f: data = json.load(f) format_tokens = data.get("format_tokens", FORMAT_TOKENS) tokenizer = cls(extra_chars="", format_tokens=format_tokens) tokenizer.id_to_token = data["id_to_token"] tokenizer.token_to_id = {t: i for i, t in enumerate(tokenizer.id_to_token)} tokenizer.special = list(tokenizer.core_special) + list(tokenizer.format_tokens) tokenizer.special_multi_tokens = sorted( [t for t in tokenizer.special if len(t) > 1], key=len, reverse=True ) tokenizer.multi_char_tokens = tokenizer.special_multi_tokens return tokenizer LetterTokenizer = WordTokenizer # --------------------------------------------------------------------------- # Model (from ailay.model) # --------------------------------------------------------------------------- class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: if hasattr(torch.nn.functional, "rms_norm"): return torch.nn.functional.rms_norm( x, self.weight.shape, self.weight, self.eps ) x_fp = x.float() rms = torch.rsqrt(x_fp.pow(2).mean(dim=-1, keepdim=True) + self.eps) return (x_fp * rms).to(dtype=x.dtype) * self.weight class RotaryEmbedding(nn.Module): def __init__(self, dim: int, base: float = 10000.0) -> None: super().__init__() inv = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv, persistent=False) def cos_sin( self, seq_len: int, device: torch.device, dtype: torch.dtype ) -> Tuple[torch.Tensor, torch.Tensor]: t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq) emb = torch.cat([freqs, freqs], dim=-1) cos = emb.cos()[None, None, :, :].to(dtype=dtype) sin = emb.sin()[None, None, :, :].to(dtype=dtype) return cos, sin def _rotate_half(x: torch.Tensor) -> torch.Tensor: x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) class CausalSelfAttention(nn.Module): def __init__( self, dim: int, n_heads: int, n_kv_heads: int, head_dim: int, dropout: float, sliding_window: int, rope_fraction: float, ) -> None: super().__init__() self.dim = dim self.n_heads = n_heads self.n_kv_heads = n_kv_heads self.head_dim = head_dim self.n_rep = n_heads // n_kv_heads self.dropout = dropout self.sliding_window = sliding_window self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) self.wo = nn.Linear(n_heads * head_dim, dim, bias=False) for lin in (self.wq, self.wk, self.wv): nn.init.normal_(lin.weight, std=dim ** -0.5) nn.init.normal_(self.wo.weight, std=(n_heads * head_dim) ** -0.5) self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2) self.rope = RotaryEmbedding(self.rope_dim) self.q_norm = RMSNorm(head_dim) self.k_norm = RMSNorm(head_dim) self.output_gate = nn.Parameter(torch.zeros(n_heads)) def forward( self, x: torch.Tensor, is_global: bool, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: B, T, _ = x.shape q = self.wq(x).view(B, T, self.n_heads, self.head_dim) k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim) v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim) q = self.q_norm(q) k = self.k_norm(k) q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) past_len = past_kv[0].shape[2] if past_kv is not None else 0 cos, sin = self.rope.cos_sin(T + past_len, x.device, q.dtype) cos_slice = cos[:, :, past_len : past_len + T, :] sin_slice = sin[:, :, past_len : past_len + T, :] q_rope = q[..., : self.rope_dim] q_pass = q[..., self.rope_dim :] k_rope = k[..., : self.rope_dim] k_pass = k[..., self.rope_dim :] q_rope = (q_rope * cos_slice) + (_rotate_half(q_rope) * sin_slice) k_rope = (k_rope * cos_slice) + (_rotate_half(k_rope) * sin_slice) q = torch.cat([q_rope, q_pass], dim=-1) k = torch.cat([k_rope, k_pass], dim=-1) if past_kv is not None: k = torch.cat([past_kv[0], k], dim=2) v = torch.cat([past_kv[1], v], dim=2) new_kv = (k, v) if use_cache else None S = k.shape[2] if self.n_rep > 1: k = ( k[:, :, None, :, :] .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim) .reshape(B, self.n_heads, S, self.head_dim) ) v = ( v[:, :, None, :, :] .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim) .reshape(B, self.n_heads, S, self.head_dim) ) drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0 if is_global: if past_kv is None and T > 1: out = F.scaled_dot_product_attention( q, k, v, is_causal=True, dropout_p=drop_p ) else: out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p) else: T_q = q.shape[2] q_pos = torch.arange(past_len, past_len + T_q, device=q.device).unsqueeze(1) k_pos = torch.arange(S, device=q.device).unsqueeze(0) mask = (q_pos >= k_pos) & ((q_pos - k_pos) < self.sliding_window) out = F.scaled_dot_product_attention( q, k, v, attn_mask=mask.unsqueeze(0).unsqueeze(0), dropout_p=drop_p ) gate = torch.sigmoid(self.output_gate).view(1, self.n_heads, 1, 1) out = out * gate out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) out = self.wo(out) return out, new_kv class SwiGLU(nn.Module): def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None: super().__init__() self.gate = nn.Linear(dim, hidden_dim, bias=False) self.up = nn.Linear(dim, hidden_dim, bias=False) self.down = nn.Linear(hidden_dim, dim, bias=False) self.drop = nn.Dropout(dropout) nn.init.normal_(self.gate.weight, std=dim ** -0.5) nn.init.normal_(self.up.weight, std=dim ** -0.5) nn.init.normal_(self.down.weight, std=hidden_dim ** -0.5) def forward(self, x: torch.Tensor) -> torch.Tensor: h = F.silu(self.gate(x)) * self.up(x) out = self.down(h) if self.training and torch.is_grad_enabled(): out = self.drop(out) return out class EngramBlock(nn.Module): """Conditional memory via O(1) hashed N-gram lookup (DeepSeek Engram).""" def __init__( self, dim: int, engram_dim: int, n_heads: int = 4, table_size: int = 8192, max_ngram: int = 3, ) -> None: super().__init__() self.dim = dim self.engram_dim = engram_dim self.n_heads = n_heads self.table_size = table_size self.max_ngram = max_ngram self.embeddings = nn.ParameterDict() for n in range(2, max_ngram + 1): for k in range(n_heads): self.embeddings[f"{n}_{k}"] = nn.Parameter( torch.randn(table_size, engram_dim) * (engram_dim**-0.5) ) for n in range(2, max_ngram + 1): for k in range(n_heads): seed = int(hashlib.md5(f"engram_{n}_{k}".encode()).hexdigest()[:8], 16) rng = torch.Generator().manual_seed(seed) a = torch.randint(1, 2**31, (1,), generator=rng).item() b = torch.randint(0, 2**31, (1,), generator=rng).item() self.register_buffer( f"hash_a_{n}_{k}", torch.tensor(a), persistent=False ) self.register_buffer( f"hash_b_{n}_{k}", torch.tensor(b), persistent=False ) total_branch_dim = engram_dim * n_heads * (max_ngram - 1) self.branch_conv = nn.Conv1d( total_branch_dim, total_branch_dim, kernel_size=4, dilation=max_ngram, padding=0, groups=total_branch_dim, bias=True, ) nn.init.zeros_(self.branch_conv.weight) nn.init.zeros_(self.branch_conv.bias) self.gate_query = nn.Linear(dim, engram_dim, bias=False) self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False) self.gate_value = nn.Linear(total_branch_dim, dim, bias=False) self.gate_scale = engram_dim**-0.5 def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor: a = getattr(self, f"hash_a_{n}_{k}") b = getattr(self, f"hash_b_{n}_{k}") B, T = token_ids.shape padded = F.pad(token_ids, (n - 1, 0), value=0) combined = torch.zeros(B, T, dtype=torch.long, device=token_ids.device) for i in range(n): combined = (combined * 31 + padded[:, i : i + T].long()) % self.table_size return ((a * combined) ^ b) % self.table_size def forward( self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None ) -> torch.Tensor: B, T, _ = hidden.shape if token_ids is None: token_ids = hidden.mean(dim=-1).long() % self.table_size all_indices = [] all_tables = [] for n in range(2, self.max_ngram + 1): for k in range(self.n_heads): all_indices.append(self._hash_ngram(token_ids, n, k)) all_tables.append(self.embeddings[f"{n}_{k}"]) branch_outputs = [tbl[idx] for idx, tbl in zip(all_indices, all_tables)] memory = torch.cat(branch_outputs, dim=-1) conv_in = memory.transpose(1, 2) conv_in = F.pad( conv_in, (self.branch_conv.dilation[0] * (self.branch_conv.kernel_size[0] - 1), 0), ) conv_out = self.branch_conv(conv_in) memory = conv_out.transpose(1, 2) query = self.gate_query(hidden) key = self.gate_key(memory) gate = torch.sigmoid( (query * key).sum(dim=-1, keepdim=True) * self.gate_scale ) value = self.gate_value(memory) return gate * value def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor: M = torch.exp(logits.clamp(-10, 10)) for _ in range(n_iters): M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10) M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10) return M class ManifoldHyperConnection(nn.Module): """Manifold-Constrained Hyper-Connections (mHC) residual wrapper.""" def __init__(self, dim: int, expansion: int = 2) -> None: super().__init__() self.dim = dim self.expansion = expansion n = expansion self.bias_pre = nn.Parameter(torch.zeros(1, n)) self.bias_post = nn.Parameter(torch.zeros(1, n)) self.bias_res = nn.Parameter(torch.zeros(n, n)) self.theta_pre = nn.Linear(n * dim, n, bias=False) self.theta_post = nn.Linear(n * dim, n, bias=False) self.theta_res = nn.Linear(n * dim, n * n, bias=False) self.alpha_pre = nn.Parameter(torch.tensor(0.0)) self.alpha_post = nn.Parameter(torch.tensor(0.0)) self.alpha_res = nn.Parameter(torch.tensor(0.0)) def _compute_mappings( self, x_expanded: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: B, T, _ = x_expanded.shape n = self.expansion x_norm = F.rms_norm(x_expanded, [x_expanded.shape[-1]]) d_pre = torch.tanh(self.theta_pre(x_norm)) d_post = torch.tanh(self.theta_post(x_norm)) d_res = self.theta_res(x_norm) H_pre_raw = torch.sigmoid(self.alpha_pre * d_pre + self.bias_pre) H_post_raw = 2.0 * torch.sigmoid(self.alpha_post * d_post + self.bias_post) H_res_raw = (self.alpha_res * d_res + self.bias_res.reshape(1, 1, -1)).reshape( B, T, n, n ) H_res = _sinkhorn_knopp(H_res_raw) return H_pre_raw.unsqueeze(-2), H_post_raw.unsqueeze(-2), H_res def expand_stream(self, x: torch.Tensor) -> torch.Tensor: return x.repeat(1, 1, self.expansion) def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor: B, T, _ = x_expanded.shape return x_expanded.view(B, T, self.expansion, self.dim).mean(dim=-2) def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor: B, T, _ = x_expanded.shape x_streams = x_expanded.view(B, T, self.expansion, self.dim) return (H_pre @ x_streams).squeeze(-2) def post_res_mix( self, layer_output: torch.Tensor, x_expanded: torch.Tensor, H_post: torch.Tensor, H_res: torch.Tensor, ) -> torch.Tensor: B, T, _ = x_expanded.shape x_streams = x_expanded.view(B, T, self.expansion, self.dim) mixed = torch.matmul(H_res, x_streams) post_out = torch.matmul(H_post.transpose(-2, -1), layer_output.unsqueeze(-2)) return (mixed + post_out).reshape(B, T, self.expansion * self.dim) class TransformerBlock(nn.Module): def __init__( self, dim: int, n_heads: int, n_kv_heads: int, head_dim: int, ffn_dim: int, dropout: float, sliding_window: int, rope_fraction: float, engram_dim: int = 0, engram_heads: int = 4, engram_table_size: int = 8192, engram_max_ngram: int = 3, mhc_expansion: int = 1, ) -> None: super().__init__() self.dim = dim self.norm1 = RMSNorm(dim) self.attn = CausalSelfAttention( dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, dropout=dropout, sliding_window=sliding_window, rope_fraction=rope_fraction, ) self.norm2 = RMSNorm(dim) self.ffn = SwiGLU(dim, ffn_dim, dropout) self.use_engram = engram_dim > 0 if self.use_engram: self.engram = EngramBlock( dim=dim, engram_dim=engram_dim, n_heads=engram_heads, table_size=engram_table_size, max_ngram=engram_max_ngram, ) self.engram_norm = RMSNorm(dim) self.use_mhc = mhc_expansion > 1 if self.use_mhc: self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion) self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion) def forward( self, x: torch.Tensor, is_global: bool, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, token_ids: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: if self.use_mhc: x_exp = self.mhc_attn.expand_stream(x) H_pre, H_post, H_res = self.mhc_attn._compute_mappings(x_exp) attn_in = self.mhc_attn.pre_mix(x_exp, H_pre) attn_out, new_kv = self.attn( self.norm1(attn_in), is_global, past_kv, use_cache ) x_exp = self.mhc_attn.post_res_mix(attn_out, x_exp, H_post, H_res) if self.use_engram: collapsed = self.mhc_attn.collapse_stream(x_exp) collapsed = collapsed + self.engram( self.engram_norm(collapsed), token_ids=token_ids ) x_exp = self.mhc_attn.expand_stream(collapsed) H_pre2, H_post2, H_res2 = self.mhc_ffn._compute_mappings(x_exp) ffn_in = self.mhc_ffn.pre_mix(x_exp, H_pre2) ffn_out = self.ffn(self.norm2(ffn_in)) x_exp = self.mhc_ffn.post_res_mix(ffn_out, x_exp, H_post2, H_res2) x = self.mhc_attn.collapse_stream(x_exp) else: attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache) x = x + attn_out if self.use_engram: x = x + self.engram(self.engram_norm(x), token_ids=token_ids) x = x + self.ffn(self.norm2(x)) return x, new_kv def _detect_engram_dim(state_dict: dict) -> int: for key in state_dict: if ".engram." in key and ".embeddings." in key: return state_dict[key].shape[-1] return 0 def _detect_mhc_expansion(state_dict: dict) -> int: for key, val in state_dict.items(): if ".mhc_attn.bias_pre" in key and val.dim() == 2: return val.shape[-1] return 1 class TinyMemoryLM(nn.Module): def __init__( self, vocab_size: int, dim: int, n_unique_layers: int, n_logical_layers: int, n_heads: int, n_kv_heads: int, ffn_dim: int, dropout: float, mtp_horizons: Sequence[int], grad_checkpoint: bool, sliding_window: int = 512, rope_fraction: float = 0.5, embed_scale: bool = True, engram_dim: int = 0, engram_heads: int = 4, engram_table_size: int = 8192, engram_max_ngram: int = 3, mhc_expansion: int = 1, ) -> None: super().__init__() self.dim = dim self.n_unique_layers = n_unique_layers self.n_logical_layers = n_logical_layers self.grad_checkpoint = grad_checkpoint self.embed_scale_factor = math.sqrt(dim) if embed_scale else 1.0 head_dim = dim // n_heads self.embed_tokens = nn.Embedding(vocab_size, dim) self.head = nn.Linear(dim, vocab_size, bias=False) self.head.weight = self.embed_tokens.weight self.output_bias = nn.Parameter(torch.zeros(vocab_size)) self.blocks = nn.ModuleList( [ TransformerBlock( dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, rope_fraction=rope_fraction, engram_dim=engram_dim, engram_heads=engram_heads, engram_table_size=engram_table_size, engram_max_ngram=engram_max_ngram, mhc_expansion=mhc_expansion, ) for _ in range(n_unique_layers) ] ) self.norm = RMSNorm(dim) self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1}) self.mtp_adapters = nn.ModuleDict( {str(h): nn.Linear(dim, dim, bias=False) for h in self.mtp_horizons} ) self.mtp_norms = nn.ModuleDict( {str(h): RMSNorm(dim) for h in self.mtp_horizons} ) res_scale = (2 * n_logical_layers) ** -0.5 for block in self.blocks: block.attn.wo.weight.data.mul_(res_scale) block.ffn.down.weight.data.mul_(res_scale) def resize_token_embeddings(self, new_vocab_size: int) -> None: old_vocab_size = self.embed_tokens.num_embeddings if new_vocab_size == old_vocab_size: return device = self.embed_tokens.weight.device old_embed_weight = self.embed_tokens.weight.data.clone() self.embed_tokens = nn.Embedding( new_vocab_size, self.embed_tokens.embedding_dim ).to(device) self.head = nn.Linear( self.embed_tokens.embedding_dim, new_vocab_size, bias=False ).to(device) self.head.weight = self.embed_tokens.weight old_bias = self.output_bias.data.clone() self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device)) copy_size = min(old_vocab_size, new_vocab_size) self.output_bias.data[:copy_size] = old_bias[:copy_size] self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size] def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]: logical = [] blocks_list = list(self.blocks) full_sequence = blocks_list + blocks_list for logical_idx, block in enumerate(full_sequence[: self.n_logical_layers]): logical.append((block, logical_idx)) return logical def forward( self, ids: torch.Tensor, use_cache: bool = False, past_key_values: Optional[ List[Optional[Tuple[torch.Tensor, torch.Tensor]]] ] = None, return_hidden: bool = False, ) -> Tuple[ torch.Tensor, Dict[int, torch.Tensor], Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]], ]: B, T = ids.shape x = self.embed_tokens(ids) * self.embed_scale_factor logical_layers = self._build_logical_layers() new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = ( [] if use_cache else None ) for layer_idx, (block, logical_idx) in enumerate(logical_layers): is_global = logical_idx % 2 == 0 past_kv = ( past_key_values[layer_idx] if past_key_values is not None and layer_idx < len(past_key_values) else None ) if self.grad_checkpoint and self.training and not use_cache: x, layer_kv = checkpoint( block, x, is_global, past_kv, use_cache, ids, use_reentrant=True ) else: x, layer_kv = block(x, is_global, past_kv, use_cache, ids) if new_past_key_values is not None: new_past_key_values.append(layer_kv) x = self.norm(x) h_out = x if return_hidden else None logits = self.head(x) if self.embed_scale_factor != 1.0: logits = logits / self.embed_scale_factor logits = logits + self.output_bias mtp: Dict[int, torch.Tensor] = {} if self.mtp_horizons and self.training: for horizon in self.mtp_horizons: if horizon > 1 and horizon <= T - 1: shifted_h = x[:, :-horizon, :] adapted_h = self.mtp_adapters[str(horizon)](shifted_h) adapted_h = self.mtp_norms[str(horizon)](adapted_h) mtp_logits = self.head(adapted_h) if self.embed_scale_factor != 1.0: mtp_logits = mtp_logits / self.embed_scale_factor mtp_logits = mtp_logits + self.output_bias mtp[horizon] = mtp_logits return logits, mtp, h_out, new_past_key_values # --------------------------------------------------------------------------- # Generation (from ailay.generation) # --------------------------------------------------------------------------- def sample_text( model: TinyMemoryLM, tokenizer: WordTokenizer, prompt: str, max_new_tokens: int, temperature: float, top_k: int, branches: int, branch_len: int, device: torch.device, seq_len: int, ) -> str: def _sample_id(logits: torch.Tensor) -> torch.Tensor: if not torch.isfinite(logits).any(): logits = torch.zeros_like(logits) logits = torch.where( torch.isfinite(logits), logits, torch.full_like(logits, -1e9) ) if top_k > 0: v, idx = torch.topk(logits, k=min(top_k, logits.shape[-1])) p = torch.softmax(v, dim=-1) return idx.gather(-1, torch.multinomial(p, 1)) p = torch.softmax(logits, dim=-1) return torch.multinomial(p, 1) model.eval() ids = tokenizer.encode(prompt, add_bos=True, add_eos=False) prompt_len = len(ids) x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0) with torch.no_grad(): generated = 0 while generated < max_new_tokens: if branches <= 1: ctx = x[:, -seq_len:] logits, _, _, _ = model(ctx) nlogits = logits[:, -1, :] / max(temperature, 1e-6) nid = _sample_id(nlogits) x = torch.cat([x, nid], dim=1) generated += 1 continue rollout = min(branch_len, max_new_tokens - generated) best_nll: Optional[float] = None best_tokens: Optional[List[torch.Tensor]] = None for _ in range(branches): cand = x cand_tokens: List[torch.Tensor] = [] nll = 0.0 for _ in range(rollout): ctx = cand[:, -seq_len:] logits, _, _, _ = model(ctx) nlogits = logits[:, -1, :] / max(temperature, 1e-6) nid = _sample_id(nlogits) lp = F.log_softmax(nlogits.float(), dim=-1) nll += float(-lp.gather(-1, nid).item()) cand = torch.cat([cand, nid], dim=1) cand_tokens.append(nid) if best_nll is None or nll < best_nll: best_nll = nll best_tokens = cand_tokens assert best_tokens is not None for t in best_tokens: x = torch.cat([x, t], dim=1) generated += 1 generated_ids = x[0, prompt_len:].tolist() return tokenizer.decode(generated_ids, skip_special=True) def sample_text_cached( model: TinyMemoryLM, tokenizer: WordTokenizer, prompt: str, max_new_tokens: int, temperature: float, top_k: int, device: torch.device, seq_len: int, ) -> str: model.eval() ids = tokenizer.encode(prompt, add_bos=True, add_eos=False) prompt_len = len(ids) x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0) with torch.no_grad(): logits, _, _, past_kv = model(x, use_cache=True) nlogits = logits[:, -1, :] / max(temperature, 1e-6) if top_k > 0: v, idx = torch.topk(nlogits, k=min(top_k, nlogits.shape[-1])) p = torch.softmax(v, dim=-1) nid = idx.gather(-1, torch.multinomial(p, 1)) else: p = torch.softmax(nlogits, dim=-1) nid = torch.multinomial(p, 1) all_ids = [int(nid.item())] for _ in range(max_new_tokens - 1): logits, _, _, past_kv = model(nid, use_cache=True, past_key_values=past_kv) nlogits = logits[:, -1, :] / max(temperature, 1e-6) if top_k > 0: v, idx = torch.topk(nlogits, k=min(top_k, nlogits.shape[-1])) p = torch.softmax(v, dim=-1) nid = idx.gather(-1, torch.multinomial(p, 1)) else: p = torch.softmax(nlogits, dim=-1) nid = torch.multinomial(p, 1) tid = int(nid.item()) all_ids.append(tid) if tid == tokenizer.eos_id: break return tokenizer.decode(all_ids, skip_special=True) def speculative_decode( model: TinyMemoryLM, tokenizer: WordTokenizer, prompt: str, max_new_tokens: int, temperature: float, top_k: int, device: torch.device, seq_len: int, ) -> str: model.eval() ids = tokenizer.encode(prompt, add_bos=True, add_eos=False) x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0) all_generated: List[int] = [] with torch.no_grad(): logits, _, h_out, past_kv = model(x, use_cache=True, return_hidden=True) def _sample_from(lg: torch.Tensor) -> int: lg = lg / max(temperature, 1e-6) if top_k > 0: v, idx = torch.topk(lg, k=min(top_k, lg.shape[-1])) p = torch.softmax(v, dim=-1) return int(idx[torch.multinomial(p, 1)].item()) p = torch.softmax(lg, dim=-1) return int(torch.multinomial(p, 1).item()) main_token = _sample_from(logits[0, -1, :]) all_generated.append(main_token) while len(all_generated) < max_new_tokens: if main_token == tokenizer.eos_id: break draft_tokens = [] if h_out is not None and model.mtp_horizons: last_hidden = h_out[:, -1:, :] for h in model.mtp_horizons: adapter = model.mtp_adapters[str(h)] norm = model.mtp_norms[str(h)] adapted = norm(adapter(last_hidden)) draft_logits = model.head(adapted) + model.output_bias draft_tok = _sample_from(draft_logits[0, 0, :]) draft_tokens.append(draft_tok) if not draft_tokens: nid = torch.tensor([[main_token]], dtype=torch.long, device=device) logits, _, h_out, past_kv = model( nid, use_cache=True, past_key_values=past_kv, return_hidden=True ) main_token = _sample_from(logits[0, -1, :]) all_generated.append(main_token) continue verify_input = torch.tensor( [[main_token] + draft_tokens], dtype=torch.long, device=device ) verify_logits, _, h_out, past_kv = model( verify_input, use_cache=True, past_key_values=past_kv, return_hidden=True, ) accepted = 0 all_generated.append(main_token) if main_token not in all_generated[ -1: ] else None for i, draft_tok in enumerate(draft_tokens): verified_tok = _sample_from(verify_logits[0, i, :]) if verified_tok == draft_tok: all_generated.append(draft_tok) accepted += 1 if draft_tok == tokenizer.eos_id: break else: all_generated.append(verified_tok) break if accepted < len(draft_tokens): trim_len = len(draft_tokens) - accepted - 1 if trim_len > 0 and past_kv is not None: past_kv = [ (k[:, :, :-trim_len, :], v[:, :, :-trim_len, :]) if k is not None else None for k, v in past_kv ] main_token = all_generated[-1] return tokenizer.decode(all_generated, skip_special=True) def build_stop_token_ids(tokenizer: WordTokenizer) -> set: stop_tokens = {tokenizer.eos_id} for tok in ("<|user|>", "<|system|>", "<|assistant|>"): tid = tokenizer.token_to_id.get(tok) if tid is not None: stop_tokens.add(int(tid)) return stop_tokens def apply_no_repeat_ngram( logits: torch.Tensor, token_history: Sequence[int], ngram_size: int, ) -> torch.Tensor: if ngram_size <= 1 or len(token_history) < max(0, ngram_size - 1): return logits prefix = tuple(token_history[-(ngram_size - 1) :]) if ngram_size > 1 else tuple() banned: set = set() for i in range(len(token_history) - ngram_size + 1): if tuple(token_history[i : i + ngram_size - 1]) == prefix: banned.add(int(token_history[i + ngram_size - 1])) if not banned: return logits out = logits.clone() banned_ids = torch.tensor(sorted(banned), device=logits.device, dtype=torch.long) out[banned_ids] = float("-inf") return out def score_candidate( prompt: str, raw_text: str, visible_text: str, avg_logprob: float, ) -> float: clean = visible_text.strip() if not clean: return -1e9 score = avg_logprob words = clean.lower().split() prompt_words = re.findall(r"[A-Za-z][A-Za-z'-]{2,}", prompt.lower()) prompt_stop = { "what", "which", "when", "where", "why", "how", "are", "is", "the", "and", "for", "with", "that", "this", "from", "into", "about", "explain", "tell", "give", "list", "show", "write", "their", "there", "your", } prompt_keywords = {w for w in prompt_words if w not in prompt_stop} candidate_keywords = set(re.findall(r"[A-Za-z][A-Za-z'-]{2,}", clean.lower())) if len(words) < 6: score -= 2.0 else: score += min(2.0, len(words) * 0.03) if clean[-1:] in ".!?": score += 0.5 if "<|user|>" in raw_text or "<|system|>" in raw_text: score -= 4.0 if raw_text.count("<|assistant|>") > 1: score -= 2.0 if prompt_keywords: overlap = len(prompt_keywords & candidate_keywords) / len(prompt_keywords) if overlap == 0.0: score -= 2.5 else: score += min(3.5, overlap * 4.0) for open_tok, close_tok in [ ("<|begin_of_thought|>", "<|end_of_thought|>"), ("<|begin_of_solution|>", "<|end_of_solution|>"), ]: if (open_tok in raw_text) != (close_tok in raw_text): score -= 1.0 if len(words) >= 3: trigrams = [tuple(words[i : i + 3]) for i in range(len(words) - 2)] if trigrams: unique_ratio = len(set(trigrams)) / len(trigrams) if unique_ratio < 0.35: score -= 4.0 elif unique_ratio < 0.55: score -= 2.0 else: score += min(1.0, (unique_ratio - 0.55) * 2.0) alpha_words = [ w for w in words if len(w) <= 18 and (sum(ch.isalpha() for ch in w) / max(len(w), 1)) > 0.7 ] alpha_ratio = len(alpha_words) / max(len(words), 1) if alpha_ratio < 0.45: score -= 3.0 elif alpha_ratio < 0.65: score -= 1.0 return score def generate_candidate( model: TinyMemoryLM, tokenizer: WordTokenizer, prompt: str, max_new_tokens: int, temperature: float, top_k: int, repetition_penalty: float, no_repeat_ngram_size: int, device: str, sft_mode: bool, force_thought: bool, stream: bool, context_window: int, ) -> Tuple[str, str, float, int]: if sft_mode: full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n" else: full_prompt = prompt if force_thought: full_prompt = f"{full_prompt}<|begin_of_thought|> " input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False) input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device) visible_tokens: List[str] = [] raw_tokens: List[str] = [] stop_token_ids = build_stop_token_ids(tokenizer) total_logprob = 0.0 sampled_tokens = 0 with torch.no_grad(): for _ in range(max_new_tokens): ctx_ids = ( input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t ) logits, _, _, _ = model(ctx_ids) next_logits = logits[0, -1, :].clone() raw_next_logits = next_logits.clone() if repetition_penalty != 1.0: seen = set(input_ids_t[0].tolist()) for token_id in seen: if next_logits[token_id] > 0: next_logits[token_id] /= repetition_penalty else: next_logits[token_id] *= repetition_penalty if temperature != 1.0: next_logits = next_logits / max(temperature, 1e-6) if no_repeat_ngram_size > 1: next_logits = apply_no_repeat_ngram( next_logits, input_ids_t[0].tolist(), no_repeat_ngram_size, ) if top_k > 0: v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0))) next_logits[next_logits < v[-1]] = float("-inf") top_p = 0.9 if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) remove_mask = cum_probs - torch.softmax(sorted_logits, dim=-1) >= top_p sorted_logits[remove_mask] = float("-inf") next_logits = sorted_logits.scatter(0, sorted_indices, sorted_logits) if not torch.isfinite(next_logits).any(): next_logits = raw_next_logits if temperature != 1.0: next_logits = next_logits / max(temperature, 1e-6) probs = torch.softmax(next_logits, dim=-1) next_id = torch.multinomial(probs, num_samples=1).item() total_logprob += float(torch.log(probs[next_id] + 1e-12).item()) sampled_tokens += 1 if next_id in stop_token_ids: break token_str = ( tokenizer.id_to_token[next_id] if next_id < len(tokenizer.id_to_token) else "" ) raw_tokens.append(token_str) if token_str not in tokenizer.special: visible_tokens.append(token_str) if stream: print(token_str, end="", flush=True) input_ids_t = torch.cat( [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1 ) if stream: print() avg_logprob = total_logprob / max(1, sampled_tokens) return "".join(visible_tokens), "".join(raw_tokens), avg_logprob, 0 def generate_beam_search( model: TinyMemoryLM, tokenizer: WordTokenizer, prompt: str, max_new_tokens: int = 60, beam_width: int = 8, length_penalty: float = 0.7, no_repeat_ngram_size: int = 3, device: str = "cuda", sft_mode: bool = False, context_window: int = 2048, ) -> str: if sft_mode: full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n" else: full_prompt = prompt prompt_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False) prompt_len = len(prompt_ids) stop_ids = build_stop_token_ids(tokenizer) beams: List[Tuple[float, List[int]]] = [(0.0, list(prompt_ids))] completed: List[Tuple[float, List[int]]] = [] for _step in range(max_new_tokens): if not beams: break candidates: List[Tuple[float, List[int]]] = [] for beam_score, beam_ids in beams: x = torch.tensor( [beam_ids[-context_window:]], dtype=torch.long, device=device ) with torch.no_grad(): logits, _, _, _ = model(x) nl = logits[0, -1, :] log_probs = F.log_softmax(nl, dim=-1) gen_ids = beam_ids[prompt_len:] if no_repeat_ngram_size > 1 and len(gen_ids) >= no_repeat_ngram_size - 1: prefix = tuple(gen_ids[-(no_repeat_ngram_size - 1) :]) for i in range(len(gen_ids) - no_repeat_ngram_size + 1): if tuple(gen_ids[i : i + no_repeat_ngram_size - 1]) == prefix: log_probs[gen_ids[i + no_repeat_ngram_size - 1]] = float("-inf") topk_lp, topk_ids = torch.topk(log_probs, beam_width) for i in range(beam_width): tid = topk_ids[i].item() new_score = beam_score + topk_lp[i].item() new_ids = beam_ids + [tid] if tid in stop_ids: completed.append((new_score, new_ids)) else: candidates.append((new_score, new_ids)) def _norm_score(pair): gen_len = max(1, len(pair[1]) - prompt_len) return pair[0] / (gen_len**length_penalty) candidates.sort(key=_norm_score, reverse=True) beams = candidates[:beam_width] pool = completed + beams if not pool: return "" def _norm_score_final(pair): gen_len = max(1, len(pair[1]) - prompt_len) return pair[0] / (gen_len**length_penalty) pool.sort(key=_norm_score_final, reverse=True) best_ids = pool[0][1][prompt_len:] text = tokenizer.decode(best_ids, skip_special=True) nl_pos = text.find("\n") if nl_pos > 5: text = text[:nl_pos] return text.strip() def generate( model: TinyMemoryLM, tokenizer: WordTokenizer, prompt: str, max_new_tokens: int = 256, temperature: float = 0.8, top_k: int = 40, repetition_penalty: float = 1.0, device: str = "cuda", sft_mode: bool = False, force_thought: bool = False, stream: bool = True, decode_mode: str = "legacy", best_of: int = 3, no_repeat_ngram_size: int = 3, context_window: int = 2048, beam_width: int = 8, length_penalty: float = 0.7, ) -> str: if decode_mode == "beam": text = generate_beam_search( model=model, tokenizer=tokenizer, prompt=prompt, max_new_tokens=max_new_tokens, beam_width=beam_width, length_penalty=length_penalty, no_repeat_ngram_size=no_repeat_ngram_size, device=device, sft_mode=sft_mode, context_window=context_window, ) if stream: print(text) return text if decode_mode == "legacy": text, _, _, _ = generate_candidate( model=model, tokenizer=tokenizer, prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, device=device, sft_mode=sft_mode, force_thought=force_thought, stream=stream, context_window=context_window, ) return text candidates: List[Tuple[float, str, str, float]] = [] for _ in range(max(1, best_of)): candidate_text, raw_text, avg_logprob, _ = generate_candidate( model=model, tokenizer=tokenizer, prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, device=device, sft_mode=sft_mode, force_thought=force_thought, stream=False, context_window=context_window, ) score = score_candidate(prompt, raw_text, candidate_text, avg_logprob) candidates.append((score, candidate_text, raw_text, avg_logprob)) best_score, best_text, _, _ = max(candidates, key=lambda item: item[0]) if stream: print(best_text, end="", flush=True) print() return best_text # --------------------------------------------------------------------------- # Web server (from interactive.py) # --------------------------------------------------------------------------- ROOT = Path(__file__).resolve().parent if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) HF_ORG = "CompactAI" HF_API = "https://huggingface.co/api" CACHE_ROOT = Path.home() / ".cache" / "compactai_web" USER_AGENT = "Mozilla/5.0 CompactAI-Web" MODEL_CACHE: dict[tuple[str, str], dict[str, object]] = {} MODEL_CACHE_LOCK = threading.RLock() GENERATION_LOCK = threading.Lock() def request_json(url: str): req = Request(url, headers={"User-Agent": USER_AGENT}) with urlopen(req, timeout=60) as response: return json.loads(response.read().decode("utf-8")) def request_text(url: str) -> str: req = Request(url, headers={"User-Agent": USER_AGENT}) with urlopen(req, timeout=60) as response: return response.read().decode("utf-8", errors="replace") def download_file(url: str, destination: Path) -> None: destination.parent.mkdir(parents=True, exist_ok=True) temp_path = destination.with_suffix(destination.suffix + ".tmp") req = Request(url, headers={"User-Agent": USER_AGENT}) with urlopen(req, timeout=120) as response, temp_path.open("wb") as handle: shutil.copyfileobj(response, handle) temp_path.replace(destination) def normalize_repo_id(raw_repo_id: str) -> str: if not isinstance(raw_repo_id, str): return "" repo_id = raw_repo_id.strip() if not repo_id: return "" try: repo_id = unquote(repo_id) except Exception: pass return ( repo_id.replace("https://huggingface.co/", "") .replace("http://huggingface.co/", "") .replace("api/models/", "") .replace("models/", "") .split("?", 1)[0] .split("#", 1)[0] .strip("/") ) def series_from_name(name: str) -> str | None: lower = (name or "").lower() if "haiku" in lower: return "Haiku" if "sonnet" in lower: return "Sonnet" if "opus" in lower: return "Opus" return None def encoded_repo_id(repo_id: str) -> str: return "/".join( quote(part, safe="") for part in normalize_repo_id(repo_id).split("/") if part ) def hf_file_url(repo_id: str, filename: str) -> str: encoded_name = "/".join( quote(part, safe="") for part in filename.split("/") if part ) return ( f"https://huggingface.co/{encoded_repo_id(repo_id)}/resolve/main/{encoded_name}" ) def model_list() -> list[dict[str, object]]: data = request_json(f"{HF_API}/models?author={quote(HF_ORG)}&full=true&limit=200") models: list[dict[str, object]] = [] for item in data: siblings = item.get("siblings") or [] filenames = [s.get("rfilename", "") for s in siblings if isinstance(s, dict)] has_model = "model.pt" in filenames or "model/model.pt" in filenames has_pretrain = "pretrain.pt" in filenames or "model/pretrain.pt" in filenames has_tokenizer = ( "tokenizer.json" in filenames or "model/tokenizer.json" in filenames ) if not has_model and not has_pretrain: continue name = (item.get("id") or "").split("/")[-1] series = series_from_name(name) if not series: continue models.append( { "id": item.get("id", ""), "name": name, "series": series, "downloads": item.get("downloads", 0) or 0, "likes": item.get("likes", 0) or 0, "has_model": has_model, "has_pretrain": has_pretrain, "has_tokenizer": has_tokenizer, } ) return sorted(models, key=lambda entry: entry["downloads"], reverse=True) def model_details(repo_id: str) -> dict[str, object] | None: normalized = normalize_repo_id(repo_id) if not normalized: return None data = request_json(f"{HF_API}/models/{encoded_repo_id(normalized)}") siblings = data.get("siblings") or [] files: dict[str, dict[str, float]] = {} has_model = False has_pretrain = False for sibling in siblings: if not isinstance(sibling, dict): continue filename = sibling.get("rfilename") or "" if not filename: continue size_mb = round((sibling.get("size") or 0) / (1024 * 1024), 2) files[filename] = {"size_mb": size_mb} if filename.startswith("model/"): files[filename.removeprefix("model/")] = {"size_mb": size_mb} if filename in {"model.pt", "model/model.pt"}: has_model = True if filename in {"pretrain.pt", "model/pretrain.pt"}: has_pretrain = True readme_raw = "" try: readme_raw = request_text( f"https://huggingface.co/{encoded_repo_id(normalized)}/raw/main/README.md" ) except Exception: readme_raw = "" name = (data.get("id") or normalized).split("/")[-1] return { "id": normalized, "name": name, "series": series_from_name(name) or "Sonnet", "downloads": data.get("downloads", 0) or 0, "files": files, "readme_raw": readme_raw, "hf_model_id": normalized, "has_model": has_model, "has_pretrain": has_pretrain, } def cache_dir(repo_id: str, model_type: str) -> Path: return CACHE_ROOT / normalize_repo_id(repo_id).replace("/", "__") / model_type def artifact_candidates(model_type: str) -> list[str]: return ( ["model/pretrain.pt", "pretrain.pt"] if model_type == "pretrain" else ["model/model.pt", "model.pt"] ) def ensure_artifact(repo_id: str, model_type: str, destination_name: str) -> Path: normalized = normalize_repo_id(repo_id) target = cache_dir(normalized, model_type) / destination_name if target.exists(): return target last_error: Exception | None = None for candidate in ( artifact_candidates(model_type) if destination_name.endswith(".pt") else ["model/tokenizer.json", "tokenizer.json"] ): try: download_file(hf_file_url(normalized, candidate), target) return target except Exception as exc: last_error = exc raise RuntimeError( f"Unable to download {destination_name} for {normalized}: {last_error}" ) def series_config(series: str) -> dict[str, object]: return MODEL_SERIES.get(series.lower(), MODEL_SERIES["sonnet"]) def load_bundle(repo_id: str, model_type: str) -> dict[str, object]: normalized = normalize_repo_id(repo_id) details = model_details(normalized) if not details: raise RuntimeError("Model details are unavailable.") series = str(details["series"]) key = (normalized, model_type) with MODEL_CACHE_LOCK: cached = MODEL_CACHE.get(key) if cached: return cached bundle_dir = cache_dir(normalized, model_type) bundle_dir.mkdir(parents=True, exist_ok=True) model_path = bundle_dir / ( "pretrain.pt" if model_type == "pretrain" else "model.pt" ) tokenizer_path = bundle_dir / "tokenizer.json" if not model_path.exists(): ensure_artifact(normalized, model_type, model_path.name) if not tokenizer_path.exists(): ensure_artifact(normalized, model_type, tokenizer_path.name) tokenizer = WordTokenizer.load(tokenizer_path) ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False) cfg = series_config(series) vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size)) state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt # Auto-detect new arch features from checkpoint weights engram_dim = _detect_engram_dim(state_dict) or int( cfg.get("engram_dim", model_config.engram_dim) ) mhc_expansion = _detect_mhc_expansion(state_dict) or int( cfg.get("mhc_expansion", model_config.mhc_expansion) ) model = TinyMemoryLM( vocab_size=vocab_size, dim=int(cfg.get("dim", model_config.dim)), n_unique_layers=int( cfg.get("n_unique_layers", model_config.n_unique_layers) ), n_logical_layers=int( cfg.get("n_logical_layers", model_config.n_logical_layers) ), n_heads=int(cfg.get("n_heads", model_config.n_heads)), n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)), ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)), dropout=float(cfg.get("dropout", model_config.dropout)), mtp_horizons=tuple( int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons) ), grad_checkpoint=False, sliding_window=int( cfg.get("sliding_window_size", model_config.sliding_window_size) ), rope_fraction=float( cfg.get("rope_fraction", model_config.rope_fraction) ), embed_scale=bool( cfg.get("embed_scale", model_config.embed_scale) ), engram_dim=engram_dim, engram_heads=int(cfg.get("engram_heads", model_config.engram_heads)), engram_table_size=int( cfg.get("engram_table_size", model_config.engram_table_size) ), engram_max_ngram=int( cfg.get("engram_max_ngram", model_config.engram_max_ngram) ), mhc_expansion=mhc_expansion, ) model.load_state_dict(state_dict, strict=False) model.eval() if tokenizer.vocab_size > vocab_size: model.resize_token_embeddings(tokenizer.vocab_size) device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) bundle = { "repo_id": normalized, "name": details["name"], "series": series, "type": model_type, "model": model, "tokenizer": tokenizer, "device": device, "model_path": str(model_path), "tokenizer_path": str(tokenizer_path), "downloads": details["downloads"], } MODEL_CACHE[key] = bundle return bundle def ensure_port(start_port: int) -> int: for port in range(start_port, start_port + 50): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: try: sock.bind(("127.0.0.1", port)) except OSError: continue return port with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.bind(("127.0.0.1", 0)) return sock.getsockname()[1] def page_html() -> str: return f""" CompactAI Web

CompactAI Web

Pull a model from Hugging Face, keep it cached locally, and chat in the browser.

Hugging Face: CompactAI pip install -r requirements.txt Local inference

Model

Loading model list…
python3 interactive_web.py

Prompt

Response

""" class Handler(BaseHTTPRequestHandler): def _send_json(self, payload, status=200): body = json.dumps(payload).encode("utf-8") self.send_response(status) self.send_header("Content-Type", "application/json; charset=utf-8") self.send_header("Content-Length", str(len(body))) self.send_header("Cache-Control", "no-store") self.end_headers() self.wfile.write(body) def _send_html(self, payload: str, status=200): body = payload.encode("utf-8") self.send_response(status) self.send_header("Content-Type", "text/html; charset=utf-8") self.send_header("Content-Length", str(len(body))) self.send_header("Cache-Control", "no-store") self.end_headers() self.wfile.write(body) def do_GET(self): parsed = urlparse(self.path) if parsed.path in {"/", "/index.html"}: self._send_html(page_html()) return if parsed.path == "/api/models": try: self._send_json(model_list()) except Exception as exc: self._send_json({"success": False, "error": str(exc)}, 500) return if parsed.path.startswith("/api/models/"): repo_id = normalize_repo_id(parsed.path.removeprefix("/api/models/")) try: details = model_details(repo_id) if not details: self._send_json( {"success": False, "error": "Model not found."}, 404 ) else: self._send_json(details) except Exception as exc: self._send_json({"success": False, "error": str(exc)}, 500) return self._send_json({"success": False, "error": "Not found."}, 404) def do_POST(self): parsed = urlparse(self.path) length = int(self.headers.get("Content-Length", "0") or "0") raw = self.rfile.read(length).decode("utf-8") if length else "{}" try: payload = json.loads(raw or "{}") except Exception: payload = {} if parsed.path == "/api/ensure": try: repo_id = normalize_repo_id(payload.get("modelId", "")) model_type = payload.get("type", "model") if not repo_id: self._send_json( {"success": False, "error": "Missing model ID."}, 400 ) return details = model_details(repo_id) if not details: self._send_json( {"success": False, "error": "Model not found."}, 404 ) return bundle = load_bundle(repo_id, model_type) self._send_json( { "success": True, "id": bundle["repo_id"], "name": bundle["name"], "series": bundle["series"], "type": bundle["type"], } ) except Exception as exc: self._send_json({"success": False, "error": str(exc)}, 500) return if parsed.path == "/api/generate": try: repo_id = normalize_repo_id(payload.get("modelId", "")) model_type = payload.get("type", "model") prompt = str(payload.get("prompt", "")) if not repo_id: self._send_json( {"success": False, "error": "Missing model ID."}, 400 ) return bundle = load_bundle(repo_id, model_type) with GENERATION_LOCK: text = generate( model=bundle["model"], tokenizer=bundle["tokenizer"], prompt=prompt, max_new_tokens=int(payload.get("max_new_tokens", 256)), temperature=float(payload.get("temperature", 0.8)), top_k=int(payload.get("top_k", 40)), repetition_penalty=float( payload.get("repetition_penalty", 1.0) ), device=str(bundle["device"]), sft_mode=model_type != "pretrain", force_thought=bool(payload.get("force_thought", False)), stream=False, decode_mode=str(payload.get("decode_mode", "legacy")), best_of=int(payload.get("best_of", 3)), no_repeat_ngram_size=int( payload.get("no_repeat_ngram_size", 3) ), context_window=int(payload.get("context_window", 2048)), beam_width=int(payload.get("beam_width", 8)), length_penalty=float(payload.get("length_penalty", 0.7)), ) self._send_json( { "success": True, "text": text, "name": bundle["name"], "series": bundle["series"], } ) except Exception as exc: self._send_json({"success": False, "error": str(exc)}, 500) return self._send_json({"success": False, "error": "Not found."}, 404) def log_message(self, format, *args): return def main(): CACHE_ROOT.mkdir(parents=True, exist_ok=True) port = ensure_port(int(os.environ.get("PORT", "7860"))) server = ThreadingHTTPServer(("127.0.0.1", port), Handler) url = f"http://127.0.0.1:{port}" print(url, flush=True) try: webbrowser.open(url) except Exception: pass try: server.serve_forever() except KeyboardInterrupt: pass finally: server.server_close() if __name__ == "__main__": main()