# (removed non-code page-extraction artifacts: "Spaces:", "Running", "Running")
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import json | |
| import math | |
| import os | |
| import re | |
| import shutil | |
| import socket | |
| import string | |
| import sys | |
| import threading | |
| import webbrowser | |
| from dataclasses import dataclass | |
| from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer | |
| from pathlib import Path | |
| from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple | |
| from urllib.parse import quote, unquote, urlparse | |
| from urllib.request import Request, urlopen | |
| import hashlib | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.checkpoint import checkpoint | |
| # --------------------------------------------------------------------------- | |
| # Config (from ailay.config) | |
| # --------------------------------------------------------------------------- | |
@dataclass
class ModelConfig:
    """Architecture hyper-parameters for TinyMemoryLM.

    Defaults describe a small debug-sized model; the MODEL_SERIES presets
    override these fields per model size.  Declared as a dataclass so
    configs can be constructed with keyword overrides (the `dataclass`
    import was previously unused — the decorator had been dropped).
    """

    dim: int = 128                       # residual-stream width
    n_unique_layers: int = 8             # physically distinct transformer blocks
    n_logical_layers: int = 16           # effective depth after block reuse
    n_heads: int = 4
    n_kv_heads: int = 2                  # GQA: KV heads shared by query heads
    ffn_dim: int = 224
    dropout: float = 0.0
    seq_len: int = 2048
    sliding_window_size: int = 512       # window for local-attention layers
    mtp_horizons: Tuple[int, ...] = (2, 3, 4)  # multi-token-prediction offsets
    rope_fraction: float = 0.5           # fraction of head_dim that is rotated
    embed_scale: bool = True             # scale embeddings by sqrt(dim)
    logit_soft_cap: float = -1.0         # NOTE(review): unused in this file; presumably <=0 disables capping
    quantization: str = "nvfp4"
    # Engram (conditional memory) config
    engram_dim: int = 0                  # 0 disables the Engram block
    engram_heads: int = 4
    engram_table_size: int = 8192
    engram_max_ngram: int = 3
    # mHC (Manifold-Constrained Hyper-Connections) config
    mhc_expansion: int = 1               # 1 disables mHC (plain residual path)

    def head_dim(self) -> int:
        """Per-head width; assumes dim is divisible by n_heads."""
        return self.dim // self.n_heads


model_config = ModelConfig()
# Preset configuration bundles keyed by model-size name.  Each preset mixes
# ModelConfig architecture fields (dim, n_heads, engram_*, mhc_expansion, ...)
# with training-loop settings (lr, batch_size, num_workers, ...); consumers
# pick out the keys they understand.
MODEL_SERIES = {
    "haiku": {
        # -- architecture --
        "dim": 64,
        "n_unique_layers": 12,
        "n_logical_layers": 24,
        "n_heads": 4,
        "n_kv_heads": 2,
        "ffn_dim": 384,
        "dropout": 0.0,
        "seq_len": 2048,
        "mtp_horizons": (2, 3, 4),
        "rope_fraction": 0.5,
        # -- training schedule --
        "batch_size": 80,
        "grad_accum": 1,
        "lr": 8e-4,
        "min_lr": 1e-5,
        "sft_lr": 2e-4,
        "sft_min_lr": 1e-5,
        "warmup_steps": 300,
        "weight_decay": 0.02,
        "pretrain_passes": 2,
        "sft_passes": 3,
        "max_sft_target_chars": 0,  # NOTE(review): 0 presumably means "no cap" — confirm in trainer
        "use_grad_checkpoint": True,
        # -- data pipeline --
        "num_workers": 24,
        "prefetch_factor": 64,
        "shuffle_buffer": 8192,
        "max_pretrain_tokens": 0,
        "min_pretrain_tokens": 100_000_000,
        "quantization": "nvfp4",
        # -- engram / mHC --
        "engram_dim": 8,
        "engram_heads": 2,
        "engram_table_size": 64,
        "engram_max_ngram": 2,
        "mhc_expansion": 2,
    },
    "sonnet": {
        "dim": 1024,
        "n_unique_layers": 20,
        "n_logical_layers": 40,
        "n_heads": 16,
        "n_kv_heads": 4,
        "ffn_dim": 4096,
        "dropout": 0.0,
        "seq_len": 2048,
        "mtp_horizons": (2,),
        "rope_fraction": 0.5,
        "batch_size": 24,
        "grad_accum": 1,
        "lr": 1e-4,
        "min_lr": 2e-5,
        "sft_lr": 5e-5,
        "sft_min_lr": 5e-6,
        "warmup_steps": 250,
        "weight_decay": 0.1,
        "pretrain_passes": 1,
        "sft_passes": 1,
        "max_sft_target_chars": 0,
        "use_grad_checkpoint": True,
        "num_workers": 32,
        "prefetch_factor": 64,
        "shuffle_buffer": 16384,
        "max_pretrain_tokens": 0,
        "min_pretrain_tokens": 100_000_000,
        "quantization": "nvfp4",
        "engram_dim": 32,
        "engram_heads": 8,
        "engram_table_size": 4096,
        "engram_max_ngram": 2,
        "mhc_expansion": 2,
    },
    "opus": {
        "dim": 1536,
        "n_unique_layers": 18,
        "n_logical_layers": 36,
        "n_heads": 16,
        "n_kv_heads": 4,
        "ffn_dim": 5888,
        "dropout": 0.0,
        "seq_len": 2048,
        "mtp_horizons": (2,),
        "rope_fraction": 0.5,
        "batch_size": 24,
        "grad_accum": 1,
        "lr": 1.6e-4,
        "min_lr": 1.6e-5,
        "sft_lr": 3e-5,
        "sft_min_lr": 3e-6,
        "warmup_steps": 200,
        "weight_decay": 0.1,
        "pretrain_passes": 1,
        "sft_passes": 1,
        "max_sft_target_chars": 0,
        "use_grad_checkpoint": True,
        "num_workers": 48,
        "prefetch_factor": 64,
        "shuffle_buffer": 16384,
        "max_pretrain_tokens": 0,
        "min_pretrain_tokens": 100_000_000,
        "quantization": "nvfp4",
        "engram_dim": 64,
        "engram_heads": 8,
        "engram_table_size": 8192,
        "engram_max_ngram": 2,
        "mhc_expansion": 4,
    },
}
| # --------------------------------------------------------------------------- | |
| # Tokenizer (from ailay.tokenizer) | |
| # --------------------------------------------------------------------------- | |
# Chat-format special tokens added to the vocabulary in addition to the core
# specials (<PAD>/<BOS>/<EOS>/<UNK>).  WordTokenizer matches these greedily
# (longest first) before ordinary word tokenization.
FORMAT_TOKENS = [
    "<|user|>",
    "<|assistant|>",
    "<|system|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|begin_of_thought|>",
    "<|end_of_thought|>",
    "<|begin_of_solution|>",
    "<|end_of_solution|>",
]
class WordTokenizer:
    """Word-level tokenizer with character-level fallback.

    The initial vocabulary contains the core special tokens, the chat-format
    tokens, and a fixed set of single characters.  Words without a dedicated
    id are encoded character-by-character (unknown characters map to <UNK>).
    Multi-character special tokens are matched greedily (longest first)
    before the word regex runs.

    Fixes vs. previous revision:
      * ``encode`` called ``self.bos_id`` / ``self.eos_id`` / ``self.unk_id``
        without parentheses, appending bound methods instead of ints.
      * ``load`` takes ``cls`` but was missing the ``@classmethod`` decorator.
    """

    # whitespace runs | words (optional apostrophe suffix) | digit runs |
    # punctuation runs
    WORD_RE = re.compile(
        r"\s+|[^\W\d_]+(?:['\u2019][^\W\d_]+)?|\d+|[^\w\s]+", re.UNICODE
    )

    def __init__(
        self, extra_chars: str = "", format_tokens: Optional[List[str]] = None
    ) -> None:
        base = string.ascii_letters + string.digits + string.punctuation + " \n\t\r"
        fallback_chars = sorted(set(base + extra_chars))
        self.core_special = ["<PAD>", "<BOS>", "<EOS>", "<UNK>"]
        self.format_tokens = (
            list(format_tokens) if format_tokens else list(FORMAT_TOKENS)
        )
        self.special = list(self.core_special) + list(self.format_tokens)
        self.id_to_token: List[str] = (
            list(self.core_special) + self.format_tokens + fallback_chars
        )
        self.token_to_id: Dict[str, int] = {
            t: i for i, t in enumerate(self.id_to_token)
        }
        # Longest-first so greedy matching prefers the most specific token.
        self.special_multi_tokens = sorted(
            [t for t in self.special if len(t) > 1], key=len, reverse=True
        )
        self.multi_char_tokens = self.special_multi_tokens
        self.dynamic_additions = 0

    def pad_id(self) -> int:
        return self.token_to_id["<PAD>"]

    def bos_id(self) -> int:
        return self.token_to_id["<BOS>"]

    def eos_id(self) -> int:
        return self.token_to_id["<EOS>"]

    def unk_id(self) -> int:
        return self.token_to_id["<UNK>"]

    def vocab_size(self) -> int:
        return len(self.id_to_token)

    def maybe_add_char(self, ch: str) -> bool:
        """Add a single character to the vocabulary; True if it was new."""
        # Same mechanics as maybe_add_token; kept as a separate public name
        # for backward compatibility.
        return self.maybe_add_token(ch)

    def maybe_add_token(self, token: str) -> bool:
        """Add a token to the vocabulary; True if it was new."""
        if token in self.token_to_id:
            return False
        self.token_to_id[token] = len(self.id_to_token)
        self.id_to_token.append(token)
        self.dynamic_additions += 1
        return True

    def iter_lexical_tokens(self, text: str) -> Iterator[str]:
        """Yield special tokens and WORD_RE lexemes covering all of *text*."""
        i = 0
        n = len(text)
        while i < n:
            matched_special = False
            for token in self.special_multi_tokens:
                if text.startswith(token, i):
                    yield token
                    i += len(token)
                    matched_special = True
                    break
            if matched_special:
                continue
            m = self.WORD_RE.match(text, i)
            if m is None:
                # Defensive: WORD_RE should match any char, but never stall.
                yield text[i]
                i += 1
                continue
            tok = m.group(0)
            yield tok
            i = m.end()

    def encode(
        self, text: str, add_bos: bool = False, add_eos: bool = False
    ) -> List[int]:
        """Encode *text* to ids; unknown words fall back to per-char ids."""
        out: List[int] = []
        if add_bos:
            out.append(self.bos_id())
        unk = self.unk_id()
        t2i = self.token_to_id
        for tok in self.iter_lexical_tokens(text):
            tid = t2i.get(tok)
            if tid is not None:
                out.append(tid)
                continue
            for ch in tok:
                out.append(t2i.get(ch, unk))
        if add_eos:
            out.append(self.eos_id())
        return out

    def decode(self, ids: Sequence[int], skip_special: bool = True) -> str:
        """Decode ids to text, silently dropping out-of-range ids."""
        pieces: List[str] = []
        for idx in ids:
            i = int(idx)
            if i < 0 or i >= len(self.id_to_token):
                continue
            tok = self.id_to_token[i]
            if skip_special and tok in self.special:
                continue
            pieces.append(tok)
        return "".join(pieces)

    def save(self, path: Path) -> None:
        """Serialize the vocabulary to a JSON file at *path*."""
        with path.open("w", encoding="utf-8") as f:
            json.dump(
                {
                    "id_to_token": self.id_to_token,
                    "format_tokens": self.format_tokens,
                    "core_special": self.core_special,
                    "tokenizer_type": "word_level_v1",
                },
                f,
                ensure_ascii=False,
                indent=2,
            )

    @classmethod
    def load(cls, path: Path) -> "WordTokenizer":
        """Reconstruct a tokenizer from a JSON file written by save()."""
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
        format_tokens = data.get("format_tokens", FORMAT_TOKENS)
        tokenizer = cls(extra_chars="", format_tokens=format_tokens)
        tokenizer.id_to_token = data["id_to_token"]
        tokenizer.token_to_id = {t: i for i, t in enumerate(tokenizer.id_to_token)}
        tokenizer.special = list(tokenizer.core_special) + list(tokenizer.format_tokens)
        tokenizer.special_multi_tokens = sorted(
            [t for t in tokenizer.special if len(t) > 1], key=len, reverse=True
        )
        tokenizer.multi_char_tokens = tokenizer.special_multi_tokens
        return tokenizer


