diff --git "a/app.py" "b/app.py" new file mode 100644--- /dev/null +++ "b/app.py" @@ -0,0 +1,3151 @@ +#!/usr/bin/env python3 +"""Public-facing TMLM-Haiku interactive CLI. + +Pulls models from the CompactAI-O HuggingFace collection: + https://huggingface.co/collections/CompactAI-O/tmlm-haiku-series +""" +from __future__ import annotations + + +#!/usr/bin/env python3 +from __future__ import annotations + +import hashlib +import json +import math +import os +import string +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterator, List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + + +HUGGINGFACE_MODELS = { + "TMLM-Haiku-1": "CompactAI-O/TMLM-Haiku-1", + "TMLM-Haiku-1.3": "CompactAI-O/TMLM-Haiku-1.3", + "TMLM-Haiku-2": "CompactAI-O/TMLM-Haiku-2", + "Glint-1": "CompactAI-O/Glint-1", +} + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +@dataclass +class ModelConfig: + dim: int = 128 + n_unique_layers: int = 8 + n_logical_layers: int = 16 + n_heads: int = 4 + n_kv_heads: int = 2 + ffn_dim: int = 224 + dropout: float = 0.0 + seq_len: int = 2048 + sliding_window_size: int = 512 + mtp_horizons: Tuple[int, ...] = (2, 3, 4) + rope_fraction: float = 0.5 + embed_scale: bool = True + logit_soft_cap: float = -1.0 + quantization: str = "nvfp4" + + @property + def head_dim(self) -> int: + return self.dim // self.n_heads + + +model_config = ModelConfig() + +MODEL_SERIES = { + "haiku": { + "dim": 64, + "n_unique_layers": 12, + "n_logical_layers": 24, + "n_heads": 4, + "n_kv_heads": 2, + "ffn_dim": 384, + "dropout": 0.0, + "seq_len": 2048, + "sliding_window_size": 2048, + "mtp_horizons": (), + "rope_fraction": 0.5, + "engram_dim": 8, + "engram_heads": 2, + "engram_table_size": 64, + "engram_max_ngram": 2, + "mhc_expansion": 2, + "sleep_gate_cap": 0, + "sleep_gate_heads": 4, + "latent_think_layers": 0, + "prelude_layers": 0, + "coda_layers": 0, + "recurrent_loops": 0, + "recurrent_act_threshold": 0.9, + "recurrent_lora_rank": 0, + "recurrent_loop_embed_dim": 0, + }, + "sonnet": { + "dim": 1024, + "n_unique_layers": 20, + "n_logical_layers": 40, + "n_heads": 16, + "n_kv_heads": 4, + "ffn_dim": 4096, + "dropout": 0.0, + "seq_len": 2048, + "mtp_horizons": (2,), + "engram_dim": 32, + "engram_heads": 8, + "engram_table_size": 4096, + "engram_max_ngram": 2, + "mhc_expansion": 2, + "sleep_gate_cap": 0, + "sleep_gate_heads": 8, + "latent_think_layers": 0, + "prelude_layers": 0, + "coda_layers": 0, + "recurrent_loops": 0, + "recurrent_act_threshold": 0.99, + "recurrent_lora_rank": 0, + "recurrent_loop_embed_dim": 0, + }, + "opus": { + "dim": 1536, + "n_unique_layers": 18, + "n_logical_layers": 36, + "n_heads": 16, + "n_kv_heads": 4, + "ffn_dim": 5888, + "dropout": 0.0, + "seq_len": 2048, + "mtp_horizons": (2,), + "engram_dim": 64, + "engram_heads": 8, + "engram_table_size": 8192, + "engram_max_ngram": 2, + "mhc_expansion": 4, + "sleep_gate_cap": 0, + "sleep_gate_heads": 8, + "latent_think_layers": 0, + "prelude_layers": 0, + "coda_layers": 0, + "recurrent_loops": 0, + "recurrent_act_threshold": 0.99, + "recurrent_lora_rank": 0, + "recurrent_loop_embed_dim": 0, + }, +} + + +# --------------------------------------------------------------------------- +# Tokenizer +# --------------------------------------------------------------------------- + +FORMAT_TOKENS = [ + "<|user|>", + "<|assistant|>", + "<|system|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|begin_of_thought|>", + "<|end_of_thought|>", + "<|begin_of_solution|>", + "<|end_of_solution|>", +] + + +class WordTokenizer: + def __init__( + self, extra_chars: str = "", format_tokens: Optional[List[str]] = None + ) -> None: + base = string.ascii_letters + string.digits + string.punctuation + " \n\t\r" + fallback_chars = sorted(set(base + extra_chars)) + self.core_special = ["", "", "", ""] + self.format_tokens = ( + list(format_tokens) if format_tokens else list(FORMAT_TOKENS) + ) + self.special = list(self.core_special) + list(self.format_tokens) + self.id_to_token: List[str] = ( + list(self.core_special) + self.format_tokens + fallback_chars + ) + self.token_to_id: Dict[str, int] = { + t: i for i, t in enumerate(self.id_to_token) + } + self.special_multi_tokens = sorted( + [t for t in self.special if len(t) > 1], key=len, reverse=True + ) + self.multi_char_tokens = self.special_multi_tokens + self.dynamic_additions = 0 + + @property + def pad_id(self) -> int: + return self.token_to_id[""] + + @property + def bos_id(self) -> int: + return self.token_to_id[""] + + @property + def eos_id(self) -> int: + return self.token_to_id[""] + + @property + def unk_id(self) -> int: + return self.token_to_id[""] + + @property + def vocab_size(self) -> int: + return len(self.id_to_token) + + def maybe_add_char(self, ch: str) -> bool: + if ch in self.token_to_id: + return False + self.token_to_id[ch] = len(self.id_to_token) + self.id_to_token.append(ch) + self.dynamic_additions += 1 + return True + + def iter_lexical_tokens(self, text: str) -> Iterator[str]: + i = 0 + n = len(text) + while i < n: + matched_special = False + for token in self.special_multi_tokens: + if text.startswith(token, i): + yield token + i += len(token) + matched_special = True + break + if matched_special: + continue + yield text[i] + i += 1 + + def encode( + self, text: str, add_bos: bool = False, add_eos: bool = False + ) -> List[int]: + out: List[int] = [] + if add_bos: + out.append(self.bos_id) + unk = self.unk_id + t2i = self.token_to_id + for tok in self.iter_lexical_tokens(text): + out.append(t2i.get(tok, unk)) + if add_eos: + out.append(self.eos_id) + return out + + def decode(self, ids: Sequence[int], skip_special: bool = True) -> str: + pieces: List[str] = [] + for idx in ids: + if int(idx) < 0 or int(idx) >= len(self.id_to_token): + continue + tok = self.id_to_token[int(idx)] + if skip_special and tok in self.special: + continue + pieces.append(tok) + return "".join(pieces) + + @classmethod + def load(cls, path: Path) -> WordTokenizer: + with path.open("r", encoding="utf-8") as f: + data = json.load(f) + format_tokens = data.get("format_tokens", FORMAT_TOKENS) + tokenizer = cls(extra_chars="", format_tokens=format_tokens) + tokenizer.id_to_token = data["id_to_token"] + tokenizer.token_to_id = {t: i for i, t in enumerate(tokenizer.id_to_token)} + tokenizer.special = list(tokenizer.core_special) + list(tokenizer.format_tokens) + tokenizer.special_multi_tokens = sorted( + [t for t in tokenizer.special if len(t) > 1], key=len, reverse=True + ) + tokenizer.multi_char_tokens = tokenizer.special_multi_tokens + return tokenizer + + +LetterTokenizer = WordTokenizer + + +# --------------------------------------------------------------------------- +# Model +# --------------------------------------------------------------------------- + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(torch.nn.functional, "rms_norm"): + return torch.nn.functional.rms_norm( + x, self.weight.shape, self.weight, self.eps + ) + x_fp = x.float() + rms = torch.rsqrt(x_fp.pow(2).mean(dim=-1, keepdim=True) + self.eps) + return (x_fp * rms).to(dtype=x.dtype) * self.weight + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim: int, base: float = 10000.0) -> None: + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + + def cos_sin( + self, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat([freqs, freqs], dim=-1) + cos = emb.cos()[None, None, :, :].to(dtype=dtype) + sin = emb.sin()[None, None, :, :].to(dtype=dtype) + return cos, sin + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + dropout: float, + sliding_window: int, + rope_fraction: float, + ) -> None: + super().__init__() + self.dim = dim + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + self.n_rep = n_heads // n_kv_heads + self.dropout = dropout + self.sliding_window = sliding_window + + self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) + self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wo = nn.Linear(n_heads * head_dim, dim, bias=False) + + self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2) + self.rope = RotaryEmbedding(self.rope_dim) + + self.q_norm = RMSNorm(head_dim) + self.k_norm = RMSNorm(head_dim) + + self.output_gate = nn.Parameter(torch.ones(n_heads)) + + def forward( + self, + x: torch.Tensor, + is_global: bool, + past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + B, T, _ = x.shape + + q = self.wq(x).view(B, T, self.n_heads, self.head_dim) + k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim) + v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim) + + q = self.q_norm(q) + k = self.k_norm(k) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + past_len = past_kv[0].shape[2] if past_kv is not None else 0 + cos, sin = self.rope.cos_sin(T + past_len, x.device, q.dtype) + cos_slice = cos[:, :, past_len : past_len + T, :] + sin_slice = sin[:, :, past_len : past_len + T, :] + + q_rope = q[..., : self.rope_dim] + q_pass = q[..., self.rope_dim :] + k_rope = k[..., : self.rope_dim] + k_pass = k[..., self.rope_dim :] + + q_rope = (q_rope * cos_slice) + (_rotate_half(q_rope) * sin_slice) + k_rope = (k_rope * cos_slice) + (_rotate_half(k_rope) * sin_slice) + + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + + if past_kv is not None: + k = torch.cat([past_kv[0], k], dim=2) + v = torch.cat([past_kv[1], v], dim=2) + + new_kv = (k, v) if use_cache else None + + S = k.shape[2] + if self.n_rep > 1: + k = ( + k[:, :, None, :, :] + .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim) + .reshape(B, self.n_heads, S, self.head_dim) + ) + v = ( + v[:, :, None, :, :] + .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim) + .reshape(B, self.n_heads, S, self.head_dim) + ) + + drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0 + + if is_global: + if past_kv is None and T > 1: + out = F.scaled_dot_product_attention( + q, k, v, is_causal=True, dropout_p=drop_p + ) + else: + out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p) + else: + T_q = q.shape[2] + q_pos = torch.arange(past_len, past_len + T_q, device=q.device).unsqueeze(1) + k_pos = torch.arange(S, device=q.device).unsqueeze(0) + mask = (q_pos >= k_pos) & ((q_pos - k_pos) < self.sliding_window) + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask.unsqueeze(0).unsqueeze(0), dropout_p=drop_p + ) + + gate = torch.sigmoid(self.output_gate).view(1, self.n_heads, 1, 1) + out = out * gate + + out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) + out = self.wo(out) + + return out, new_kv + + +class SwiGLU(nn.Module): + def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None: + super().__init__() + self.gate = nn.Linear(dim, hidden_dim, bias=False) + self.up = nn.Linear(dim, hidden_dim, bias=False) + self.down = nn.Linear(hidden_dim, dim, bias=False) + self.drop = nn.Dropout(dropout) + + nn.init.normal_(self.gate.weight, std=dim**-0.5) + nn.init.normal_(self.up.weight, std=dim**-0.5) + nn.init.normal_(self.down.weight, std=hidden_dim**-0.5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = F.silu(self.gate(x)) * self.up(x) + out = self.down(h) + if self.training and torch.is_grad_enabled(): + out = self.drop(out) + return out + + +def loop_index_embedding(h: torch.Tensor, loop_t: int, loop_dim: int, theta: float = 10000.0) -> torch.Tensor: + if loop_dim <= 0: + return h + loop_dim = min(loop_dim, h.shape[-1]) + if loop_dim % 2 == 1: + loop_dim -= 1 + if loop_dim <= 0: + return h + inv_freq = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim)) + phase = torch.tensor(float(loop_t), device=h.device, dtype=h.dtype) * inv_freq + loop_embed = torch.cat([phase.sin(), phase.cos()], dim=0).view(1, 1, loop_dim) + out = h.clone() + out[..., :loop_dim] = out[..., :loop_dim] + loop_embed + return out + + +class DepthLoRAAdapter(nn.Module): + def __init__(self, dim: int, rank: int, max_loops: int) -> None: + super().__init__() + self.rank = max(0, rank) + if self.rank <= 0: + self.down = None + self.B = None + self.scale = None + return + self.down = nn.Linear(dim, self.rank, bias=False) + self.B = nn.Parameter(torch.randn(self.rank, dim) * 0.02) + self.scale = nn.Embedding(max(1, max_loops), self.rank) + nn.init.zeros_(self.scale.weight) + + def forward(self, x: torch.Tensor, loop_t: int) -> torch.Tensor: + if self.rank <= 0 or self.down is None or self.B is None or self.scale is None: + return torch.zeros_like(x) + t_idx = min(loop_t, self.scale.num_embeddings - 1) + scale = self.scale(torch.tensor(t_idx, device=x.device)) + return (self.down(x) * scale) @ self.B + + +class StableRecurrentInjection(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.log_A = nn.Parameter(torch.full((dim,), -2.0)) + self.log_dt = nn.Parameter(torch.full((dim,), -2.0)) + self.input_gate = nn.Parameter(torch.zeros(dim)) + + def forward(self, h: torch.Tensor, e: torch.Tensor, transformer_out: torch.Tensor) -> torch.Tensor: + A = torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20))).view(1, 1, -1) + B = torch.sigmoid(self.input_gate).view(1, 1, -1) + return A * h + B * e + transformer_out + + +class AdaptiveHalting(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + self.halt = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.halt.weight) + nn.init.constant_(self.halt.bias, -2.0) + + def forward(self, h: torch.Tensor) -> torch.Tensor: + return torch.sigmoid(self.halt(h)).squeeze(-1) + + +class EngramBlock(nn.Module): + """DeepSeek Engram: conditional memory via O(1) hashed N-gram lookup. + + Stores common token-pair/triplet patterns in an embedding table and + retrieves them with multi-head hashing. A context-aware gate (using the + current hidden state as query) decides how much of the retrieved memory + to inject into the residual stream. + + Reference: DeepSeek-AI, "Conditional Memory via Scalable Lookup" (2025). + """ + + def __init__( + self, + dim: int, + engram_dim: int, + n_heads: int = 4, + table_size: int = 8192, + max_ngram: int = 3, + ) -> None: + super().__init__() + self.dim = dim + self.engram_dim = engram_dim + self.n_heads = n_heads + self.table_size = table_size + self.max_ngram = max_ngram + + # One embedding table per (ngram_order, hash_head) + self.embeddings = nn.ParameterDict() + for n in range(2, max_ngram + 1): + for k in range(n_heads): + self.embeddings[f"{n}_{k}"] = nn.Parameter( + torch.randn(table_size, engram_dim) * (engram_dim**-0.5) + ) + + # Fixed hash parameters (non-learnable, deterministic) + for n in range(2, max_ngram + 1): + for k in range(n_heads): + seed = int(hashlib.md5(f"engram_{n}_{k}".encode()).hexdigest()[:8], 16) + rng = torch.Generator().manual_seed(seed) + a = torch.randint(1, 2**31, (1,), generator=rng).item() + b = torch.randint(0, 2**31, (1,), generator=rng).item() + self.register_buffer( + f"hash_a_{n}_{k}", torch.tensor(a), persistent=False + ) + self.register_buffer( + f"hash_b_{n}_{k}", torch.tensor(b), persistent=False + ) + + # Causal convolution over N-gram branch outputs (kernel=4, dilation=max_ngram) + total_branch_dim = engram_dim * n_heads * (max_ngram - 1) + self.branch_conv = nn.Conv1d( + total_branch_dim, + total_branch_dim, + kernel_size=4, + dilation=max_ngram, + padding=0, + groups=total_branch_dim, + bias=True, + ) + nn.init.zeros_(self.branch_conv.weight) + nn.init.zeros_(self.branch_conv.bias) + + # Context-aware gating: hidden state as query, memory as key/value + self.gate_query = nn.Linear(dim, engram_dim, bias=False) + self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False) + self.gate_value = nn.Linear(total_branch_dim, dim, bias=False) + self.gate_scale = engram_dim**-0.5 + + def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor: + """Hash n-gram token sequences into table indices. + + Args: + token_ids: (B, T) token IDs + n: n-gram order (2 = bigram, 3 = trigram) + k: hash head index + Returns: + indices: (B, T) integer indices into embedding table + """ + a = getattr(self, f"hash_a_{n}_{k}") + b = getattr(self, f"hash_b_{n}_{k}") + B, T = token_ids.shape + + # Pad left with zeros so every position has a valid n-gram + padded = F.pad(token_ids, (n - 1, 0), value=0) # (B, T+n-1) + + # Polynomial rolling hash + combined = torch.zeros(B, T, dtype=torch.long, device=token_ids.device) + for i in range(n): + combined = combined * 31 + padded[:, i : i + T].long() + + indices = ((a * combined) ^ b) % self.table_size + return indices + + def forward( + self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Forward pass. + + Args: + hidden: (B, T, dim) current hidden state + token_ids: (B, T) input token IDs for n-gram hashing. + If None, uses argmax of hidden projections as proxy. + Returns: + output: (B, T, dim) memory injection for residual stream + """ + B, T, _ = hidden.shape + + if token_ids is None: + # Fallback: derive pseudo-token-ids from hidden state + token_ids = hidden.mean(dim=-1).long() % self.table_size + + # Retrieve and concatenate across n-gram orders and hash heads + branch_outputs = [] + for n in range(2, self.max_ngram + 1): + for k in range(self.n_heads): + indices = self._hash_ngram(token_ids, n, k) # (B, T) + table = self.embeddings[f"{n}_{k}"] # (table_size, engram_dim) + retrieved = table[indices] # (B, T, engram_dim) + branch_outputs.append(retrieved) + + # (B, T, engram_dim * n_heads * (max_ngram - 1)) + memory = torch.cat(branch_outputs, dim=-1) + + # Causal convolution over sequence dimension + # Pad left for causality (kernel_size - 1 = 3) + conv_in = memory.transpose(1, 2) # (B, C, T) + conv_in = F.pad( + conv_in, + ((self.branch_conv.kernel_size[0] - 1) * self.branch_conv.dilation[0], 0), + ) + conv_out = self.branch_conv(conv_in) # (B, C, T) + memory = conv_out.transpose(1, 2) # (B, T, C) + + # Context-aware gating + query = self.gate_query(hidden) # (B, T, engram_dim) + key = self.gate_key(memory) # (B, T, engram_dim) + gate = torch.sigmoid( + (query * key).sum(dim=-1, keepdim=True) * self.gate_scale + ) # (B, T, 1) + value = self.gate_value(memory) # (B, T, dim) + + return gate * value + + +class SleepGate(nn.Module): + """Persistent memory + periodic consolidation gate.""" + + def __init__( + self, + dim: int, + cap: int = 128, + n_heads: int = 4, + retention_enabled: bool = True, + retention_hidden: int = 0, + ) -> None: + super().__init__() + self.dim = dim + self.cap = cap + self.n_heads = n_heads + self.head_dim = dim // n_heads + self.scale = self.head_dim ** -0.5 + self.retention_enabled = retention_enabled + + self.register_buffer("mem_emb", torch.zeros(cap, dim, dtype=torch.bfloat16)) + self.register_buffer("mem_age", torch.zeros(cap, dtype=torch.long)) + self.register_buffer("mem_beta", torch.ones(cap, dtype=torch.float32)) + self.register_buffer("mem_count", torch.zeros((), dtype=torch.long)) + self.register_buffer("mem_head", torch.zeros((), dtype=torch.long)) + self.register_buffer("global_step", torch.zeros((), dtype=torch.long)) + + self.q_proj = nn.Linear(dim, dim, bias=False) + self.k_proj = nn.Linear(dim, dim, bias=False) + self.v_proj = nn.Linear(dim, dim, bias=False) + self.o_proj = nn.Linear(dim, dim, bias=False) + nn.init.zeros_(self.o_proj.weight) + self.gate_scale = nn.Parameter(torch.zeros(())) + + if retention_enabled: + if retention_hidden > 0: + self.retention_gate: Optional[nn.Module] = nn.Sequential( + nn.Linear(dim, retention_hidden, bias=False), + nn.GELU(), + nn.Linear(retention_hidden, 1, bias=True), + ) + nn.init.constant_(self.retention_gate[-1].bias, 2.2) + else: + self.retention_gate = nn.Linear(dim, 1, bias=True) + nn.init.constant_(self.retention_gate.bias, 2.2) + else: + self.retention_gate = None + + self._last_beta: Optional[torch.Tensor] = None + + def write(self, hidden: torch.Tensor) -> None: + B, T, _ = hidden.shape + tail_full = hidden[:, max(0, T - 16):, :].float().mean(dim=1) + if self.retention_gate is not None: + beta_live = torch.sigmoid(self.retention_gate(tail_full).squeeze(-1)) + self._last_beta = beta_live if self.training else None + beta_store = beta_live.detach().float() + else: + self._last_beta = None + beta_store = torch.ones(B, device=hidden.device, dtype=torch.float32) + tail = tail_full.to(self.mem_emb.dtype).detach() + with torch.no_grad(): + head = int(self.mem_head.item()) + count = int(self.mem_count.item()) + step = int(self.global_step.item()) + for b in range(B): + self.mem_emb[head] = tail[b] + self.mem_age[head] = step + self.mem_beta[head] = beta_store[b] + head = (head + 1) % self.cap + if count < self.cap: + count += 1 + self.mem_head.fill_(head) + self.mem_count.fill_(count) + + def read(self, x: torch.Tensor) -> torch.Tensor: + count = int(self.mem_count.item()) + if count == 0: + return torch.zeros_like(x) + B, T, D = x.shape + mem = self.mem_emb[:count].clone().to(x.dtype) + q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1) + v = self.v_proj(mem).view(count, self.n_heads, self.head_dim).transpose(0, 1) + attn = torch.einsum("bhtd,hmd->bhtm", q, k) * self.scale + attn = F.softmax(attn, dim=-1) + if self.retention_enabled: + step = int(self.global_step.item()) + ages = self.mem_age[:count].to(x.device) + delta = (step - ages).clamp(min=0).to(x.dtype) + betas = self.mem_beta[:count].to(x.dtype).clamp(min=1e-6, max=1.0) + weights = betas.pow(delta) + attn = attn * weights.view(1, 1, 1, count) + attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-9) + out = torch.einsum("bhtm,hmd->bhtd", attn, v) + out = out.transpose(1, 2).contiguous().view(B, T, D) + out = self.o_proj(out) + return torch.sigmoid(self.gate_scale) * out + + @torch.no_grad() + def reset(self) -> None: + self.mem_emb.zero_() + self.mem_age.zero_() + self.mem_beta.fill_(1.0) + self.mem_count.zero_() + self.mem_head.zero_() + self.global_step.zero_() + self._last_beta = None + + +def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor: + M = torch.exp(logits.clamp(-10, 10)) + for _ in range(n_iters): + M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10) + M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10) + return M + + +class ManifoldHyperConnection(nn.Module): + def __init__(self, dim: int, expansion: int = 2) -> None: + super().__init__() + self.dim = dim + self.expansion = expansion + n = expansion + + self.expand_fn = "duplicate" + self.collapse_fn = "mean" + + self.bias_pre = nn.Parameter(torch.zeros(1, n)) + self.bias_post = nn.Parameter(torch.zeros(1, n)) + self.bias_res = nn.Parameter(torch.zeros(n, n)) + + self.theta_pre = nn.Linear(n * dim, n, bias=False) + self.theta_post = nn.Linear(n * dim, n, bias=False) + self.theta_res = nn.Linear(n * dim, n * n, bias=False) + + self.alpha_pre = nn.Parameter(torch.tensor(0.0)) + self.alpha_post = nn.Parameter(torch.tensor(0.0)) + self.alpha_res = nn.Parameter(torch.tensor(0.0)) + + def _compute_mappings( + self, x_expanded: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, _ = x_expanded.shape + n = self.expansion + + x_norm = F.rms_norm(x_expanded, [x_expanded.shape[-1]]) + + d_pre = torch.tanh(self.theta_pre(x_norm)) + d_post = torch.tanh(self.theta_post(x_norm)) + d_res = self.theta_res(x_norm) + + H_pre_raw = torch.sigmoid(self.alpha_pre * d_pre + self.bias_pre) + H_post_raw = 2.0 * torch.sigmoid(self.alpha_post * d_post + self.bias_post) + H_res_raw = (self.alpha_res * d_res + self.bias_res.reshape(1, 1, -1)).reshape( + B, T, n, n + ) + + H_res = _sinkhorn_knopp(H_res_raw) + + return H_pre_raw.unsqueeze(-2), H_post_raw.unsqueeze(-2), H_res + + def expand_stream(self, x: torch.Tensor) -> torch.Tensor: + return x.repeat(1, 1, self.expansion) + + def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor: + B, T, _ = x_expanded.shape + n = self.expansion + C = self.dim + return x_expanded.view(B, T, n, C).mean(dim=-2) + + def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor: + B, T, _ = x_expanded.shape + n = self.expansion + x_streams = x_expanded.view(B, T, n, self.dim) + return (H_pre @ x_streams).squeeze(-2) + + def post_res_mix( + self, + layer_output: torch.Tensor, + x_expanded: torch.Tensor, + H_post: torch.Tensor, + H_res: torch.Tensor, + ) -> torch.Tensor: + B, T, _ = x_expanded.shape + n = self.expansion + C = self.dim + + x_streams = x_expanded.view(B, T, n, C) + mixed = torch.matmul(H_res, x_streams) + post_out = torch.matmul(H_post.transpose(-2, -1), layer_output.unsqueeze(-2)) + + result = mixed + post_out + return result.reshape(B, T, n * C) + + +class TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + ffn_dim: int, + dropout: float, + sliding_window: int, + rope_fraction: float, + engram_dim: int = 0, + engram_heads: int = 4, + engram_table_size: int = 8192, + engram_max_ngram: int = 3, + mhc_expansion: int = 1, + ) -> None: + super().__init__() + self.norm1 = RMSNorm(dim) + self.attn = CausalSelfAttention( + dim=dim, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + head_dim=head_dim, + dropout=dropout, + sliding_window=sliding_window, + rope_fraction=rope_fraction, + ) + self.norm2 = RMSNorm(dim) + self.ffn = SwiGLU(dim, ffn_dim, dropout) + self.use_engram = engram_dim > 0 + if self.use_engram: + self.engram = EngramBlock( + dim=dim, + engram_dim=engram_dim, + n_heads=engram_heads, + table_size=engram_table_size, + max_ngram=engram_max_ngram, + ) + self.engram_norm = RMSNorm(dim) + self.use_mhc = mhc_expansion > 1 + if self.use_mhc: + self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion) + self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion) + + def forward( + self, + x: torch.Tensor, + is_global: bool, + past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + token_ids: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + if self.use_mhc: + x_exp = self.mhc_attn.expand_stream(x) + H_pre, H_post, H_res = self.mhc_attn._compute_mappings(x_exp) + attn_in = self.mhc_attn.pre_mix(x_exp, H_pre) + attn_out, new_kv = self.attn( + self.norm1(attn_in), is_global, past_kv, use_cache + ) + x_exp = self.mhc_attn.post_res_mix(attn_out, x_exp, H_post, H_res) + if self.use_engram: + collapsed = self.mhc_attn.collapse_stream(x_exp) + collapsed = collapsed + self.engram( + self.engram_norm(collapsed), token_ids=token_ids + ) + x_exp = self.mhc_attn.expand_stream(collapsed) + H_pre2, H_post2, H_res2 = self.mhc_ffn._compute_mappings(x_exp) + ffn_in = self.mhc_ffn.pre_mix(x_exp, H_pre2) + ffn_out = self.ffn(self.norm2(ffn_in)) + x_exp = self.mhc_ffn.post_res_mix(ffn_out, x_exp, H_post2, H_res2) + x = self.mhc_attn.collapse_stream(x_exp) + else: + attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache) + x = x + attn_out + if self.use_engram: + x = x + self.engram(self.engram_norm(x), token_ids=token_ids) + x = x + self.ffn(self.norm2(x)) + return x, new_kv + + +class RecurrentDepthBlock(nn.Module): + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + ffn_dim: int, + dropout: float, + sliding_window: int, + rope_fraction: float, + n_loops: int, + act_threshold: float, + lora_rank: int, + loop_embed_dim: int, + ) -> None: + super().__init__() + self.n_loops = max(1, n_loops) + self.act_threshold = act_threshold + self.loop_embed_dim = max(0, loop_embed_dim) + self.norm = RMSNorm(dim) + self.block = TransformerBlock( + dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, + ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, + rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1, + ) + self.injection = StableRecurrentInjection(dim) + self.act = AdaptiveHalting(dim) + self.lora = DepthLoRAAdapter(dim, lora_rank, self.n_loops) + + def forward( + self, + h: torch.Tensor, + e: torch.Tensor, + token_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, + use_cache: bool = False, + n_loops: Optional[int] = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: + loops = max(1, n_loops or self.n_loops) + B, T, _ = h.shape + halted = torch.zeros(B, T, device=h.device, dtype=torch.bool) + cumulative_p = torch.zeros(B, T, device=h.device, dtype=h.dtype) + output = torch.zeros_like(h) + new_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None + current = h + final_halt = None + + for t in range(loops): + h_loop = loop_index_embedding(current, t, self.loop_embed_dim) + combined = self.norm(h_loop + e) + past_kv = None + if past_key_values is not None and t < len(past_key_values): + past_kv = past_key_values[t] + trans_out, layer_kv = self.block(combined, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=token_ids) + trans_out = trans_out + self.lora(trans_out, t) + next_h = self.injection(current, e, trans_out) + p = self.act(next_h) + p = p * (~halted).to(p.dtype) + final_halt = p + should_halt = (~halted) & ((cumulative_p + p) >= self.act_threshold) + update_weight = torch.where(should_halt, (1.0 - cumulative_p).clamp(min=0.0), p) + output = output + next_h * update_weight.unsqueeze(-1) + cumulative_p = cumulative_p + update_weight + current = torch.where(halted.unsqueeze(-1), current, next_h) + halted = halted | should_halt + if new_past is not None: + new_past.append(layer_kv) + if not use_cache and bool(halted.all()): + break + + remainder = (1.0 - cumulative_p).clamp(min=0.0) + output = output + current * remainder.unsqueeze(-1) + aux: Dict[str, torch.Tensor] = {} + if final_halt is not None: + aux["recurrent_halt_mean"] = final_halt.mean() + return output, aux, new_past + + +class TinyMemoryLM(nn.Module): + def __init__( + self, + vocab_size: int, + dim: int, + n_unique_layers: int, + n_logical_layers: int, + n_heads: int, + n_kv_heads: int, + ffn_dim: int, + dropout: float, + mtp_horizons: Sequence[int], + grad_checkpoint: bool, + sliding_window: int = 512, + rope_fraction: float = 0.5, + embed_scale: bool = True, + engram_dim: int = 0, + engram_heads: int = 4, + engram_table_size: int = 8192, + engram_max_ngram: int = 3, + mhc_expansion: int = 1, + sleep_gate_cap: int = 0, + sleep_gate_heads: int = 4, + sleep_retention_enabled: bool = True, + sleep_retention_hidden: int = 0, + latent_think_layers: int = 0, + prelude_layers: int = 0, + coda_layers: int = 0, + recurrent_loops: int = 0, + recurrent_act_threshold: float = 0.99, + recurrent_lora_rank: int = 0, + recurrent_loop_embed_dim: int = 0, + ) -> None: + super().__init__() + self.dim = dim + self.n_unique_layers = n_unique_layers + self.n_logical_layers = n_logical_layers + self.grad_checkpoint = grad_checkpoint + self.embed_scale_factor = math.sqrt(dim) if embed_scale else 1.0 + head_dim = dim // n_heads + + self.embed_tokens = nn.Embedding(vocab_size, dim) + self.head = nn.Linear(dim, vocab_size, bias=False) + self.head.weight = self.embed_tokens.weight + self.output_bias = nn.Parameter(torch.zeros(vocab_size)) + + self.use_recurrent_depth = recurrent_loops > 0 + self.prelude_layers = max(0, prelude_layers) + self.coda_layers = max(0, coda_layers) + self.recurrent_loops = max(0, recurrent_loops) + + self.blocks: Optional[nn.ModuleList] = None + self.prelude: Optional[nn.ModuleList] = None + self.recurrent: Optional[RecurrentDepthBlock] = None + self.coda: Optional[nn.ModuleList] = None + + def _make_blocks(n: int) -> nn.ModuleList: + return nn.ModuleList([ + TransformerBlock( + dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, + ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, + rope_fraction=rope_fraction, engram_dim=engram_dim, + engram_heads=engram_heads, engram_table_size=engram_table_size, + engram_max_ngram=engram_max_ngram, mhc_expansion=mhc_expansion, + ) + for _ in range(n) + ]) + + if self.use_recurrent_depth: + if self.prelude_layers > 0: + self.prelude = _make_blocks(self.prelude_layers) + self.recurrent = RecurrentDepthBlock( + dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, + ffn_dim=ffn_dim, dropout=dropout, sliding_window=sliding_window, + rope_fraction=rope_fraction, n_loops=self.recurrent_loops, + act_threshold=recurrent_act_threshold, lora_rank=recurrent_lora_rank, + loop_embed_dim=recurrent_loop_embed_dim or max(2, dim // 8), + ) + if self.coda_layers > 0: + self.coda = _make_blocks(self.coda_layers) + else: + self.blocks = _make_blocks(max(1, n_unique_layers)) + + self.norm = RMSNorm(dim) + + self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1}) + self.mtp_adapters = nn.ModuleDict( + {str(h): nn.Linear(dim, dim, bias=False) for h in self.mtp_horizons} + ) + self.mtp_norms = nn.ModuleDict( + {str(h): RMSNorm(dim) for h in self.mtp_horizons} + ) + + res_scale = (2 * max(1, n_logical_layers)) ** -0.5 + for group in (self.blocks, self.prelude, self.coda): + if group is None: + continue + for block in group: + block.attn.wo.weight.data.mul_(res_scale) + block.ffn.down.weight.data.mul_(res_scale) + if self.recurrent is not None: + self.recurrent.block.attn.wo.weight.data.mul_(res_scale) + self.recurrent.block.ffn.down.weight.data.mul_(res_scale) + + self.sleep_gate: Optional[SleepGate] = None + if sleep_gate_cap > 0: + self.sleep_gate = SleepGate( + dim=dim, cap=sleep_gate_cap, n_heads=sleep_gate_heads, + retention_enabled=sleep_retention_enabled, + retention_hidden=sleep_retention_hidden, + ) + + self.think_blocks: Optional[nn.ModuleList] = None + self.think_norm: Optional[RMSNorm] = None + if latent_think_layers > 0: + self.think_blocks = nn.ModuleList([ + TransformerBlock( + dim=dim, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, + ffn_dim=ffn_dim, dropout=0.0, sliding_window=2048, + rope_fraction=rope_fraction, engram_dim=0, mhc_expansion=1, + ) + for _ in range(latent_think_layers) + ]) + self.think_norm = RMSNorm(dim) + + def resize_token_embeddings(self, new_vocab_size: int) -> None: + old_vocab_size = self.embed_tokens.num_embeddings + if new_vocab_size == old_vocab_size: + return + device = self.embed_tokens.weight.device + old_embed_weight = self.embed_tokens.weight.data.clone() + self.embed_tokens = nn.Embedding(new_vocab_size, self.embed_tokens.embedding_dim).to(device) + self.head = nn.Linear(self.embed_tokens.embedding_dim, new_vocab_size, bias=False).to(device) + self.head.weight = self.embed_tokens.weight + old_bias = self.output_bias.data.clone() + self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device)) + copy_size = min(old_vocab_size, new_vocab_size) + self.output_bias.data[:copy_size] = old_bias[:copy_size] + self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size] + + def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]: + if self.blocks is None: + return [] + blocks_list = list(self.blocks) + full_sequence = blocks_list + blocks_list + return [(block, i) for i, block in enumerate(full_sequence[: self.n_logical_layers])] + + def forward( + self, + ids: torch.Tensor, + use_cache: bool = False, + past_key_values: Optional[List[Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None, + return_hidden: bool = False, + ) -> Tuple[torch.Tensor, Dict[int, torch.Tensor], Dict[str, torch.Tensor], Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: + B, T = ids.shape + x = self.embed_tokens(ids) * self.embed_scale_factor + new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None + aux: Dict[str, torch.Tensor] = {} + + if self.use_recurrent_depth: + offset = 0 + if self.prelude is not None: + for block in self.prelude: + past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None + x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids) + if new_past_key_values is not None: + new_past_key_values.append(layer_kv) + offset += 1 + encoded = x + recurrent_past = past_key_values[offset: offset + self.recurrent_loops] if past_key_values is not None else None + x, recurrent_aux, recurrent_kv = self.recurrent( + x, encoded, token_ids=ids, past_key_values=recurrent_past, use_cache=use_cache, + ) + aux.update(recurrent_aux) + if new_past_key_values is not None and recurrent_kv is not None: + new_past_key_values.extend(recurrent_kv) + offset += self.recurrent_loops + if self.coda is not None: + for block in self.coda: + past_kv = past_key_values[offset] if past_key_values is not None and offset < len(past_key_values) else None + x, layer_kv = block(x, is_global=True, past_kv=past_kv, use_cache=use_cache, token_ids=ids) + if new_past_key_values is not None: + new_past_key_values.append(layer_kv) + offset += 1 + else: + logical_layers = self._build_logical_layers() + last_logical_idx = len(logical_layers) - 1 + for layer_idx, (block, logical_idx) in enumerate(logical_layers): + is_global = logical_idx % 2 == 0 or layer_idx == last_logical_idx + past_kv = past_key_values[layer_idx] if past_key_values is not None and layer_idx < len(past_key_values) else None + if self.grad_checkpoint and self.training and not use_cache: + x, layer_kv = checkpoint(block, x, is_global, past_kv, use_cache, ids, use_reentrant=True) + else: + x, layer_kv = block(x, is_global, past_kv, use_cache, ids) + if new_past_key_values is not None: + new_past_key_values.append(layer_kv) + + x = self.norm(x) + + if self.sleep_gate is not None: + x = x + self.sleep_gate.read(x) + if self.training: + self.sleep_gate.write(x) + + if self.think_blocks is not None: + for think_block in self.think_blocks: + x, _ = think_block(x, is_global=True) + x = self.think_norm(x) + + h_out = x if return_hidden else None + logits = self.head(x) + if self.embed_scale_factor != 1.0: + logits = logits / self.embed_scale_factor + logits = logits + self.output_bias + + mtp: Dict[int, torch.Tensor] = {} + if self.mtp_horizons and self.training: + for horizon in self.mtp_horizons: + if horizon > 1 and horizon <= T - 1: + shifted_h = x[:, :-horizon, :] + adapted_h = self.mtp_adapters[str(horizon)](shifted_h) + adapted_h = self.mtp_norms[str(horizon)](adapted_h) + mtp_logits = self.head(adapted_h) + if self.embed_scale_factor != 1.0: + mtp_logits = mtp_logits / self.embed_scale_factor + mtp_logits = mtp_logits + self.output_bias + mtp[horizon] = mtp_logits + + return logits, mtp, aux, h_out, new_past_key_values + + +# --------------------------------------------------------------------------- +# Generation +# --------------------------------------------------------------------------- + + +def build_stop_token_ids(tokenizer: WordTokenizer) -> set: + stop_tokens = {tokenizer.eos_id} + for tok in ("<|user|>", "<|system|>", "<|assistant|>"): + tid = tokenizer.token_to_id.get(tok) + if tid is not None: + stop_tokens.add(int(tid)) + return stop_tokens + + +def apply_no_repeat_ngram( + logits: torch.Tensor, + token_history: Sequence[int], + ngram_size: int, +) -> torch.Tensor: + if ngram_size <= 1 or len(token_history) < max(0, ngram_size - 1): + return logits + prefix = tuple(token_history[-(ngram_size - 1) :]) if ngram_size > 1 else tuple() + banned: set = set() + for i in range(len(token_history) - ngram_size + 1): + if tuple(token_history[i : i + ngram_size - 1]) == prefix: + banned.add(int(token_history[i + ngram_size - 1])) + if not banned: + return logits + out = logits.clone() + banned_ids = torch.tensor(sorted(banned), device=logits.device, dtype=torch.long) + out[banned_ids] = float("-inf") + return out + + +def apply_loop_penalty( + logits: torch.Tensor, + tokenizer: WordTokenizer, + generated_text: str, + penalty: float = 5.0, +) -> torch.Tensor: + """Detect repeated substring loops and penalise continuation tokens.""" + if len(generated_text) < 16: + return logits + out = logits.clone() + for span_len in [24, 16, 12, 8]: + if len(generated_text) < span_len * 2: + continue + suffix = generated_text[-span_len:] + prev = generated_text[:-span_len].rfind(suffix) + if prev == -1: + continue + next_pos = prev + span_len + if next_pos < len(generated_text): + next_char = generated_text[next_pos] + tid = tokenizer.token_to_id.get(next_char) + if tid is not None: + out[tid] -= penalty + break + return out + + +def apply_min_p(logits: torch.Tensor, min_p: float) -> torch.Tensor: + """Filter tokens below min_p fraction of the top token probability.""" + if min_p <= 0.0: + return logits + probs = torch.softmax(logits, dim=-1) + threshold = probs.max() * min_p + out = logits.clone() + out[probs < threshold] = float("-inf") + return out + + +def generate( + model: TinyMemoryLM, + tokenizer: WordTokenizer, + prompt: str, + max_new_tokens: int = 256, + temperature: float = 0.8, + top_k: int = 16, + top_p: float = 0.95, + repetition_penalty: float = 1.0, + device: str = "cuda", + sft_mode: bool = True, + stream: bool = True, + no_repeat_ngram_size: int = 0, + context_window: int = 2048, + logit_soft_cap: float = 15.0, + min_p: float = 0.05, + loop_penalty: float = 5.0, +) -> str: + if sft_mode: + full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n" + else: + full_prompt = prompt + input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False) + input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device) + visible_tokens: List[str] = [] + stop_token_ids = build_stop_token_ids(tokenizer) + generated_text = "" + + generated_ids: List[int] = [] + # Full history (prompt + generated) for ngram blocking — prevents echoing prompt + full_ids_history: List[int] = list(input_ids) + + with torch.no_grad(): + for _ in range(max_new_tokens): + ctx_ids = ( + input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t + ) + logits, *_ = model(ctx_ids) + next_logits = logits[0, -1, :].clone() + + # Logit soft-capping (Gemma-style) — prevents overconfident collapse + if logit_soft_cap > 0: + next_logits = logit_soft_cap * torch.tanh(next_logits / logit_soft_cap) + + raw_next_logits = next_logits.clone() + + # Repetition penalty on previously generated tokens + if repetition_penalty != 1.0 and generated_ids: + for tok_id in set(generated_ids): + if next_logits[tok_id] > 0: + next_logits[tok_id] /= repetition_penalty + else: + next_logits[tok_id] *= repetition_penalty + + # No-repeat n-gram blocking on generated tokens only + if no_repeat_ngram_size > 0 and generated_ids: + next_logits = apply_no_repeat_ngram(next_logits, generated_ids, no_repeat_ngram_size) + + # Substring loop detection + next_logits = apply_loop_penalty(next_logits, tokenizer, generated_text, penalty=loop_penalty) + + # Temperature scaling + if temperature != 1.0: + next_logits = next_logits / max(temperature, 1e-6) + + # Min-p filtering — remove tokens below min_p * max_prob + if min_p > 0: + next_logits = apply_min_p(next_logits, min_p) + + # Top-k filtering + if top_k > 0: + v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0))) + next_logits[next_logits < v[-1]] = float("-inf") + + # Top-p (nucleus) filtering + if 0.0 < top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) + sorted_probs = torch.softmax(sorted_logits, dim=-1) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + remove_mask = cumulative_probs > top_p + remove_mask[0] = False + indices_to_remove = sorted_indices[remove_mask] + next_logits[indices_to_remove] = float("-inf") + + # Fallback if all tokens masked + if not torch.isfinite(next_logits).any(): + next_logits = raw_next_logits + if temperature != 1.0: + next_logits = next_logits / max(temperature, 1e-6) + + if temperature == 0: + next_id = torch.argmax(next_logits).item() + else: + probs = torch.softmax(next_logits, dim=-1) + next_id = torch.multinomial(probs, num_samples=1).item() + if next_id in stop_token_ids: + break + token_str = ( + tokenizer.id_to_token[next_id] + if next_id < len(tokenizer.id_to_token) + else "" + ) + generated_ids.append(next_id) + full_ids_history.append(next_id) + if token_str not in tokenizer.special: + visible_tokens.append(token_str) + generated_text += token_str + if stream: + print(token_str, end="", flush=True) + input_ids_t = torch.cat( + [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1 + ) + if stream: + print() + return "".join(visible_tokens) + + +def generate_stream( + model: TinyMemoryLM, + tokenizer: WordTokenizer, + prompt: str, + max_new_tokens: int = 256, + temperature: float = 0.8, + top_k: int = 16, + top_p: float = 0.95, + repetition_penalty: float = 1.0, + device: str = "cpu", + sft_mode: bool = True, + no_repeat_ngram_size: int = 0, + context_window: int = 2048, + logit_soft_cap: float = 15.0, + min_p: float = 0.05, + loop_penalty: float = 5.0, +) -> "Iterator[str]": + """Yield the accumulated response string after each new token (for Gradio streaming).""" + if sft_mode: + full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n" + else: + full_prompt = prompt + input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False) + input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device) + stop_token_ids = build_stop_token_ids(tokenizer) + generated_ids: list = [] + generated_text = "" + + with torch.no_grad(): + for _ in range(max_new_tokens): + ctx_ids = input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t + logits, *_ = model(ctx_ids) + next_logits = logits[0, -1, :].clone() + + if logit_soft_cap > 0: + next_logits = logit_soft_cap * torch.tanh(next_logits / logit_soft_cap) + + raw_next_logits = next_logits.clone() + + if repetition_penalty != 1.0 and generated_ids: + for tok_id in set(generated_ids): + if next_logits[tok_id] > 0: + next_logits[tok_id] /= repetition_penalty + else: + next_logits[tok_id] *= repetition_penalty + + if no_repeat_ngram_size > 0 and generated_ids: + next_logits = apply_no_repeat_ngram(next_logits, generated_ids, no_repeat_ngram_size) + + next_logits = apply_loop_penalty(next_logits, tokenizer, generated_text, penalty=loop_penalty) + + if temperature != 1.0: + next_logits = next_logits / max(temperature, 1e-6) + if min_p > 0: + next_logits = apply_min_p(next_logits, min_p) + if top_k > 0: + v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0))) + next_logits[next_logits < v[-1]] = float("-inf") + if 0.0 < top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) + sorted_probs = torch.softmax(sorted_logits, dim=-1) + cumulative_probs = torch.cumsum(sorted_probs, dim=-1) + remove_mask = cumulative_probs > top_p + remove_mask[0] = False + next_logits[sorted_indices[remove_mask]] = float("-inf") + if not torch.isfinite(next_logits).any(): + next_logits = raw_next_logits + if temperature != 1.0: + next_logits = next_logits / max(temperature, 1e-6) + + if temperature == 0: + next_id = int(torch.argmax(next_logits).item()) + else: + probs = torch.softmax(next_logits, dim=-1) + next_id = int(torch.multinomial(probs, num_samples=1).item()) + + if next_id in stop_token_ids: + break + + token_str = tokenizer.id_to_token[next_id] if next_id < len(tokenizer.id_to_token) else "" + generated_ids.append(next_id) + if token_str not in tokenizer.special: + generated_text += token_str + yield generated_text + + input_ids_t = torch.cat( + [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1 + ) + + +# --------------------------------------------------------------------------- +# Local model loading +# --------------------------------------------------------------------------- + + +def series_from_name(name: str) -> str | None: + lower = (name or "").lower() + if "haiku" in lower: + return "Haiku" + if "sonnet" in lower: + return "Sonnet" + if "opus" in lower: + return "Opus" + return None + + +def series_config(series: str) -> dict[str, object]: + return MODEL_SERIES.get(series.lower(), MODEL_SERIES["sonnet"]) + + +def discover_models(runs_dir: Path) -> List[dict]: + models = [] + if not runs_dir.is_dir(): + return models + for child in sorted(runs_dir.iterdir()): + if not child.is_dir(): + continue + tokenizer_path = child / "tokenizer.json" + if not tokenizer_path.exists(): + continue + name = child.name + series = None + for ckpt_name in ("model.pt", "pretrain.pt"): + ckpt_path = child / ckpt_name + if ckpt_path.exists(): + series = _fast_series_from_checkpoint(ckpt_path) + break + if series is None: + series = series_from_name(name) or "Sonnet" + found = False + for ckpt_name in ("model.pt", "model_rep.pt", "pretrain.pt"): + ckpt_path = child / ckpt_name + if ckpt_path.exists(): + models.append( + { + "name": name, + "checkpoint": ckpt_name, + "series": series, + "model_path": ckpt_path, + "tokenizer_path": tokenizer_path, + } + ) + found = True + if not found: + step_ckpts = sorted( + child.glob("checkpoint_step_*.pt"), + key=lambda p: int(p.stem.rsplit("_", 1)[-1]), + ) + if step_ckpts: + ckpt_path = step_ckpts[-1] + models.append( + { + "name": name, + "checkpoint": ckpt_path.name, + "series": series, + "model_path": ckpt_path, + "tokenizer_path": tokenizer_path, + } + ) + return models + + +def _detect_engram(state_dict): + for key in state_dict: + if ".engram." in key: + if ".embeddings." in key: + return state_dict[key].shape[-1] + return 0 + + +def _detect_mhc(state_dict): + for key, val in state_dict.items(): + if ".mhc_attn.bias_pre" in key and val.dim() == 2: + return val.shape[-1] # (1, expansion) + return 1 + + +def _detect_sleep_gate(state_dict) -> Tuple[int, int]: + for key, val in state_dict.items(): + if key == "sleep_gate.mem_emb" and val.dim() == 2: + cap = val.shape[0] + return cap, 4 + return 0, 4 + + +def _detect_latent_think(state_dict) -> int: + indices = { + int(k.split(".")[1]) + for k in state_dict + if k.startswith("think_blocks.") and k.split(".")[1].isdigit() + } + return max(indices) + 1 if indices else 0 + + +def _detect_prelude_layers(state_dict) -> int: + indices = { + int(k.split(".")[1]) + for k in state_dict + if k.startswith("prelude.") and k.split(".")[1].isdigit() + } + return max(indices) + 1 if indices else 0 + + +def _detect_coda_layers(state_dict) -> int: + indices = { + int(k.split(".")[1]) + for k in state_dict + if k.startswith("coda.") and k.split(".")[1].isdigit() + } + return max(indices) + 1 if indices else 0 + + +def _detect_recurrent_loops(state_dict) -> int: + if "recurrent.norm.weight" in state_dict or "recurrent.block.attn.wq.weight" in state_dict: + if "recurrent.lora.scale.weight" in state_dict: + return state_dict["recurrent.lora.scale.weight"].shape[0] + return 1 + return 0 + + +def _detect_recurrent_lora_rank(state_dict) -> int: + for key in ("recurrent.lora.B", "recurrent.lora.down.weight"): + if key in state_dict: + shape = state_dict[key].shape + if len(shape) == 2: + return int(shape[0]) + return 0 + + +def _infer_series_from_lora_rank(rank: int) -> str | None: + if rank == 0: + return None + if rank <= 8: + return "haiku" + if rank <= 16: + return "sonnet" + return "opus" + + +def _fast_series_from_checkpoint(ckpt_path: Path) -> str | None: + try: + cp = torch.load(ckpt_path, map_location="cpu", weights_only=False) + sd = cp.get("model_state", cp.get("state_dict", {})) + rank = 0 + for key in ("recurrent.lora.B", "recurrent.lora.down.weight"): + if key in sd: + rank = int(sd[key].shape[0]) + break + if rank == 0: + return None + if rank <= 8: + return "Haiku" + if rank <= 16: + return "Sonnet" + return "Opus" + except Exception: + pass + return None + + +def _infer_arch_from_state_dict(state_dict, cfg): + """Infer architecture hyper-parameters directly from checkpoint weights, + falling back to *cfg* (series config) when a key is not found.""" + overrides = {} + + has_prelude = any(k.startswith("prelude.") for k in state_dict) + has_blocks = any(k.startswith("blocks.") for k in state_dict) + has_recurrent = any(k.startswith("recurrent.") for k in state_dict) + uses_recurrent_arch = has_prelude and has_recurrent and not has_blocks + + # dim from embed_tokens.weight [vocab, dim] + if "embed_tokens.weight" in state_dict: + overrides["dim"] = state_dict["embed_tokens.weight"].shape[1] + + if uses_recurrent_arch: + if "prelude.0.ffn.gate.weight" in state_dict: + overrides["ffn_dim"] = state_dict["prelude.0.ffn.gate.weight"].shape[0] + overrides["n_unique_layers"] = 0 + src = "prelude.0" + else: + if "blocks.0.ffn.gate.weight" in state_dict: + overrides["ffn_dim"] = state_dict["blocks.0.ffn.gate.weight"].shape[0] + block_ids = { + int(k.split(".")[1]) + for k in state_dict + if k.startswith("blocks.") and k.split(".")[1].isdigit() + } + if block_ids: + overrides["n_unique_layers"] = max(block_ids) + 1 + src = "blocks.0" + + dim = overrides.get("dim", int(cfg.get("dim", model_config.dim))) + if f"{src}.attn.wq.weight" in state_dict: + wq_rows = state_dict[f"{src}.attn.wq.weight"].shape[0] + if f"{src}.attn.q_norm.weight" in state_dict: + head_dim = state_dict[f"{src}.attn.q_norm.weight"].shape[0] + overrides["n_heads"] = wq_rows // head_dim + if f"{src}.attn.wk.weight" in state_dict: + wk_rows = state_dict[f"{src}.attn.wk.weight"].shape[0] + if f"{src}.attn.k_norm.weight" in state_dict: + head_dim = state_dict[f"{src}.attn.k_norm.weight"].shape[0] + overrides["n_kv_heads"] = wk_rows // head_dim + + # engram params + for key, val in state_dict.items(): + if ".engram.embeddings." in key and key.endswith("_0") and val.dim() == 2: + overrides["engram_table_size"] = val.shape[0] + overrides["engram_dim"] = val.shape[1] + break + engram_dim = overrides.get("engram_dim", int(cfg.get("engram_dim", 0))) + engram_max_ngram = int(cfg.get("engram_max_ngram", 2)) + if engram_dim > 0: + for key, val in state_dict.items(): + if ".engram.branch_conv.weight" in key and val.dim() == 3: + total_branch_dim = val.shape[0] + denom = engram_dim * (engram_max_ngram - 1) + if denom > 0 and total_branch_dim % denom == 0: + overrides["engram_heads"] = total_branch_dim // denom + break + + merged = dict(cfg) + merged.update(overrides) + return merged + + +def load_local_model(model_path: Path, tokenizer_path: Path, series: str) -> dict: + tokenizer = WordTokenizer.load(tokenizer_path) + ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False) + cfg = series_config(series) + vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size)) + + state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt + + cfg = _infer_arch_from_state_dict(state_dict, cfg) + + engram_dim = int(cfg.get("engram_dim", 0)) + if _detect_engram(state_dict) == 0: + engram_dim = 0 + + mhc_expansion = _detect_mhc(state_dict) + if mhc_expansion == 1: + mhc_expansion = int(cfg.get("mhc_expansion", 1)) + + ckpt_sleep_cap, ckpt_sleep_heads = _detect_sleep_gate(state_dict) + sleep_gate_cap = ckpt_sleep_cap if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_cap", 0)) + sleep_gate_heads = ckpt_sleep_heads if ckpt_sleep_cap > 0 else int(cfg.get("sleep_gate_heads", 4)) + sleep_retention_enabled = bool(cfg.get("sleep_retention_enabled", True)) + sleep_retention_hidden = int(cfg.get("sleep_retention_hidden", 0)) + + latent_think_layers = _detect_latent_think(state_dict) + if latent_think_layers == 0: + latent_think_layers = int(cfg.get("latent_think_layers", 0)) + + prelude_layers = _detect_prelude_layers(state_dict) + coda_layers = _detect_coda_layers(state_dict) + recurrent_loops = _detect_recurrent_loops(state_dict) + + ckpt_lora_rank = _detect_recurrent_lora_rank(state_dict) + if ckpt_lora_rank > 0: + inferred_series = _infer_series_from_lora_rank(ckpt_lora_rank) + if inferred_series and inferred_series != series.lower(): + series = inferred_series.capitalize() + cfg = series_config(series) + recurrent_lora_rank = ckpt_lora_rank + else: + recurrent_lora_rank = int(cfg.get("recurrent_lora_rank", 0)) + + recurrent_act_threshold = float(cfg.get("recurrent_act_threshold", 0.99)) + recurrent_loop_embed_dim = int(cfg.get("recurrent_loop_embed_dim", 0)) + + n_unique = int(cfg.get("n_unique_layers", model_config.n_unique_layers)) + + model = TinyMemoryLM( + vocab_size=vocab_size, + dim=int(cfg.get("dim", model_config.dim)), + n_unique_layers=n_unique, + n_logical_layers=int(cfg.get("n_logical_layers", model_config.n_logical_layers)), + n_heads=int(cfg.get("n_heads", model_config.n_heads)), + n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)), + ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)), + dropout=float(cfg.get("dropout", model_config.dropout)), + mtp_horizons=tuple(int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)), + grad_checkpoint=False, + sliding_window=int(cfg.get("sliding_window_size", getattr(model_config, "sliding_window_size", 512))), + rope_fraction=float(cfg.get("rope_fraction", getattr(model_config, "rope_fraction", 0.25))), + embed_scale=bool(cfg.get("embed_scale", getattr(model_config, "embed_scale", True))), + engram_dim=engram_dim, + engram_heads=int(cfg.get("engram_heads", 4)), + engram_table_size=int(cfg.get("engram_table_size", 8192)), + engram_max_ngram=int(cfg.get("engram_max_ngram", 3)), + mhc_expansion=mhc_expansion, + sleep_gate_cap=sleep_gate_cap, + sleep_gate_heads=sleep_gate_heads, + sleep_retention_enabled=sleep_retention_enabled, + sleep_retention_hidden=sleep_retention_hidden, + latent_think_layers=latent_think_layers, + prelude_layers=prelude_layers, + coda_layers=coda_layers, + recurrent_loops=recurrent_loops, + recurrent_act_threshold=recurrent_act_threshold, + recurrent_lora_rank=recurrent_lora_rank, + recurrent_loop_embed_dim=recurrent_loop_embed_dim, + ) + model.load_state_dict(state_dict, strict=False) + model.eval() + if tokenizer.vocab_size > vocab_size: + model.resize_token_embeddings(tokenizer.vocab_size) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = model.to(device) + return { + "model": model, + "tokenizer": tokenizer, + "device": device, + "series": series, + "sft_mode": ckpt.get("sft_mode", None), + "phase": ckpt.get("phase", None), + } + + +# --------------------------------------------------------------------------- +# HuggingFace Model Download & Loading +# --------------------------------------------------------------------------- + +def download_huggingface_model(hf_id: str, cache_dir: Path) -> dict: + try: + from huggingface_hub import snapshot_download + except ImportError: + print("huggingface_hub not installed. Install with: pip install huggingface_hub") + sys.exit(1) + + print(f"Downloading {hf_id}...") + try: + local_dir = Path(snapshot_download(repo_id=hf_id, cache_dir=str(cache_dir))) + except Exception as e: + print(f"Failed to download {hf_id}: {e}") + return None + + print(f"Using cached {hf_id} from {local_dir}") + + # Check common subdirectory names: "models/", "model/" + if (local_dir / "models").exists(): + model_dir = local_dir / "models" + elif (local_dir / "model").exists(): + model_dir = local_dir / "model" + else: + model_dir = local_dir + model_path = model_dir / "model.pt" + pretrain_path = model_dir / "pretrain.pt" + tokenizer_path = model_dir / "tokenizer.json" + + ckpt_path = None + for p in [model_path, pretrain_path]: + if p.exists(): + ckpt_path = p + break + + if ckpt_path is None or not tokenizer_path.exists(): + print(f"Missing model files in {model_dir}") + print(f" model.pt exists: {model_path.exists()}") + print(f" pretrain.pt exists: {pretrain_path.exists()}") + print(f" tokenizer.json exists: {tokenizer_path.exists()}") + return None + + return { + "model_path": ckpt_path, + "tokenizer_path": tokenizer_path, + "model_name": ckpt_path.stem, + } + + +def load_huggingface_model(hf_id: str, cache_dir: Path) -> dict: + files = download_huggingface_model(hf_id, cache_dir) + if files is None: + return None + + return load_local_model(files["model_path"], files["tokenizer_path"], "Haiku") + + +# --------------------------------------------------------------------------- +# Compare All Models +# --------------------------------------------------------------------------- + +_hf_model_cache: Dict[str, dict] = {} + + +def prefetch_huggingface_models() -> None: + root = Path(__file__).resolve().parent + cache_dir = root / "cache" / "huggingface" + cache_dir.mkdir(parents=True, exist_ok=True) + + print("Downloading/preparing HuggingFace models...") + for name, hf_id in HUGGINGFACE_MODELS.items(): + print(f" {name}...") + bundle = load_huggingface_model(hf_id, cache_dir) + if bundle: + _hf_model_cache[name] = bundle + print(f"Prepared {len(_hf_model_cache)} HuggingFace models") + + +def compare_all_models(prompt: str, cfg: dict) -> None: + root = Path(__file__).resolve().parent + runs_dir = root / "runs" + all_models = discover_models(runs_dir) + + is_pretrain = not cfg.get("sft_mode", True) + local_models = [ + m for m in all_models + if ("pretrain" in m["checkpoint"]) == is_pretrain + ] + + if not local_models: + print("No models found matching mode.") + return + + results: List[dict] = [] + + for m in local_models: + print(f"\n{'='*60}") + print(f"Loading local {m['name']}/{m['checkpoint']}...") + try: + bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) + except Exception as e: + print(f"Failed to load {m['name']}: {e}") + continue + + model = bundle["model"] + tokenizer = bundle["tokenizer"] + device = bundle["device"] + + print(f"Generating on '{prompt}'...") + output = generate( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_new_tokens=cfg["max_new_tokens"], + temperature=cfg["temperature"], + top_k=cfg["top_k"], + top_p=cfg["top_p"], + min_p=cfg["min_p"], + no_repeat_ngram_size=cfg["no_repeat_ngram_size"], + repetition_penalty=cfg["repetition_penalty"], + logit_soft_cap=cfg["logit_soft_cap"], + loop_penalty=cfg["loop_penalty"], + device=str(device), + sft_mode=cfg["sft_mode"], + stream=True, + context_window=cfg["context_window"], + ) + + results.append({ + "name": f"[LOCAL] {m['name']}/{m['checkpoint']}", + "output": output, + "device": device, + }) + + for name, bundle in _hf_model_cache.items(): + print(f"\n{'='*60}") + print(f"Loading {name} (cached)...") + + model = bundle["model"] + tokenizer = bundle["tokenizer"] + device = bundle["device"] + + print(f"Generating on '{prompt}'...") + output = generate( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_new_tokens=cfg["max_new_tokens"], + temperature=cfg["temperature"], + top_k=cfg["top_k"], + top_p=cfg["top_p"], + min_p=cfg["min_p"], + no_repeat_ngram_size=cfg["no_repeat_ngram_size"], + repetition_penalty=cfg["repetition_penalty"], + logit_soft_cap=cfg["logit_soft_cap"], + loop_penalty=cfg["loop_penalty"], + device=str(device), + sft_mode=cfg["sft_mode"], + stream=True, + context_window=cfg["context_window"], + ) + + results.append({ + "name": name, + "output": output, + "device": device, + }) + + print(f"\n{'='*60}") + print("=" * 60) + print("SIDE-BY-SIDE COMPARISON") + print("=" * 60) + for r in results: + print(f"\n--- {r['name']} ---") + print(r["output"]) + print(f"\n{'='*60}") + + +# --------------------------------------------------------------------------- +# Benchmark +# --------------------------------------------------------------------------- + +BENCHMARKS = { + "blimp": { + "label": "BLiMP", + "desc": "Grammaticality minimal pairs (67 paradigms). Accuracy = % grammatical < ungrammatical perplexity.", + "hf_dataset": ("nyu-mll/blimp", None), + "metric": "accuracy", + }, + "wikitext2": { + "label": "WikiText-2", + "desc": "LM perplexity on Wikipedia test split. Lower is better.", + "hf_dataset": ("Salesforce/wikitext", "wikitext-2-raw-v1"), + "metric": "perplexity", + }, + "arc_easy": { + "label": "ARC-Easy", + "desc": "Multiple-choice science QA (~2.4K). Perplexity-based answer selection.", + "hf_dataset": ("allenai/ai2_arc", "ARC-Easy"), + "metric": "accuracy", + }, +} + + +def _score_text(model: TinyMemoryLM, tokenizer: WordTokenizer, text: str, device: str) -> float: + ids = tokenizer.encode(text, add_bos=True, add_eos=False) + if len(ids) < 2: + return float("nan") + ids_t = torch.tensor([ids], dtype=torch.long, device=device) + with torch.no_grad(): + logits, *_ = model(ids_t) + log_probs = F.log_softmax(logits[0], dim=-1) + targets = ids_t[0, 1:] + nll = -log_probs[range(len(targets)), targets].mean().item() + return nll + + +def _score_completion(model: TinyMemoryLM, tokenizer: WordTokenizer, context: str, completion: str, device: str) -> float: + full_ids = tokenizer.encode(context + completion, add_bos=True, add_eos=False) + ctx_ids = tokenizer.encode(context, add_bos=True, add_eos=False) + n_ctx = len(ctx_ids) + n_ref = len(full_ids) - n_ctx + if n_ref <= 0: + return float("nan") + ids_t = torch.tensor([full_ids], dtype=torch.long, device=device) + with torch.no_grad(): + logits, *_ = model(ids_t) + log_probs = F.log_softmax(logits[0], dim=-1) + targets = ids_t[0, 1:] + ref_start = n_ctx - 1 + ref_end = min(ref_start + n_ref, log_probs.shape[0]) + if ref_start >= ref_end: + return float("nan") + nll = -log_probs[ref_start:ref_end][range(ref_end - ref_start), targets[ref_start:ref_end]].mean().item() + return nll + + +BLIMP_PARADIGMS = [ + "adjunct_island", "anaphor_gender_agreement", "anaphor_number_agreement", + "animate_subject_passive", "animate_subject_trans", "causative", + "complex_NP_island", "coordinate_structure_constraint_complex_left_branch", + "coordinate_structure_constraint_object_extraction", + "determiner_noun_agreement_1", "determiner_noun_agreement_2", + "determiner_noun_agreement_irregular_1", "determiner_noun_agreement_irregular_2", + "determiner_noun_agreement_with_adj_2", "determiner_noun_agreement_with_adj_irregular_1", + "determiner_noun_agreement_with_adj_irregular_2", "determiner_noun_agreement_with_adjective_1", + "distractor_agreement_relational_noun", "distractor_agreement_relative_clause", + "drop_argument", "ellipsis_n_bar_1", "ellipsis_n_bar_2", + "existential_there_object_raising", "existential_there_quantifiers_1", + "existential_there_quantifiers_2", "existential_there_subject_raising", + "expletive_it_object_raising", "inchoative", "intransitive", + "irregular_past_participle_adjectives", "irregular_past_participle_verbs", + "irregular_plural_subject_verb_agreement_1", "irregular_plural_subject_verb_agreement_2", + "left_branch_island_echo_question", "left_branch_island_simple_question", + "matrix_question_npi_licensor_present", "npi_present_1", "npi_present_2", + "only_npi_licensor_present", "only_npi_scope", "passive_1", "passive_2", + "principle_A_c_command", "principle_A_case_1", "principle_A_case_2", + "principle_A_domain_1", "principle_A_domain_2", "principle_A_domain_3", + "principle_A_reconstruction", "regular_plural_subject_verb_agreement_1", + "regular_plural_subject_verb_agreement_2", "sentential_negation_npi_licensor_present", + "sentential_negation_npi_scope", "sentential_subject_island", + "superlative_quantifiers_1", "superlative_quantifiers_2", + "tough_vs_raising_1", "tough_vs_raising_2", "transitive", "wh_island", + "wh_questions_object_gap", "wh_questions_subject_gap", + "wh_questions_subject_gap_long_distance", "wh_vs_that_no_gap", + "wh_vs_that_no_gap_long_distance", "wh_vs_that_with_gap", + "wh_vs_that_with_gap_long_distance", +] + + +def _run_blimp(model, tokenizer, device, n_samples: int = 200) -> Tuple[List[str], List[float]]: + from datasets import load_dataset # type: ignore + accuracies: List[float] = [] + for paradigm in BLIMP_PARADIGMS: + try: + ds = load_dataset("nyu-mll/blimp", paradigm, split="train") + except Exception as e: + print(f" {paradigm}: skip ({e})") + accuracies.append(float("nan")) + continue + items = list(ds)[:n_samples] + correct = 0 + for ex in items: + good_nll = _score_text(model, tokenizer, ex["sentence_good"], device) + bad_nll = _score_text(model, tokenizer, ex["sentence_bad"], device) + if math.isnan(good_nll) or math.isnan(bad_nll): + continue + if good_nll < bad_nll: + correct += 1 + acc = correct / len(items) if items else float("nan") + accuracies.append(acc) + print(f" {paradigm:50s} acc={acc:.3f}") + return BLIMP_PARADIGMS, accuracies + + +def _run_wikitext2(model, tokenizer, device, chunk_chars: int = 512, max_chunks: int = 100) -> Tuple[List[str], List[float]]: + from datasets import load_dataset # type: ignore + ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") + full_text = "\n".join(ex["text"] for ex in ds if ex["text"].strip()) + chunks = [full_text[i:i + chunk_chars] for i in range(0, len(full_text), chunk_chars)] + chunks = [c for c in chunks if len(c) > 20][:max_chunks] + labels: List[str] = [] + ppls: List[float] = [] + for i, chunk in enumerate(chunks): + nll = _score_text(model, tokenizer, chunk, device) + ppl = math.exp(nll) if not math.isnan(nll) else float("nan") + labels.append(f"chunk {i + 1}") + ppls.append(ppl) + if (i + 1) % 10 == 0: + valid = [v for v in ppls if not math.isnan(v)] + mean = sum(valid) / len(valid) if valid else float("nan") + print(f" chunk {i + 1}/{len(chunks)} running mean ppl={mean:.2f}") + return labels, ppls + + +def _run_arc_easy(model, tokenizer, device, max_samples: int = 200) -> Tuple[List[str], List[float]]: + from datasets import load_dataset # type: ignore + ds = load_dataset("allenai/ai2_arc", "ARC-Easy", split="test") + items = list(ds)[:max_samples] + labels: List[str] = [] + scores: List[float] = [] + for i, ex in enumerate(items): + question = ex["question"] + choices = ex["choices"]["text"] + choice_labels = ex["choices"]["label"] + answer_key = ex["answerKey"] + context = f"Question: {question}\nAnswer:" + nlls = [_score_completion(model, tokenizer, context, f" {c}", device) for c in choices] + if all(math.isnan(v) for v in nlls): + scores.append(float("nan")) + else: + best_idx = min(range(len(nlls)), key=lambda j: nlls[j] if not math.isnan(nlls[j]) else float("inf")) + predicted = choice_labels[best_idx] + scores.append(1.0 if predicted == answer_key else 0.0) + labels.append(f"Q{i + 1}") + n_valid = sum(1 for s in scores if not math.isnan(s)) + acc = sum(s for s in scores if not math.isnan(s)) / n_valid if n_valid else float("nan") + print(f" {n_valid} questions evaluated, accuracy={acc:.3f}") + return labels, scores + + +def run_benchmark_mode() -> None: + try: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except ImportError: + print("matplotlib not installed. pip install matplotlib") + return + + bench_keys = list(BENCHMARKS.keys()) + print("\nBenchmarks:") + for i, k in enumerate(bench_keys): + b = BENCHMARKS[k] + print(f" [{i + 1}] {b['label']} — {b['desc']}") + print("Select benchmark [1]:", end=" ", flush=True) + try: + b_choice = input().strip() or "1" + except (EOFError, KeyboardInterrupt): + print() + return + if not (b_choice.isdigit() and 1 <= int(b_choice) <= len(bench_keys)): + print("Invalid selection.") + return + bench_key = bench_keys[int(b_choice) - 1] + bench = BENCHMARKS[bench_key] + print(f"Benchmark: {bench['label']}") + + root = Path(__file__).resolve().parent + runs_dir = root / "runs" + all_models = discover_models(runs_dir) + + model_entries: List[dict] = [] + for m in all_models: + model_entries.append({"label": f"[LOCAL] {m['name']}/{m['checkpoint']}", "type": "local", "meta": m}) + for hf_name, hf_id in HUGGINGFACE_MODELS.items(): + model_entries.append({"label": f"[HF] {hf_name}", "type": "hf", "hf_id": hf_id, "hf_name": hf_name}) + + if not model_entries: + print("No models found.") + return + + print("\nAvailable models:") + for i, e in enumerate(model_entries): + print(f" [{i + 1}] {e['label']}") + print(" [a] All models") + print("Select models (comma-separated or 'a'):", end=" ", flush=True) + try: + raw = input().strip() + except (EOFError, KeyboardInterrupt): + print() + return + + if raw.lower() == "a": + selected = list(range(len(model_entries))) + else: + selected = [] + for tok in raw.split(","): + tok = tok.strip() + if tok.isdigit() and 1 <= int(tok) <= len(model_entries): + selected.append(int(tok) - 1) + if not selected: + print("No valid selection.") + return + + all_results: List[dict] = [] + shared_x_labels: Optional[List[str]] = None + + for idx in selected: + entry = model_entries[idx] + print(f"\n{'='*60}\nLoading {entry['label']}...") + try: + if entry["type"] == "local": + m = entry["meta"] + bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) + else: + bundle = load_huggingface_model(entry["hf_id"], root / ".hf_cache") + except Exception as e: + print(f" Failed: {e}") + continue + + model = bundle["model"] + tokenizer = bundle["tokenizer"] + device = str(bundle["device"]) + model.eval() + + if bench_key == "blimp": + x_labels, y_vals = _run_blimp(model, tokenizer, device) + elif bench_key == "wikitext2": + x_labels, y_vals = _run_wikitext2(model, tokenizer, device) + else: + x_labels, y_vals = _run_arc_easy(model, tokenizer, device) + + if shared_x_labels is None: + shared_x_labels = x_labels + + valid = [v for v in y_vals if not math.isnan(v)] + summary = sum(valid) / len(valid) if valid else float("nan") + all_results.append({"label": entry["label"], "y": y_vals, "summary": summary}) + + if not all_results or shared_x_labels is None: + print("No results to plot.") + return + + metric = bench["metric"] + paired = sorted(zip([r["summary"] for r in all_results], [r["label"] for r in all_results]), + reverse=(metric != "perplexity")) + summaries, model_labels = zip(*paired) if paired else ([], []) + n = len(summaries) + colors = [plt.cm.RdYlGn(i / max(n - 1, 1)) for i in range(n)] + + fig, ax = plt.subplots(figsize=(max(6, n * 1.4), 6)) + bars = ax.bar(range(n), summaries, color=colors, edgecolor="black") + for bar, val in zip(bars, summaries): + ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005, + f"{val:.3f}", ha="center", va="bottom", fontsize=9, fontweight="bold") + + ylabel = "Mean Perplexity (↓ better)" if metric == "perplexity" else "Mean Accuracy (↑ better)" + ax.set_ylabel(ylabel) + ax.set_title(f"{bench['label']} Benchmark — Model Comparison") + ax.set_xticks(range(n)) + ax.set_xticklabels(model_labels, rotation=20, ha="right", fontsize=9) + if metric == "accuracy": + ax.set_ylim(0, 1.05) + ax.grid(True, axis="y", alpha=0.3) + plt.tight_layout() + + out_path = root / f"benchmark_{bench_key}.png" + plt.savefig(str(out_path), dpi=150) + print(f"\nChart saved to {out_path}") + try: + import subprocess + subprocess.Popen(["xdg-open", str(out_path)]) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Interactive CLI +# --------------------------------------------------------------------------- + + +def _pick_series(detected: str) -> str: + series_list = list(MODEL_SERIES.keys()) + detected_lower = detected.lower() + default_idx = next( + (i + 1 for i, s in enumerate(series_list) if s == detected_lower), 1 + ) + + # Skip selection if only one series available + if len(series_list) == 1: + return series_list[0].capitalize() + + print("Series:") + for i, s in enumerate(series_list): + marker = " (detected)" if s == detected_lower else "" + print(f" [{i + 1}] {s.capitalize()}{marker}") + while True: + try: + choice = input(f"Select series [{default_idx}]: ").strip() + except (EOFError, KeyboardInterrupt): + print() + sys.exit(0) + if not choice: + choice = str(default_idx) + if choice.isdigit() and 1 <= int(choice) <= len(series_list): + return series_list[int(choice) - 1].capitalize() + print(f"Enter a number 1-{len(series_list)}") + + +def pick_model(runs_dir: Path) -> tuple[dict, str]: + models = discover_models(runs_dir) + if not models: + print(f"No models found in {runs_dir}") + print("Expected layout: runs//model.pt (or pretrain.pt) + tokenizer.json") + sys.exit(1) + + if len(models) == 1: + m = models[0] + print(f"Loading {m['name']}/{m['checkpoint']} ({m['series']})...") + bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) + return bundle, m["checkpoint"] + + print("Available models:") + for i, m in enumerate(models): + print(f" [{i + 1}] {m['name']}/{m['checkpoint']} ({m['series']})") + while True: + try: + choice = input("Select model [1]: ").strip() + except (EOFError, KeyboardInterrupt): + print() + sys.exit(0) + if not choice: + choice = "1" + if choice.isdigit() and 1 <= int(choice) <= len(models): + m = models[int(choice) - 1] + print(f"Loading {m['name']}/{m['checkpoint']} ({m['series']})...") + bundle = load_local_model(m["model_path"], m["tokenizer_path"], m["series"]) + return bundle, m["checkpoint"] + print(f"Enter a number 1-{len(models)}") + + +# --------------------------------------------------------------------------- +# Generation mode configs +# --------------------------------------------------------------------------- + +MODES = { + "chat-coherent": { + "label": "Chat — Coherent", + "desc": "structured, consistent, strong repetition control", + "sft_mode": "chat", + "temperature": 0.35, + "top_k": 20, + "top_p": 0.88, + "min_p": 0.10, + "no_repeat_ngram_size": 4, + "repetition_penalty": 1.22, + "logit_soft_cap": 20.0, + "loop_penalty": 20.0, + "max_new_tokens": 4096, + "context_window": 2048, + }, + "chat-variants": { + "label": "Chat — Variants", + "desc": "creative, diverse, more surprising outputs", + "sft_mode": "chat", + "temperature": 0.65, + "top_k": 60, + "top_p": 0.92, + "min_p": 0.05, + "no_repeat_ngram_size": 4, + "repetition_penalty": 1.12, + "logit_soft_cap": 20.0, + "loop_penalty": 14.0, + "max_new_tokens": 4096, + "context_window": 2048, + }, + "pretrain-coherent": { + "label": "Pretrain — Coherent", + "desc": "grounded continuation, low temperature, tight sampling", + "sft_mode": False, + "temperature": 0.3, + "top_k": 20, + "top_p": 0.85, + "min_p": 0.10, + "no_repeat_ngram_size": 4, + "repetition_penalty": 1.2, + "logit_soft_cap": 20.0, + "loop_penalty": 20.0, + "max_new_tokens": 4096, + "context_window": 2048, + }, + "pretrain-variants": { + "label": "Pretrain — Variants", + "desc": "free-form continuation, higher temperature, more exploration", + "sft_mode": False, + "temperature": 0.7, + "top_k": 60, + "top_p": 0.93, + "min_p": 0.04, + "no_repeat_ngram_size": 4, + "repetition_penalty": 1.12, + "logit_soft_cap": 20.0, + "loop_penalty": 12.0, + "max_new_tokens": 4096, + "context_window": 2048, + }, +} + +_MODE_LIST = list(MODES.keys()) + + +def pick_mode(is_pretrain: bool) -> dict: + """Prompt the user to choose a generation mode. Returns a config dict.""" + # Filter to relevant modes based on checkpoint type + candidates = [k for k in _MODE_LIST if ("pretrain" in k) == is_pretrain] + print("\nGeneration mode:") + for i, key in enumerate(candidates): + cfg = MODES[key] + print(f" [{i + 1}] {cfg['label']} — {cfg['desc']}") + while True: + try: + choice = input("Select mode [1]: ").strip() + except (EOFError, KeyboardInterrupt): + print() + sys.exit(0) + if not choice: + choice = "1" + if choice.isdigit() and 1 <= int(choice) <= len(candidates): + key = candidates[int(choice) - 1] + cfg = MODES[key] + print(f"Mode: {cfg['label']}") + return cfg + print(f"Enter a number 1-{len(candidates)}") + + +def _run_loop(bundle: dict, cfg: dict) -> None: + model = bundle["model"] + tokenizer = bundle["tokenizer"] + device = bundle["device"] + sft = cfg["sft_mode"] + prompt_label = "You" if sft else "Prompt" + print(f"\nModel ready on {device}. Type your message, or /quit to exit.") + print(f" temp={cfg['temperature']} top_k={cfg['top_k']} top_p={cfg['top_p']}") + print(f" min_p={cfg['min_p']} ng={cfg['no_repeat_ngram_size']} rp={cfg['repetition_penalty']}") + print(f" cap={cfg['logit_soft_cap']} loop_penalty={cfg['loop_penalty']}\n") + while True: + try: + prompt = input(f"{prompt_label}: ").strip() + except (EOFError, KeyboardInterrupt): + print() + break + if not prompt: + continue + if prompt in ("/quit", "/exit", "/q"): + break + if prompt == "/help": + print("Commands: /quit /exit /q /help /mode") + if sft: + print("Anything else is sent as a chat prompt.") + else: + print("Anything else is sent as a raw continuation prompt.") + continue + if prompt == "/mode": + print(f"Current: {cfg['label']} — {cfg['desc']}") + continue + print("AI: ", end="", flush=True) + generate( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_new_tokens=cfg["max_new_tokens"], + temperature=cfg["temperature"], + top_k=cfg["top_k"], + top_p=cfg["top_p"], + min_p=cfg["min_p"], + no_repeat_ngram_size=cfg["no_repeat_ngram_size"], + repetition_penalty=cfg["repetition_penalty"], + logit_soft_cap=cfg["logit_soft_cap"], + loop_penalty=cfg["loop_penalty"], + device=str(device), + sft_mode=cfg["sft_mode"], + stream=True, + context_window=cfg["context_window"], + ) + + + + +# --------------------------------------------------------------------------- +# Dynamic collection discovery +# --------------------------------------------------------------------------- + +_COLLECTION_SLUG = "CompactAI-O/tmlm-haiku-series" +_AUTHOR = "CompactAI-O" +_SEARCH = "TMLM-Haiku" + +_FALLBACK_COLLECTION = [ + {"version": "TMLM-Haiku-2.3", "hf_id": "CompactAI-O/TMLM-Haiku-2.3"}, + {"version": "TMLM-Haiku-2", "hf_id": "CompactAI-O/TMLM-Haiku-2"}, + {"version": "TMLM-Haiku-1.3", "hf_id": "CompactAI-O/TMLM-Haiku-1.3"}, + {"version": "TMLM-Haiku-1", "hf_id": "CompactAI-O/TMLM-Haiku-1"}, + {"version": "Glint-1", "hf_id": "CompactAI-O/Glint-1"}, +] + +_EXTRA_REPOS = ["CompactAI-O/Glint-1"] + + +def _probe_repo(hf_id: str) -> dict | None: + """Return entry dict for one repo, or None if no usable checkpoints found.""" + from huggingface_hub import list_repo_files + + try: + files = set(list_repo_files(hf_id)) + except Exception: + return None + + # Detect which subdirectory holds the checkpoints + subdir: str | None = None + for candidate in ("models", "model"): + if any(f.startswith(f"{candidate}/") for f in files): + subdir = candidate + break + + prefix = f"{subdir}/" if subdir else "" + + # Collect all .pt files in the checkpoint directory + pt_files = sorted( + f[len(prefix):] for f in files + if f.startswith(prefix) and f.endswith(".pt") + ) + + _LABELS = { + "model.pt": ("Chat (SFT)", False), + "model_rep.pt": ("Chat (anti-repetition)", False), + "pretrain.pt": ("Pretrain (base)", True), + } + + checkpoints = [] + for fname in pt_files: + label, is_pretrain = _LABELS.get(fname, (fname.removesuffix(".pt"), "pretrain" in fname)) + checkpoints.append((label, fname, is_pretrain)) + + if not checkpoints: + return None + + return { + "version": hf_id.split("/")[-1], + "hf_id": hf_id, + "subdir": subdir, + "checkpoints": checkpoints, + "desc": "", + } + + +def fetch_collection() -> list[dict]: + """Query HF for all CompactAI-O TMLM-Haiku models, newest first.""" + from huggingface_hub import HfApi + + print("Checking HuggingFace collection for available models...") + try: + api = HfApi() + infos = list( + api.list_models( + author=_AUTHOR, + search=_SEARCH, + sort="lastModified", + ) + ) + infos.sort(key=lambda m: getattr(m, "lastModified", ""), reverse=True) + except Exception as exc: + print(f" Could not reach HuggingFace ({exc}); using fallback list.") + infos = [type("M", (), {"id": e["hf_id"]})() for e in _FALLBACK_COLLECTION] + + entries = [] + seen_ids: set = set() + for info in infos: + repo_id = info.id + if _SEARCH.lower() not in repo_id.lower(): + continue + entry = _probe_repo(repo_id) + if entry: + entries.append(entry) + seen_ids.add(repo_id) + + # Always include extra repos (e.g. Glint-1) not caught by TMLM-Haiku search + for repo_id in _EXTRA_REPOS: + if repo_id not in seen_ids: + entry = _probe_repo(repo_id) + if entry: + entries.append(entry) + seen_ids.add(repo_id) + + if not entries: + print(" No models found; using fallback list.") + for fb in _FALLBACK_COLLECTION: + e = _probe_repo(fb["hf_id"]) + if e: + entries.append(e) + + return entries + + +# --------------------------------------------------------------------------- +# Download helper +# --------------------------------------------------------------------------- + + +def _download_version(entry: dict, cache_dir: Path) -> Path: + """Download full repo snapshot; return the directory containing model files.""" + try: + from huggingface_hub import snapshot_download + except ImportError: + print("huggingface_hub not installed. Run: pip install huggingface_hub") + sys.exit(1) + + hf_id = entry["hf_id"] + print(f"Fetching {hf_id} ...") + try: + local_dir = Path(snapshot_download(repo_id=hf_id, cache_dir=str(cache_dir))) + except Exception as exc: + print(f"Download failed: {exc}") + sys.exit(1) + + subdir = entry.get("subdir") + model_dir = (local_dir / subdir) if subdir else local_dir + if not model_dir.exists(): + # Fallback to root + model_dir = local_dir + return model_dir + + +# --------------------------------------------------------------------------- +# Selection prompts +# --------------------------------------------------------------------------- + + +def _prompt_int(prompt: str, lo: int, hi: int, default: int = 1) -> int: + while True: + try: + raw = input(f"{prompt} [{default}]: ").strip() + except (EOFError, KeyboardInterrupt): + print() + sys.exit(0) + if not raw: + return default + if raw.isdigit() and lo <= int(raw) <= hi: + return int(raw) + print(f" Enter a number {lo}–{hi}.") + + +def pick_version(collection: list[dict]) -> dict: + print("\nTMLM-Haiku series (CompactAI-O)\n") + for i, entry in enumerate(collection): + desc = f" — {entry['desc']}" if entry["desc"] else "" + print(f" [{i + 1}] {entry['version']}{desc}") + idx = _prompt_int("Select version", 1, len(collection)) + return collection[idx - 1] + + +def pick_checkpoint(entry: dict) -> tuple[str, bool]: + """Return (filename, is_pretrain).""" + ckpts = entry["checkpoints"] + if len(ckpts) == 1: + label, fname, is_pretrain = ckpts[0] + print(f" Using: {label} ({fname})") + return fname, is_pretrain + + print(f"\nCheckpoints for {entry['version']}:") + for i, (label, fname, _) in enumerate(ckpts): + print(f" [{i + 1}] {label} ({fname})") + idx = _prompt_int("Select checkpoint", 1, len(ckpts)) + label, fname, is_pretrain = ckpts[idx - 1] + return fname, is_pretrain + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + + +# --------------------------------------------------------------------------- +# Gradio Space +# --------------------------------------------------------------------------- + +import gradio as gr + +_CACHE_DIR = Path(__file__).parent / ".hf_cache" +_CACHE_DIR.mkdir(parents=True, exist_ok=True) + +_collection_cache: list = [] +_model_cache: dict = {} + + +def _get_collection() -> list: + global _collection_cache + if not _collection_cache: + try: + _collection_cache = fetch_collection() + except Exception as e: + print(f"Warning: fetch_collection failed ({e}); using fallback.") + _collection_cache = [ + _probe_repo(e["hf_id"]) or {"version": e["version"], "hf_id": e["hf_id"], + "subdir": None, "checkpoints": [("Chat (SFT)", "model.pt", False)], "desc": ""} + for e in _FALLBACK_COLLECTION + ] + return _collection_cache + + +def _collection_versions() -> list[str]: + return [e["version"] for e in _get_collection()] + + +def _checkpoints_for(version: str) -> list[tuple[str, str, bool]]: + for e in _get_collection(): + if e["version"] == version: + return e["checkpoints"] + return [] + + +def _ckpt_labels(version: str) -> list[str]: + return [label for label, _, _ in _checkpoints_for(version)] + + +def _ckpt_is_pretrain(version: str, label: str) -> bool: + for lbl, _, is_pt in _checkpoints_for(version): + if lbl == label: + return is_pt + return False + + +def _ckpt_fname(version: str, label: str) -> str: + for lbl, fname, _ in _checkpoints_for(version): + if lbl == label: + return fname + return "model.pt" + + +def _load_bundle(version: str, ckpt_label: str) -> dict: + key = f"{version}/{ckpt_label}" + if key not in _model_cache: + fname = _ckpt_fname(version, ckpt_label) + for entry in _get_collection(): + if entry["version"] == version: + model_dir = _download_version(entry, _CACHE_DIR) + model_path = model_dir / fname + tokenizer_path = model_dir / "tokenizer.json" + _model_cache[key] = load_local_model(model_path, tokenizer_path, "Haiku") + break + return _model_cache[key] + + +def _build_conversation_prompt(history: list[dict], new_message: str) -> str: + """Flatten Gradio messages-format history + new turn into a raw prompt.""" + parts = [] + # history is [{role, content}, ...] pairs already in order + i = 0 + while i < len(history) - 1: + u = history[i] + a = history[i + 1] + if u["role"] == "user" and a["role"] == "assistant": + parts.append(f"<|user|>\n{u['content']}\n<|assistant|>\n{a['content']}") + i += 2 + parts.append(f"<|user|>\n{new_message}\n<|assistant|>\n") + return "".join(parts) + + +# ---- chat ---- + +def _on_version_change(version): + labels = _ckpt_labels(version) + return gr.update(choices=labels, value=labels[0] if labels else None) + + +def _chat_submit(message, history): + history = history or [] + history.append({"role": "user", "content": message}) + history.append({"role": "assistant", "content": ""}) + return "", history + + +def _chat_stream(history, version, ckpt_label, mode_key, use_custom, + temperature, top_k, top_p, min_p, rep_penalty, + ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode): + if not history or history[-1]["role"] != "assistant": + yield history + return + try: + bundle = _load_bundle(version, ckpt_label) + except Exception as e: + history[-1]["content"] = f"[Error loading model: {e}]" + yield history + return + + prior_msgs = history[:-2] # exclude the current user+empty-assistant pair + new_msg = history[-2]["content"] + + if use_custom: + cfg = { + "sft_mode": not raw_mode, + "temperature": temperature, "top_k": top_k, "top_p": top_p, + "min_p": min_p, "repetition_penalty": rep_penalty, + "no_repeat_ngram_size": ngram_size, "logit_soft_cap": soft_cap, + "loop_penalty": loop_pen, "max_new_tokens": max_tokens, + "context_window": ctx_win, + } + else: + cfg = dict(MODES[mode_key]) + + if prior_msgs: + prompt = _build_conversation_prompt(prior_msgs, new_msg) + sft = False + else: + prompt = new_msg + sft = cfg["sft_mode"] + + for partial in generate_stream( + model=bundle["model"], tokenizer=bundle["tokenizer"], + prompt=prompt, device=str(bundle["device"]), + sft_mode=sft, + temperature=cfg["temperature"], top_k=cfg["top_k"], + top_p=cfg["top_p"], min_p=cfg["min_p"], + repetition_penalty=cfg["repetition_penalty"], + no_repeat_ngram_size=cfg["no_repeat_ngram_size"], + logit_soft_cap=cfg["logit_soft_cap"], + loop_penalty=cfg["loop_penalty"], + max_new_tokens=cfg["max_new_tokens"], + context_window=cfg["context_window"], + ): + history[-1]["content"] = partial + yield history + + +# ---- compare ---- + +def _compare_fn(prompt, selected_versions, mode_key, use_custom, + temperature, top_k, top_p, min_p, rep_penalty, + ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode, + progress=gr.Progress(track_tqdm=True)): + if use_custom: + cfg = { + "sft_mode": not raw_mode, + "temperature": temperature, "top_k": top_k, "top_p": top_p, + "min_p": min_p, "repetition_penalty": rep_penalty, + "no_repeat_ngram_size": ngram_size, "logit_soft_cap": soft_cap, + "loop_penalty": loop_pen, "max_new_tokens": max_tokens, + "context_window": ctx_win, + } + else: + cfg = dict(MODES[mode_key]) + + all_versions = _collection_versions() + results = {} + for version in progress.tqdm(selected_versions or [], desc="Running models"): + labels = _ckpt_labels(version) + ckpt_label = labels[0] if labels else None + if not ckpt_label: + results[version] = "[No checkpoint found]" + continue + try: + bundle = _load_bundle(version, ckpt_label) + out = generate( + model=bundle["model"], tokenizer=bundle["tokenizer"], + prompt=prompt, device=str(bundle["device"]), + sft_mode=cfg["sft_mode"], + temperature=cfg["temperature"], top_k=cfg["top_k"], + top_p=cfg["top_p"], min_p=cfg["min_p"], + repetition_penalty=cfg["repetition_penalty"], + no_repeat_ngram_size=cfg["no_repeat_ngram_size"], + logit_soft_cap=cfg["logit_soft_cap"], + loop_penalty=cfg["loop_penalty"], + max_new_tokens=cfg["max_new_tokens"], + context_window=cfg["context_window"], + stream=False, + ) + results[version] = out + except Exception as e: + results[version] = f"[Error: {e}]" + + # Return one value per discovered version (empty string if not selected/run) + return [results.get(v, "") for v in all_versions] + + +# ---- benchmark ---- + +def _benchmark_fn(bench_key, selected_versions, max_samples, + progress=gr.Progress(track_tqdm=True)): + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + if not selected_versions: + return "No models selected.", None + + bench = BENCHMARKS[bench_key] + all_results = [] + log_lines = [f"Benchmark: {bench['label']}", ""] + + for version in progress.tqdm(selected_versions, desc="Benchmarking"): + log_lines.append(f"--- {version} ---") + labels = _ckpt_labels(version) + ckpt_label = labels[0] if labels else None + if not ckpt_label: + log_lines.append(" (no checkpoint)") + continue + try: + bundle = _load_bundle(version, ckpt_label) + model, tokenizer, device = bundle["model"], bundle["tokenizer"], str(bundle["device"]) + model.eval() + if bench_key == "blimp": + _, y = _run_blimp(model, tokenizer, device, n_samples=max_samples) + elif bench_key == "wikitext2": + _, y = _run_wikitext2(model, tokenizer, device, max_chunks=max_samples) + else: + _, y = _run_arc_easy(model, tokenizer, device, max_samples=max_samples) + valid = [v for v in y if not math.isnan(v)] + summary = sum(valid) / len(valid) if valid else float("nan") + all_results.append({"label": version, "summary": summary}) + log_lines.append(f" score: {summary:.4f}") + except Exception as e: + log_lines.append(f" error: {e}") + + if not all_results: + return "\n".join(log_lines), None + + metric = bench["metric"] + paired = sorted( + zip([r["summary"] for r in all_results], [r["label"] for r in all_results]), + reverse=(metric != "perplexity"), + ) + summaries, labels_ = zip(*paired) + n = len(summaries) + colors = [plt.cm.RdYlGn(i / max(n - 1, 1)) for i in range(n)] + fig, ax = plt.subplots(figsize=(max(6, n * 1.6), 5)) + bars = ax.bar(range(n), summaries, color=colors, edgecolor="black") + for bar, val in zip(bars, summaries): + ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005, + f"{val:.3f}", ha="center", va="bottom", fontsize=9, fontweight="bold") + ylabel = "Mean Perplexity (↓ better)" if metric == "perplexity" else "Mean Accuracy (↑ better)" + ax.set_ylabel(ylabel) + ax.set_title(f"{bench['label']} — Model Comparison") + ax.set_xticks(range(n)) + ax.set_xticklabels(labels_, rotation=20, ha="right", fontsize=9) + if metric == "accuracy": + ax.set_ylim(0, 1.05) + ax.grid(True, axis="y", alpha=0.3) + plt.tight_layout() + out_path = "/tmp/benchmark_result.png" + plt.savefig(out_path, dpi=150) + plt.close(fig) + log_lines += ["", "Done."] + return "\n".join(log_lines), out_path + + +# ---- shared advanced params ---- + +def _advanced_block(): + with gr.Accordion("Advanced parameters", open=False): + use_custom = gr.Checkbox(label="Override preset with custom values below", value=False) + raw_mode = gr.Checkbox(label="Raw / pretrain mode (no <|user|> wrapping)", value=False) + with gr.Row(): + temperature = gr.Slider(0.0, 2.0, value=0.5, step=0.01, label="Temperature") + top_k = gr.Slider(0, 200, value=20, step=1, label="Top-k") + with gr.Row(): + top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p") + min_p = gr.Slider(0.0, 0.5, value=0.05, step=0.01, label="Min-p") + with gr.Row(): + rep_penalty = gr.Slider(1.0, 2.0, value=1.15, step=0.01, label="Repetition penalty") + ngram_size = gr.Slider(0, 8, value=4, step=1, label="No-repeat n-gram size") + with gr.Row(): + soft_cap = gr.Slider(0.0, 50.0, value=20.0, step=0.5, label="Logit soft cap") + loop_pen = gr.Slider(0.0, 50.0, value=15.0, step=0.5, label="Loop penalty") + with gr.Row(): + max_tokens = gr.Slider(16, 2048, value=512, step=16, label="Max new tokens") + ctx_win = gr.Slider(128, 2048, value=2048, step=128, label="Context window") + return use_custom, temperature, top_k, top_p, min_p, rep_penalty, ngram_size, soft_cap, loop_pen, max_tokens, ctx_win, raw_mode + + +# ---- build UI ---- + +_initial_versions = _collection_versions() +_initial_version = _initial_versions[0] if _initial_versions else None +_initial_ckpt_labels = _ckpt_labels(_initial_version) if _initial_version else [] +_mode_keys = list(MODES.keys()) + +with gr.Blocks(title="CompactAI Models", theme=gr.themes.Soft()) as demo: + gr.Markdown( + "# CompactAI — TinyMemoryLM\n" + "Tiny recurrent-depth language models from [CompactAI-O](https://huggingface.co/CompactAI-O)." + ) + + # ── Chat ────────────────────────────────────────────────────────────────── + with gr.Tab("Chat"): + with gr.Row(): + with gr.Column(scale=1, min_width=240): + chat_version = gr.Dropdown( + choices=_initial_versions, + value=_initial_version, + label="Model version", + ) + chat_ckpt = gr.Dropdown( + choices=_initial_ckpt_labels, + value=_initial_ckpt_labels[0] if _initial_ckpt_labels else None, + label="Checkpoint", + ) + chat_mode = gr.Radio( + choices=_mode_keys, + value="chat-coherent", + label="Mode preset", + info="Ignored when 'Override preset' is checked.", + ) + c_use_custom, c_temp, c_topk, c_topp, c_minp, c_rep, c_ng, c_cap, c_lp, c_maxt, c_ctx, c_raw = _advanced_block() + + with gr.Column(scale=3): + chatbot = gr.Chatbot(label="Conversation", height=500, type="messages") + with gr.Row(): + msg_box = gr.Textbox(placeholder="Type a message…", show_label=False, scale=5) + send_btn = gr.Button("Send", variant="primary", scale=1) + clear_btn = gr.Button("Clear") + + chat_version.change(_on_version_change, chat_version, chat_ckpt) + + _chat_adv = [chat_version, chat_ckpt, chat_mode, + c_use_custom, c_temp, c_topk, c_topp, c_minp, + c_rep, c_ng, c_cap, c_lp, c_maxt, c_ctx, c_raw] + + msg_box.submit(_chat_submit, [msg_box, chatbot], [msg_box, chatbot], queue=False).then( + _chat_stream, [chatbot] + _chat_adv, chatbot + ) + send_btn.click(_chat_submit, [msg_box, chatbot], [msg_box, chatbot], queue=False).then( + _chat_stream, [chatbot] + _chat_adv, chatbot + ) + clear_btn.click(lambda: [], None, chatbot, queue=False) + + # ── Compare ─────────────────────────────────────────────────────────────── + with gr.Tab("Compare All Models"): + gr.Markdown("Run the same prompt on multiple models and compare side-by-side.") + with gr.Row(): + cmp_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt…", lines=3, scale=3) + with gr.Column(scale=1): + cmp_models = gr.CheckboxGroup( + choices=_initial_versions, value=_initial_versions, label="Models" + ) + cmp_mode = gr.Dropdown( + choices=_mode_keys, value="chat-coherent", label="Mode preset" + ) + cmp_use_custom, cmp_temp, cmp_topk, cmp_topp, cmp_minp, cmp_rep, cmp_ng, cmp_cap, cmp_lp, cmp_maxt, cmp_ctx, cmp_raw = _advanced_block() + cmp_run = gr.Button("Run comparison", variant="primary") + + with gr.Row(): + cmp_outputs = [ + gr.Textbox(label=v, lines=8, show_copy_button=True) + for v in _initial_versions + ] + + cmp_run.click( + _compare_fn, + inputs=[cmp_prompt, cmp_models, cmp_mode, + cmp_use_custom, cmp_temp, cmp_topk, cmp_topp, cmp_minp, + cmp_rep, cmp_ng, cmp_cap, cmp_lp, cmp_maxt, cmp_ctx, cmp_raw], + outputs=cmp_outputs, + ) + + # ── Benchmark ───────────────────────────────────────────────────────────── + with gr.Tab("Benchmark"): + gr.Markdown( + "Evaluate models on standard benchmarks.\n\n" + "- **BLiMP** — grammaticality minimal pairs (accuracy)\n" + "- **WikiText-2** — LM perplexity (lower = better)\n" + "- **ARC-Easy** — multiple-choice science QA (accuracy)" + ) + with gr.Row(): + bench_type = gr.Radio( + choices=list(BENCHMARKS.keys()), value="arc_easy", label="Benchmark" + ) + bench_models = gr.CheckboxGroup( + choices=_initial_versions, + value=[_initial_versions[0]] if _initial_versions else [], + label="Models", + ) + bench_samples = gr.Slider(10, 500, value=100, step=10, label="Max samples (fewer = faster)") + bench_run = gr.Button("Run benchmark", variant="primary") + with gr.Row(): + bench_log = gr.Textbox(label="Progress log", lines=12, interactive=False) + bench_plot = gr.Image(label="Results chart", type="filepath") + + bench_run.click( + _benchmark_fn, + inputs=[bench_type, bench_models, bench_samples], + outputs=[bench_log, bench_plot], + ) + + +if __name__ == "__main__": + demo.launch()