#!/usr/bin/env python3
from __future__ import annotations
import json
import math
import os
import re
import shutil
import socket
import string
import sys
import threading
import webbrowser
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
from urllib.parse import quote, unquote, urlparse
from urllib.request import Request, urlopen
import hashlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
# ---------------------------------------------------------------------------
# Config (from ailay.config)
# ---------------------------------------------------------------------------
@dataclass
class ModelConfig:
dim: int = 128
n_unique_layers: int = 8
n_logical_layers: int = 16
n_heads: int = 4
n_kv_heads: int = 2
ffn_dim: int = 224
dropout: float = 0.0
seq_len: int = 2048
sliding_window_size: int = 512
mtp_horizons: Tuple[int, ...] = (2, 3, 4)
rope_fraction: float = 0.5
embed_scale: bool = True
logit_soft_cap: float = -1.0
quantization: str = "nvfp4"
# Engram (conditional memory) config
engram_dim: int = 0
engram_heads: int = 4
engram_table_size: int = 8192
engram_max_ngram: int = 3
# mHC (Manifold-Constrained Hyper-Connections) config
mhc_expansion: int = 1
@property
def head_dim(self) -> int:
return self.dim // self.n_heads
model_config = ModelConfig()
MODEL_SERIES = {
"haiku": {
"dim": 64,
"n_unique_layers": 12,
"n_logical_layers": 24,
"n_heads": 4,
"n_kv_heads": 2,
"ffn_dim": 384,
"dropout": 0.0,
"seq_len": 2048,
"mtp_horizons": (2, 3, 4),
"rope_fraction": 0.5,
"batch_size": 80,
"grad_accum": 1,
"lr": 8e-4,
"min_lr": 1e-5,
"sft_lr": 2e-4,
"sft_min_lr": 1e-5,
"warmup_steps": 300,
"weight_decay": 0.02,
"pretrain_passes": 2,
"sft_passes": 3,
"max_sft_target_chars": 0,
"use_grad_checkpoint": True,
"num_workers": 24,
"prefetch_factor": 64,
"shuffle_buffer": 8192,
"max_pretrain_tokens": 0,
"min_pretrain_tokens": 100_000_000,
"quantization": "nvfp4",
"engram_dim": 8,
"engram_heads": 2,
"engram_table_size": 64,
"engram_max_ngram": 2,
"mhc_expansion": 2,
},
"sonnet": {
"dim": 1024,
"n_unique_layers": 20,
"n_logical_layers": 40,
"n_heads": 16,
"n_kv_heads": 4,
"ffn_dim": 4096,
"dropout": 0.0,
"seq_len": 2048,
"mtp_horizons": (2,),
"rope_fraction": 0.5,
"batch_size": 24,
"grad_accum": 1,
"lr": 1e-4,
"min_lr": 2e-5,
"sft_lr": 5e-5,
"sft_min_lr": 5e-6,
"warmup_steps": 250,
"weight_decay": 0.1,
"pretrain_passes": 1,
"sft_passes": 1,
"max_sft_target_chars": 0,
"use_grad_checkpoint": True,
"num_workers": 32,
"prefetch_factor": 64,
"shuffle_buffer": 16384,
"max_pretrain_tokens": 0,
"min_pretrain_tokens": 100_000_000,
"quantization": "nvfp4",
"engram_dim": 32,
"engram_heads": 8,
"engram_table_size": 4096,
"engram_max_ngram": 2,
"mhc_expansion": 2,
},
"opus": {
"dim": 1536,
"n_unique_layers": 18,
"n_logical_layers": 36,
"n_heads": 16,
"n_kv_heads": 4,
"ffn_dim": 5888,
"dropout": 0.0,
"seq_len": 2048,
"mtp_horizons": (2,),
"rope_fraction": 0.5,
"batch_size": 24,
"grad_accum": 1,
"lr": 1.6e-4,
"min_lr": 1.6e-5,
"sft_lr": 3e-5,
"sft_min_lr": 3e-6,
"warmup_steps": 200,
"weight_decay": 0.1,
"pretrain_passes": 1,
"sft_passes": 1,
"max_sft_target_chars": 0,
"use_grad_checkpoint": True,
"num_workers": 48,
"prefetch_factor": 64,
"shuffle_buffer": 16384,
"max_pretrain_tokens": 0,
"min_pretrain_tokens": 100_000_000,
"quantization": "nvfp4",
"engram_dim": 64,
"engram_heads": 8,
"engram_table_size": 8192,
"engram_max_ngram": 2,
"mhc_expansion": 4,
},
}
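# Illustrative sketch (not used elsewhere in this script): folding a MODEL_SERIES
# preset into a ModelConfig. Only keys that exist as ModelConfig fields are applied;
# the remaining entries (lr, batch_size, ...) are training-time settings. The helper
# name `config_from_series_preset` is ours, purely for demonstration.
def config_from_series_preset(series: str) -> ModelConfig:
    preset = MODEL_SERIES.get(series.lower(), {})
    fields = set(ModelConfig.__dataclass_fields__)
    return ModelConfig(**{k: v for k, v in preset.items() if k in fields})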
# ---------------------------------------------------------------------------
# Tokenizer (from ailay.tokenizer)
# ---------------------------------------------------------------------------
FORMAT_TOKENS = [
"<|user|>",
"<|assistant|>",
"<|system|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|begin_of_thought|>",
"<|end_of_thought|>",
"<|begin_of_solution|>",
"<|end_of_solution|>",
]
class WordTokenizer:
WORD_RE = re.compile(
r"\s+|[^\W\d_]+(?:['\u2019][^\W\d_]+)?|\d+|[^\w\s]+", re.UNICODE
)
def __init__(
self, extra_chars: str = "", format_tokens: Optional[List[str]] = None
) -> None:
base = string.ascii_letters + string.digits + string.punctuation + " \n\t\r"
fallback_chars = sorted(set(base + extra_chars))
self.core_special = ["<PAD>", "<BOS>", "<EOS>", "<UNK>"]
self.format_tokens = (
list(format_tokens) if format_tokens else list(FORMAT_TOKENS)
)
self.special = list(self.core_special) + list(self.format_tokens)
self.id_to_token: List[str] = (
list(self.core_special) + self.format_tokens + fallback_chars
)
self.token_to_id: Dict[str, int] = {
t: i for i, t in enumerate(self.id_to_token)
}
self.special_multi_tokens = sorted(
[t for t in self.special if len(t) > 1], key=len, reverse=True
)
self.multi_char_tokens = self.special_multi_tokens
self.dynamic_additions = 0
@property
def pad_id(self) -> int:
return self.token_to_id["<PAD>"]
@property
def bos_id(self) -> int:
return self.token_to_id["<BOS>"]
@property
def eos_id(self) -> int:
return self.token_to_id["<EOS>"]
@property
def unk_id(self) -> int:
return self.token_to_id["<UNK>"]
@property
def vocab_size(self) -> int:
return len(self.id_to_token)
def maybe_add_char(self, ch: str) -> bool:
if ch in self.token_to_id:
return False
self.token_to_id[ch] = len(self.id_to_token)
self.id_to_token.append(ch)
self.dynamic_additions += 1
return True
def maybe_add_token(self, token: str) -> bool:
if token in self.token_to_id:
return False
self.token_to_id[token] = len(self.id_to_token)
self.id_to_token.append(token)
self.dynamic_additions += 1
return True
def iter_lexical_tokens(self, text: str) -> Iterator[str]:
i = 0
n = len(text)
while i < n:
matched_special = False
for token in self.special_multi_tokens:
if text.startswith(token, i):
yield token
i += len(token)
matched_special = True
break
if matched_special:
continue
m = self.WORD_RE.match(text, i)
if m is None:
yield text[i]
i += 1
continue
tok = m.group(0)
yield tok
i = m.end()
def encode(
self, text: str, add_bos: bool = False, add_eos: bool = False
) -> List[int]:
out: List[int] = []
if add_bos:
out.append(self.bos_id)
unk = self.unk_id
t2i = self.token_to_id
for tok in self.iter_lexical_tokens(text):
tid = t2i.get(tok)
if tid is not None:
out.append(tid)
continue
for ch in tok:
out.append(t2i.get(ch, unk))
if add_eos:
out.append(self.eos_id)
return out
def decode(self, ids: Sequence[int], skip_special: bool = True) -> str:
pieces: List[str] = []
for idx in ids:
if int(idx) < 0 or int(idx) >= len(self.id_to_token):
continue
tok = self.id_to_token[int(idx)]
if skip_special and tok in self.special:
continue
pieces.append(tok)
return "".join(pieces)
def save(self, path: Path) -> None:
with path.open("w", encoding="utf-8") as f:
json.dump(
{
"id_to_token": self.id_to_token,
"format_tokens": self.format_tokens,
"core_special": self.core_special,
"tokenizer_type": "word_level_v1",
},
f,
ensure_ascii=False,
indent=2,
)
@classmethod
def load(cls, path: Path) -> WordTokenizer:
with path.open("r", encoding="utf-8") as f:
data = json.load(f)
format_tokens = data.get("format_tokens", FORMAT_TOKENS)
tokenizer = cls(extra_chars="", format_tokens=format_tokens)
tokenizer.id_to_token = data["id_to_token"]
tokenizer.token_to_id = {t: i for i, t in enumerate(tokenizer.id_to_token)}
tokenizer.special = list(tokenizer.core_special) + list(tokenizer.format_tokens)
tokenizer.special_multi_tokens = sorted(
[t for t in tokenizer.special if len(t) > 1], key=len, reverse=True
)
tokenizer.multi_char_tokens = tokenizer.special_multi_tokens
return tokenizer
LetterTokenizer = WordTokenizer
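# Illustrative sketch (not called by the server): the WordTokenizer round-trip.
# Special tokens are matched first, words/digits/punctuation via WORD_RE, and any
# word not in the vocabulary falls back to per-character ids, so decoding recovers
# the exact input text. The helper name `_tokenizer_roundtrip_example` is ours.
def _tokenizer_roundtrip_example() -> None:
    tok = WordTokenizer()
    ids = tok.encode("<|user|> Hello, world!", add_bos=True)
    assert tok.decode(ids, skip_special=True) == " Hello, world!"
    assert tok.decode(ids, skip_special=False) == "<BOS><|user|> Hello, world!"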
# ---------------------------------------------------------------------------
# Model (from ailay.model)
# ---------------------------------------------------------------------------
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
if hasattr(torch.nn.functional, "rms_norm"):
return torch.nn.functional.rms_norm(
x, self.weight.shape, self.weight, self.eps
)
x_fp = x.float()
rms = torch.rsqrt(x_fp.pow(2).mean(dim=-1, keepdim=True) + self.eps)
return (x_fp * rms).to(dtype=x.dtype) * self.weight
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, base: float = 10000.0) -> None:
super().__init__()
inv = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv, persistent=False)
def cos_sin(
self, seq_len: int, device: torch.device, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
cos = emb.cos()[None, None, :, :].to(dtype=dtype)
sin = emb.sin()[None, None, :, :].to(dtype=dtype)
return cos, sin
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
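# Illustrative sketch: how the partial rotary embedding is applied in the attention
# module below. Only the first `rope_dim` channels of each head are rotated; the rest
# pass through unchanged (the `rope_fraction` trick). Shapes are (B, heads, T, head_dim);
# the toy sizes here are arbitrary.
def _partial_rope_example() -> torch.Tensor:
    head_dim, rope_dim, T = 8, 4, 5
    rope = RotaryEmbedding(rope_dim)
    q = torch.randn(1, 1, T, head_dim)
    cos, sin = rope.cos_sin(T, q.device, q.dtype)
    q_rot, q_pass = q[..., :rope_dim], q[..., rope_dim:]
    q_rot = (q_rot * cos) + (_rotate_half(q_rot) * sin)
    return torch.cat([q_rot, q_pass], dim=-1)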
class CausalSelfAttention(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: int,
head_dim: int,
dropout: float,
sliding_window: int,
rope_fraction: float,
) -> None:
super().__init__()
self.dim = dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
self.n_rep = n_heads // n_kv_heads
self.dropout = dropout
self.sliding_window = sliding_window
self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)
for lin in (self.wq, self.wk, self.wv):
nn.init.normal_(lin.weight, std=dim ** -0.5)
nn.init.normal_(self.wo.weight, std=(n_heads * head_dim) ** -0.5)
self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2)
self.rope = RotaryEmbedding(self.rope_dim)
self.q_norm = RMSNorm(head_dim)
self.k_norm = RMSNorm(head_dim)
self.output_gate = nn.Parameter(torch.zeros(n_heads))
def forward(
self,
x: torch.Tensor,
is_global: bool,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
B, T, _ = x.shape
q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)
q = self.q_norm(q)
k = self.k_norm(k)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
past_len = past_kv[0].shape[2] if past_kv is not None else 0
cos, sin = self.rope.cos_sin(T + past_len, x.device, q.dtype)
cos_slice = cos[:, :, past_len : past_len + T, :]
sin_slice = sin[:, :, past_len : past_len + T, :]
q_rope = q[..., : self.rope_dim]
q_pass = q[..., self.rope_dim :]
k_rope = k[..., : self.rope_dim]
k_pass = k[..., self.rope_dim :]
q_rope = (q_rope * cos_slice) + (_rotate_half(q_rope) * sin_slice)
k_rope = (k_rope * cos_slice) + (_rotate_half(k_rope) * sin_slice)
q = torch.cat([q_rope, q_pass], dim=-1)
k = torch.cat([k_rope, k_pass], dim=-1)
if past_kv is not None:
k = torch.cat([past_kv[0], k], dim=2)
v = torch.cat([past_kv[1], v], dim=2)
new_kv = (k, v) if use_cache else None
S = k.shape[2]
if self.n_rep > 1:
k = (
k[:, :, None, :, :]
.expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim)
.reshape(B, self.n_heads, S, self.head_dim)
)
v = (
v[:, :, None, :, :]
.expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim)
.reshape(B, self.n_heads, S, self.head_dim)
)
drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0
if is_global:
if past_kv is None and T > 1:
out = F.scaled_dot_product_attention(
q, k, v, is_causal=True, dropout_p=drop_p
)
else:
out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
else:
T_q = q.shape[2]
q_pos = torch.arange(past_len, past_len + T_q, device=q.device).unsqueeze(1)
k_pos = torch.arange(S, device=q.device).unsqueeze(0)
mask = (q_pos >= k_pos) & ((q_pos - k_pos) < self.sliding_window)
out = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask.unsqueeze(0).unsqueeze(0), dropout_p=drop_p
)
gate = torch.sigmoid(self.output_gate).view(1, self.n_heads, 1, 1)
out = out * gate
out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
out = self.wo(out)
return out, new_kv
class SwiGLU(nn.Module):
def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None:
super().__init__()
self.gate = nn.Linear(dim, hidden_dim, bias=False)
self.up = nn.Linear(dim, hidden_dim, bias=False)
self.down = nn.Linear(hidden_dim, dim, bias=False)
self.drop = nn.Dropout(dropout)
nn.init.normal_(self.gate.weight, std=dim ** -0.5)
nn.init.normal_(self.up.weight, std=dim ** -0.5)
nn.init.normal_(self.down.weight, std=hidden_dim ** -0.5)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = F.silu(self.gate(x)) * self.up(x)
out = self.down(h)
if self.training and torch.is_grad_enabled():
out = self.drop(out)
return out
class EngramBlock(nn.Module):
"""Conditional memory via O(1) hashed N-gram lookup (DeepSeek Engram)."""
def __init__(
self,
dim: int,
engram_dim: int,
n_heads: int = 4,
table_size: int = 8192,
max_ngram: int = 3,
) -> None:
super().__init__()
self.dim = dim
self.engram_dim = engram_dim
self.n_heads = n_heads
self.table_size = table_size
self.max_ngram = max_ngram
self.embeddings = nn.ParameterDict()
for n in range(2, max_ngram + 1):
for k in range(n_heads):
self.embeddings[f"{n}_{k}"] = nn.Parameter(
torch.randn(table_size, engram_dim) * (engram_dim**-0.5)
)
for n in range(2, max_ngram + 1):
for k in range(n_heads):
seed = int(hashlib.md5(f"engram_{n}_{k}".encode()).hexdigest()[:8], 16)
rng = torch.Generator().manual_seed(seed)
a = torch.randint(1, 2**31, (1,), generator=rng).item()
b = torch.randint(0, 2**31, (1,), generator=rng).item()
self.register_buffer(
f"hash_a_{n}_{k}", torch.tensor(a), persistent=False
)
self.register_buffer(
f"hash_b_{n}_{k}", torch.tensor(b), persistent=False
)
total_branch_dim = engram_dim * n_heads * (max_ngram - 1)
self.branch_conv = nn.Conv1d(
total_branch_dim,
total_branch_dim,
kernel_size=4,
dilation=max_ngram,
padding=0,
groups=total_branch_dim,
bias=True,
)
nn.init.zeros_(self.branch_conv.weight)
nn.init.zeros_(self.branch_conv.bias)
self.gate_query = nn.Linear(dim, engram_dim, bias=False)
self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False)
self.gate_value = nn.Linear(total_branch_dim, dim, bias=False)
self.gate_scale = engram_dim**-0.5
def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor:
a = getattr(self, f"hash_a_{n}_{k}")
b = getattr(self, f"hash_b_{n}_{k}")
B, T = token_ids.shape
padded = F.pad(token_ids, (n - 1, 0), value=0)
combined = torch.zeros(B, T, dtype=torch.long, device=token_ids.device)
for i in range(n):
combined = (combined * 31 + padded[:, i : i + T].long()) % self.table_size
return ((a * combined) ^ b) % self.table_size
def forward(
self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None
) -> torch.Tensor:
B, T, _ = hidden.shape
if token_ids is None:
    # Fallback when no token ids are provided: derive pseudo ids from the hidden states.
    token_ids = hidden.mean(dim=-1).long() % self.table_size
all_indices = []
all_tables = []
for n in range(2, self.max_ngram + 1):
for k in range(self.n_heads):
all_indices.append(self._hash_ngram(token_ids, n, k))
all_tables.append(self.embeddings[f"{n}_{k}"])
branch_outputs = [tbl[idx] for idx, tbl in zip(all_indices, all_tables)]
memory = torch.cat(branch_outputs, dim=-1)
conv_in = memory.transpose(1, 2)
conv_in = F.pad(
conv_in,
(self.branch_conv.dilation[0] * (self.branch_conv.kernel_size[0] - 1), 0),
)
conv_out = self.branch_conv(conv_in)
memory = conv_out.transpose(1, 2)
query = self.gate_query(hidden)
key = self.gate_key(memory)
gate = torch.sigmoid(
(query * key).sum(dim=-1, keepdim=True) * self.gate_scale
)
value = self.gate_value(memory)
return gate * value
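# Illustrative sketch: a freshly constructed EngramBlock is an identity on the residual
# stream. branch_conv is zero-initialised, so the gated memory update is exactly zero
# until training moves its weights, which makes the block safe to add to any layer.
def _engram_zero_init_example() -> None:
    block = EngramBlock(dim=16, engram_dim=4, n_heads=2, table_size=32, max_ngram=2)
    hidden = torch.randn(2, 7, 16)
    token_ids = torch.randint(0, 100, (2, 7))
    update = block(hidden, token_ids=token_ids)
    assert update.shape == hidden.shape
    assert torch.allclose(update, torch.zeros_like(update))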
def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor:
M = torch.exp(logits.clamp(-10, 10))
for _ in range(n_iters):
M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10)
M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10)
return M
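# Illustrative sketch: Sinkhorn-Knopp alternates row and column normalisation, driving
# the matrix towards doubly-stochastic form (rows and columns each summing to ~1). This
# is the constraint applied to the mHC residual-mixing matrix H_res below.
def _sinkhorn_example() -> None:
    M = _sinkhorn_knopp(torch.tensor([[0.5, -0.5], [-1.0, 1.0]]), n_iters=50)
    assert torch.allclose(M.sum(dim=-1), torch.ones(2), atol=1e-4)
    assert torch.allclose(M.sum(dim=-2), torch.ones(2), atol=1e-4)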
class ManifoldHyperConnection(nn.Module):
"""Manifold-Constrained Hyper-Connections (mHC) residual wrapper."""
def __init__(self, dim: int, expansion: int = 2) -> None:
super().__init__()
self.dim = dim
self.expansion = expansion
n = expansion
self.bias_pre = nn.Parameter(torch.zeros(1, n))
self.bias_post = nn.Parameter(torch.zeros(1, n))
self.bias_res = nn.Parameter(torch.zeros(n, n))
self.theta_pre = nn.Linear(n * dim, n, bias=False)
self.theta_post = nn.Linear(n * dim, n, bias=False)
self.theta_res = nn.Linear(n * dim, n * n, bias=False)
self.alpha_pre = nn.Parameter(torch.tensor(0.0))
self.alpha_post = nn.Parameter(torch.tensor(0.0))
self.alpha_res = nn.Parameter(torch.tensor(0.0))
def _compute_mappings(
self, x_expanded: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
B, T, _ = x_expanded.shape
n = self.expansion
x_norm = F.rms_norm(x_expanded, [x_expanded.shape[-1]])
d_pre = torch.tanh(self.theta_pre(x_norm))
d_post = torch.tanh(self.theta_post(x_norm))
d_res = self.theta_res(x_norm)
H_pre_raw = torch.sigmoid(self.alpha_pre * d_pre + self.bias_pre)
H_post_raw = 2.0 * torch.sigmoid(self.alpha_post * d_post + self.bias_post)
H_res_raw = (self.alpha_res * d_res + self.bias_res.reshape(1, 1, -1)).reshape(
B, T, n, n
)
H_res = _sinkhorn_knopp(H_res_raw)
return H_pre_raw.unsqueeze(-2), H_post_raw.unsqueeze(-2), H_res
def expand_stream(self, x: torch.Tensor) -> torch.Tensor:
return x.repeat(1, 1, self.expansion)
def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor:
B, T, _ = x_expanded.shape
return x_expanded.view(B, T, self.expansion, self.dim).mean(dim=-2)
def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor:
B, T, _ = x_expanded.shape
x_streams = x_expanded.view(B, T, self.expansion, self.dim)
return (H_pre @ x_streams).squeeze(-2)
def post_res_mix(
self,
layer_output: torch.Tensor,
x_expanded: torch.Tensor,
H_post: torch.Tensor,
H_res: torch.Tensor,
) -> torch.Tensor:
B, T, _ = x_expanded.shape
x_streams = x_expanded.view(B, T, self.expansion, self.dim)
mixed = torch.matmul(H_res, x_streams)
post_out = torch.matmul(H_post.transpose(-2, -1), layer_output.unsqueeze(-2))
return (mixed + post_out).reshape(B, T, self.expansion * self.dim)
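# Illustrative sketch of the mHC stream plumbing used in TransformerBlock.forward below:
# the d-dim residual is replicated into `expansion` parallel streams, pre-mixed into a
# single layer input, re-mixed with the Sinkhorn-constrained residual matrix after the
# layer, and finally averaged back to d dims. torch.tanh stands in for a real layer here.
def _mhc_shapes_example() -> None:
    mhc = ManifoldHyperConnection(dim=16, expansion=2)
    x = torch.randn(2, 5, 16)
    x_exp = mhc.expand_stream(x)                        # (2, 5, 32)
    H_pre, H_post, H_res = mhc._compute_mappings(x_exp)
    layer_in = mhc.pre_mix(x_exp, H_pre)                # (2, 5, 16)
    x_exp = mhc.post_res_mix(torch.tanh(layer_in), x_exp, H_post, H_res)
    assert mhc.collapse_stream(x_exp).shape == x.shape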
class TransformerBlock(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: int,
head_dim: int,
ffn_dim: int,
dropout: float,
sliding_window: int,
rope_fraction: float,
engram_dim: int = 0,
engram_heads: int = 4,
engram_table_size: int = 8192,
engram_max_ngram: int = 3,
mhc_expansion: int = 1,
) -> None:
super().__init__()
self.dim = dim
self.norm1 = RMSNorm(dim)
self.attn = CausalSelfAttention(
dim=dim,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
head_dim=head_dim,
dropout=dropout,
sliding_window=sliding_window,
rope_fraction=rope_fraction,
)
self.norm2 = RMSNorm(dim)
self.ffn = SwiGLU(dim, ffn_dim, dropout)
self.use_engram = engram_dim > 0
if self.use_engram:
self.engram = EngramBlock(
dim=dim,
engram_dim=engram_dim,
n_heads=engram_heads,
table_size=engram_table_size,
max_ngram=engram_max_ngram,
)
self.engram_norm = RMSNorm(dim)
self.use_mhc = mhc_expansion > 1
if self.use_mhc:
self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion)
self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion)
def forward(
self,
x: torch.Tensor,
is_global: bool,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
token_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
if self.use_mhc:
x_exp = self.mhc_attn.expand_stream(x)
H_pre, H_post, H_res = self.mhc_attn._compute_mappings(x_exp)
attn_in = self.mhc_attn.pre_mix(x_exp, H_pre)
attn_out, new_kv = self.attn(
self.norm1(attn_in), is_global, past_kv, use_cache
)
x_exp = self.mhc_attn.post_res_mix(attn_out, x_exp, H_post, H_res)
if self.use_engram:
collapsed = self.mhc_attn.collapse_stream(x_exp)
collapsed = collapsed + self.engram(
self.engram_norm(collapsed), token_ids=token_ids
)
x_exp = self.mhc_attn.expand_stream(collapsed)
H_pre2, H_post2, H_res2 = self.mhc_ffn._compute_mappings(x_exp)
ffn_in = self.mhc_ffn.pre_mix(x_exp, H_pre2)
ffn_out = self.ffn(self.norm2(ffn_in))
x_exp = self.mhc_ffn.post_res_mix(ffn_out, x_exp, H_post2, H_res2)
x = self.mhc_attn.collapse_stream(x_exp)
else:
attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache)
x = x + attn_out
if self.use_engram:
x = x + self.engram(self.engram_norm(x), token_ids=token_ids)
x = x + self.ffn(self.norm2(x))
return x, new_kv
def _detect_engram_dim(state_dict: dict) -> int:
for key in state_dict:
if ".engram." in key and ".embeddings." in key:
return state_dict[key].shape[-1]
return 0
def _detect_mhc_expansion(state_dict: dict) -> int:
for key, val in state_dict.items():
if ".mhc_attn.bias_pre" in key and val.dim() == 2:
return val.shape[-1]
return 1
class TinyMemoryLM(nn.Module):
def __init__(
self,
vocab_size: int,
dim: int,
n_unique_layers: int,
n_logical_layers: int,
n_heads: int,
n_kv_heads: int,
ffn_dim: int,
dropout: float,
mtp_horizons: Sequence[int],
grad_checkpoint: bool,
sliding_window: int = 512,
rope_fraction: float = 0.5,
embed_scale: bool = True,
engram_dim: int = 0,
engram_heads: int = 4,
engram_table_size: int = 8192,
engram_max_ngram: int = 3,
mhc_expansion: int = 1,
) -> None:
super().__init__()
self.dim = dim
self.n_unique_layers = n_unique_layers
self.n_logical_layers = n_logical_layers
self.grad_checkpoint = grad_checkpoint
self.embed_scale_factor = math.sqrt(dim) if embed_scale else 1.0
head_dim = dim // n_heads
self.embed_tokens = nn.Embedding(vocab_size, dim)
self.head = nn.Linear(dim, vocab_size, bias=False)
self.head.weight = self.embed_tokens.weight
self.output_bias = nn.Parameter(torch.zeros(vocab_size))
self.blocks = nn.ModuleList(
[
TransformerBlock(
dim=dim,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
head_dim=head_dim,
ffn_dim=ffn_dim,
dropout=dropout,
sliding_window=sliding_window,
rope_fraction=rope_fraction,
engram_dim=engram_dim,
engram_heads=engram_heads,
engram_table_size=engram_table_size,
engram_max_ngram=engram_max_ngram,
mhc_expansion=mhc_expansion,
)
for _ in range(n_unique_layers)
]
)
self.norm = RMSNorm(dim)
self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1})
self.mtp_adapters = nn.ModuleDict(
{str(h): nn.Linear(dim, dim, bias=False) for h in self.mtp_horizons}
)
self.mtp_norms = nn.ModuleDict(
{str(h): RMSNorm(dim) for h in self.mtp_horizons}
)
res_scale = (2 * n_logical_layers) ** -0.5
for block in self.blocks:
block.attn.wo.weight.data.mul_(res_scale)
block.ffn.down.weight.data.mul_(res_scale)
def resize_token_embeddings(self, new_vocab_size: int) -> None:
old_vocab_size = self.embed_tokens.num_embeddings
if new_vocab_size == old_vocab_size:
return
device = self.embed_tokens.weight.device
old_embed_weight = self.embed_tokens.weight.data.clone()
self.embed_tokens = nn.Embedding(
new_vocab_size, self.embed_tokens.embedding_dim
).to(device)
self.head = nn.Linear(
self.embed_tokens.embedding_dim, new_vocab_size, bias=False
).to(device)
self.head.weight = self.embed_tokens.weight
old_bias = self.output_bias.data.clone()
self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device))
copy_size = min(old_vocab_size, new_vocab_size)
self.output_bias.data[:copy_size] = old_bias[:copy_size]
self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size]
def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]:
logical = []
blocks_list = list(self.blocks)
full_sequence = blocks_list + blocks_list
for logical_idx, block in enumerate(full_sequence[: self.n_logical_layers]):
logical.append((block, logical_idx))
return logical
def forward(
self,
ids: torch.Tensor,
use_cache: bool = False,
past_key_values: Optional[
List[Optional[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
return_hidden: bool = False,
) -> Tuple[
torch.Tensor,
Dict[int, torch.Tensor],
Optional[torch.Tensor],
Optional[List[Tuple[torch.Tensor, torch.Tensor]]],
]:
B, T = ids.shape
x = self.embed_tokens(ids) * self.embed_scale_factor
logical_layers = self._build_logical_layers()
new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = (
[] if use_cache else None
)
for layer_idx, (block, logical_idx) in enumerate(logical_layers):
is_global = logical_idx % 2 == 0
past_kv = (
past_key_values[layer_idx]
if past_key_values is not None and layer_idx < len(past_key_values)
else None
)
if self.grad_checkpoint and self.training and not use_cache:
x, layer_kv = checkpoint(
block, x, is_global, past_kv, use_cache, ids, use_reentrant=True
)
else:
x, layer_kv = block(x, is_global, past_kv, use_cache, ids)
if new_past_key_values is not None:
new_past_key_values.append(layer_kv)
x = self.norm(x)
h_out = x if return_hidden else None
logits = self.head(x)
if self.embed_scale_factor != 1.0:
logits = logits / self.embed_scale_factor
logits = logits + self.output_bias
mtp: Dict[int, torch.Tensor] = {}
if self.mtp_horizons and self.training:
for horizon in self.mtp_horizons:
if horizon > 1 and horizon <= T - 1:
shifted_h = x[:, :-horizon, :]
adapted_h = self.mtp_adapters[str(horizon)](shifted_h)
adapted_h = self.mtp_norms[str(horizon)](adapted_h)
mtp_logits = self.head(adapted_h)
if self.embed_scale_factor != 1.0:
mtp_logits = mtp_logits / self.embed_scale_factor
mtp_logits = mtp_logits + self.output_bias
mtp[horizon] = mtp_logits
return logits, mtp, h_out, new_past_key_values
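# Illustrative sketch (not used by the server): a toy TinyMemoryLM forward pass with a
# KV cache. Weights are shared: n_unique_layers blocks are reused to realise
# n_logical_layers passes, alternating global and sliding-window attention, so the
# cache holds one entry per *logical* layer. All sizes below are arbitrary toy values.
def _tiny_model_forward_example() -> None:
    tok = WordTokenizer()
    model = TinyMemoryLM(
        vocab_size=tok.vocab_size, dim=32, n_unique_layers=2, n_logical_layers=4,
        n_heads=4, n_kv_heads=2, ffn_dim=64, dropout=0.0, mtp_horizons=(2,),
        grad_checkpoint=False, engram_dim=4, mhc_expansion=2,
    ).eval()
    ids = torch.tensor([tok.encode("Hello", add_bos=True)], dtype=torch.long)
    with torch.no_grad():
        logits, _, _, past_kv = model(ids, use_cache=True)
    assert logits.shape == (1, ids.shape[1], tok.vocab_size)
    assert len(past_kv) == 4  # one cache entry per logical layer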
# ---------------------------------------------------------------------------
# Generation (from ailay.generation)
# ---------------------------------------------------------------------------
def sample_text(
model: TinyMemoryLM,
tokenizer: WordTokenizer,
prompt: str,
max_new_tokens: int,
temperature: float,
top_k: int,
branches: int,
branch_len: int,
device: torch.device,
seq_len: int,
) -> str:
def _sample_id(logits: torch.Tensor) -> torch.Tensor:
if not torch.isfinite(logits).any():
logits = torch.zeros_like(logits)
logits = torch.where(
torch.isfinite(logits), logits, torch.full_like(logits, -1e9)
)
if top_k > 0:
v, idx = torch.topk(logits, k=min(top_k, logits.shape[-1]))
p = torch.softmax(v, dim=-1)
return idx.gather(-1, torch.multinomial(p, 1))
p = torch.softmax(logits, dim=-1)
return torch.multinomial(p, 1)
model.eval()
ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
prompt_len = len(ids)
x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
with torch.no_grad():
generated = 0
while generated < max_new_tokens:
if branches <= 1:
ctx = x[:, -seq_len:]
logits, _, _, _ = model(ctx)
nlogits = logits[:, -1, :] / max(temperature, 1e-6)
nid = _sample_id(nlogits)
x = torch.cat([x, nid], dim=1)
generated += 1
continue
rollout = min(branch_len, max_new_tokens - generated)
best_nll: Optional[float] = None
best_tokens: Optional[List[torch.Tensor]] = None
for _ in range(branches):
cand = x
cand_tokens: List[torch.Tensor] = []
nll = 0.0
for _ in range(rollout):
ctx = cand[:, -seq_len:]
logits, _, _, _ = model(ctx)
nlogits = logits[:, -1, :] / max(temperature, 1e-6)
nid = _sample_id(nlogits)
lp = F.log_softmax(nlogits.float(), dim=-1)
nll += float(-lp.gather(-1, nid).item())
cand = torch.cat([cand, nid], dim=1)
cand_tokens.append(nid)
if best_nll is None or nll < best_nll:
best_nll = nll
best_tokens = cand_tokens
assert best_tokens is not None
for t in best_tokens:
x = torch.cat([x, t], dim=1)
generated += 1
generated_ids = x[0, prompt_len:].tolist()
return tokenizer.decode(generated_ids, skip_special=True)
def sample_text_cached(
model: TinyMemoryLM,
tokenizer: WordTokenizer,
prompt: str,
max_new_tokens: int,
temperature: float,
top_k: int,
device: torch.device,
seq_len: int,
) -> str:
model.eval()
ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
prompt_len = len(ids)
x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
with torch.no_grad():
logits, _, _, past_kv = model(x, use_cache=True)
nlogits = logits[:, -1, :] / max(temperature, 1e-6)
if top_k > 0:
v, idx = torch.topk(nlogits, k=min(top_k, nlogits.shape[-1]))
p = torch.softmax(v, dim=-1)
nid = idx.gather(-1, torch.multinomial(p, 1))
else:
p = torch.softmax(nlogits, dim=-1)
nid = torch.multinomial(p, 1)
all_ids = [int(nid.item())]
for _ in range(max_new_tokens - 1):
logits, _, _, past_kv = model(nid, use_cache=True, past_key_values=past_kv)
nlogits = logits[:, -1, :] / max(temperature, 1e-6)
if top_k > 0:
v, idx = torch.topk(nlogits, k=min(top_k, nlogits.shape[-1]))
p = torch.softmax(v, dim=-1)
nid = idx.gather(-1, torch.multinomial(p, 1))
else:
p = torch.softmax(nlogits, dim=-1)
nid = torch.multinomial(p, 1)
tid = int(nid.item())
all_ids.append(tid)
if tid == tokenizer.eos_id:
break
return tokenizer.decode(all_ids, skip_special=True)
def speculative_decode(
model: TinyMemoryLM,
tokenizer: WordTokenizer,
prompt: str,
max_new_tokens: int,
temperature: float,
top_k: int,
device: torch.device,
seq_len: int,
) -> str:
model.eval()
ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
all_generated: List[int] = []
with torch.no_grad():
logits, _, h_out, past_kv = model(x, use_cache=True, return_hidden=True)
def _sample_from(lg: torch.Tensor) -> int:
lg = lg / max(temperature, 1e-6)
if top_k > 0:
v, idx = torch.topk(lg, k=min(top_k, lg.shape[-1]))
p = torch.softmax(v, dim=-1)
return int(idx[torch.multinomial(p, 1)].item())
p = torch.softmax(lg, dim=-1)
return int(torch.multinomial(p, 1).item())
main_token = _sample_from(logits[0, -1, :])
all_generated.append(main_token)
while len(all_generated) < max_new_tokens:
if main_token == tokenizer.eos_id:
break
draft_tokens = []
if h_out is not None and model.mtp_horizons:
last_hidden = h_out[:, -1:, :]
for h in model.mtp_horizons:
adapter = model.mtp_adapters[str(h)]
norm = model.mtp_norms[str(h)]
adapted = norm(adapter(last_hidden))
draft_logits = model.head(adapted) + model.output_bias
draft_tok = _sample_from(draft_logits[0, 0, :])
draft_tokens.append(draft_tok)
if not draft_tokens:
nid = torch.tensor([[main_token]], dtype=torch.long, device=device)
logits, _, h_out, past_kv = model(
nid, use_cache=True, past_key_values=past_kv, return_hidden=True
)
main_token = _sample_from(logits[0, -1, :])
all_generated.append(main_token)
continue
verify_input = torch.tensor(
[[main_token] + draft_tokens], dtype=torch.long, device=device
)
verify_logits, _, h_out, past_kv = model(
verify_input,
use_cache=True,
past_key_values=past_kv,
return_hidden=True,
)
accepted = 0
if not all_generated or all_generated[-1] != main_token:
    all_generated.append(main_token)
for i, draft_tok in enumerate(draft_tokens):
verified_tok = _sample_from(verify_logits[0, i, :])
if verified_tok == draft_tok:
all_generated.append(draft_tok)
accepted += 1
if draft_tok == tokenizer.eos_id:
break
else:
all_generated.append(verified_tok)
break
if accepted < len(draft_tokens):
trim_len = len(draft_tokens) - accepted - 1
if trim_len > 0 and past_kv is not None:
past_kv = [
(k[:, :, :-trim_len, :], v[:, :, :-trim_len, :])
if k is not None
else None
for k, v in past_kv
]
main_token = all_generated[-1]
return tokenizer.decode(all_generated, skip_special=True)
def build_stop_token_ids(tokenizer: WordTokenizer) -> set:
stop_tokens = {tokenizer.eos_id}
for tok in ("<|user|>", "<|system|>", "<|assistant|>"):
tid = tokenizer.token_to_id.get(tok)
if tid is not None:
stop_tokens.add(int(tid))
return stop_tokens
def apply_no_repeat_ngram(
logits: torch.Tensor,
token_history: Sequence[int],
ngram_size: int,
) -> torch.Tensor:
if ngram_size <= 1 or len(token_history) < max(0, ngram_size - 1):
return logits
prefix = tuple(token_history[-(ngram_size - 1) :]) if ngram_size > 1 else tuple()
banned: set = set()
for i in range(len(token_history) - ngram_size + 1):
if tuple(token_history[i : i + ngram_size - 1]) == prefix:
banned.add(int(token_history[i + ngram_size - 1]))
if not banned:
return logits
out = logits.clone()
banned_ids = torch.tensor(sorted(banned), device=logits.device, dtype=torch.long)
out[banned_ids] = float("-inf")
return out
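# Illustrative sketch: with ngram_size=2 and history [5, 7, 5], the current prefix is
# (5,) and token 7 already completed that bigram earlier, so 7 is banned from following
# 5 again while other tokens keep their original logits.
def _no_repeat_ngram_example() -> None:
    out = apply_no_repeat_ngram(torch.zeros(10), token_history=[5, 7, 5], ngram_size=2)
    assert out[7] == float("-inf")
    assert torch.isfinite(out[6])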
def score_candidate(
prompt: str,
raw_text: str,
visible_text: str,
avg_logprob: float,
) -> float:
clean = visible_text.strip()
if not clean:
return -1e9
score = avg_logprob
words = clean.lower().split()
prompt_words = re.findall(r"[A-Za-z][A-Za-z'-]{2,}", prompt.lower())
prompt_stop = {
"what",
"which",
"when",
"where",
"why",
"how",
"are",
"is",
"the",
"and",
"for",
"with",
"that",
"this",
"from",
"into",
"about",
"explain",
"tell",
"give",
"list",
"show",
"write",
"their",
"there",
"your",
}
prompt_keywords = {w for w in prompt_words if w not in prompt_stop}
candidate_keywords = set(re.findall(r"[A-Za-z][A-Za-z'-]{2,}", clean.lower()))
if len(words) < 6:
score -= 2.0
else:
score += min(2.0, len(words) * 0.03)
if clean[-1:] in ".!?":
score += 0.5
if "<|user|>" in raw_text or "<|system|>" in raw_text:
score -= 4.0
if raw_text.count("<|assistant|>") > 1:
score -= 2.0
if prompt_keywords:
overlap = len(prompt_keywords & candidate_keywords) / len(prompt_keywords)
if overlap == 0.0:
score -= 2.5
else:
score += min(3.5, overlap * 4.0)
for open_tok, close_tok in [
("<|begin_of_thought|>", "<|end_of_thought|>"),
("<|begin_of_solution|>", "<|end_of_solution|>"),
]:
if (open_tok in raw_text) != (close_tok in raw_text):
score -= 1.0
if len(words) >= 3:
trigrams = [tuple(words[i : i + 3]) for i in range(len(words) - 2)]
if trigrams:
unique_ratio = len(set(trigrams)) / len(trigrams)
if unique_ratio < 0.35:
score -= 4.0
elif unique_ratio < 0.55:
score -= 2.0
else:
score += min(1.0, (unique_ratio - 0.55) * 2.0)
alpha_words = [
w
for w in words
if len(w) <= 18 and (sum(ch.isalpha() for ch in w) / max(len(w), 1)) > 0.7
]
alpha_ratio = len(alpha_words) / max(len(words), 1)
if alpha_ratio < 0.45:
score -= 3.0
elif alpha_ratio < 0.65:
score -= 1.0
return score
def generate_candidate(
model: TinyMemoryLM,
tokenizer: WordTokenizer,
prompt: str,
max_new_tokens: int,
temperature: float,
top_k: int,
repetition_penalty: float,
no_repeat_ngram_size: int,
device: str,
sft_mode: bool,
force_thought: bool,
stream: bool,
context_window: int,
) -> Tuple[str, str, float, int]:
if sft_mode:
full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
else:
full_prompt = prompt
if force_thought:
full_prompt = f"{full_prompt}<|begin_of_thought|> "
input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False)
input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device)
visible_tokens: List[str] = []
raw_tokens: List[str] = []
stop_token_ids = build_stop_token_ids(tokenizer)
total_logprob = 0.0
sampled_tokens = 0
with torch.no_grad():
for _ in range(max_new_tokens):
ctx_ids = (
input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t
)
logits, _, _, _ = model(ctx_ids)
next_logits = logits[0, -1, :].clone()
raw_next_logits = next_logits.clone()
if repetition_penalty != 1.0:
seen = set(input_ids_t[0].tolist())
for token_id in seen:
if next_logits[token_id] > 0:
next_logits[token_id] /= repetition_penalty
else:
next_logits[token_id] *= repetition_penalty
if temperature != 1.0:
next_logits = next_logits / max(temperature, 1e-6)
if no_repeat_ngram_size > 1:
next_logits = apply_no_repeat_ngram(
next_logits,
input_ids_t[0].tolist(),
no_repeat_ngram_size,
)
if top_k > 0:
v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0)))
next_logits[next_logits < v[-1]] = float("-inf")
top_p = 0.9  # nucleus (top-p) threshold, fixed here rather than exposed as a parameter
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
remove_mask = cum_probs - torch.softmax(sorted_logits, dim=-1) >= top_p
sorted_logits[remove_mask] = float("-inf")
next_logits = sorted_logits.scatter(0, sorted_indices, sorted_logits)
if not torch.isfinite(next_logits).any():
next_logits = raw_next_logits
if temperature != 1.0:
next_logits = next_logits / max(temperature, 1e-6)
probs = torch.softmax(next_logits, dim=-1)
next_id = torch.multinomial(probs, num_samples=1).item()
total_logprob += float(torch.log(probs[next_id] + 1e-12).item())
sampled_tokens += 1
if next_id in stop_token_ids:
break
token_str = (
tokenizer.id_to_token[next_id]
if next_id < len(tokenizer.id_to_token)
else ""
)
raw_tokens.append(token_str)
if token_str not in tokenizer.special:
visible_tokens.append(token_str)
if stream:
print(token_str, end="", flush=True)
input_ids_t = torch.cat(
[input_ids_t, torch.tensor([[next_id]], device=device)], dim=1
)
if stream:
print()
avg_logprob = total_logprob / max(1, sampled_tokens)
return "".join(visible_tokens), "".join(raw_tokens), avg_logprob, 0
def generate_beam_search(
model: TinyMemoryLM,
tokenizer: WordTokenizer,
prompt: str,
max_new_tokens: int = 60,
beam_width: int = 8,
length_penalty: float = 0.7,
no_repeat_ngram_size: int = 3,
device: str = "cuda",
sft_mode: bool = False,
context_window: int = 2048,
) -> str:
if sft_mode:
full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
else:
full_prompt = prompt
prompt_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False)
prompt_len = len(prompt_ids)
stop_ids = build_stop_token_ids(tokenizer)
beams: List[Tuple[float, List[int]]] = [(0.0, list(prompt_ids))]
completed: List[Tuple[float, List[int]]] = []
for _step in range(max_new_tokens):
if not beams:
break
candidates: List[Tuple[float, List[int]]] = []
for beam_score, beam_ids in beams:
x = torch.tensor(
[beam_ids[-context_window:]], dtype=torch.long, device=device
)
with torch.no_grad():
logits, _, _, _ = model(x)
nl = logits[0, -1, :]
log_probs = F.log_softmax(nl, dim=-1)
gen_ids = beam_ids[prompt_len:]
if no_repeat_ngram_size > 1 and len(gen_ids) >= no_repeat_ngram_size - 1:
prefix = tuple(gen_ids[-(no_repeat_ngram_size - 1) :])
for i in range(len(gen_ids) - no_repeat_ngram_size + 1):
if tuple(gen_ids[i : i + no_repeat_ngram_size - 1]) == prefix:
log_probs[gen_ids[i + no_repeat_ngram_size - 1]] = float("-inf")
topk_lp, topk_ids = torch.topk(log_probs, beam_width)
for i in range(beam_width):
tid = topk_ids[i].item()
new_score = beam_score + topk_lp[i].item()
new_ids = beam_ids + [tid]
if tid in stop_ids:
completed.append((new_score, new_ids))
else:
candidates.append((new_score, new_ids))
def _norm_score(pair):
gen_len = max(1, len(pair[1]) - prompt_len)
return pair[0] / (gen_len**length_penalty)
candidates.sort(key=_norm_score, reverse=True)
beams = candidates[:beam_width]
pool = completed + beams
if not pool:
return ""
def _norm_score_final(pair):
gen_len = max(1, len(pair[1]) - prompt_len)
return pair[0] / (gen_len**length_penalty)
pool.sort(key=_norm_score_final, reverse=True)
best_ids = pool[0][1][prompt_len:]
text = tokenizer.decode(best_ids, skip_special=True)
nl_pos = text.find("\n")
if nl_pos > 5:
text = text[:nl_pos]
return text.strip()
def generate(
model: TinyMemoryLM,
tokenizer: WordTokenizer,
prompt: str,
max_new_tokens: int = 256,
temperature: float = 0.8,
top_k: int = 40,
repetition_penalty: float = 1.0,
device: str = "cuda",
sft_mode: bool = False,
force_thought: bool = False,
stream: bool = True,
decode_mode: str = "legacy",
best_of: int = 3,
no_repeat_ngram_size: int = 3,
context_window: int = 2048,
beam_width: int = 8,
length_penalty: float = 0.7,
) -> str:
if decode_mode == "beam":
text = generate_beam_search(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_new_tokens=max_new_tokens,
beam_width=beam_width,
length_penalty=length_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
device=device,
sft_mode=sft_mode,
context_window=context_window,
)
if stream:
print(text)
return text
if decode_mode == "legacy":
text, _, _, _ = generate_candidate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
device=device,
sft_mode=sft_mode,
force_thought=force_thought,
stream=stream,
context_window=context_window,
)
return text
candidates: List[Tuple[float, str, str, float]] = []
for _ in range(max(1, best_of)):
candidate_text, raw_text, avg_logprob, _ = generate_candidate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
device=device,
sft_mode=sft_mode,
force_thought=force_thought,
stream=False,
context_window=context_window,
)
score = score_candidate(prompt, raw_text, candidate_text, avg_logprob)
candidates.append((score, candidate_text, raw_text, avg_logprob))
best_score, best_text, _, _ = max(candidates, key=lambda item: item[0])
if stream:
print(best_text, end="", flush=True)
print()
return best_text
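# Illustrative sketch: the call signature the web handler uses, exercised here with an
# untrained toy model on CPU (the output is gibberish; this only shows the wiring).
# The helper name `_generate_example` and the toy sizes are ours.
def _generate_example() -> str:
    tok = WordTokenizer()
    model = TinyMemoryLM(
        vocab_size=tok.vocab_size, dim=32, n_unique_layers=2, n_logical_layers=4,
        n_heads=4, n_kv_heads=2, ffn_dim=64, dropout=0.0, mtp_horizons=(2,),
        grad_checkpoint=False,
    ).eval()
    return generate(
        model, tok, "Hello", max_new_tokens=8, device="cpu",
        sft_mode=True, stream=False, decode_mode="legacy",
    )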
# ---------------------------------------------------------------------------
# Web server (from interactive.py)
# ---------------------------------------------------------------------------
ROOT = Path(__file__).resolve().parent
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
HF_ORG = "CompactAI"
HF_API = "https://huggingface.co/api"
CACHE_ROOT = Path.home() / ".cache" / "compactai_web"
USER_AGENT = "Mozilla/5.0 CompactAI-Web"
MODEL_CACHE: dict[tuple[str, str], dict[str, object]] = {}
MODEL_CACHE_LOCK = threading.RLock()
GENERATION_LOCK = threading.Lock()
def request_json(url: str):
req = Request(url, headers={"User-Agent": USER_AGENT})
with urlopen(req, timeout=60) as response:
return json.loads(response.read().decode("utf-8"))
def request_text(url: str) -> str:
req = Request(url, headers={"User-Agent": USER_AGENT})
with urlopen(req, timeout=60) as response:
return response.read().decode("utf-8", errors="replace")
def download_file(url: str, destination: Path) -> None:
destination.parent.mkdir(parents=True, exist_ok=True)
temp_path = destination.with_suffix(destination.suffix + ".tmp")
req = Request(url, headers={"User-Agent": USER_AGENT})
with urlopen(req, timeout=120) as response, temp_path.open("wb") as handle:
shutil.copyfileobj(response, handle)
temp_path.replace(destination)
def normalize_repo_id(raw_repo_id: str) -> str:
if not isinstance(raw_repo_id, str):
return ""
repo_id = raw_repo_id.strip()
if not repo_id:
return ""
try:
repo_id = unquote(repo_id)
except Exception:
pass
return (
repo_id.replace("https://huggingface.co/", "")
.replace("http://huggingface.co/", "")
.replace("api/models/", "")
.replace("models/", "")
.split("?", 1)[0]
.split("#", 1)[0]
.strip("/")
)
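# Illustrative sketch: normalize_repo_id accepts bare ids, full URLs, and API paths and
# reduces them all to "org/name". The repo name below is a placeholder, not a real model.
def _normalize_repo_id_example() -> None:
    for raw in (
        "CompactAI/example-model",
        "https://huggingface.co/CompactAI/example-model?library=pytorch",
        "models/CompactAI/example-model#readme",
    ):
        assert normalize_repo_id(raw) == "CompactAI/example-model"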
def series_from_name(name: str) -> str | None:
lower = (name or "").lower()
if "haiku" in lower:
return "Haiku"
if "sonnet" in lower:
return "Sonnet"
if "opus" in lower:
return "Opus"
return None
def encoded_repo_id(repo_id: str) -> str:
return "/".join(
quote(part, safe="") for part in normalize_repo_id(repo_id).split("/") if part
)
def hf_file_url(repo_id: str, filename: str) -> str:
encoded_name = "/".join(
quote(part, safe="") for part in filename.split("/") if part
)
return (
f"https://huggingface.co/{encoded_repo_id(repo_id)}/resolve/main/{encoded_name}"
)
def model_list() -> list[dict[str, object]]:
data = request_json(f"{HF_API}/models?author={quote(HF_ORG)}&full=true&limit=200")
models: list[dict[str, object]] = []
for item in data:
siblings = item.get("siblings") or []
filenames = [s.get("rfilename", "") for s in siblings if isinstance(s, dict)]
has_model = "model.pt" in filenames or "model/model.pt" in filenames
has_pretrain = "pretrain.pt" in filenames or "model/pretrain.pt" in filenames
has_tokenizer = (
"tokenizer.json" in filenames or "model/tokenizer.json" in filenames
)
if not has_model and not has_pretrain:
continue
name = (item.get("id") or "").split("/")[-1]
series = series_from_name(name)
if not series:
continue
models.append(
{
"id": item.get("id", ""),
"name": name,
"series": series,
"downloads": item.get("downloads", 0) or 0,
"likes": item.get("likes", 0) or 0,
"has_model": has_model,
"has_pretrain": has_pretrain,
"has_tokenizer": has_tokenizer,
}
)
return sorted(models, key=lambda entry: entry["downloads"], reverse=True)
def model_details(repo_id: str) -> dict[str, object] | None:
normalized = normalize_repo_id(repo_id)
if not normalized:
return None
data = request_json(f"{HF_API}/models/{encoded_repo_id(normalized)}")
siblings = data.get("siblings") or []
files: dict[str, dict[str, float]] = {}
has_model = False
has_pretrain = False
for sibling in siblings:
if not isinstance(sibling, dict):
continue
filename = sibling.get("rfilename") or ""
if not filename:
continue
size_mb = round((sibling.get("size") or 0) / (1024 * 1024), 2)
files[filename] = {"size_mb": size_mb}
if filename.startswith("model/"):
files[filename.removeprefix("model/")] = {"size_mb": size_mb}
if filename in {"model.pt", "model/model.pt"}:
has_model = True
if filename in {"pretrain.pt", "model/pretrain.pt"}:
has_pretrain = True
readme_raw = ""
try:
readme_raw = request_text(
f"https://huggingface.co/{encoded_repo_id(normalized)}/raw/main/README.md"
)
except Exception:
readme_raw = ""
name = (data.get("id") or normalized).split("/")[-1]
return {
"id": normalized,
"name": name,
"series": series_from_name(name) or "Sonnet",
"downloads": data.get("downloads", 0) or 0,
"files": files,
"readme_raw": readme_raw,
"hf_model_id": normalized,
"has_model": has_model,
"has_pretrain": has_pretrain,
}
def cache_dir(repo_id: str, model_type: str) -> Path:
return CACHE_ROOT / normalize_repo_id(repo_id).replace("/", "__") / model_type
def artifact_candidates(model_type: str) -> list[str]:
return (
["model/pretrain.pt", "pretrain.pt"]
if model_type == "pretrain"
else ["model/model.pt", "model.pt"]
)
def ensure_artifact(repo_id: str, model_type: str, destination_name: str) -> Path:
normalized = normalize_repo_id(repo_id)
target = cache_dir(normalized, model_type) / destination_name
if target.exists():
return target
last_error: Exception | None = None
for candidate in (
artifact_candidates(model_type)
if destination_name.endswith(".pt")
else ["model/tokenizer.json", "tokenizer.json"]
):
try:
download_file(hf_file_url(normalized, candidate), target)
return target
except Exception as exc:
last_error = exc
raise RuntimeError(
f"Unable to download {destination_name} for {normalized}: {last_error}"
)
def series_config(series: str) -> dict[str, object]:
return MODEL_SERIES.get(series.lower(), MODEL_SERIES["sonnet"])
def load_bundle(repo_id: str, model_type: str) -> dict[str, object]:
normalized = normalize_repo_id(repo_id)
details = model_details(normalized)
if not details:
raise RuntimeError("Model details are unavailable.")
series = str(details["series"])
key = (normalized, model_type)
with MODEL_CACHE_LOCK:
cached = MODEL_CACHE.get(key)
if cached:
return cached
bundle_dir = cache_dir(normalized, model_type)
bundle_dir.mkdir(parents=True, exist_ok=True)
model_path = bundle_dir / (
"pretrain.pt" if model_type == "pretrain" else "model.pt"
)
tokenizer_path = bundle_dir / "tokenizer.json"
if not model_path.exists():
ensure_artifact(normalized, model_type, model_path.name)
if not tokenizer_path.exists():
ensure_artifact(normalized, model_type, tokenizer_path.name)
tokenizer = WordTokenizer.load(tokenizer_path)
ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False)
cfg = series_config(series)
vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size))
state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
# Auto-detect new arch features from checkpoint weights
engram_dim = _detect_engram_dim(state_dict) or int(
cfg.get("engram_dim", model_config.engram_dim)
)
mhc_expansion = _detect_mhc_expansion(state_dict) or int(
cfg.get("mhc_expansion", model_config.mhc_expansion)
)
model = TinyMemoryLM(
vocab_size=vocab_size,
dim=int(cfg.get("dim", model_config.dim)),
n_unique_layers=int(
cfg.get("n_unique_layers", model_config.n_unique_layers)
),
n_logical_layers=int(
cfg.get("n_logical_layers", model_config.n_logical_layers)
),
n_heads=int(cfg.get("n_heads", model_config.n_heads)),
n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)),
ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)),
dropout=float(cfg.get("dropout", model_config.dropout)),
mtp_horizons=tuple(
int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)
),
grad_checkpoint=False,
sliding_window=int(
cfg.get("sliding_window_size", model_config.sliding_window_size)
),
rope_fraction=float(
cfg.get("rope_fraction", model_config.rope_fraction)
),
embed_scale=bool(
cfg.get("embed_scale", model_config.embed_scale)
),
engram_dim=engram_dim,
engram_heads=int(cfg.get("engram_heads", model_config.engram_heads)),
engram_table_size=int(
cfg.get("engram_table_size", model_config.engram_table_size)
),
engram_max_ngram=int(
cfg.get("engram_max_ngram", model_config.engram_max_ngram)
),
mhc_expansion=mhc_expansion,
)
model.load_state_dict(state_dict, strict=False)
model.eval()
if tokenizer.vocab_size > vocab_size:
model.resize_token_embeddings(tokenizer.vocab_size)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
bundle = {
"repo_id": normalized,
"name": details["name"],
"series": series,
"type": model_type,
"model": model,
"tokenizer": tokenizer,
"device": device,
"model_path": str(model_path),
"tokenizer_path": str(tokenizer_path),
"downloads": details["downloads"],
}
MODEL_CACHE[key] = bundle
return bundle
def ensure_port(start_port: int) -> int:
for port in range(start_port, start_port + 50):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
try:
sock.bind(("127.0.0.1", port))
except OSError:
continue
return port
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
return sock.getsockname()[1]
def page_html() -> str:
return f"""<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>CompactAI Web</title>
<style>
:root {{
color-scheme: dark;
--bg: #050505;
--panel: #111111;
--panel-2: #161616;
--line: #262626;
--text: #f5f5f5;
--muted: #a3a3a3;
--accent: #d97706;
--accent-2: #b45309;
--soft: #1f1f1f;
}}
* {{ box-sizing: border-box; }}
body {{
margin: 0;
font-family: Geist, -apple-system, BlinkMacSystemFont, sans-serif;
background: var(--bg);
color: var(--text);
line-height: 1.5;
}}
a {{ color: inherit; }}
.wrap {{ max-width: 1120px; margin: 0 auto; padding: 28px 20px 40px; }}
.hero {{
display: flex;
justify-content: space-between;
align-items: end;
gap: 16px;
padding: 22px 0 28px;
border-bottom: 1px solid var(--line);
margin-bottom: 22px;
}}
h1 {{ margin: 0; font-size: clamp(2rem, 5vw, 3.5rem); letter-spacing: -0.04em; }}
.subtitle {{ margin: 10px 0 0; color: var(--muted); max-width: 58ch; }}
.grid {{
display: grid;
grid-template-columns: 1.1fr 1fr;
gap: 18px;
}}
.panel {{
background: var(--panel);
border: 1px solid var(--line);
border-radius: 18px;
padding: 18px;
}}
.panel h2 {{ margin: 0 0 12px; font-size: 15px; letter-spacing: 0.02em; text-transform: uppercase; color: var(--muted); }}
.row {{ display: flex; gap: 10px; flex-wrap: wrap; }}
select, textarea, input {{
width: 100%;
background: var(--panel-2);
color: var(--text);
border: 1px solid var(--line);
border-radius: 12px;
padding: 12px 14px;
font: inherit;
outline: none;
}}
textarea {{ min-height: 170px; resize: vertical; }}
select {{ appearance: none; }}
.choice {{
flex: 1 1 150px;
display: flex;
align-items: center;
gap: 10px;
padding: 10px 12px;
border: 1px solid var(--line);
border-radius: 12px;
background: var(--panel-2);
cursor: pointer;
}}
.choice input {{ width: auto; }}
.btns {{ display: flex; flex-wrap: wrap; gap: 10px; }}
button {{
border: 1px solid var(--line);
border-radius: 12px;
padding: 11px 14px;
background: var(--soft);
color: var(--text);
font: inherit;
cursor: pointer;
transition: transform 0.15s ease, border-color 0.15s ease, background 0.15s ease;
}}
button:hover {{ transform: translateY(-1px); border-color: #3a3a3a; }}
.primary {{ background: var(--accent); border-color: var(--accent); color: #fff; }}
.primary:hover {{ background: var(--accent-2); border-color: var(--accent-2); }}
.status {{
margin-top: 12px;
color: var(--muted);
font-size: 13px;
min-height: 1.4em;
}}
.output {{
white-space: pre-wrap;
background: #0b0b0b;
border: 1px solid var(--line);
border-radius: 16px;
min-height: 280px;
padding: 16px;
color: #e7e5e4;
overflow: auto;
}}
.meta {{
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-top: 8px;
}}
.chip {{
display: inline-flex;
align-items: center;
gap: 6px;
padding: 6px 10px;
border-radius: 999px;
border: 1px solid var(--line);
background: var(--panel-2);
font-size: 12px;
color: var(--muted);
}}
.code {{
margin-top: 14px;
padding: 12px 14px;
border-radius: 12px;
border: 1px solid var(--line);
background: #0b0b0b;
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
font-size: 13px;
overflow-x: auto;
}}
@media (max-width: 900px) {{
.grid {{ grid-template-columns: 1fr; }}
.hero {{ align-items: start; flex-direction: column; }}
}}
</style>
</head>
<body>
<div class="wrap">
<div class="hero">
<div>
<h1>CompactAI Web</h1>
<p class="subtitle">Pull a model from Hugging Face, keep it cached locally, and chat in the browser.</p>
</div>
<div class="meta">
<span class="chip">Hugging Face: CompactAI</span>
<span class="chip">pip install -r requirements.txt</span>
<span class="chip">Local inference</span>
</div>
</div>
<div class="grid">
<section class="panel">
<h2>Model</h2>
<select id="modelSelect"></select>
<div class="row" style="margin-top: 10px;">
<label class="choice"><input type="radio" name="type" value="model" checked> Instruct / final</label>
<label class="choice"><input type="radio" name="type" value="pretrain"> Pretrain</label>
</div>
<div class="btns" style="margin-top: 12px;">
<button id="downloadBtn">Download</button>
<button id="refreshBtn">Refresh models</button>
</div>
<div class="status" id="modelStatus">Loading model list…</div>
<div class="code">python3 interactive_web.py</div>
</section>
<section class="panel">
<h2>Prompt</h2>
<textarea id="prompt" placeholder="Ask something…"></textarea>
<div class="row" style="margin-top: 10px;">
<input id="temperature" type="number" min="0.1" max="2" step="0.05" value="0.8" style="flex: 1 1 120px;">
<input id="topK" type="number" min="1" max="100" step="1" value="40" style="flex: 1 1 120px;">
<input id="maxTokens" type="number" min="16" max="2048" step="16" value="256" style="flex: 1 1 120px;">
</div>
<div class="btns" style="margin-top: 12px;">
<button id="generateBtn" class="primary">Generate</button>
</div>
<div class="status" id="genStatus"></div>
</section>
</div>
<section class="panel" style="margin-top: 18px;">
<h2>Response</h2>
<div id="output" class="output"></div>
</section>
</div>
<script>
const modelSelect = document.getElementById('modelSelect');
const modelStatus = document.getElementById('modelStatus');
const genStatus = document.getElementById('genStatus');
const output = document.getElementById('output');
const promptBox = document.getElementById('prompt');
async function api(path, body) {{
const response = await fetch(path, {{
method: body ? 'POST' : 'GET',
headers: body ? {{ 'Content-Type': 'application/json' }} : undefined,
body: body ? JSON.stringify(body) : undefined,
}});
return response.json();
}}
function currentType() {{
return document.querySelector('input[name="type"]:checked').value;
}}
function currentModelId() {{
return modelSelect.value;
}}
function setModels(models) {{
modelSelect.innerHTML = '';
for (const model of models) {{
const option = document.createElement('option');
option.value = model.id;
option.textContent = `${{model.name}} • ${{model.series}}`;
modelSelect.appendChild(option);
}}
if (models.length === 0) {{
const option = document.createElement('option');
option.value = '';
option.textContent = 'No CompactAI models found';
modelSelect.appendChild(option);
}}
}}
async function refreshModels() {{
modelStatus.textContent = 'Loading model list…';
try {{
const models = await api('/api/models');
setModels(models);
modelStatus.textContent = models.length ? `${{models.length}} models available from CompactAI` : 'No compatible models found.';
}} catch (error) {{
modelStatus.textContent = 'Failed to load model list.';
}}
}}
async function ensureModel() {{
const modelId = currentModelId();
if (!modelId) {{
modelStatus.textContent = 'Pick a model first.';
return null;
}}
modelStatus.textContent = 'Downloading model files…';
const result = await api('/api/ensure', {{ modelId, type: currentType() }});
if (!result.success) {{
modelStatus.textContent = result.error || 'Download failed.';
return null;
}}
modelStatus.textContent = `${{result.name}} ready on ${{result.series}}`;
return result;
}}
async function generate() {{
output.textContent = '';
genStatus.textContent = '';
const modelId = currentModelId();
const prompt = promptBox.value.trim();
if (!modelId) {{
genStatus.textContent = 'Pick a model first.';
return;
}}
if (!prompt) {{
genStatus.textContent = 'Enter a prompt first.';
return;
}}
genStatus.textContent = 'Preparing model…';
const result = await api('/api/generate', {{
modelId,
type: currentType(),
prompt,
temperature: Number(document.getElementById('temperature').value || 0.8),
top_k: Number(document.getElementById('topK').value || 40),
max_new_tokens: Number(document.getElementById('maxTokens').value || 256),
}});
if (!result.success) {{
genStatus.textContent = result.error || 'Generation failed.';
return;
}}
output.textContent = result.text || '';
genStatus.textContent = 'Done.';
}}
document.getElementById('refreshBtn').addEventListener('click', refreshModels);
document.getElementById('downloadBtn').addEventListener('click', ensureModel);
document.getElementById('generateBtn').addEventListener('click', generate);
promptBox.addEventListener('keydown', (event) => {{
if (event.key === 'Enter' && (event.ctrlKey || event.metaKey)) {{
event.preventDefault();
generate();
}}
}});
refreshModels();
</script>
</body>
</html>"""
class Handler(BaseHTTPRequestHandler):
def _send_json(self, payload, status=200):
body = json.dumps(payload).encode("utf-8")
self.send_response(status)
self.send_header("Content-Type", "application/json; charset=utf-8")
self.send_header("Content-Length", str(len(body)))
self.send_header("Cache-Control", "no-store")
self.end_headers()
self.wfile.write(body)
def _send_html(self, payload: str, status=200):
body = payload.encode("utf-8")
self.send_response(status)
self.send_header("Content-Type", "text/html; charset=utf-8")
self.send_header("Content-Length", str(len(body)))
self.send_header("Cache-Control", "no-store")
self.end_headers()
self.wfile.write(body)
def do_GET(self):
parsed = urlparse(self.path)
if parsed.path in {"/", "/index.html"}:
self._send_html(page_html())
return
if parsed.path == "/api/models":
try:
self._send_json(model_list())
except Exception as exc:
self._send_json({"success": False, "error": str(exc)}, 500)
return
if parsed.path.startswith("/api/models/"):
repo_id = normalize_repo_id(parsed.path.removeprefix("/api/models/"))
try:
details = model_details(repo_id)
if not details:
self._send_json(
{"success": False, "error": "Model not found."}, 404
)
else:
self._send_json(details)
except Exception as exc:
self._send_json({"success": False, "error": str(exc)}, 500)
return
self._send_json({"success": False, "error": "Not found."}, 404)
def do_POST(self):
parsed = urlparse(self.path)
length = int(self.headers.get("Content-Length", "0") or "0")
raw = self.rfile.read(length).decode("utf-8") if length else "{}"
try:
payload = json.loads(raw or "{}")
except Exception:
payload = {}
if parsed.path == "/api/ensure":
try:
repo_id = normalize_repo_id(payload.get("modelId", ""))
model_type = payload.get("type", "model")
if not repo_id:
self._send_json(
{"success": False, "error": "Missing model ID."}, 400
)
return
details = model_details(repo_id)
if not details:
self._send_json(
{"success": False, "error": "Model not found."}, 404
)
return
bundle = load_bundle(repo_id, model_type)
self._send_json(
{
"success": True,
"id": bundle["repo_id"],
"name": bundle["name"],
"series": bundle["series"],
"type": bundle["type"],
}
)
except Exception as exc:
self._send_json({"success": False, "error": str(exc)}, 500)
return
if parsed.path == "/api/generate":
try:
repo_id = normalize_repo_id(payload.get("modelId", ""))
model_type = payload.get("type", "model")
prompt = str(payload.get("prompt", ""))
if not repo_id:
self._send_json(
{"success": False, "error": "Missing model ID."}, 400
)
return
bundle = load_bundle(repo_id, model_type)
with GENERATION_LOCK:
text = generate(
model=bundle["model"],
tokenizer=bundle["tokenizer"],
prompt=prompt,
max_new_tokens=int(payload.get("max_new_tokens", 256)),
temperature=float(payload.get("temperature", 0.8)),
top_k=int(payload.get("top_k", 40)),
repetition_penalty=float(
payload.get("repetition_penalty", 1.0)
),
device=str(bundle["device"]),
sft_mode=model_type != "pretrain",
force_thought=bool(payload.get("force_thought", False)),
stream=False,
decode_mode=str(payload.get("decode_mode", "legacy")),
best_of=int(payload.get("best_of", 3)),
no_repeat_ngram_size=int(
payload.get("no_repeat_ngram_size", 3)
),
context_window=int(payload.get("context_window", 2048)),
beam_width=int(payload.get("beam_width", 8)),
length_penalty=float(payload.get("length_penalty", 0.7)),
)
self._send_json(
{
"success": True,
"text": text,
"name": bundle["name"],
"series": bundle["series"],
}
)
except Exception as exc:
self._send_json({"success": False, "error": str(exc)}, 500)
return
self._send_json({"success": False, "error": "Not found."}, 404)
def log_message(self, format, *args):
return
def main():
CACHE_ROOT.mkdir(parents=True, exist_ok=True)
port = ensure_port(int(os.environ.get("PORT", "7860")))
server = ThreadingHTTPServer(("127.0.0.1", port), Handler)
url = f"http://127.0.0.1:{port}"
print(url, flush=True)
try:
webbrowser.open(url)
except Exception:
pass
try:
server.serve_forever()
except KeyboardInterrupt:
pass
finally:
server.server_close()
if __name__ == "__main__":
main()