# Backward-compatible alias for the earlier character-level tokenizer name.
LetterTokenizer = WordTokenizer
| # --------------------------------------------------------------------------- | |
| # Model (from ailay.model) | |
| # --------------------------------------------------------------------------- | |
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: x / rms(x) * weight (no mean, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Use the fused kernel when this torch build provides it.
        fused = getattr(torch.nn.functional, "rms_norm", None)
        if fused is not None:
            return fused(x, self.weight.shape, self.weight, self.eps)
        # Fallback: compute in float32 for stability, cast back afterwards.
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x32 * inv_rms).to(dtype=x.dtype) * self.weight
class RotaryEmbedding(nn.Module):
    """Rotary position embedding: precomputes inverse frequencies and emits
    cos/sin tables shaped (1, 1, seq_len, dim) for broadcasting over heads."""

    def __init__(self, dim: int, base: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        # Not persistent: recomputed from (dim, base) rather than checkpointed.
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

    def cos_sin(
        self, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate the half-dim angle table so it lines up with _rotate_half.
        table = torch.cat((angles, angles), dim=-1)
        cos_table = table.cos()[None, None, :, :].to(dtype=dtype)
        sin_table = table.sin()[None, None, :, :].to(dtype=dtype)
        return cos_table, sin_table
| def _rotate_half(x: torch.Tensor) -> torch.Tensor: | |
| x1 = x[..., : x.shape[-1] // 2] | |
| x2 = x[..., x.shape[-1] // 2 :] | |
| return torch.cat((-x2, x1), dim=-1) | |
class CausalSelfAttention(nn.Module):
    """Grouped-query causal attention with partial RoPE, QK RMS-norm,
    per-head sigmoid output gates, and either full-causal or
    sliding-window masking (selected per call via ``is_global``)."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        dropout: float,
        sliding_window: int,
        rope_fraction: float,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        # GQA replication factor: query heads per KV head.
        self.n_rep = n_heads // n_kv_heads
        self.dropout = dropout
        self.sliding_window = sliding_window
        self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)
        for lin in (self.wq, self.wk, self.wv):
            nn.init.normal_(lin.weight, std=dim ** -0.5)
        nn.init.normal_(self.wo.weight, std=(n_heads * head_dim) ** -0.5)
        # Rotate only a fraction of each head's dims; rounded down to an even
        # count (RoPE pairs dims), with a floor of 2.
        self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2)
        self.rope = RotaryEmbedding(self.rope_dim)
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)
        # Per-head output gate; zeros -> sigmoid(0) = 0.5 at init.
        self.output_gate = nn.Parameter(torch.zeros(n_heads))

    def forward(
        self,
        x: torch.Tensor,
        is_global: bool,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attend over x (B, T, dim), optionally appending to a KV cache.

        Returns the projected attention output and, when ``use_cache``,
        the concatenated (k, v) tensors for reuse at the next step.
        """
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)
        # QK-norm applied per head before the head/seq transpose.
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Offset rotary positions by the cached prefix length.
        past_len = past_kv[0].shape[2] if past_kv is not None else 0
        cos, sin = self.rope.cos_sin(T + past_len, x.device, q.dtype)
        cos_slice = cos[:, :, past_len : past_len + T, :]
        sin_slice = sin[:, :, past_len : past_len + T, :]
        # Partial RoPE: rotate the first rope_dim dims, pass the rest through.
        q_rope = q[..., : self.rope_dim]
        q_pass = q[..., self.rope_dim :]
        k_rope = k[..., : self.rope_dim]
        k_pass = k[..., self.rope_dim :]
        q_rope = (q_rope * cos_slice) + (_rotate_half(q_rope) * sin_slice)
        k_rope = (k_rope * cos_slice) + (_rotate_half(k_rope) * sin_slice)
        q = torch.cat([q_rope, q_pass], dim=-1)
        k = torch.cat([k_rope, k_pass], dim=-1)
        if past_kv is not None:
            k = torch.cat([past_kv[0], k], dim=2)
            v = torch.cat([past_kv[1], v], dim=2)
        new_kv = (k, v) if use_cache else None
        S = k.shape[2]
        # Expand KV heads to match query heads (GQA).
        if self.n_rep > 1:
            k = (
                k[:, :, None, :, :]
                .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim)
                .reshape(B, self.n_heads, S, self.head_dim)
            )
            v = (
                v[:, :, None, :, :]
                .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim)
                .reshape(B, self.n_heads, S, self.head_dim)
            )
        drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0
        if is_global:
            if past_kv is None and T > 1:
                # Prefill without cache: rely on SDPA's fused causal mask.
                out = F.scaled_dot_product_attention(
                    q, k, v, is_causal=True, dropout_p=drop_p
                )
            else:
                # Single-token decode (or cached path): the query may attend
                # to everything already in k/v, so no mask is needed.
                # NOTE(review): a multi-token prefill WITH past_kv would be
                # unmasked here — confirm callers never do that.
                out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
        else:
            # Sliding-window (local) attention: explicit boolean mask allowing
            # key positions in (q_pos - window, q_pos].
            T_q = q.shape[2]
            q_pos = torch.arange(past_len, past_len + T_q, device=q.device).unsqueeze(1)
            k_pos = torch.arange(S, device=q.device).unsqueeze(0)
            mask = (q_pos >= k_pos) & ((q_pos - k_pos) < self.sliding_window)
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask.unsqueeze(0).unsqueeze(0), dropout_p=drop_p
            )
        # Learned per-head gate on the attention output.
        gate = torch.sigmoid(self.output_gate).view(1, self.n_heads, 1, 1)
        out = out * gate
        out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
        out = self.wo(out)
        return out, new_kv
class SwiGLU(nn.Module):
    """Gated feed-forward: down(silu(gate(x)) * up(x)), dropout only while
    training with gradients enabled."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None:
        super().__init__()
        self.gate = nn.Linear(dim, hidden_dim, bias=False)
        self.up = nn.Linear(dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, dim, bias=False)
        self.drop = nn.Dropout(dropout)
        # 1/sqrt(fan_in) init for each projection (same order as construction,
        # so RNG consumption matches the original exactly).
        for lin, fan_in in (
            (self.gate, dim),
            (self.up, dim),
            (self.down, hidden_dim),
        ):
            nn.init.normal_(lin.weight, std=fan_in ** -0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.silu(self.gate(x)) * self.up(x)
        projected = self.down(hidden)
        if self.training and torch.is_grad_enabled():
            projected = self.drop(projected)
        return projected
class EngramBlock(nn.Module):
    """Conditional memory via O(1) hashed N-gram lookup (DeepSeek Engram).

    Each (n-gram order, head) pair owns a learned embedding table indexed by
    a hashed rolling n-gram of the input token ids.  The concatenated lookups
    pass through a zero-initialized depthwise causal conv, then a sigmoid
    gate (query from the hidden state, key from the memory) scales a value
    projection that the caller adds to the residual stream.
    """
    def __init__(
        self,
        dim: int,
        engram_dim: int,
        n_heads: int = 4,
        table_size: int = 8192,
        max_ngram: int = 3,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.engram_dim = engram_dim
        self.n_heads = n_heads
        self.table_size = table_size
        self.max_ngram = max_ngram
        # One learned table per (n-gram order, head); orders start at 2.
        self.embeddings = nn.ParameterDict()
        for n in range(2, max_ngram + 1):
            for k in range(n_heads):
                self.embeddings[f"{n}_{k}"] = nn.Parameter(
                    torch.randn(table_size, engram_dim) * (engram_dim**-0.5)
                )
        # Per-(n, k) affine hash constants, seeded from md5 for run-to-run
        # reproducibility; non-persistent, so not stored in checkpoints.
        for n in range(2, max_ngram + 1):
            for k in range(n_heads):
                seed = int(hashlib.md5(f"engram_{n}_{k}".encode()).hexdigest()[:8], 16)
                rng = torch.Generator().manual_seed(seed)
                a = torch.randint(1, 2**31, (1,), generator=rng).item()
                b = torch.randint(0, 2**31, (1,), generator=rng).item()
                self.register_buffer(
                    f"hash_a_{n}_{k}", torch.tensor(a), persistent=False
                )
                self.register_buffer(
                    f"hash_b_{n}_{k}", torch.tensor(b), persistent=False
                )
        total_branch_dim = engram_dim * n_heads * (max_ngram - 1)
        # Depthwise (groups == channels) conv over time; weights/bias start at
        # zero so the memory path contributes nothing at initialization.
        self.branch_conv = nn.Conv1d(
            total_branch_dim,
            total_branch_dim,
            kernel_size=4,
            dilation=max_ngram,
            padding=0,
            groups=total_branch_dim,
            bias=True,
        )
        nn.init.zeros_(self.branch_conv.weight)
        nn.init.zeros_(self.branch_conv.bias)
        self.gate_query = nn.Linear(dim, engram_dim, bias=False)
        self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False)
        self.gate_value = nn.Linear(total_branch_dim, dim, bias=False)
        self.gate_scale = engram_dim**-0.5

    def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor:
        """Hash each position's trailing n-gram to a table index in [0, table_size)."""
        a = getattr(self, f"hash_a_{n}_{k}")
        b = getattr(self, f"hash_b_{n}_{k}")
        B, T = token_ids.shape
        # Left-pad with 0 so every position has a full n-gram window.
        padded = F.pad(token_ids, (n - 1, 0), value=0)
        combined = torch.zeros(B, T, dtype=torch.long, device=token_ids.device)
        for i in range(n):
            # Polynomial rolling combine (base 31), reduced mod table_size.
            combined = (combined * 31 + padded[:, i : i + T].long()) % self.table_size
        return ((a * combined) ^ b) % self.table_size

    def forward(
        self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return the gated memory contribution for hidden (B, T, dim)."""
        B, T, _ = hidden.shape
        if token_ids is None:
            # Fallback pseudo-ids derived from the hidden state itself.
            token_ids = hidden.mean(dim=-1).long() % self.table_size
        all_indices = []
        all_tables = []
        for n in range(2, self.max_ngram + 1):
            for k in range(self.n_heads):
                all_indices.append(self._hash_ngram(token_ids, n, k))
                all_tables.append(self.embeddings[f"{n}_{k}"])
        branch_outputs = [tbl[idx] for idx, tbl in zip(all_indices, all_tables)]
        memory = torch.cat(branch_outputs, dim=-1)
        # Causal left-pad so the conv never sees future positions.
        conv_in = memory.transpose(1, 2)
        conv_in = F.pad(
            conv_in,
            (self.branch_conv.dilation[0] * (self.branch_conv.kernel_size[0] - 1), 0),
        )
        conv_out = self.branch_conv(conv_in)
        memory = conv_out.transpose(1, 2)
        # Scaled-dot-product-style scalar gate per position.
        query = self.gate_query(hidden)
        key = self.gate_key(memory)
        gate = torch.sigmoid(
            (query * key).sum(dim=-1, keepdim=True) * self.gate_scale
        )
        value = self.gate_value(memory)
        return gate * value
| def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor: | |
| M = torch.exp(logits.clamp(-10, 10)) | |
| for _ in range(n_iters): | |
| M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10) | |
| M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10) | |
| return M | |
class ManifoldHyperConnection(nn.Module):
    """Manifold-Constrained Hyper-Connections (mHC) residual wrapper.

    Maintains ``expansion`` parallel residual streams (concatenated along the
    feature dim).  Dynamic mixing weights are predicted from the normalized
    expanded state: H_pre selects the layer input from the streams, H_post
    writes the layer output back, and H_res (Sinkhorn-projected, approximately
    doubly-stochastic) mixes the streams among themselves.
    """
    def __init__(self, dim: int, expansion: int = 2) -> None:
        super().__init__()
        self.dim = dim
        self.expansion = expansion
        n = expansion
        # Static biases for the three mixing maps.
        self.bias_pre = nn.Parameter(torch.zeros(1, n))
        self.bias_post = nn.Parameter(torch.zeros(1, n))
        self.bias_res = nn.Parameter(torch.zeros(n, n))
        # Dynamic (input-conditioned) components of the mixing maps.
        self.theta_pre = nn.Linear(n * dim, n, bias=False)
        self.theta_post = nn.Linear(n * dim, n, bias=False)
        self.theta_res = nn.Linear(n * dim, n * n, bias=False)
        # alpha_* start at 0, so mixing is purely bias-driven at init.
        self.alpha_pre = nn.Parameter(torch.tensor(0.0))
        self.alpha_post = nn.Parameter(torch.tensor(0.0))
        self.alpha_res = nn.Parameter(torch.tensor(0.0))

    def _compute_mappings(
        self, x_expanded: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute (H_pre, H_post, H_res) from x_expanded (B, T, n*dim)."""
        B, T, _ = x_expanded.shape
        n = self.expansion
        # Unweighted RMS norm purely for conditioning the theta projections.
        x_norm = F.rms_norm(x_expanded, [x_expanded.shape[-1]])
        d_pre = torch.tanh(self.theta_pre(x_norm))
        d_post = torch.tanh(self.theta_post(x_norm))
        d_res = self.theta_res(x_norm)
        # H_pre in (0, 1); H_post in (0, 2) — can amplify the layer output.
        H_pre_raw = torch.sigmoid(self.alpha_pre * d_pre + self.bias_pre)
        H_post_raw = 2.0 * torch.sigmoid(self.alpha_post * d_post + self.bias_post)
        H_res_raw = (self.alpha_res * d_res + self.bias_res.reshape(1, 1, -1)).reshape(
            B, T, n, n
        )
        # Project stream-mixing weights toward doubly-stochastic.
        H_res = _sinkhorn_knopp(H_res_raw)
        return H_pre_raw.unsqueeze(-2), H_post_raw.unsqueeze(-2), H_res

    def expand_stream(self, x: torch.Tensor) -> torch.Tensor:
        """Replicate (B, T, dim) into n concatenated streams (B, T, n*dim)."""
        return x.repeat(1, 1, self.expansion)

    def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor:
        """Average the n streams back down to (B, T, dim)."""
        B, T, _ = x_expanded.shape
        return x_expanded.view(B, T, self.expansion, self.dim).mean(dim=-2)

    def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor:
        """Weighted combination of streams that feeds the wrapped layer."""
        B, T, _ = x_expanded.shape
        x_streams = x_expanded.view(B, T, self.expansion, self.dim)
        return (H_pre @ x_streams).squeeze(-2)

    def post_res_mix(
        self,
        layer_output: torch.Tensor,
        x_expanded: torch.Tensor,
        H_post: torch.Tensor,
        H_res: torch.Tensor,
    ) -> torch.Tensor:
        """Mix streams via H_res and add the layer output distributed by H_post."""
        B, T, _ = x_expanded.shape
        x_streams = x_expanded.view(B, T, self.expansion, self.dim)
        mixed = torch.matmul(H_res, x_streams)
        post_out = torch.matmul(H_post.transpose(-2, -1), layer_output.unsqueeze(-2))
        return (mixed + post_out).reshape(B, T, self.expansion * self.dim)
class TransformerBlock(nn.Module):
    """Pre-norm attention + SwiGLU block with optional Engram memory and
    optional mHC multi-stream residuals (enabled when mhc_expansion > 1)."""
    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        ffn_dim: int,
        dropout: float,
        sliding_window: int,
        rope_fraction: float,
        engram_dim: int = 0,
        engram_heads: int = 4,
        engram_table_size: int = 8192,
        engram_max_ngram: int = 3,
        mhc_expansion: int = 1,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.norm1 = RMSNorm(dim)
        self.attn = CausalSelfAttention(
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            head_dim=head_dim,
            dropout=dropout,
            sliding_window=sliding_window,
            rope_fraction=rope_fraction,
        )
        self.norm2 = RMSNorm(dim)
        self.ffn = SwiGLU(dim, ffn_dim, dropout)
        # engram_dim == 0 disables the conditional-memory path entirely.
        self.use_engram = engram_dim > 0
        if self.use_engram:
            self.engram = EngramBlock(
                dim=dim,
                engram_dim=engram_dim,
                n_heads=engram_heads,
                table_size=engram_table_size,
                max_ngram=engram_max_ngram,
            )
            self.engram_norm = RMSNorm(dim)
        # mhc_expansion == 1 means a plain single-stream residual.
        self.use_mhc = mhc_expansion > 1
        if self.use_mhc:
            self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion)
            self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion)

    def forward(
        self,
        x: torch.Tensor,
        is_global: bool,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        token_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run one block over x (B, T, dim); token_ids feed the Engram hash."""
        if self.use_mhc:
            # --- mHC path: attention sub-layer on expanded streams ---
            x_exp = self.mhc_attn.expand_stream(x)
            H_pre, H_post, H_res = self.mhc_attn._compute_mappings(x_exp)
            attn_in = self.mhc_attn.pre_mix(x_exp, H_pre)
            attn_out, new_kv = self.attn(
                self.norm1(attn_in), is_global, past_kv, use_cache
            )
            x_exp = self.mhc_attn.post_res_mix(attn_out, x_exp, H_post, H_res)
            if self.use_engram:
                # Engram operates on the collapsed (single-stream) state,
                # then the result is re-expanded.
                collapsed = self.mhc_attn.collapse_stream(x_exp)
                collapsed = collapsed + self.engram(
                    self.engram_norm(collapsed), token_ids=token_ids
                )
                x_exp = self.mhc_attn.expand_stream(collapsed)
            # --- mHC path: FFN sub-layer with its own mapping set ---
            H_pre2, H_post2, H_res2 = self.mhc_ffn._compute_mappings(x_exp)
            ffn_in = self.mhc_ffn.pre_mix(x_exp, H_pre2)
            ffn_out = self.ffn(self.norm2(ffn_in))
            x_exp = self.mhc_ffn.post_res_mix(ffn_out, x_exp, H_post2, H_res2)
            x = self.mhc_attn.collapse_stream(x_exp)
        else:
            # --- plain pre-norm residual path ---
            attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache)
            x = x + attn_out
            if self.use_engram:
                x = x + self.engram(self.engram_norm(x), token_ids=token_ids)
            x = x + self.ffn(self.norm2(x))
        return x, new_kv
| def _detect_engram_dim(state_dict: dict) -> int: | |
| for key in state_dict: | |
| if ".engram." in key and ".embeddings." in key: | |
| return state_dict[key].shape[-1] | |
| return 0 | |
| def _detect_mhc_expansion(state_dict: dict) -> int: | |
| for key, val in state_dict.items(): | |
| if ".mhc_attn.bias_pre" in key and val.dim() == 2: | |
| return val.shape[-1] | |
| return 1 | |
class TinyMemoryLM(nn.Module):
    """Decoder-only LM with weight-tied layer reuse, tied input/output
    embeddings, alternating global/sliding-window attention, optional
    Engram/mHC blocks, and multi-token-prediction (MTP) auxiliary heads.

    ``n_unique_layers`` physical blocks are cycled to reach
    ``n_logical_layers`` of effective depth (see _build_logical_layers).
    """
    def __init__(
        self,
        vocab_size: int,
        dim: int,
        n_unique_layers: int,
        n_logical_layers: int,
        n_heads: int,
        n_kv_heads: int,
        ffn_dim: int,
        dropout: float,
        mtp_horizons: Sequence[int],
        grad_checkpoint: bool,
        sliding_window: int = 512,
        rope_fraction: float = 0.5,
        embed_scale: bool = True,
        engram_dim: int = 0,
        engram_heads: int = 4,
        engram_table_size: int = 8192,
        engram_max_ngram: int = 3,
        mhc_expansion: int = 1,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.n_unique_layers = n_unique_layers
        self.n_logical_layers = n_logical_layers
        self.grad_checkpoint = grad_checkpoint
        self.embed_scale_factor = math.sqrt(dim) if embed_scale else 1.0
        head_dim = dim // n_heads
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)
        # Tie output projection to the input embedding matrix.
        self.head.weight = self.embed_tokens.weight
        self.output_bias = nn.Parameter(torch.zeros(vocab_size))
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    dim=dim,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    head_dim=head_dim,
                    ffn_dim=ffn_dim,
                    dropout=dropout,
                    sliding_window=sliding_window,
                    rope_fraction=rope_fraction,
                    engram_dim=engram_dim,
                    engram_heads=engram_heads,
                    engram_table_size=engram_table_size,
                    engram_max_ngram=engram_max_ngram,
                    mhc_expansion=mhc_expansion,
                )
                for _ in range(n_unique_layers)
            ]
        )
        self.norm = RMSNorm(dim)
        # Deduplicate, keep only horizons > 1, sort ascending.
        self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1})
        self.mtp_adapters = nn.ModuleDict(
            {str(h): nn.Linear(dim, dim, bias=False) for h in self.mtp_horizons}
        )
        self.mtp_norms = nn.ModuleDict(
            {str(h): RMSNorm(dim) for h in self.mtp_horizons}
        )
        # Scale residual-branch output projections by 1/sqrt(2 * depth)
        # (logical depth, since blocks are reused).
        res_scale = (2 * n_logical_layers) ** -0.5
        for block in self.blocks:
            block.attn.wo.weight.data.mul_(res_scale)
            block.ffn.down.weight.data.mul_(res_scale)

    def resize_token_embeddings(self, new_vocab_size: int) -> None:
        """Grow/shrink the tied embedding, head, and output bias in place,
        copying the overlapping rows from the old weights.

        NOTE(review): rows beyond the copied range keep nn.Embedding's
        default (fresh random) initialization.
        """
        old_vocab_size = self.embed_tokens.num_embeddings
        if new_vocab_size == old_vocab_size:
            return
        device = self.embed_tokens.weight.device
        old_embed_weight = self.embed_tokens.weight.data.clone()
        self.embed_tokens = nn.Embedding(
            new_vocab_size, self.embed_tokens.embedding_dim
        ).to(device)
        self.head = nn.Linear(
            self.embed_tokens.embedding_dim, new_vocab_size, bias=False
        ).to(device)
        # Re-tie after replacing both modules.
        self.head.weight = self.embed_tokens.weight
        old_bias = self.output_bias.data.clone()
        self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device))
        copy_size = min(old_vocab_size, new_vocab_size)
        self.output_bias.data[:copy_size] = old_bias[:copy_size]
        self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size]

    def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]:
        """Map logical depth onto physical blocks by repeating the block list
        and truncating to n_logical_layers.

        NOTE(review): covers at most 2 * n_unique_layers logical layers;
        larger n_logical_layers values are silently truncated.
        """
        logical = []
        blocks_list = list(self.blocks)
        full_sequence = blocks_list + blocks_list
        for logical_idx, block in enumerate(full_sequence[: self.n_logical_layers]):
            logical.append((block, logical_idx))
        return logical

    def forward(
        self,
        ids: torch.Tensor,
        use_cache: bool = False,
        past_key_values: Optional[
            List[Optional[Tuple[torch.Tensor, torch.Tensor]]]
        ] = None,
        return_hidden: bool = False,
    ) -> Tuple[
        torch.Tensor,
        Dict[int, torch.Tensor],
        Optional[torch.Tensor],
        Optional[List[Tuple[torch.Tensor, torch.Tensor]]],
    ]:
        """Run the model over ids (B, T).

        Returns (logits, mtp_logits_by_horizon, hidden_or_None,
        new_past_key_values_or_None).  MTP heads are computed only in
        training mode.
        """
        B, T = ids.shape
        x = self.embed_tokens(ids) * self.embed_scale_factor
        logical_layers = self._build_logical_layers()
        new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = (
            [] if use_cache else None
        )
        for layer_idx, (block, logical_idx) in enumerate(logical_layers):
            # Even logical indices use full-causal attention; odd ones use
            # the sliding window.  The cache is indexed by logical layer.
            is_global = logical_idx % 2 == 0
            past_kv = (
                past_key_values[layer_idx]
                if past_key_values is not None and layer_idx < len(past_key_values)
                else None
            )
            if self.grad_checkpoint and self.training and not use_cache:
                x, layer_kv = checkpoint(
                    block, x, is_global, past_kv, use_cache, ids, use_reentrant=True
                )
            else:
                x, layer_kv = block(x, is_global, past_kv, use_cache, ids)
            if new_past_key_values is not None:
                new_past_key_values.append(layer_kv)
        x = self.norm(x)
        h_out = x if return_hidden else None
        logits = self.head(x)
        # Undo the embedding scale on the tied output projection.
        if self.embed_scale_factor != 1.0:
            logits = logits / self.embed_scale_factor
        logits = logits + self.output_bias
        # Multi-token-prediction auxiliary logits (training only): predict
        # the token `horizon` steps ahead from an adapted hidden state.
        mtp: Dict[int, torch.Tensor] = {}
        if self.mtp_horizons and self.training:
            for horizon in self.mtp_horizons:
                if horizon > 1 and horizon <= T - 1:
                    shifted_h = x[:, :-horizon, :]
                    adapted_h = self.mtp_adapters[str(horizon)](shifted_h)
                    adapted_h = self.mtp_norms[str(horizon)](adapted_h)
                    mtp_logits = self.head(adapted_h)
                    if self.embed_scale_factor != 1.0:
                        mtp_logits = mtp_logits / self.embed_scale_factor
                    mtp_logits = mtp_logits + self.output_bias
                    mtp[horizon] = mtp_logits
        return logits, mtp, h_out, new_past_key_values
| # --------------------------------------------------------------------------- | |
| # Generation (from ailay.generation) | |
| # --------------------------------------------------------------------------- | |
def sample_text(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    branches: int,
    branch_len: int,
    device: torch.device,
    seq_len: int,
) -> str:
    """Autoregressively sample text without a KV cache (full-context recompute
    each step, truncated to the last seq_len tokens).

    With branches <= 1 this is plain top-k / temperature sampling.  With
    branches > 1 it does best-of-N rollout search: sample `branches`
    candidate continuations of up to `branch_len` tokens each and commit the
    one with the lowest total negative log-likelihood.
    """
    def _sample_id(logits: torch.Tensor) -> torch.Tensor:
        # Guard against NaN/inf logits: all-non-finite falls back to uniform;
        # otherwise non-finite entries are masked to -1e9.
        if not torch.isfinite(logits).any():
            logits = torch.zeros_like(logits)
        logits = torch.where(
            torch.isfinite(logits), logits, torch.full_like(logits, -1e9)
        )
        if top_k > 0:
            v, idx = torch.topk(logits, k=min(top_k, logits.shape[-1]))
            p = torch.softmax(v, dim=-1)
            return idx.gather(-1, torch.multinomial(p, 1))
        p = torch.softmax(logits, dim=-1)
        return torch.multinomial(p, 1)
    model.eval()
    ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
    prompt_len = len(ids)
    x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
    with torch.no_grad():
        generated = 0
        while generated < max_new_tokens:
            if branches <= 1:
                # Plain sampling: one token per full forward pass.
                ctx = x[:, -seq_len:]
                logits, _, _, _ = model(ctx)
                nlogits = logits[:, -1, :] / max(temperature, 1e-6)
                nid = _sample_id(nlogits)
                x = torch.cat([x, nid], dim=1)
                generated += 1
                continue
            # Best-of-N rollout: score each candidate by summed NLL.
            rollout = min(branch_len, max_new_tokens - generated)
            best_nll: Optional[float] = None
            best_tokens: Optional[List[torch.Tensor]] = None
            for _ in range(branches):
                cand = x
                cand_tokens: List[torch.Tensor] = []
                nll = 0.0
                for _ in range(rollout):
                    ctx = cand[:, -seq_len:]
                    logits, _, _, _ = model(ctx)
                    nlogits = logits[:, -1, :] / max(temperature, 1e-6)
                    nid = _sample_id(nlogits)
                    # Accumulate NLL of the sampled token under the
                    # temperature-scaled distribution.
                    lp = F.log_softmax(nlogits.float(), dim=-1)
                    nll += float(-lp.gather(-1, nid).item())
                    cand = torch.cat([cand, nid], dim=1)
                    cand_tokens.append(nid)
                if best_nll is None or nll < best_nll:
                    best_nll = nll
                    best_tokens = cand_tokens
            assert best_tokens is not None
            # Commit the winning rollout to the running sequence.
            for t in best_tokens:
                x = torch.cat([x, t], dim=1)
                generated += 1
    generated_ids = x[0, prompt_len:].tolist()
    return tokenizer.decode(generated_ids, skip_special=True)
def sample_text_cached(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    device: torch.device,
    seq_len: int,
) -> str:
    """Sample a continuation of *prompt* using incremental KV-cache decoding.

    Fixes over the previous version: the duplicated top-k sampling code is
    factored into one local helper, and an EOS sampled as the very first
    token now stops generation (it previously kept generating).
    """

    def _sample(step_logits: torch.Tensor) -> torch.Tensor:
        # Top-k (if enabled) multinomial sampling over one logits row.
        if top_k > 0:
            v, idx = torch.topk(step_logits, k=min(top_k, step_logits.shape[-1]))
            p = torch.softmax(v, dim=-1)
            return idx.gather(-1, torch.multinomial(p, 1))
        p = torch.softmax(step_logits, dim=-1)
        return torch.multinomial(p, 1)

    model.eval()
    ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
    x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
    with torch.no_grad():
        # Prime the cache with the full prompt, then feed one token at a time.
        logits, _, _, past_kv = model(x, use_cache=True)
        nid = _sample(logits[:, -1, :] / max(temperature, 1e-6))
        all_ids = [int(nid.item())]
        if all_ids[0] != tokenizer.eos_id:
            for _ in range(max_new_tokens - 1):
                logits, _, _, past_kv = model(
                    nid, use_cache=True, past_key_values=past_kv
                )
                nid = _sample(logits[:, -1, :] / max(temperature, 1e-6))
                tid = int(nid.item())
                all_ids.append(tid)
                if tid == tokenizer.eos_id:
                    break
    return tokenizer.decode(all_ids, skip_special=True)
def speculative_decode(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    device: torch.device,
    seq_len: int,
) -> str:
    """Self-speculative decoding using the model's MTP heads as the drafter.

    Each step drafts one token per MTP horizon from the last hidden state,
    then verifies the drafts in a single forward pass; on a mismatch the
    KV cache is trimmed back to the last accepted position.

    Fix: the side-effecting conditional expression
    ``all_generated.append(...) if ... else None`` is rewritten as an
    explicit guard (same behavior, readable and lintable).
    """
    model.eval()
    ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
    x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
    all_generated: List[int] = []
    with torch.no_grad():
        logits, _, h_out, past_kv = model(x, use_cache=True, return_hidden=True)

        def _sample_from(lg: torch.Tensor) -> int:
            # Temperature + optional top-k multinomial sampling for one row.
            lg = lg / max(temperature, 1e-6)
            if top_k > 0:
                v, idx = torch.topk(lg, k=min(top_k, lg.shape[-1]))
                p = torch.softmax(v, dim=-1)
                return int(idx[torch.multinomial(p, 1)].item())
            p = torch.softmax(lg, dim=-1)
            return int(torch.multinomial(p, 1).item())

        main_token = _sample_from(logits[0, -1, :])
        all_generated.append(main_token)
        while len(all_generated) < max_new_tokens:
            if main_token == tokenizer.eos_id:
                break
            # Draft one token per MTP horizon from the last hidden state.
            draft_tokens = []
            if h_out is not None and model.mtp_horizons:
                last_hidden = h_out[:, -1:, :]
                for h in model.mtp_horizons:
                    adapter = model.mtp_adapters[str(h)]
                    norm = model.mtp_norms[str(h)]
                    adapted = norm(adapter(last_hidden))
                    draft_logits = model.head(adapted) + model.output_bias
                    draft_tok = _sample_from(draft_logits[0, 0, :])
                    draft_tokens.append(draft_tok)
            if not draft_tokens:
                # No MTP heads available: plain incremental decoding.
                nid = torch.tensor([[main_token]], dtype=torch.long, device=device)
                logits, _, h_out, past_kv = model(
                    nid, use_cache=True, past_key_values=past_kv, return_hidden=True
                )
                main_token = _sample_from(logits[0, -1, :])
                all_generated.append(main_token)
                continue
            # Verify all drafts in one forward pass.
            verify_input = torch.tensor(
                [[main_token] + draft_tokens], dtype=torch.long, device=device
            )
            verify_logits, _, h_out, past_kv = model(
                verify_input,
                use_cache=True,
                past_key_values=past_kv,
                return_hidden=True,
            )
            accepted = 0
            # Avoid double-appending main_token (it is normally already the
            # last generated token).
            if main_token not in all_generated[-1:]:
                all_generated.append(main_token)
            for i, draft_tok in enumerate(draft_tokens):
                verified_tok = _sample_from(verify_logits[0, i, :])
                if verified_tok == draft_tok:
                    all_generated.append(draft_tok)
                    accepted += 1
                    if draft_tok == tokenizer.eos_id:
                        break
                else:
                    # First mismatch: keep the verified token, drop the rest.
                    all_generated.append(verified_tok)
                    break
            if accepted < len(draft_tokens):
                # Roll the KV cache back past the rejected draft positions.
                trim_len = len(draft_tokens) - accepted - 1
                if trim_len > 0 and past_kv is not None:
                    past_kv = [
                        (k[:, :, :-trim_len, :], v[:, :, :-trim_len, :])
                        if k is not None
                        else None
                        for k, v in past_kv
                    ]
            main_token = all_generated[-1]
    return tokenizer.decode(all_generated, skip_special=True)
def build_stop_token_ids(tokenizer: WordTokenizer) -> set:
    """Collect token ids that should end generation: EOS plus any chat-role
    markers present in the tokenizer vocabulary."""
    stops = {tokenizer.eos_id}
    for marker in ("<|user|>", "<|system|>", "<|assistant|>"):
        marker_id = tokenizer.token_to_id.get(marker)
        if marker_id is not None:
            stops.add(int(marker_id))
    return stops
def apply_no_repeat_ngram(
    logits: torch.Tensor,
    token_history: Sequence[int],
    ngram_size: int,
) -> torch.Tensor:
    """Mask (to -inf) every token that would complete an n-gram already
    present in *token_history*; returns *logits* unchanged if none apply."""
    if ngram_size <= 1 or len(token_history) < max(0, ngram_size - 1):
        return logits
    tail = tuple(token_history[-(ngram_size - 1) :]) if ngram_size > 1 else tuple()
    blocked = {
        int(token_history[start + ngram_size - 1])
        for start in range(len(token_history) - ngram_size + 1)
        if tuple(token_history[start : start + ngram_size - 1]) == tail
    }
    if not blocked:
        return logits
    masked = logits.clone()
    ban_idx = torch.tensor(sorted(blocked), device=logits.device, dtype=torch.long)
    masked[ban_idx] = float("-inf")
    return masked
def score_candidate(
    prompt: str,
    raw_text: str,
    visible_text: str,
    avg_logprob: float,
) -> float:
    """Heuristically score one sampled completion for best-of-N reranking.

    Starts from the average log-prob and applies shaping terms for length,
    terminal punctuation, leaked role markers, prompt-keyword overlap,
    trigram diversity, and the fraction of clean alphabetic words.
    Higher is better; an empty completion scores -1e9.
    """
    text = visible_text.strip()
    if not text:
        return -1e9
    total = avg_logprob
    words = text.lower().split()
    stopwords = {
        "what", "which", "when", "where", "why", "how", "are", "is",
        "the", "and", "for", "with", "that", "this", "from", "into",
        "about", "explain", "tell", "give", "list", "show", "write",
        "their", "there", "your",
    }
    word_re = r"[A-Za-z][A-Za-z'-]{2,}"
    keywords = {w for w in re.findall(word_re, prompt.lower()) if w not in stopwords}
    answer_words = set(re.findall(word_re, text.lower()))
    # Length shaping: penalize one-liners, mildly reward longer answers.
    if len(words) < 6:
        total -= 2.0
    else:
        total += min(2.0, len(words) * 0.03)
    if text[-1:] in ".!?":
        total += 0.5
    # Penalize leaked chat-role markers.
    if "<|user|>" in raw_text or "<|system|>" in raw_text:
        total -= 4.0
    if raw_text.count("<|assistant|>") > 1:
        total -= 2.0
    # Topical overlap with the prompt's content words.
    if keywords:
        overlap = len(keywords & answer_words) / len(keywords)
        if overlap == 0.0:
            total -= 2.5
        else:
            total += min(3.5, overlap * 4.0)
    # Unbalanced thought/solution tags each cost a point.
    for open_tok, close_tok in (
        ("<|begin_of_thought|>", "<|end_of_thought|>"),
        ("<|begin_of_solution|>", "<|end_of_solution|>"),
    ):
        if (open_tok in raw_text) != (close_tok in raw_text):
            total -= 1.0
    # Trigram diversity: repetitive text is heavily penalized.
    if len(words) >= 3:
        grams = [tuple(words[i : i + 3]) for i in range(len(words) - 2)]
        if grams:
            diversity = len(set(grams)) / len(grams)
            if diversity < 0.35:
                total -= 4.0
            elif diversity < 0.55:
                total -= 2.0
            else:
                total += min(1.0, (diversity - 0.55) * 2.0)
    # Mostly-alphabetic words as a proxy for clean, non-garbled output.
    wordlike = [
        w
        for w in words
        if len(w) <= 18 and (sum(ch.isalpha() for ch in w) / max(len(w), 1)) > 0.7
    ]
    clean_ratio = len(wordlike) / max(len(words), 1)
    if clean_ratio < 0.45:
        total -= 3.0
    elif clean_ratio < 0.65:
        total -= 1.0
    return total
def generate_candidate(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    device: str,
    sft_mode: bool,
    force_thought: bool,
    stream: bool,
    context_window: int,
    top_p: float = 0.9,
) -> Tuple[str, str, float, int]:
    """Sample one completion with repetition penalty, top-k, top-p and
    no-repeat-ngram filtering.

    Returns ``(visible_text, raw_text, avg_logprob, 0)``.

    Improvements: ``top_p`` was a hard-coded 0.9 and is now a trailing
    parameter with the same default (backward-compatible); the repetition-
    penalty ``seen`` set is maintained incrementally instead of being
    rebuilt from the whole sequence every step; the top-p softmax is
    computed once instead of twice.
    """
    if sft_mode:
        full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
    else:
        full_prompt = prompt
    if force_thought:
        full_prompt = f"{full_prompt}<|begin_of_thought|> "
    input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False)
    input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device)
    visible_tokens: List[str] = []
    raw_tokens: List[str] = []
    stop_token_ids = build_stop_token_ids(tokenizer)
    seen_ids = set(input_ids)  # tokens penalized by repetition_penalty
    total_logprob = 0.0
    sampled_tokens = 0
    with torch.no_grad():
        for _ in range(max_new_tokens):
            ctx_ids = (
                input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t
            )
            logits, _, _, _ = model(ctx_ids)
            next_logits = logits[0, -1, :].clone()
            raw_next_logits = next_logits.clone()
            if repetition_penalty != 1.0:
                # Standard CTRL-style penalty: shrink positive logits,
                # amplify negative ones for already-seen tokens.
                for token_id in seen_ids:
                    if next_logits[token_id] > 0:
                        next_logits[token_id] /= repetition_penalty
                    else:
                        next_logits[token_id] *= repetition_penalty
            if temperature != 1.0:
                next_logits = next_logits / max(temperature, 1e-6)
            if no_repeat_ngram_size > 1:
                next_logits = apply_no_repeat_ngram(
                    next_logits,
                    input_ids_t[0].tolist(),
                    no_repeat_ngram_size,
                )
            if top_k > 0:
                v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0)))
                next_logits[next_logits < v[-1]] = float("-inf")
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                sorted_probs = torch.softmax(sorted_logits, dim=-1)
                cum_probs = torch.cumsum(sorted_probs, dim=-1)
                # Remove tokens whose preceding cumulative mass already
                # reaches top_p; the shift keeps at least the top token.
                remove_mask = cum_probs - sorted_probs >= top_p
                sorted_logits[remove_mask] = float("-inf")
                next_logits = sorted_logits.scatter(0, sorted_indices, sorted_logits)
            if not torch.isfinite(next_logits).any():
                # Everything was filtered out: fall back to unfiltered logits.
                next_logits = raw_next_logits
                if temperature != 1.0:
                    next_logits = next_logits / max(temperature, 1e-6)
            probs = torch.softmax(next_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1).item()
            total_logprob += float(torch.log(probs[next_id] + 1e-12).item())
            sampled_tokens += 1
            if next_id in stop_token_ids:
                break
            token_str = (
                tokenizer.id_to_token[next_id]
                if next_id < len(tokenizer.id_to_token)
                else ""
            )
            raw_tokens.append(token_str)
            if token_str not in tokenizer.special:
                visible_tokens.append(token_str)
                if stream:
                    print(token_str, end="", flush=True)
            seen_ids.add(next_id)
            input_ids_t = torch.cat(
                [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1
            )
    if stream:
        print()
    avg_logprob = total_logprob / max(1, sampled_tokens)
    return "".join(visible_tokens), "".join(raw_tokens), avg_logprob, 0
def generate_beam_search(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int = 60,
    beam_width: int = 8,
    length_penalty: float = 0.7,
    no_repeat_ngram_size: int = 3,
    device: str = "cuda",
    sft_mode: bool = False,
    context_window: int = 2048,
) -> str:
    """Beam-search decoding with length normalization and no-repeat-ngram
    blocking.

    Returns the best completion, cut at the first newline when the text
    before it is longer than five characters.

    Fix: the original defined two byte-identical length-normalization
    closures (one re-created on every search step); a single helper is
    defined once and shared.
    """
    if sft_mode:
        full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
    else:
        full_prompt = prompt
    prompt_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False)
    prompt_len = len(prompt_ids)
    stop_ids = build_stop_token_ids(tokenizer)

    def _norm_score(pair: Tuple[float, List[int]]) -> float:
        # Length-normalized beam score (Wu et al.-style length penalty).
        gen_len = max(1, len(pair[1]) - prompt_len)
        return pair[0] / (gen_len**length_penalty)

    beams: List[Tuple[float, List[int]]] = [(0.0, list(prompt_ids))]
    completed: List[Tuple[float, List[int]]] = []
    for _step in range(max_new_tokens):
        if not beams:
            break
        candidates: List[Tuple[float, List[int]]] = []
        for beam_score, beam_ids in beams:
            x = torch.tensor(
                [beam_ids[-context_window:]], dtype=torch.long, device=device
            )
            with torch.no_grad():
                logits, _, _, _ = model(x)
            nl = logits[0, -1, :]
            log_probs = F.log_softmax(nl, dim=-1)
            gen_ids = beam_ids[prompt_len:]
            # Ban tokens that would repeat an already-generated n-gram.
            if no_repeat_ngram_size > 1 and len(gen_ids) >= no_repeat_ngram_size - 1:
                prefix = tuple(gen_ids[-(no_repeat_ngram_size - 1) :])
                for i in range(len(gen_ids) - no_repeat_ngram_size + 1):
                    if tuple(gen_ids[i : i + no_repeat_ngram_size - 1]) == prefix:
                        log_probs[gen_ids[i + no_repeat_ngram_size - 1]] = float("-inf")
            topk_lp, topk_ids = torch.topk(log_probs, beam_width)
            for i in range(beam_width):
                tid = topk_ids[i].item()
                new_score = beam_score + topk_lp[i].item()
                new_ids = beam_ids + [tid]
                if tid in stop_ids:
                    completed.append((new_score, new_ids))
                else:
                    candidates.append((new_score, new_ids))
        candidates.sort(key=_norm_score, reverse=True)
        beams = candidates[:beam_width]
    pool = completed + beams
    if not pool:
        return ""
    pool.sort(key=_norm_score, reverse=True)
    best_ids = pool[0][1][prompt_len:]
    text = tokenizer.decode(best_ids, skip_special=True)
    # Keep only the first line when it is non-trivial.
    nl_pos = text.find("\n")
    if nl_pos > 5:
        text = text[:nl_pos]
    return text.strip()
def generate(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_k: int = 40,
    repetition_penalty: float = 1.0,
    device: str = "cuda",
    sft_mode: bool = False,
    force_thought: bool = False,
    stream: bool = True,
    decode_mode: str = "legacy",
    best_of: int = 3,
    no_repeat_ngram_size: int = 3,
    context_window: int = 2048,
    beam_width: int = 8,
    length_penalty: float = 0.7,
) -> str:
    """Dispatch generation to one decoding strategy.

    ``decode_mode`` selects the path: "beam" runs beam search, "legacy"
    samples a single candidate, and any other value samples ``best_of``
    candidates silently and returns the best-scored one.
    """
    if decode_mode == "beam":
        beam_text = generate_beam_search(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            beam_width=beam_width,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            device=device,
            sft_mode=sft_mode,
            context_window=context_window,
        )
        if stream:
            print(beam_text)
        return beam_text
    if decode_mode == "legacy":
        legacy_text, _, _, _ = generate_candidate(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            device=device,
            sft_mode=sft_mode,
            force_thought=force_thought,
            stream=stream,
            context_window=context_window,
        )
        return legacy_text
    # Best-of-N reranking: sample silently, score, then emit the winner.
    scored: List[Tuple[float, str, str, float]] = []
    for _ in range(max(1, best_of)):
        text, raw, avg_lp, _ = generate_candidate(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            device=device,
            sft_mode=sft_mode,
            force_thought=force_thought,
            stream=False,
            context_window=context_window,
        )
        scored.append((score_candidate(prompt, raw, text, avg_lp), text, raw, avg_lp))
    _, best_text, _, _ = max(scored, key=lambda item: item[0])
    if stream:
        print(best_text, end="", flush=True)
        print()
    return best_text
| # --------------------------------------------------------------------------- | |
| # Web server (from interactive.py) | |
| # --------------------------------------------------------------------------- | |
| ROOT = Path(__file__).resolve().parent | |
| if str(ROOT) not in sys.path: | |
| sys.path.insert(0, str(ROOT)) | |
| HF_ORG = "CompactAI" | |
| HF_API = "https://huggingface.co/api" | |
| CACHE_ROOT = Path.home() / ".cache" / "compactai_web" | |
| USER_AGENT = "Mozilla/5.0 CompactAI-Web" | |
| MODEL_CACHE: dict[tuple[str, str], dict[str, object]] = {} | |
| MODEL_CACHE_LOCK = threading.RLock() | |
| GENERATION_LOCK = threading.Lock() | |
def request_json(url: str):
    """GET *url* and decode the response body as JSON."""
    request = Request(url, headers={"User-Agent": USER_AGENT})
    with urlopen(request, timeout=60) as resp:
        payload = resp.read().decode("utf-8")
    return json.loads(payload)
def request_text(url: str) -> str:
    """GET *url* and return the body as UTF-8 text (bad bytes replaced)."""
    request = Request(url, headers={"User-Agent": USER_AGENT})
    with urlopen(request, timeout=60) as resp:
        return resp.read().decode("utf-8", errors="replace")
def download_file(url: str, destination: Path) -> None:
    """Download *url* to *destination* atomically via a ``.tmp`` sibling.

    Fix: the temporary file is now removed when the download fails partway,
    instead of being left behind in the cache directory.
    """
    destination.parent.mkdir(parents=True, exist_ok=True)
    temp_path = destination.with_suffix(destination.suffix + ".tmp")
    req = Request(url, headers={"User-Agent": USER_AGENT})
    try:
        with urlopen(req, timeout=120) as response, temp_path.open("wb") as handle:
            shutil.copyfileobj(response, handle)
        # Rename only after a complete write, so readers never see a partial file.
        temp_path.replace(destination)
    except Exception:
        temp_path.unlink(missing_ok=True)
        raise
def normalize_repo_id(raw_repo_id: str) -> str:
    """Reduce a pasted URL or API path to a bare ``org/name`` repo id.

    Non-strings and blank input normalize to the empty string.
    """
    if not isinstance(raw_repo_id, str):
        return ""
    candidate = raw_repo_id.strip()
    if not candidate:
        return ""
    try:
        candidate = unquote(candidate)
    except Exception:
        pass
    # Strip known URL/API prefixes, then query string, fragment and slashes.
    for prefix in (
        "https://huggingface.co/",
        "http://huggingface.co/",
        "api/models/",
        "models/",
    ):
        candidate = candidate.replace(prefix, "")
    candidate = candidate.split("?", 1)[0]
    candidate = candidate.split("#", 1)[0]
    return candidate.strip("/")
def series_from_name(name: str) -> str | None:
    """Map a model name onto its series label, or None if unrecognized."""
    lowered = (name or "").lower()
    # Order matters: first match wins, mirroring the original if-chain.
    for needle, label in (("haiku", "Haiku"), ("sonnet", "Sonnet"), ("opus", "Opus")):
        if needle in lowered:
            return label
    return None
def encoded_repo_id(repo_id: str) -> str:
    """URL-encode each path segment of the normalized repo id."""
    segments = [part for part in normalize_repo_id(repo_id).split("/") if part]
    return "/".join(quote(segment, safe="") for segment in segments)
def hf_file_url(repo_id: str, filename: str) -> str:
    """Build a ``resolve/main`` download URL for *filename* in *repo_id*."""
    path = "/".join(quote(seg, safe="") for seg in filename.split("/") if seg)
    repo = encoded_repo_id(repo_id)
    return f"https://huggingface.co/{repo}/resolve/main/{path}"
def model_list() -> list[dict[str, object]]:
    """List CompactAI models on the Hub that ship usable artifacts.

    Keeps only repos with a recognizable series name and at least one of
    ``model.pt`` / ``pretrain.pt``; sorted by download count, descending.
    """
    data = request_json(f"{HF_API}/models?author={quote(HF_ORG)}&full=true&limit=200")
    entries: list[dict[str, object]] = []
    for item in data:
        siblings = item.get("siblings") or []
        names = [s.get("rfilename", "") for s in siblings if isinstance(s, dict)]
        with_model = "model.pt" in names or "model/model.pt" in names
        with_pretrain = "pretrain.pt" in names or "model/pretrain.pt" in names
        if not (with_model or with_pretrain):
            continue
        short_name = (item.get("id") or "").split("/")[-1]
        series = series_from_name(short_name)
        if not series:
            continue
        with_tokenizer = (
            "tokenizer.json" in names or "model/tokenizer.json" in names
        )
        entries.append(
            {
                "id": item.get("id", ""),
                "name": short_name,
                "series": series,
                "downloads": item.get("downloads", 0) or 0,
                "likes": item.get("likes", 0) or 0,
                "has_model": with_model,
                "has_pretrain": with_pretrain,
                "has_tokenizer": with_tokenizer,
            }
        )
    entries.sort(key=lambda entry: entry["downloads"], reverse=True)
    return entries
def model_details(repo_id: str) -> dict[str, object] | None:
    """Fetch repo metadata, per-file sizes and README text for one model.

    Returns None when the repo id does not normalize to anything usable.
    """
    normalized = normalize_repo_id(repo_id)
    if not normalized:
        return None
    data = request_json(f"{HF_API}/models/{encoded_repo_id(normalized)}")
    files: dict[str, dict[str, float]] = {}
    has_model = False
    has_pretrain = False
    for sibling in data.get("siblings") or []:
        if not isinstance(sibling, dict):
            continue
        filename = sibling.get("rfilename") or ""
        if not filename:
            continue
        size_mb = round((sibling.get("size") or 0) / (1024 * 1024), 2)
        files[filename] = {"size_mb": size_mb}
        # Mirror model/-prefixed entries under their bare names too.
        if filename.startswith("model/"):
            files[filename.removeprefix("model/")] = {"size_mb": size_mb}
        if filename in {"model.pt", "model/model.pt"}:
            has_model = True
        if filename in {"pretrain.pt", "model/pretrain.pt"}:
            has_pretrain = True
    # README is best-effort: missing or unreachable files yield empty text.
    try:
        readme_raw = request_text(
            f"https://huggingface.co/{encoded_repo_id(normalized)}/raw/main/README.md"
        )
    except Exception:
        readme_raw = ""
    name = (data.get("id") or normalized).split("/")[-1]
    return {
        "id": normalized,
        "name": name,
        "series": series_from_name(name) or "Sonnet",
        "downloads": data.get("downloads", 0) or 0,
        "files": files,
        "readme_raw": readme_raw,
        "hf_model_id": normalized,
        "has_model": has_model,
        "has_pretrain": has_pretrain,
    }
def cache_dir(repo_id: str, model_type: str) -> Path:
    """Local cache directory for one (repo, model_type) pair."""
    safe_name = normalize_repo_id(repo_id).replace("/", "__")
    return CACHE_ROOT / safe_name / model_type
def artifact_candidates(model_type: str) -> list[str]:
    """Candidate checkpoint paths to try, in priority order."""
    if model_type == "pretrain":
        return ["model/pretrain.pt", "pretrain.pt"]
    return ["model/model.pt", "model.pt"]
def ensure_artifact(repo_id: str, model_type: str, destination_name: str) -> Path:
    """Download one artifact into the cache unless it already exists.

    Tries each candidate remote path in order and raises RuntimeError
    carrying the last download error if every candidate fails.
    """
    normalized = normalize_repo_id(repo_id)
    target = cache_dir(normalized, model_type) / destination_name
    if target.exists():
        return target
    if destination_name.endswith(".pt"):
        candidates = artifact_candidates(model_type)
    else:
        candidates = ["model/tokenizer.json", "tokenizer.json"]
    last_error: Exception | None = None
    for candidate in candidates:
        try:
            download_file(hf_file_url(normalized, candidate), target)
        except Exception as exc:
            last_error = exc
        else:
            return target
    raise RuntimeError(
        f"Unable to download {destination_name} for {normalized}: {last_error}"
    )
def series_config(series: str) -> dict[str, object]:
    """Architecture preset for a series name; falls back to Sonnet."""
    key = series.lower()
    return MODEL_SERIES.get(key, MODEL_SERIES["sonnet"])
def load_bundle(repo_id: str, model_type: str) -> dict[str, object]:
    """Download (if needed), construct, and cache a ready-to-use model bundle.

    Returns a dict holding the model, tokenizer, device string, artifact
    paths and repo metadata.  The load path is held under MODEL_CACHE_LOCK
    so concurrent requests for the same (repo, type) key do not duplicate
    the download/construction work.
    """
    normalized = normalize_repo_id(repo_id)
    details = model_details(normalized)
    if not details:
        raise RuntimeError("Model details are unavailable.")
    series = str(details["series"])
    key = (normalized, model_type)
    with MODEL_CACHE_LOCK:
        cached = MODEL_CACHE.get(key)
        if cached:
            return cached
        bundle_dir = cache_dir(normalized, model_type)
        bundle_dir.mkdir(parents=True, exist_ok=True)
        model_path = bundle_dir / (
            "pretrain.pt" if model_type == "pretrain" else "model.pt"
        )
        tokenizer_path = bundle_dir / "tokenizer.json"
        if not model_path.exists():
            ensure_artifact(normalized, model_type, model_path.name)
        if not tokenizer_path.exists():
            ensure_artifact(normalized, model_type, tokenizer_path.name)
        tokenizer = WordTokenizer.load(tokenizer_path)
        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # only load checkpoints from trusted repos.
        ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False)
        cfg = series_config(series)
        vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size))
        # Checkpoints may store weights under model_state, state_dict, or be
        # a bare state dict themselves.
        state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
        # Auto-detect new arch features from checkpoint weights
        engram_dim = _detect_engram_dim(state_dict) or int(
            cfg.get("engram_dim", model_config.engram_dim)
        )
        mhc_expansion = _detect_mhc_expansion(state_dict) or int(
            cfg.get("mhc_expansion", model_config.mhc_expansion)
        )
        # Build the model from the series preset, with model_config defaults
        # filling any missing keys.
        model = TinyMemoryLM(
            vocab_size=vocab_size,
            dim=int(cfg.get("dim", model_config.dim)),
            n_unique_layers=int(
                cfg.get("n_unique_layers", model_config.n_unique_layers)
            ),
            n_logical_layers=int(
                cfg.get("n_logical_layers", model_config.n_logical_layers)
            ),
            n_heads=int(cfg.get("n_heads", model_config.n_heads)),
            n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)),
            ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)),
            dropout=float(cfg.get("dropout", model_config.dropout)),
            mtp_horizons=tuple(
                int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)
            ),
            grad_checkpoint=False,
            sliding_window=int(
                cfg.get("sliding_window_size", model_config.sliding_window_size)
            ),
            rope_fraction=float(
                cfg.get("rope_fraction", model_config.rope_fraction)
            ),
            embed_scale=bool(
                cfg.get("embed_scale", model_config.embed_scale)
            ),
            engram_dim=engram_dim,
            engram_heads=int(cfg.get("engram_heads", model_config.engram_heads)),
            engram_table_size=int(
                cfg.get("engram_table_size", model_config.engram_table_size)
            ),
            engram_max_ngram=int(
                cfg.get("engram_max_ngram", model_config.engram_max_ngram)
            ),
            mhc_expansion=mhc_expansion,
        )
        # strict=False tolerates missing/extra keys across arch revisions.
        model.load_state_dict(state_dict, strict=False)
        model.eval()
        # Grow the embedding table when the tokenizer has more tokens than
        # the checkpoint's recorded vocab.
        if tokenizer.vocab_size > vocab_size:
            model.resize_token_embeddings(tokenizer.vocab_size)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        bundle = {
            "repo_id": normalized,
            "name": details["name"],
            "series": series,
            "type": model_type,
            "model": model,
            "tokenizer": tokenizer,
            "device": device,
            "model_path": str(model_path),
            "tokenizer_path": str(tokenizer_path),
            "downloads": details["downloads"],
        }
        MODEL_CACHE[key] = bundle
        return bundle
def ensure_port(start_port: int) -> int:
    """Return a free localhost port.

    Scans ``[start_port, start_port + 50)`` first and falls back to an
    OS-assigned ephemeral port when the whole range is busy.
    """
    for candidate in range(start_port, start_port + 50):
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with probe:
            try:
                probe.bind(("127.0.0.1", candidate))
            except OSError:
                continue
            return candidate
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as fallback:
        fallback.bind(("127.0.0.1", 0))
        return fallback.getsockname()[1]
| def page_html() -> str: | |
| return f"""<!doctype html> | |
| <html lang="en"> | |
| <head> | |
| <meta charset="utf-8"> | |
| <meta name="viewport" content="width=device-width, initial-scale=1"> | |
| <title>CompactAI Web</title> | |
| <style> | |
| :root {{ | |
| color-scheme: dark; | |
| --bg: #050505; | |
| --panel: #111111; | |
| --panel-2: #161616; | |
| --line: #262626; | |
| --text: #f5f5f5; | |
| --muted: #a3a3a3; | |
| --accent: #d97706; | |
| --accent-2: #b45309; | |
| --soft: #1f1f1f; | |
| }} | |
| * {{ box-sizing: border-box; }} | |
| body {{ | |
| margin: 0; | |
| font-family: Geist, -apple-system, BlinkMacSystemFont, sans-serif; | |
| background: var(--bg); | |
| color: var(--text); | |
| line-height: 1.5; | |
| }} | |
| a {{ color: inherit; }} | |
| .wrap {{ max-width: 1120px; margin: 0 auto; padding: 28px 20px 40px; }} | |
| .hero {{ | |
| display: flex; | |
| justify-content: space-between; | |
| align-items: end; | |
| gap: 16px; | |
| padding: 22px 0 28px; | |
| border-bottom: 1px solid var(--line); | |
| margin-bottom: 22px; | |
| }} | |
| h1 {{ margin: 0; font-size: clamp(2rem, 5vw, 3.5rem); letter-spacing: -0.04em; }} | |
| .subtitle {{ margin: 10px 0 0; color: var(--muted); max-width: 58ch; }} | |
| .grid {{ | |
| display: grid; | |
| grid-template-columns: 1.1fr 1fr; | |
| gap: 18px; | |
| }} | |
| .panel {{ | |
| background: var(--panel); | |
| border: 1px solid var(--line); | |
| border-radius: 18px; | |
| padding: 18px; | |
| }} | |
| .panel h2 {{ margin: 0 0 12px; font-size: 15px; letter-spacing: 0.02em; text-transform: uppercase; color: var(--muted); }} | |
| .row {{ display: flex; gap: 10px; flex-wrap: wrap; }} | |
| select, textarea, input {{ | |
| width: 100%; | |
| background: var(--panel-2); | |
| color: var(--text); | |
| border: 1px solid var(--line); | |
| border-radius: 12px; | |
| padding: 12px 14px; | |
| font: inherit; | |
| outline: none; | |
| }} | |
| textarea {{ min-height: 170px; resize: vertical; }} | |
| select {{ appearance: none; }} | |
| .choice {{ | |
| flex: 1 1 150px; | |
| display: flex; | |
| align-items: center; | |
| gap: 10px; | |
| padding: 10px 12px; | |
| border: 1px solid var(--line); | |
| border-radius: 12px; | |
| background: var(--panel-2); | |
| cursor: pointer; | |
| }} | |
| .choice input {{ width: auto; }} | |
| .btns {{ display: flex; flex-wrap: wrap; gap: 10px; }} | |
| button {{ | |
| border: 1px solid var(--line); | |
| border-radius: 12px; | |
| padding: 11px 14px; | |
| background: var(--soft); | |
| color: var(--text); | |
| font: inherit; | |
| cursor: pointer; | |
| transition: transform 0.15s ease, border-color 0.15s ease, background 0.15s ease; | |
| }} | |
| button:hover {{ transform: translateY(-1px); border-color: #3a3a3a; }} | |
| .primary {{ background: var(--accent); border-color: var(--accent); color: #fff; }} | |
| .primary:hover {{ background: var(--accent-2); border-color: var(--accent-2); }} | |
| .status {{ | |
| margin-top: 12px; | |
| color: var(--muted); | |
| font-size: 13px; | |
| min-height: 1.4em; | |
| }} | |
| .output {{ | |
| white-space: pre-wrap; | |
| background: #0b0b0b; | |
| border: 1px solid var(--line); | |
| border-radius: 16px; | |
| min-height: 280px; | |
| padding: 16px; | |
| color: #e7e5e4; | |
| overflow: auto; | |
| }} | |
| .meta {{ | |
| display: flex; | |
| flex-wrap: wrap; | |
| gap: 8px; | |
| margin-top: 8px; | |
| }} | |
| .chip {{ | |
| display: inline-flex; | |
| align-items: center; | |
| gap: 6px; | |
| padding: 6px 10px; | |
| border-radius: 999px; | |
| border: 1px solid var(--line); | |
| background: var(--panel-2); | |
| font-size: 12px; | |
| color: var(--muted); | |
| }} | |
| .code {{ | |
| margin-top: 14px; | |
| padding: 12px 14px; | |
| border-radius: 12px; | |
| border: 1px solid var(--line); | |
| background: #0b0b0b; | |
| font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace; | |
| font-size: 13px; | |
| overflow-x: auto; | |
| }} | |
| @media (max-width: 900px) {{ | |
| .grid {{ grid-template-columns: 1fr; }} | |
| .hero {{ align-items: start; flex-direction: column; }} | |
| }} | |
| </style> | |
| </head> | |
| <body> | |
| <div class="wrap"> | |
| <div class="hero"> | |
| <div> | |
| <h1>CompactAI Web</h1> | |
| <p class="subtitle">Pull a model from Hugging Face, keep it cached locally, and chat in the browser.</p> | |
| </div> | |
| <div class="meta"> | |
| <span class="chip">Hugging Face: CompactAI</span> | |
| <span class="chip">pip install -r requirements.txt</span> | |
| <span class="chip">Local inference</span> | |
| </div> | |
| </div> | |
| <div class="grid"> | |
| <section class="panel"> | |
| <h2>Model</h2> | |
| <select id="modelSelect"></select> | |
| <div class="row" style="margin-top: 10px;"> | |
| <label class="choice"><input type="radio" name="type" value="model" checked> Instruct / final</label> | |
| <label class="choice"><input type="radio" name="type" value="pretrain"> Pretrain</label> | |
| </div> | |
| <div class="btns" style="margin-top: 12px;"> | |
| <button id="downloadBtn">Download</button> | |
| <button id="refreshBtn">Refresh models</button> | |
| </div> | |
| <div class="status" id="modelStatus">Loading model list…</div> | |
| <div class="code">python3 interactive_web.py</div> | |
| </section> | |
| <section class="panel"> | |
| <h2>Prompt</h2> | |
| <textarea id="prompt" placeholder="Ask something…"></textarea> | |
| <div class="row" style="margin-top: 10px;"> | |
| <input id="temperature" type="number" min="0.1" max="2" step="0.05" value="0.8" style="flex: 1 1 120px;"> | |
| <input id="topK" type="number" min="1" max="100" step="1" value="40" style="flex: 1 1 120px;"> | |
| <input id="maxTokens" type="number" min="16" max="2048" step="16" value="256" style="flex: 1 1 120px;"> | |
| </div> | |
| <div class="btns" style="margin-top: 12px;"> | |
| <button id="generateBtn" class="primary">Generate</button> | |
| </div> | |
| <div class="status" id="genStatus"></div> | |
| </section> | |
| </div> | |
| <section class="panel" style="margin-top: 18px;"> | |
| <h2>Response</h2> | |
| <div id="output" class="output"></div> | |
| </section> | |
| </div> | |
| <script> | |
| const modelSelect = document.getElementById('modelSelect'); | |
| const modelStatus = document.getElementById('modelStatus'); | |
| const genStatus = document.getElementById('genStatus'); | |
| const output = document.getElementById('output'); | |
| const promptBox = document.getElementById('prompt'); | |
| async function api(path, body) {{ | |
| const response = await fetch(path, {{ | |
| method: body ? 'POST' : 'GET', | |
| headers: body ? {{ 'Content-Type': 'application/json' }} : undefined, | |
| body: body ? JSON.stringify(body) : undefined, | |
| }}); | |
| return response.json(); | |
| }} | |
| function currentType() {{ | |
| return document.querySelector('input[name="type"]:checked').value; | |
| }} | |
| function currentModelId() {{ | |
| return modelSelect.value; | |
| }} | |
| function setModels(models) {{ | |
| modelSelect.innerHTML = ''; | |
| for (const model of models) {{ | |
| const option = document.createElement('option'); | |
| option.value = model.id; | |
| option.textContent = `${{model.name}} • ${{model.series}}`; | |
| modelSelect.appendChild(option); | |
| }} | |
| if (models.length === 0) {{ | |
| const option = document.createElement('option'); | |
| option.value = ''; | |
| option.textContent = 'No CompactAI models found'; | |
| modelSelect.appendChild(option); | |
| }} | |
| }} | |
| async function refreshModels() {{ | |
| modelStatus.textContent = 'Loading model list…'; | |
| try {{ | |
| const models = await api('/api/models'); | |
| setModels(models); | |
| modelStatus.textContent = models.length ? `${{models.length}} models available from CompactAI` : 'No compatible models found.'; | |
| }} catch (error) {{ | |
| modelStatus.textContent = 'Failed to load model list.'; | |
| }} | |
| }} | |
| async function ensureModel() {{ | |
| const modelId = currentModelId(); | |
| if (!modelId) {{ | |
| modelStatus.textContent = 'Pick a model first.'; | |
| return null; | |
| }} | |
| modelStatus.textContent = 'Downloading model files…'; | |
| const result = await api('/api/ensure', {{ modelId, type: currentType() }}); | |
| if (!result.success) {{ | |
| modelStatus.textContent = result.error || 'Download failed.'; | |
| return null; | |
| }} | |
| modelStatus.textContent = `${{result.name}} ready on ${{result.series}}`; | |
| return result; | |
| }} | |
| async function generate() {{ | |
| output.textContent = ''; | |
| genStatus.textContent = ''; | |
| const modelId = currentModelId(); | |
| const prompt = promptBox.value.trim(); | |
| if (!modelId) {{ | |
| genStatus.textContent = 'Pick a model first.'; | |
| return; | |
| }} | |
| if (!prompt) {{ | |
| genStatus.textContent = 'Enter a prompt first.'; | |
| return; | |
| }} | |
| genStatus.textContent = 'Preparing model…'; | |
| const result = await api('/api/generate', {{ | |
| modelId, | |
| type: currentType(), | |
| prompt, | |
| temperature: Number(document.getElementById('temperature').value || 0.8), | |
| top_k: Number(document.getElementById('topK').value || 40), | |
| max_new_tokens: Number(document.getElementById('maxTokens').value || 256), | |
| }}); | |
| if (!result.success) {{ | |
| genStatus.textContent = result.error || 'Generation failed.'; | |
| return; | |
| }} | |
| output.textContent = result.text || ''; | |
| genStatus.textContent = 'Done.'; | |
| }} | |
| document.getElementById('refreshBtn').addEventListener('click', refreshModels); | |
| document.getElementById('downloadBtn').addEventListener('click', ensureModel); | |
| document.getElementById('generateBtn').addEventListener('click', generate); | |
| promptBox.addEventListener('keydown', (event) => {{ | |
| if (event.key === 'Enter' && (event.ctrlKey || event.metaKey)) {{ | |
| event.preventDefault(); | |
| generate(); | |
| }} | |
| }}); | |
| refreshModels(); | |
| </script> | |
| </body> | |
| </html>""" | |
class Handler(BaseHTTPRequestHandler):
    """HTTP request handler serving the single-page UI and the JSON API.

    Routes:
      GET  /, /index.html      -> rendered HTML page
      GET  /api/models         -> JSON list of available models
      GET  /api/models/<id>    -> JSON details for one model
      POST /api/ensure         -> download/cache a model bundle
      POST /api/generate       -> run text generation (serialized by GENERATION_LOCK)

    All API responses are JSON objects; errors carry {"success": False, "error": ...}.
    """

    def _send_json(self, payload, status=200):
        """Serialize *payload* as JSON and write a complete HTTP response."""
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        # The model list and bundle state change at runtime; never cache.
        self.send_header("Cache-Control", "no-store")
        self.end_headers()
        self.wfile.write(body)

    def _send_html(self, payload: str, status=200):
        """Write *payload* as a UTF-8 HTML response."""
        body = payload.encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.send_header("Cache-Control", "no-store")
        self.end_headers()
        self.wfile.write(body)

    def _read_json_body(self) -> dict:
        """Read and parse the request body as a JSON object.

        Robust against malformed input: a non-numeric or negative
        Content-Length, a non-UTF-8 body, invalid JSON, or a JSON value
        that is not an object all yield an empty dict instead of crashing
        the handler (which would otherwise drop the connection with no
        JSON error response).
        """
        try:
            length = int(self.headers.get("Content-Length", "0") or "0")
        except ValueError:
            # Malformed header: treat as no body rather than raising.
            length = 0
        try:
            # Guard length > 0: rfile.read(-1) would read until EOF and can
            # hang on a keep-alive connection.
            raw = self.rfile.read(length).decode("utf-8") if length > 0 else "{}"
            payload = json.loads(raw or "{}")
        except ValueError:
            # Covers both UnicodeDecodeError and json.JSONDecodeError.
            payload = {}
        # A valid JSON array/string/number would crash payload.get() later.
        return payload if isinstance(payload, dict) else {}

    def do_GET(self):
        """Dispatch GET requests to the page or the read-only API."""
        parsed = urlparse(self.path)
        if parsed.path in {"/", "/index.html"}:
            self._send_html(page_html())
            return
        if parsed.path == "/api/models":
            try:
                self._send_json(model_list())
            except Exception as exc:
                self._send_json({"success": False, "error": str(exc)}, 500)
            return
        if parsed.path.startswith("/api/models/"):
            repo_id = normalize_repo_id(parsed.path.removeprefix("/api/models/"))
            try:
                details = model_details(repo_id)
                if not details:
                    self._send_json(
                        {"success": False, "error": "Model not found."}, 404
                    )
                else:
                    self._send_json(details)
            except Exception as exc:
                self._send_json({"success": False, "error": str(exc)}, 500)
            return
        self._send_json({"success": False, "error": "Not found."}, 404)

    def do_POST(self):
        """Dispatch POST requests: model download (ensure) and generation."""
        parsed = urlparse(self.path)
        payload = self._read_json_body()
        if parsed.path == "/api/ensure":
            try:
                repo_id = normalize_repo_id(payload.get("modelId", ""))
                model_type = payload.get("type", "model")
                if not repo_id:
                    self._send_json(
                        {"success": False, "error": "Missing model ID."}, 400
                    )
                    return
                details = model_details(repo_id)
                if not details:
                    self._send_json(
                        {"success": False, "error": "Model not found."}, 404
                    )
                    return
                bundle = load_bundle(repo_id, model_type)
                self._send_json(
                    {
                        "success": True,
                        "id": bundle["repo_id"],
                        "name": bundle["name"],
                        "series": bundle["series"],
                        "type": bundle["type"],
                    }
                )
            except Exception as exc:
                self._send_json({"success": False, "error": str(exc)}, 500)
            return
        if parsed.path == "/api/generate":
            try:
                repo_id = normalize_repo_id(payload.get("modelId", ""))
                model_type = payload.get("type", "model")
                prompt = str(payload.get("prompt", ""))
                if not repo_id:
                    self._send_json(
                        {"success": False, "error": "Missing model ID."}, 400
                    )
                    return
                bundle = load_bundle(repo_id, model_type)
                # Serialize generation: the model is not safe for concurrent
                # inference from multiple handler threads.
                with GENERATION_LOCK:
                    text = generate(
                        model=bundle["model"],
                        tokenizer=bundle["tokenizer"],
                        prompt=prompt,
                        max_new_tokens=int(payload.get("max_new_tokens", 256)),
                        temperature=float(payload.get("temperature", 0.8)),
                        top_k=int(payload.get("top_k", 40)),
                        repetition_penalty=float(
                            payload.get("repetition_penalty", 1.0)
                        ),
                        device=str(bundle["device"]),
                        # Pretrain checkpoints get raw completion; everything
                        # else uses the SFT chat format.
                        sft_mode=model_type != "pretrain",
                        force_thought=bool(payload.get("force_thought", False)),
                        stream=False,
                        decode_mode=str(payload.get("decode_mode", "legacy")),
                        best_of=int(payload.get("best_of", 3)),
                        no_repeat_ngram_size=int(
                            payload.get("no_repeat_ngram_size", 3)
                        ),
                        context_window=int(payload.get("context_window", 2048)),
                        beam_width=int(payload.get("beam_width", 8)),
                        length_penalty=float(payload.get("length_penalty", 0.7)),
                    )
                self._send_json(
                    {
                        "success": True,
                        "text": text,
                        "name": bundle["name"],
                        "series": bundle["series"],
                    }
                )
            except Exception as exc:
                self._send_json({"success": False, "error": str(exc)}, 500)
            return
        self._send_json({"success": False, "error": "Not found."}, 404)

    def log_message(self, format, *args):
        # Silence the default per-request stderr logging.
        return
def main():
    """Start the local web server, print its URL, and open a browser tab.

    Honors the PORT environment variable (default 7860); blocks in the
    serve loop until interrupted, then closes the listening socket.
    """
    CACHE_ROOT.mkdir(parents=True, exist_ok=True)
    chosen_port = ensure_port(int(os.environ.get("PORT", "7860")))
    httpd = ThreadingHTTPServer(("127.0.0.1", chosen_port), Handler)
    address = f"http://127.0.0.1:{chosen_port}"
    print(address, flush=True)
    try:
        webbrowser.open(address)
    except Exception:
        # Headless environments have no browser; the printed URL suffices.
        pass
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        # Ctrl+C is the normal shutdown path; exit quietly.
        pass
    finally:
        httpd.server_close()
# Script entry point: launch the web UI server when executed directly.
if __name__ == "__main__":
    main()