Spaces:
Running
Running
Upload 88 files
Browse files- downloads/CompactAI Studio Setup 1.0.0.exe +2 -2
- downloads/CompactAI Studio-1.0.0.AppImage +2 -2
- downloads/index.html +0 -0
- downloads/interactive.py +337 -48
- interactive.py +2277 -0
- requirements.txt +1 -0
downloads/CompactAI Studio Setup 1.0.0.exe
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f685d561cb3d1e9b8f41bfea7b50c8d5bd0b72007000c7f70f63747127c5a57f
|
| 3 |
+
size 128
|
downloads/CompactAI Studio-1.0.0.AppImage
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fa6103e15bfdc80bfea40471ef418e54ca4f6f8b6c90ec166072d821d93dfe3c
|
| 3 |
+
size 128
|
downloads/index.html
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
downloads/interactive.py
CHANGED
|
@@ -18,6 +18,8 @@ from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
|
|
| 18 |
from urllib.parse import quote, unquote, urlparse
|
| 19 |
from urllib.request import Request, urlopen
|
| 20 |
|
|
|
|
|
|
|
| 21 |
import torch
|
| 22 |
import torch.nn as nn
|
| 23 |
import torch.nn.functional as F
|
|
@@ -41,10 +43,17 @@ class ModelConfig:
|
|
| 41 |
seq_len: int = 2048
|
| 42 |
sliding_window_size: int = 512
|
| 43 |
mtp_horizons: Tuple[int, ...] = (2, 3, 4)
|
| 44 |
-
rope_fraction: float = 0.
|
| 45 |
embed_scale: bool = True
|
| 46 |
logit_soft_cap: float = -1.0
|
| 47 |
quantization: str = "nvfp4"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
@property
|
| 50 |
def head_dim(self) -> int:
|
|
@@ -55,16 +64,17 @@ model_config = ModelConfig()
|
|
| 55 |
|
| 56 |
MODEL_SERIES = {
|
| 57 |
"haiku": {
|
| 58 |
-
"dim":
|
| 59 |
-
"n_unique_layers":
|
| 60 |
-
"n_logical_layers":
|
| 61 |
"n_heads": 4,
|
| 62 |
"n_kv_heads": 2,
|
| 63 |
-
"ffn_dim":
|
| 64 |
"dropout": 0.0,
|
| 65 |
"seq_len": 2048,
|
| 66 |
"mtp_horizons": (2, 3, 4),
|
| 67 |
-
"
|
|
|
|
| 68 |
"grad_accum": 1,
|
| 69 |
"lr": 8e-4,
|
| 70 |
"min_lr": 1e-5,
|
|
@@ -74,29 +84,34 @@ MODEL_SERIES = {
|
|
| 74 |
"weight_decay": 0.02,
|
| 75 |
"pretrain_passes": 2,
|
| 76 |
"sft_passes": 3,
|
| 77 |
-
"max_sft_target_chars":
|
| 78 |
-
"use_grad_checkpoint":
|
| 79 |
-
"use_torch_compile": True,
|
| 80 |
"num_workers": 24,
|
| 81 |
"prefetch_factor": 64,
|
| 82 |
"shuffle_buffer": 8192,
|
| 83 |
"max_pretrain_tokens": 0,
|
| 84 |
"min_pretrain_tokens": 100_000_000,
|
| 85 |
"quantization": "nvfp4",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
},
|
| 87 |
"sonnet": {
|
| 88 |
-
"dim":
|
| 89 |
-
"n_unique_layers":
|
| 90 |
-
"n_logical_layers":
|
| 91 |
-
"n_heads":
|
| 92 |
"n_kv_heads": 4,
|
| 93 |
-
"ffn_dim":
|
| 94 |
"dropout": 0.0,
|
| 95 |
"seq_len": 2048,
|
| 96 |
"mtp_horizons": (2,),
|
| 97 |
-
"
|
|
|
|
| 98 |
"grad_accum": 1,
|
| 99 |
-
"lr":
|
| 100 |
"min_lr": 2e-5,
|
| 101 |
"sft_lr": 5e-5,
|
| 102 |
"sft_min_lr": 5e-6,
|
|
@@ -104,27 +119,32 @@ MODEL_SERIES = {
|
|
| 104 |
"weight_decay": 0.1,
|
| 105 |
"pretrain_passes": 1,
|
| 106 |
"sft_passes": 1,
|
| 107 |
-
"max_sft_target_chars":
|
| 108 |
"use_grad_checkpoint": True,
|
| 109 |
-
"use_torch_compile": True,
|
| 110 |
"num_workers": 32,
|
| 111 |
-
"prefetch_factor":
|
| 112 |
"shuffle_buffer": 16384,
|
| 113 |
"max_pretrain_tokens": 0,
|
| 114 |
-
"min_pretrain_tokens":
|
| 115 |
"quantization": "nvfp4",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
},
|
| 117 |
"opus": {
|
| 118 |
-
"dim":
|
| 119 |
-
"n_unique_layers":
|
| 120 |
-
"n_logical_layers":
|
| 121 |
"n_heads": 16,
|
| 122 |
"n_kv_heads": 4,
|
| 123 |
-
"ffn_dim":
|
| 124 |
"dropout": 0.0,
|
| 125 |
"seq_len": 2048,
|
| 126 |
"mtp_horizons": (2,),
|
| 127 |
-
"
|
|
|
|
| 128 |
"grad_accum": 1,
|
| 129 |
"lr": 1.6e-4,
|
| 130 |
"min_lr": 1.6e-5,
|
|
@@ -134,15 +154,19 @@ MODEL_SERIES = {
|
|
| 134 |
"weight_decay": 0.1,
|
| 135 |
"pretrain_passes": 1,
|
| 136 |
"sft_passes": 1,
|
| 137 |
-
"max_sft_target_chars":
|
| 138 |
"use_grad_checkpoint": True,
|
| 139 |
-
"use_torch_compile": True,
|
| 140 |
"num_workers": 48,
|
| 141 |
-
"prefetch_factor":
|
| 142 |
"shuffle_buffer": 16384,
|
| 143 |
"max_pretrain_tokens": 0,
|
| 144 |
-
"min_pretrain_tokens":
|
| 145 |
"quantization": "nvfp4",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
},
|
| 147 |
}
|
| 148 |
|
|
@@ -381,6 +405,10 @@ class CausalSelfAttention(nn.Module):
|
|
| 381 |
self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
| 382 |
self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)
|
| 383 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2)
|
| 385 |
self.rope = RotaryEmbedding(self.rope_dim)
|
| 386 |
|
|
@@ -444,7 +472,7 @@ class CausalSelfAttention(nn.Module):
|
|
| 444 |
.reshape(B, self.n_heads, S, self.head_dim)
|
| 445 |
)
|
| 446 |
|
| 447 |
-
drop_p = self.dropout if self.training else 0.0
|
| 448 |
|
| 449 |
if is_global:
|
| 450 |
if past_kv is None and T > 1:
|
|
@@ -479,9 +507,184 @@ class SwiGLU(nn.Module):
|
|
| 479 |
self.down = nn.Linear(hidden_dim, dim, bias=False)
|
| 480 |
self.drop = nn.Dropout(dropout)
|
| 481 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 482 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 483 |
h = F.silu(self.gate(x)) * self.up(x)
|
| 484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
|
| 486 |
|
| 487 |
class TransformerBlock(nn.Module):
|
|
@@ -495,8 +698,14 @@ class TransformerBlock(nn.Module):
|
|
| 495 |
dropout: float,
|
| 496 |
sliding_window: int,
|
| 497 |
rope_fraction: float,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
) -> None:
|
| 499 |
super().__init__()
|
|
|
|
| 500 |
self.norm1 = RMSNorm(dim)
|
| 501 |
self.attn = CausalSelfAttention(
|
| 502 |
dim=dim,
|
|
@@ -509,6 +718,20 @@ class TransformerBlock(nn.Module):
|
|
| 509 |
)
|
| 510 |
self.norm2 = RMSNorm(dim)
|
| 511 |
self.ffn = SwiGLU(dim, ffn_dim, dropout)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
|
| 513 |
def forward(
|
| 514 |
self,
|
|
@@ -516,13 +739,50 @@ class TransformerBlock(nn.Module):
|
|
| 516 |
is_global: bool,
|
| 517 |
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 518 |
use_cache: bool = False,
|
|
|
|
| 519 |
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
return x, new_kv
|
| 524 |
|
| 525 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
class TinyMemoryLM(nn.Module):
|
| 527 |
def __init__(
|
| 528 |
self,
|
|
@@ -537,8 +797,13 @@ class TinyMemoryLM(nn.Module):
|
|
| 537 |
mtp_horizons: Sequence[int],
|
| 538 |
grad_checkpoint: bool,
|
| 539 |
sliding_window: int = 512,
|
| 540 |
-
rope_fraction: float = 0.
|
| 541 |
embed_scale: bool = True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
) -> None:
|
| 543 |
super().__init__()
|
| 544 |
self.dim = dim
|
|
@@ -565,6 +830,11 @@ class TinyMemoryLM(nn.Module):
|
|
| 565 |
dropout=dropout,
|
| 566 |
sliding_window=sliding_window,
|
| 567 |
rope_fraction=rope_fraction,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 568 |
)
|
| 569 |
for _ in range(n_unique_layers)
|
| 570 |
]
|
|
@@ -634,7 +904,7 @@ class TinyMemoryLM(nn.Module):
|
|
| 634 |
)
|
| 635 |
|
| 636 |
for layer_idx, (block, logical_idx) in enumerate(logical_layers):
|
| 637 |
-
is_global = logical_idx %
|
| 638 |
past_kv = (
|
| 639 |
past_key_values[layer_idx]
|
| 640 |
if past_key_values is not None and layer_idx < len(past_key_values)
|
|
@@ -643,17 +913,20 @@ class TinyMemoryLM(nn.Module):
|
|
| 643 |
|
| 644 |
if self.grad_checkpoint and self.training and not use_cache:
|
| 645 |
x, layer_kv = checkpoint(
|
| 646 |
-
block, x, is_global, past_kv, use_cache, use_reentrant=
|
| 647 |
)
|
| 648 |
else:
|
| 649 |
-
x, layer_kv = block(x, is_global, past_kv, use_cache)
|
| 650 |
|
| 651 |
if new_past_key_values is not None:
|
| 652 |
new_past_key_values.append(layer_kv)
|
| 653 |
|
| 654 |
x = self.norm(x)
|
| 655 |
h_out = x if return_hidden else None
|
| 656 |
-
logits = self.head(x)
|
|
|
|
|
|
|
|
|
|
| 657 |
|
| 658 |
mtp: Dict[int, torch.Tensor] = {}
|
| 659 |
if self.mtp_horizons and self.training:
|
|
@@ -662,7 +935,10 @@ class TinyMemoryLM(nn.Module):
|
|
| 662 |
shifted_h = x[:, :-horizon, :]
|
| 663 |
adapted_h = self.mtp_adapters[str(horizon)](shifted_h)
|
| 664 |
adapted_h = self.mtp_norms[str(horizon)](adapted_h)
|
| 665 |
-
mtp_logits = self.head(adapted_h)
|
|
|
|
|
|
|
|
|
|
| 666 |
mtp[horizon] = mtp_logits
|
| 667 |
|
| 668 |
return logits, mtp, h_out, new_past_key_values
|
|
@@ -1462,6 +1738,14 @@ def load_bundle(repo_id: str, model_type: str) -> dict[str, object]:
|
|
| 1462 |
ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False)
|
| 1463 |
cfg = series_config(series)
|
| 1464 |
vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1465 |
model = TinyMemoryLM(
|
| 1466 |
vocab_size=vocab_size,
|
| 1467 |
dim=int(cfg.get("dim", model_config.dim)),
|
|
@@ -1480,19 +1764,24 @@ def load_bundle(repo_id: str, model_type: str) -> dict[str, object]:
|
|
| 1480 |
),
|
| 1481 |
grad_checkpoint=False,
|
| 1482 |
sliding_window=int(
|
| 1483 |
-
cfg.get(
|
| 1484 |
-
"sliding_window_size",
|
| 1485 |
-
getattr(model_config, "sliding_window_size", 512),
|
| 1486 |
-
)
|
| 1487 |
),
|
| 1488 |
rope_fraction=float(
|
| 1489 |
-
cfg.get("rope_fraction",
|
| 1490 |
),
|
| 1491 |
embed_scale=bool(
|
| 1492 |
-
cfg.get("embed_scale",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1493 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1494 |
)
|
| 1495 |
-
state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
|
| 1496 |
model.load_state_dict(state_dict, strict=False)
|
| 1497 |
model.eval()
|
| 1498 |
if tokenizer.vocab_size > vocab_size:
|
|
@@ -1678,7 +1967,7 @@ def page_html() -> str:
|
|
| 1678 |
</div>
|
| 1679 |
<div class="meta">
|
| 1680 |
<span class="chip">Hugging Face: CompactAI</span>
|
| 1681 |
-
<span class="chip">
|
| 1682 |
<span class="chip">Local inference</span>
|
| 1683 |
</div>
|
| 1684 |
</div>
|
|
|
|
| 18 |
from urllib.parse import quote, unquote, urlparse
|
| 19 |
from urllib.request import Request, urlopen
|
| 20 |
|
| 21 |
+
import hashlib
|
| 22 |
+
|
| 23 |
import torch
|
| 24 |
import torch.nn as nn
|
| 25 |
import torch.nn.functional as F
|
|
|
|
| 43 |
seq_len: int = 2048
|
| 44 |
sliding_window_size: int = 512
|
| 45 |
mtp_horizons: Tuple[int, ...] = (2, 3, 4)
|
| 46 |
+
rope_fraction: float = 0.5
|
| 47 |
embed_scale: bool = True
|
| 48 |
logit_soft_cap: float = -1.0
|
| 49 |
quantization: str = "nvfp4"
|
| 50 |
+
# Engram (conditional memory) config
|
| 51 |
+
engram_dim: int = 0
|
| 52 |
+
engram_heads: int = 4
|
| 53 |
+
engram_table_size: int = 8192
|
| 54 |
+
engram_max_ngram: int = 3
|
| 55 |
+
# mHC (Manifold-Constrained Hyper-Connections) config
|
| 56 |
+
mhc_expansion: int = 1
|
| 57 |
|
| 58 |
@property
|
| 59 |
def head_dim(self) -> int:
|
|
|
|
| 64 |
|
| 65 |
MODEL_SERIES = {
|
| 66 |
"haiku": {
|
| 67 |
+
"dim": 64,
|
| 68 |
+
"n_unique_layers": 12,
|
| 69 |
+
"n_logical_layers": 24,
|
| 70 |
"n_heads": 4,
|
| 71 |
"n_kv_heads": 2,
|
| 72 |
+
"ffn_dim": 384,
|
| 73 |
"dropout": 0.0,
|
| 74 |
"seq_len": 2048,
|
| 75 |
"mtp_horizons": (2, 3, 4),
|
| 76 |
+
"rope_fraction": 0.5,
|
| 77 |
+
"batch_size": 80,
|
| 78 |
"grad_accum": 1,
|
| 79 |
"lr": 8e-4,
|
| 80 |
"min_lr": 1e-5,
|
|
|
|
| 84 |
"weight_decay": 0.02,
|
| 85 |
"pretrain_passes": 2,
|
| 86 |
"sft_passes": 3,
|
| 87 |
+
"max_sft_target_chars": 0,
|
| 88 |
+
"use_grad_checkpoint": True,
|
|
|
|
| 89 |
"num_workers": 24,
|
| 90 |
"prefetch_factor": 64,
|
| 91 |
"shuffle_buffer": 8192,
|
| 92 |
"max_pretrain_tokens": 0,
|
| 93 |
"min_pretrain_tokens": 100_000_000,
|
| 94 |
"quantization": "nvfp4",
|
| 95 |
+
"engram_dim": 8,
|
| 96 |
+
"engram_heads": 2,
|
| 97 |
+
"engram_table_size": 64,
|
| 98 |
+
"engram_max_ngram": 2,
|
| 99 |
+
"mhc_expansion": 2,
|
| 100 |
},
|
| 101 |
"sonnet": {
|
| 102 |
+
"dim": 1024,
|
| 103 |
+
"n_unique_layers": 20,
|
| 104 |
+
"n_logical_layers": 40,
|
| 105 |
+
"n_heads": 16,
|
| 106 |
"n_kv_heads": 4,
|
| 107 |
+
"ffn_dim": 4096,
|
| 108 |
"dropout": 0.0,
|
| 109 |
"seq_len": 2048,
|
| 110 |
"mtp_horizons": (2,),
|
| 111 |
+
"rope_fraction": 0.5,
|
| 112 |
+
"batch_size": 24,
|
| 113 |
"grad_accum": 1,
|
| 114 |
+
"lr": 1e-4,
|
| 115 |
"min_lr": 2e-5,
|
| 116 |
"sft_lr": 5e-5,
|
| 117 |
"sft_min_lr": 5e-6,
|
|
|
|
| 119 |
"weight_decay": 0.1,
|
| 120 |
"pretrain_passes": 1,
|
| 121 |
"sft_passes": 1,
|
| 122 |
+
"max_sft_target_chars": 0,
|
| 123 |
"use_grad_checkpoint": True,
|
|
|
|
| 124 |
"num_workers": 32,
|
| 125 |
+
"prefetch_factor": 64,
|
| 126 |
"shuffle_buffer": 16384,
|
| 127 |
"max_pretrain_tokens": 0,
|
| 128 |
+
"min_pretrain_tokens": 100_000_000,
|
| 129 |
"quantization": "nvfp4",
|
| 130 |
+
"engram_dim": 32,
|
| 131 |
+
"engram_heads": 8,
|
| 132 |
+
"engram_table_size": 4096,
|
| 133 |
+
"engram_max_ngram": 2,
|
| 134 |
+
"mhc_expansion": 2,
|
| 135 |
},
|
| 136 |
"opus": {
|
| 137 |
+
"dim": 1536,
|
| 138 |
+
"n_unique_layers": 18,
|
| 139 |
+
"n_logical_layers": 36,
|
| 140 |
"n_heads": 16,
|
| 141 |
"n_kv_heads": 4,
|
| 142 |
+
"ffn_dim": 5888,
|
| 143 |
"dropout": 0.0,
|
| 144 |
"seq_len": 2048,
|
| 145 |
"mtp_horizons": (2,),
|
| 146 |
+
"rope_fraction": 0.5,
|
| 147 |
+
"batch_size": 24,
|
| 148 |
"grad_accum": 1,
|
| 149 |
"lr": 1.6e-4,
|
| 150 |
"min_lr": 1.6e-5,
|
|
|
|
| 154 |
"weight_decay": 0.1,
|
| 155 |
"pretrain_passes": 1,
|
| 156 |
"sft_passes": 1,
|
| 157 |
+
"max_sft_target_chars": 0,
|
| 158 |
"use_grad_checkpoint": True,
|
|
|
|
| 159 |
"num_workers": 48,
|
| 160 |
+
"prefetch_factor": 64,
|
| 161 |
"shuffle_buffer": 16384,
|
| 162 |
"max_pretrain_tokens": 0,
|
| 163 |
+
"min_pretrain_tokens": 100_000_000,
|
| 164 |
"quantization": "nvfp4",
|
| 165 |
+
"engram_dim": 64,
|
| 166 |
+
"engram_heads": 8,
|
| 167 |
+
"engram_table_size": 8192,
|
| 168 |
+
"engram_max_ngram": 2,
|
| 169 |
+
"mhc_expansion": 4,
|
| 170 |
},
|
| 171 |
}
|
| 172 |
|
|
|
|
| 405 |
self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
| 406 |
self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)
|
| 407 |
|
| 408 |
+
for lin in (self.wq, self.wk, self.wv):
|
| 409 |
+
nn.init.normal_(lin.weight, std=dim ** -0.5)
|
| 410 |
+
nn.init.normal_(self.wo.weight, std=(n_heads * head_dim) ** -0.5)
|
| 411 |
+
|
| 412 |
self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2)
|
| 413 |
self.rope = RotaryEmbedding(self.rope_dim)
|
| 414 |
|
|
|
|
| 472 |
.reshape(B, self.n_heads, S, self.head_dim)
|
| 473 |
)
|
| 474 |
|
| 475 |
+
drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0
|
| 476 |
|
| 477 |
if is_global:
|
| 478 |
if past_kv is None and T > 1:
|
|
|
|
| 507 |
self.down = nn.Linear(hidden_dim, dim, bias=False)
|
| 508 |
self.drop = nn.Dropout(dropout)
|
| 509 |
|
| 510 |
+
nn.init.normal_(self.gate.weight, std=dim ** -0.5)
|
| 511 |
+
nn.init.normal_(self.up.weight, std=dim ** -0.5)
|
| 512 |
+
nn.init.normal_(self.down.weight, std=hidden_dim ** -0.5)
|
| 513 |
+
|
| 514 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 515 |
h = F.silu(self.gate(x)) * self.up(x)
|
| 516 |
+
out = self.down(h)
|
| 517 |
+
if self.training and torch.is_grad_enabled():
|
| 518 |
+
out = self.drop(out)
|
| 519 |
+
return out
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
class EngramBlock(nn.Module):
|
| 523 |
+
"""Conditional memory via O(1) hashed N-gram lookup (DeepSeek Engram)."""
|
| 524 |
+
|
| 525 |
+
def __init__(
|
| 526 |
+
self,
|
| 527 |
+
dim: int,
|
| 528 |
+
engram_dim: int,
|
| 529 |
+
n_heads: int = 4,
|
| 530 |
+
table_size: int = 8192,
|
| 531 |
+
max_ngram: int = 3,
|
| 532 |
+
) -> None:
|
| 533 |
+
super().__init__()
|
| 534 |
+
self.dim = dim
|
| 535 |
+
self.engram_dim = engram_dim
|
| 536 |
+
self.n_heads = n_heads
|
| 537 |
+
self.table_size = table_size
|
| 538 |
+
self.max_ngram = max_ngram
|
| 539 |
+
|
| 540 |
+
self.embeddings = nn.ParameterDict()
|
| 541 |
+
for n in range(2, max_ngram + 1):
|
| 542 |
+
for k in range(n_heads):
|
| 543 |
+
self.embeddings[f"{n}_{k}"] = nn.Parameter(
|
| 544 |
+
torch.randn(table_size, engram_dim) * (engram_dim**-0.5)
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
for n in range(2, max_ngram + 1):
|
| 548 |
+
for k in range(n_heads):
|
| 549 |
+
seed = int(hashlib.md5(f"engram_{n}_{k}".encode()).hexdigest()[:8], 16)
|
| 550 |
+
rng = torch.Generator().manual_seed(seed)
|
| 551 |
+
a = torch.randint(1, 2**31, (1,), generator=rng).item()
|
| 552 |
+
b = torch.randint(0, 2**31, (1,), generator=rng).item()
|
| 553 |
+
self.register_buffer(
|
| 554 |
+
f"hash_a_{n}_{k}", torch.tensor(a), persistent=False
|
| 555 |
+
)
|
| 556 |
+
self.register_buffer(
|
| 557 |
+
f"hash_b_{n}_{k}", torch.tensor(b), persistent=False
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
total_branch_dim = engram_dim * n_heads * (max_ngram - 1)
|
| 561 |
+
self.branch_conv = nn.Conv1d(
|
| 562 |
+
total_branch_dim,
|
| 563 |
+
total_branch_dim,
|
| 564 |
+
kernel_size=4,
|
| 565 |
+
dilation=max_ngram,
|
| 566 |
+
padding=0,
|
| 567 |
+
groups=total_branch_dim,
|
| 568 |
+
bias=True,
|
| 569 |
+
)
|
| 570 |
+
nn.init.zeros_(self.branch_conv.weight)
|
| 571 |
+
nn.init.zeros_(self.branch_conv.bias)
|
| 572 |
+
|
| 573 |
+
self.gate_query = nn.Linear(dim, engram_dim, bias=False)
|
| 574 |
+
self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False)
|
| 575 |
+
self.gate_value = nn.Linear(total_branch_dim, dim, bias=False)
|
| 576 |
+
self.gate_scale = engram_dim**-0.5
|
| 577 |
+
|
| 578 |
+
def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor:
|
| 579 |
+
a = getattr(self, f"hash_a_{n}_{k}")
|
| 580 |
+
b = getattr(self, f"hash_b_{n}_{k}")
|
| 581 |
+
B, T = token_ids.shape
|
| 582 |
+
padded = F.pad(token_ids, (n - 1, 0), value=0)
|
| 583 |
+
combined = torch.zeros(B, T, dtype=torch.long, device=token_ids.device)
|
| 584 |
+
for i in range(n):
|
| 585 |
+
combined = (combined * 31 + padded[:, i : i + T].long()) % self.table_size
|
| 586 |
+
return ((a * combined) ^ b) % self.table_size
|
| 587 |
+
|
| 588 |
+
def forward(
|
| 589 |
+
self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None
|
| 590 |
+
) -> torch.Tensor:
|
| 591 |
+
B, T, _ = hidden.shape
|
| 592 |
+
if token_ids is None:
|
| 593 |
+
token_ids = hidden.mean(dim=-1).long() % self.table_size
|
| 594 |
+
all_indices = []
|
| 595 |
+
all_tables = []
|
| 596 |
+
for n in range(2, self.max_ngram + 1):
|
| 597 |
+
for k in range(self.n_heads):
|
| 598 |
+
all_indices.append(self._hash_ngram(token_ids, n, k))
|
| 599 |
+
all_tables.append(self.embeddings[f"{n}_{k}"])
|
| 600 |
+
branch_outputs = [tbl[idx] for idx, tbl in zip(all_indices, all_tables)]
|
| 601 |
+
memory = torch.cat(branch_outputs, dim=-1)
|
| 602 |
+
conv_in = memory.transpose(1, 2)
|
| 603 |
+
conv_in = F.pad(
|
| 604 |
+
conv_in,
|
| 605 |
+
(self.branch_conv.dilation[0] * (self.branch_conv.kernel_size[0] - 1), 0),
|
| 606 |
+
)
|
| 607 |
+
conv_out = self.branch_conv(conv_in)
|
| 608 |
+
memory = conv_out.transpose(1, 2)
|
| 609 |
+
query = self.gate_query(hidden)
|
| 610 |
+
key = self.gate_key(memory)
|
| 611 |
+
gate = torch.sigmoid(
|
| 612 |
+
(query * key).sum(dim=-1, keepdim=True) * self.gate_scale
|
| 613 |
+
)
|
| 614 |
+
value = self.gate_value(memory)
|
| 615 |
+
return gate * value
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor:
|
| 619 |
+
M = torch.exp(logits.clamp(-10, 10))
|
| 620 |
+
for _ in range(n_iters):
|
| 621 |
+
M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10)
|
| 622 |
+
M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10)
|
| 623 |
+
return M
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
class ManifoldHyperConnection(nn.Module):
|
| 627 |
+
"""Manifold-Constrained Hyper-Connections (mHC) residual wrapper."""
|
| 628 |
+
|
| 629 |
+
def __init__(self, dim: int, expansion: int = 2) -> None:
|
| 630 |
+
super().__init__()
|
| 631 |
+
self.dim = dim
|
| 632 |
+
self.expansion = expansion
|
| 633 |
+
n = expansion
|
| 634 |
+
|
| 635 |
+
self.bias_pre = nn.Parameter(torch.zeros(1, n))
|
| 636 |
+
self.bias_post = nn.Parameter(torch.zeros(1, n))
|
| 637 |
+
self.bias_res = nn.Parameter(torch.zeros(n, n))
|
| 638 |
+
|
| 639 |
+
self.theta_pre = nn.Linear(n * dim, n, bias=False)
|
| 640 |
+
self.theta_post = nn.Linear(n * dim, n, bias=False)
|
| 641 |
+
self.theta_res = nn.Linear(n * dim, n * n, bias=False)
|
| 642 |
+
|
| 643 |
+
self.alpha_pre = nn.Parameter(torch.tensor(0.0))
|
| 644 |
+
self.alpha_post = nn.Parameter(torch.tensor(0.0))
|
| 645 |
+
self.alpha_res = nn.Parameter(torch.tensor(0.0))
|
| 646 |
+
|
| 647 |
+
def _compute_mappings(
|
| 648 |
+
self, x_expanded: torch.Tensor
|
| 649 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 650 |
+
B, T, _ = x_expanded.shape
|
| 651 |
+
n = self.expansion
|
| 652 |
+
x_norm = F.rms_norm(x_expanded, [x_expanded.shape[-1]])
|
| 653 |
+
d_pre = torch.tanh(self.theta_pre(x_norm))
|
| 654 |
+
d_post = torch.tanh(self.theta_post(x_norm))
|
| 655 |
+
d_res = self.theta_res(x_norm)
|
| 656 |
+
H_pre_raw = torch.sigmoid(self.alpha_pre * d_pre + self.bias_pre)
|
| 657 |
+
H_post_raw = 2.0 * torch.sigmoid(self.alpha_post * d_post + self.bias_post)
|
| 658 |
+
H_res_raw = (self.alpha_res * d_res + self.bias_res.reshape(1, 1, -1)).reshape(
|
| 659 |
+
B, T, n, n
|
| 660 |
+
)
|
| 661 |
+
H_res = _sinkhorn_knopp(H_res_raw)
|
| 662 |
+
return H_pre_raw.unsqueeze(-2), H_post_raw.unsqueeze(-2), H_res
|
| 663 |
+
|
| 664 |
+
def expand_stream(self, x: torch.Tensor) -> torch.Tensor:
|
| 665 |
+
return x.repeat(1, 1, self.expansion)
|
| 666 |
+
|
| 667 |
+
def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor:
|
| 668 |
+
B, T, _ = x_expanded.shape
|
| 669 |
+
return x_expanded.view(B, T, self.expansion, self.dim).mean(dim=-2)
|
| 670 |
+
|
| 671 |
+
def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor:
|
| 672 |
+
B, T, _ = x_expanded.shape
|
| 673 |
+
x_streams = x_expanded.view(B, T, self.expansion, self.dim)
|
| 674 |
+
return (H_pre @ x_streams).squeeze(-2)
|
| 675 |
+
|
| 676 |
+
def post_res_mix(
|
| 677 |
+
self,
|
| 678 |
+
layer_output: torch.Tensor,
|
| 679 |
+
x_expanded: torch.Tensor,
|
| 680 |
+
H_post: torch.Tensor,
|
| 681 |
+
H_res: torch.Tensor,
|
| 682 |
+
) -> torch.Tensor:
|
| 683 |
+
B, T, _ = x_expanded.shape
|
| 684 |
+
x_streams = x_expanded.view(B, T, self.expansion, self.dim)
|
| 685 |
+
mixed = torch.matmul(H_res, x_streams)
|
| 686 |
+
post_out = torch.matmul(H_post.transpose(-2, -1), layer_output.unsqueeze(-2))
|
| 687 |
+
return (mixed + post_out).reshape(B, T, self.expansion * self.dim)
|
| 688 |
|
| 689 |
|
| 690 |
class TransformerBlock(nn.Module):
|
|
|
|
| 698 |
dropout: float,
|
| 699 |
sliding_window: int,
|
| 700 |
rope_fraction: float,
|
| 701 |
+
engram_dim: int = 0,
|
| 702 |
+
engram_heads: int = 4,
|
| 703 |
+
engram_table_size: int = 8192,
|
| 704 |
+
engram_max_ngram: int = 3,
|
| 705 |
+
mhc_expansion: int = 1,
|
| 706 |
) -> None:
|
| 707 |
super().__init__()
|
| 708 |
+
self.dim = dim
|
| 709 |
self.norm1 = RMSNorm(dim)
|
| 710 |
self.attn = CausalSelfAttention(
|
| 711 |
dim=dim,
|
|
|
|
| 718 |
)
|
| 719 |
self.norm2 = RMSNorm(dim)
|
| 720 |
self.ffn = SwiGLU(dim, ffn_dim, dropout)
|
| 721 |
+
self.use_engram = engram_dim > 0
|
| 722 |
+
if self.use_engram:
|
| 723 |
+
self.engram = EngramBlock(
|
| 724 |
+
dim=dim,
|
| 725 |
+
engram_dim=engram_dim,
|
| 726 |
+
n_heads=engram_heads,
|
| 727 |
+
table_size=engram_table_size,
|
| 728 |
+
max_ngram=engram_max_ngram,
|
| 729 |
+
)
|
| 730 |
+
self.engram_norm = RMSNorm(dim)
|
| 731 |
+
self.use_mhc = mhc_expansion > 1
|
| 732 |
+
if self.use_mhc:
|
| 733 |
+
self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion)
|
| 734 |
+
self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion)
|
| 735 |
|
| 736 |
def forward(
|
| 737 |
self,
|
|
|
|
| 739 |
is_global: bool,
|
| 740 |
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 741 |
use_cache: bool = False,
|
| 742 |
+
token_ids: Optional[torch.Tensor] = None,
|
| 743 |
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
| 744 |
+
if self.use_mhc:
|
| 745 |
+
x_exp = self.mhc_attn.expand_stream(x)
|
| 746 |
+
H_pre, H_post, H_res = self.mhc_attn._compute_mappings(x_exp)
|
| 747 |
+
attn_in = self.mhc_attn.pre_mix(x_exp, H_pre)
|
| 748 |
+
attn_out, new_kv = self.attn(
|
| 749 |
+
self.norm1(attn_in), is_global, past_kv, use_cache
|
| 750 |
+
)
|
| 751 |
+
x_exp = self.mhc_attn.post_res_mix(attn_out, x_exp, H_post, H_res)
|
| 752 |
+
if self.use_engram:
|
| 753 |
+
collapsed = self.mhc_attn.collapse_stream(x_exp)
|
| 754 |
+
collapsed = collapsed + self.engram(
|
| 755 |
+
self.engram_norm(collapsed), token_ids=token_ids
|
| 756 |
+
)
|
| 757 |
+
x_exp = self.mhc_attn.expand_stream(collapsed)
|
| 758 |
+
H_pre2, H_post2, H_res2 = self.mhc_ffn._compute_mappings(x_exp)
|
| 759 |
+
ffn_in = self.mhc_ffn.pre_mix(x_exp, H_pre2)
|
| 760 |
+
ffn_out = self.ffn(self.norm2(ffn_in))
|
| 761 |
+
x_exp = self.mhc_ffn.post_res_mix(ffn_out, x_exp, H_post2, H_res2)
|
| 762 |
+
x = self.mhc_attn.collapse_stream(x_exp)
|
| 763 |
+
else:
|
| 764 |
+
attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache)
|
| 765 |
+
x = x + attn_out
|
| 766 |
+
if self.use_engram:
|
| 767 |
+
x = x + self.engram(self.engram_norm(x), token_ids=token_ids)
|
| 768 |
+
x = x + self.ffn(self.norm2(x))
|
| 769 |
return x, new_kv
|
| 770 |
|
| 771 |
|
| 772 |
+
def _detect_engram_dim(state_dict: dict) -> int:
|
| 773 |
+
for key in state_dict:
|
| 774 |
+
if ".engram." in key and ".embeddings." in key:
|
| 775 |
+
return state_dict[key].shape[-1]
|
| 776 |
+
return 0
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
def _detect_mhc_expansion(state_dict: dict) -> int:
|
| 780 |
+
for key, val in state_dict.items():
|
| 781 |
+
if ".mhc_attn.bias_pre" in key and val.dim() == 2:
|
| 782 |
+
return val.shape[-1]
|
| 783 |
+
return 1
|
| 784 |
+
|
| 785 |
+
|
| 786 |
class TinyMemoryLM(nn.Module):
|
| 787 |
def __init__(
|
| 788 |
self,
|
|
|
|
| 797 |
mtp_horizons: Sequence[int],
|
| 798 |
grad_checkpoint: bool,
|
| 799 |
sliding_window: int = 512,
|
| 800 |
+
rope_fraction: float = 0.5,
|
| 801 |
embed_scale: bool = True,
|
| 802 |
+
engram_dim: int = 0,
|
| 803 |
+
engram_heads: int = 4,
|
| 804 |
+
engram_table_size: int = 8192,
|
| 805 |
+
engram_max_ngram: int = 3,
|
| 806 |
+
mhc_expansion: int = 1,
|
| 807 |
) -> None:
|
| 808 |
super().__init__()
|
| 809 |
self.dim = dim
|
|
|
|
| 830 |
dropout=dropout,
|
| 831 |
sliding_window=sliding_window,
|
| 832 |
rope_fraction=rope_fraction,
|
| 833 |
+
engram_dim=engram_dim,
|
| 834 |
+
engram_heads=engram_heads,
|
| 835 |
+
engram_table_size=engram_table_size,
|
| 836 |
+
engram_max_ngram=engram_max_ngram,
|
| 837 |
+
mhc_expansion=mhc_expansion,
|
| 838 |
)
|
| 839 |
for _ in range(n_unique_layers)
|
| 840 |
]
|
|
|
|
| 904 |
)
|
| 905 |
|
| 906 |
for layer_idx, (block, logical_idx) in enumerate(logical_layers):
|
| 907 |
+
is_global = logical_idx % 2 == 0
|
| 908 |
past_kv = (
|
| 909 |
past_key_values[layer_idx]
|
| 910 |
if past_key_values is not None and layer_idx < len(past_key_values)
|
|
|
|
| 913 |
|
| 914 |
if self.grad_checkpoint and self.training and not use_cache:
|
| 915 |
x, layer_kv = checkpoint(
|
| 916 |
+
block, x, is_global, past_kv, use_cache, ids, use_reentrant=True
|
| 917 |
)
|
| 918 |
else:
|
| 919 |
+
x, layer_kv = block(x, is_global, past_kv, use_cache, ids)
|
| 920 |
|
| 921 |
if new_past_key_values is not None:
|
| 922 |
new_past_key_values.append(layer_kv)
|
| 923 |
|
| 924 |
x = self.norm(x)
|
| 925 |
h_out = x if return_hidden else None
|
| 926 |
+
logits = self.head(x)
|
| 927 |
+
if self.embed_scale_factor != 1.0:
|
| 928 |
+
logits = logits / self.embed_scale_factor
|
| 929 |
+
logits = logits + self.output_bias
|
| 930 |
|
| 931 |
mtp: Dict[int, torch.Tensor] = {}
|
| 932 |
if self.mtp_horizons and self.training:
|
|
|
|
| 935 |
shifted_h = x[:, :-horizon, :]
|
| 936 |
adapted_h = self.mtp_adapters[str(horizon)](shifted_h)
|
| 937 |
adapted_h = self.mtp_norms[str(horizon)](adapted_h)
|
| 938 |
+
mtp_logits = self.head(adapted_h)
|
| 939 |
+
if self.embed_scale_factor != 1.0:
|
| 940 |
+
mtp_logits = mtp_logits / self.embed_scale_factor
|
| 941 |
+
mtp_logits = mtp_logits + self.output_bias
|
| 942 |
mtp[horizon] = mtp_logits
|
| 943 |
|
| 944 |
return logits, mtp, h_out, new_past_key_values
|
|
|
|
| 1738 |
ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False)
|
| 1739 |
cfg = series_config(series)
|
| 1740 |
vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size))
|
| 1741 |
+
state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
|
| 1742 |
+
# Auto-detect new arch features from checkpoint weights
|
| 1743 |
+
engram_dim = _detect_engram_dim(state_dict) or int(
|
| 1744 |
+
cfg.get("engram_dim", model_config.engram_dim)
|
| 1745 |
+
)
|
| 1746 |
+
mhc_expansion = _detect_mhc_expansion(state_dict) or int(
|
| 1747 |
+
cfg.get("mhc_expansion", model_config.mhc_expansion)
|
| 1748 |
+
)
|
| 1749 |
model = TinyMemoryLM(
|
| 1750 |
vocab_size=vocab_size,
|
| 1751 |
dim=int(cfg.get("dim", model_config.dim)),
|
|
|
|
| 1764 |
),
|
| 1765 |
grad_checkpoint=False,
|
| 1766 |
sliding_window=int(
|
| 1767 |
+
cfg.get("sliding_window_size", model_config.sliding_window_size)
|
|
|
|
|
|
|
|
|
|
| 1768 |
),
|
| 1769 |
rope_fraction=float(
|
| 1770 |
+
cfg.get("rope_fraction", model_config.rope_fraction)
|
| 1771 |
),
|
| 1772 |
embed_scale=bool(
|
| 1773 |
+
cfg.get("embed_scale", model_config.embed_scale)
|
| 1774 |
+
),
|
| 1775 |
+
engram_dim=engram_dim,
|
| 1776 |
+
engram_heads=int(cfg.get("engram_heads", model_config.engram_heads)),
|
| 1777 |
+
engram_table_size=int(
|
| 1778 |
+
cfg.get("engram_table_size", model_config.engram_table_size)
|
| 1779 |
),
|
| 1780 |
+
engram_max_ngram=int(
|
| 1781 |
+
cfg.get("engram_max_ngram", model_config.engram_max_ngram)
|
| 1782 |
+
),
|
| 1783 |
+
mhc_expansion=mhc_expansion,
|
| 1784 |
)
|
|
|
|
| 1785 |
model.load_state_dict(state_dict, strict=False)
|
| 1786 |
model.eval()
|
| 1787 |
if tokenizer.vocab_size > vocab_size:
|
|
|
|
| 1967 |
</div>
|
| 1968 |
<div class="meta">
|
| 1969 |
<span class="chip">Hugging Face: CompactAI</span>
|
| 1970 |
+
<span class="chip">pip install -r requirements.txt</span>
|
| 1971 |
<span class="chip">Local inference</span>
|
| 1972 |
</div>
|
| 1973 |
</div>
|
interactive.py
ADDED
|
@@ -0,0 +1,2277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import shutil
|
| 9 |
+
import socket
|
| 10 |
+
import string
|
| 11 |
+
import sys
|
| 12 |
+
import threading
|
| 13 |
+
import webbrowser
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
|
| 18 |
+
from urllib.parse import quote, unquote, urlparse
|
| 19 |
+
from urllib.request import Request, urlopen
|
| 20 |
+
|
| 21 |
+
import hashlib
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
import torch.nn.functional as F
|
| 26 |
+
from torch.utils.checkpoint import checkpoint
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
# Config (from ailay.config)
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
class ModelConfig:
    """Architecture hyper-parameters for the model (defaults = smallest preset)."""

    dim: int = 128                  # model width
    n_unique_layers: int = 8        # physically distinct transformer blocks
    n_logical_layers: int = 16      # logical depth (blocks are reused/looped)
    n_heads: int = 4
    n_kv_heads: int = 2             # grouped-query attention: KV heads
    ffn_dim: int = 224
    dropout: float = 0.0
    seq_len: int = 2048
    sliding_window_size: int = 512
    mtp_horizons: Tuple[int, ...] = (2, 3, 4)  # multi-token-prediction offsets
    rope_fraction: float = 0.5      # fraction of head_dim receiving rotary embedding
    embed_scale: bool = True
    logit_soft_cap: float = -1.0    # NOTE(review): presumably <= 0 disables capping — confirm at use site
    quantization: str = "nvfp4"
    # Engram (conditional memory) config
    engram_dim: int = 0             # 0 disables engram blocks
    engram_heads: int = 4
    engram_table_size: int = 8192
    engram_max_ngram: int = 3
    # mHC (Manifold-Constrained Hyper-Connections) config
    mhc_expansion: int = 1

    @property
    def head_dim(self) -> int:
        """Per-head width; assumes `dim` is divisible by `n_heads`."""
        return self.dim // self.n_heads


# Module-level default configuration instance.
model_config = ModelConfig()
|
| 64 |
+
|
| 65 |
+
# Per-size training/architecture presets. Keys are series names; each value is
# a flat dict of ModelConfig overrides plus training hyper-parameters
# (batch_size, lr schedule, data-loader settings, ...).
MODEL_SERIES = {
    "haiku": {
        # architecture
        "dim": 64,
        "n_unique_layers": 12,
        "n_logical_layers": 24,
        "n_heads": 4,
        "n_kv_heads": 2,
        "ffn_dim": 384,
        "dropout": 0.0,
        "seq_len": 2048,
        "mtp_horizons": (2, 3, 4),
        "rope_fraction": 0.5,
        # optimization
        "batch_size": 80,
        "grad_accum": 1,
        "lr": 8e-4,
        "min_lr": 1e-5,
        "sft_lr": 2e-4,
        "sft_min_lr": 1e-5,
        "warmup_steps": 300,
        "weight_decay": 0.02,
        "pretrain_passes": 2,
        "sft_passes": 3,
        "max_sft_target_chars": 0,
        "use_grad_checkpoint": True,
        # data loading
        "num_workers": 24,
        "prefetch_factor": 64,
        "shuffle_buffer": 8192,
        "max_pretrain_tokens": 0,
        "min_pretrain_tokens": 100_000_000,
        # extras
        "quantization": "nvfp4",
        "engram_dim": 8,
        "engram_heads": 2,
        "engram_table_size": 64,
        "engram_max_ngram": 2,
        "mhc_expansion": 2,
    },
    "sonnet": {
        # architecture
        "dim": 1024,
        "n_unique_layers": 20,
        "n_logical_layers": 40,
        "n_heads": 16,
        "n_kv_heads": 4,
        "ffn_dim": 4096,
        "dropout": 0.0,
        "seq_len": 2048,
        "mtp_horizons": (2,),
        "rope_fraction": 0.5,
        # optimization
        "batch_size": 24,
        "grad_accum": 1,
        "lr": 1e-4,
        "min_lr": 2e-5,
        "sft_lr": 5e-5,
        "sft_min_lr": 5e-6,
        "warmup_steps": 250,
        "weight_decay": 0.1,
        "pretrain_passes": 1,
        "sft_passes": 1,
        "max_sft_target_chars": 0,
        "use_grad_checkpoint": True,
        # data loading
        "num_workers": 32,
        "prefetch_factor": 64,
        "shuffle_buffer": 16384,
        "max_pretrain_tokens": 0,
        "min_pretrain_tokens": 100_000_000,
        # extras
        "quantization": "nvfp4",
        "engram_dim": 32,
        "engram_heads": 8,
        "engram_table_size": 4096,
        "engram_max_ngram": 2,
        "mhc_expansion": 2,
    },
    "opus": {
        # architecture
        "dim": 1536,
        "n_unique_layers": 18,
        "n_logical_layers": 36,
        "n_heads": 16,
        "n_kv_heads": 4,
        "ffn_dim": 5888,
        "dropout": 0.0,
        "seq_len": 2048,
        "mtp_horizons": (2,),
        "rope_fraction": 0.5,
        # optimization
        "batch_size": 24,
        "grad_accum": 1,
        "lr": 1.6e-4,
        "min_lr": 1.6e-5,
        "sft_lr": 3e-5,
        "sft_min_lr": 3e-6,
        "warmup_steps": 200,
        "weight_decay": 0.1,
        "pretrain_passes": 1,
        "sft_passes": 1,
        "max_sft_target_chars": 0,
        "use_grad_checkpoint": True,
        # data loading
        "num_workers": 48,
        "prefetch_factor": 64,
        "shuffle_buffer": 16384,
        "max_pretrain_tokens": 0,
        "min_pretrain_tokens": 100_000_000,
        # extras
        "quantization": "nvfp4",
        "engram_dim": 64,
        "engram_heads": 8,
        "engram_table_size": 8192,
        "engram_max_ngram": 2,
        "mhc_expansion": 4,
    },
}
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ---------------------------------------------------------------------------
|
| 175 |
+
# Tokenizer (from ailay.tokenizer)
|
| 176 |
+
# ---------------------------------------------------------------------------
|
| 177 |
+
|
| 178 |
+
# Multi-character special tokens used for chat/reasoning formatting.
# Order matters: it fixes the token IDs assigned by WordTokenizer.
FORMAT_TOKENS = [
    "<|user|>",
    "<|assistant|>",
    "<|system|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|begin_of_thought|>",
    "<|end_of_thought|>",
    "<|begin_of_solution|>",
    "<|end_of_solution|>",
]
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class WordTokenizer:
    """Word-level tokenizer with single-character fallback.

    Vocabulary layout (fixed order): core specials, format tokens, then a
    sorted set of fallback characters. Words absent from the vocabulary are
    encoded character-by-character; unknown characters map to <UNK>.
    """

    # One lexical token per match: whitespace run, word (with optional
    # apostrophe suffix), digit run, or punctuation run.
    WORD_RE = re.compile(
        r"\s+|[^\W\d_]+(?:['\u2019][^\W\d_]+)?|\d+|[^\w\s]+", re.UNICODE
    )

    def __init__(
        self, extra_chars: str = "", format_tokens: Optional[List[str]] = None
    ) -> None:
        charset = string.ascii_letters + string.digits + string.punctuation + " \n\t\r"
        fallback_chars = sorted(set(charset + extra_chars))
        self.core_special = ["<PAD>", "<BOS>", "<EOS>", "<UNK>"]
        self.format_tokens = (
            list(format_tokens) if format_tokens else list(FORMAT_TOKENS)
        )
        self.special = list(self.core_special) + list(self.format_tokens)
        self.id_to_token: List[str] = (
            list(self.core_special) + self.format_tokens + fallback_chars
        )
        self.token_to_id: Dict[str, int] = {
            tok: i for i, tok in enumerate(self.id_to_token)
        }
        # Longest-first so greedy scanning prefers the longest special token.
        self.special_multi_tokens = sorted(
            [t for t in self.special if len(t) > 1], key=len, reverse=True
        )
        self.multi_char_tokens = self.special_multi_tokens
        # Count of tokens added after construction (via maybe_add_*).
        self.dynamic_additions = 0

    @property
    def pad_id(self) -> int:
        return self.token_to_id["<PAD>"]

    @property
    def bos_id(self) -> int:
        return self.token_to_id["<BOS>"]

    @property
    def eos_id(self) -> int:
        return self.token_to_id["<EOS>"]

    @property
    def unk_id(self) -> int:
        return self.token_to_id["<UNK>"]

    @property
    def vocab_size(self) -> int:
        return len(self.id_to_token)

    def maybe_add_char(self, ch: str) -> bool:
        """Add a single character to the vocabulary; return True if it was new."""
        return self.maybe_add_token(ch)

    def maybe_add_token(self, token: str) -> bool:
        """Append *token* to the vocabulary if absent; return True if added."""
        if token in self.token_to_id:
            return False
        self.token_to_id[token] = len(self.id_to_token)
        self.id_to_token.append(token)
        self.dynamic_additions += 1
        return True

    def iter_lexical_tokens(self, text: str) -> Iterator[str]:
        """Yield lexical tokens; multi-char special tokens take precedence."""
        pos = 0
        end = len(text)
        while pos < end:
            special = next(
                (t for t in self.special_multi_tokens if text.startswith(t, pos)),
                None,
            )
            if special is not None:
                yield special
                pos += len(special)
                continue
            m = self.WORD_RE.match(text, pos)
            if m is None:
                # Defensive: the regex should always match, but never stall.
                yield text[pos]
                pos += 1
            else:
                yield m.group(0)
                pos = m.end()

    def encode(
        self, text: str, add_bos: bool = False, add_eos: bool = False
    ) -> List[int]:
        """Encode *text* to token IDs, with optional BOS/EOS framing."""
        ids: List[int] = [self.bos_id] if add_bos else []
        lookup = self.token_to_id
        unk = self.unk_id
        for tok in self.iter_lexical_tokens(text):
            tid = lookup.get(tok)
            if tid is None:
                # Word not in vocab: fall back to per-character encoding.
                ids.extend(lookup.get(ch, unk) for ch in tok)
            else:
                ids.append(tid)
        if add_eos:
            ids.append(self.eos_id)
        return ids

    def decode(self, ids: Sequence[int], skip_special: bool = True) -> str:
        """Decode token IDs to text; out-of-range IDs are silently dropped."""
        vocab = self.id_to_token
        pieces: List[str] = []
        for raw in ids:
            i = int(raw)
            if not 0 <= i < len(vocab):
                continue
            tok = vocab[i]
            if skip_special and tok in self.special:
                continue
            pieces.append(tok)
        return "".join(pieces)

    def save(self, path: Path) -> None:
        """Serialize the vocabulary to a JSON file at *path*."""
        payload = {
            "id_to_token": self.id_to_token,
            "format_tokens": self.format_tokens,
            "core_special": self.core_special,
            "tokenizer_type": "word_level_v1",
        }
        with path.open("w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, path: Path) -> "WordTokenizer":
        """Restore a tokenizer previously written by :meth:`save`."""
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
        tokenizer = cls(
            extra_chars="", format_tokens=data.get("format_tokens", FORMAT_TOKENS)
        )
        tokenizer.id_to_token = data["id_to_token"]
        tokenizer.token_to_id = {t: i for i, t in enumerate(tokenizer.id_to_token)}
        tokenizer.special = list(tokenizer.core_special) + list(tokenizer.format_tokens)
        tokenizer.special_multi_tokens = sorted(
            [t for t in tokenizer.special if len(t) > 1], key=len, reverse=True
        )
        tokenizer.multi_char_tokens = tokenizer.special_multi_tokens
        return tokenizer


# Backwards-compatible alias for the earlier character-level tokenizer name.
LetterTokenizer = WordTokenizer
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
# ---------------------------------------------------------------------------
|
| 340 |
+
# Model (from ailay.model)
|
| 341 |
+
# ---------------------------------------------------------------------------
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction) with a learnable scale."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Prefer the fused kernel when this torch build provides it.
        if hasattr(torch.nn.functional, "rms_norm"):
            return torch.nn.functional.rms_norm(
                x, self.weight.shape, self.weight, self.eps
            )
        # Fallback: normalize in fp32 for numerical stability, then cast back.
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x32 * inv_rms).to(dtype=x.dtype) * self.weight
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class RotaryEmbedding(nn.Module):
    """Computes cos/sin tables for rotary position embeddings (RoPE)."""

    def __init__(self, dim: int, base: float = 10000.0) -> None:
        super().__init__()
        # Standard RoPE inverse frequencies over even channel indices.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def cos_sin(
        self, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) tables of shape (1, 1, seq_len, dim)."""
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(positions, self.inv_freq)
        # Duplicate so both halves of the head channels share the same angle.
        angles = torch.cat([freqs, freqs], dim=-1)
        cos = angles.cos()[None, None, :, :].to(dtype=dtype)
        sin = angles.sin()[None, None, :, :].to(dtype=dtype)
        return cos, sin
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
|
| 378 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 379 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 380 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
class CausalSelfAttention(nn.Module):
    """Grouped-query causal attention with partial RoPE, QK-norm, per-head
    output gates, and an optional sliding-window mask.

    NOTE(review): relies on RMSNorm, RotaryEmbedding and _rotate_half defined
    elsewhere in this module.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        dropout: float,
        sliding_window: int,
        rope_fraction: float,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.n_rep = n_heads // n_kv_heads  # query heads per KV head (GQA)
        self.dropout = dropout
        self.sliding_window = sliding_window

        self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)

        for proj in (self.wq, self.wk, self.wv):
            nn.init.normal_(proj.weight, std=dim ** -0.5)
        nn.init.normal_(self.wo.weight, std=(n_heads * head_dim) ** -0.5)

        # Only the first rope_dim channels of each head are rotated
        # (rounded down to an even number, minimum 2).
        self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2)
        self.rope = RotaryEmbedding(self.rope_dim)

        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)

        # Learned per-head sigmoid gate on the attention output (init 0 => 0.5).
        self.output_gate = nn.Parameter(torch.zeros(n_heads))

    def forward(
        self,
        x: torch.Tensor,
        is_global: bool,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        bsz, seq, _ = x.shape

        q = self.wq(x).view(bsz, seq, self.n_heads, self.head_dim)
        k = self.wk(x).view(bsz, seq, self.n_kv_heads, self.head_dim)
        v = self.wv(x).view(bsz, seq, self.n_kv_heads, self.head_dim)

        # QK-norm before RoPE, then move to (B, H, T, D) layout.
        q = self.q_norm(q).transpose(1, 2)
        k = self.k_norm(k).transpose(1, 2)
        v = v.transpose(1, 2)

        past_len = 0 if past_kv is None else past_kv[0].shape[2]
        cos, sin = self.rope.cos_sin(seq + past_len, x.device, q.dtype)
        cos = cos[:, :, past_len : past_len + seq, :]
        sin = sin[:, :, past_len : past_len + seq, :]

        # Rotate only the leading rope_dim channels; pass the rest through.
        q_rot, q_pass = q[..., : self.rope_dim], q[..., self.rope_dim :]
        k_rot, k_pass = k[..., : self.rope_dim], k[..., self.rope_dim :]
        q_rot = (q_rot * cos) + (_rotate_half(q_rot) * sin)
        k_rot = (k_rot * cos) + (_rotate_half(k_rot) * sin)
        q = torch.cat([q_rot, q_pass], dim=-1)
        k = torch.cat([k_rot, k_pass], dim=-1)

        # Append cached KV from previous decode steps.
        if past_kv is not None:
            k = torch.cat([past_kv[0], k], dim=2)
            v = torch.cat([past_kv[1], v], dim=2)
        new_kv = (k, v) if use_cache else None

        total = k.shape[2]
        if self.n_rep > 1:
            # Duplicate each KV head across its query-head group (GQA).
            k = (
                k[:, :, None, :, :]
                .expand(bsz, self.n_kv_heads, self.n_rep, total, self.head_dim)
                .reshape(bsz, self.n_heads, total, self.head_dim)
            )
            v = (
                v[:, :, None, :, :]
                .expand(bsz, self.n_kv_heads, self.n_rep, total, self.head_dim)
                .reshape(bsz, self.n_heads, total, self.head_dim)
            )

        drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0

        if is_global:
            if past_kv is None and seq > 1:
                # Prefill path: fused causal mask.
                out = F.scaled_dot_product_attention(
                    q, k, v, is_causal=True, dropout_p=drop_p
                )
            else:
                # Single-token decode (or cache present): every key is attendable.
                out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
        else:
            # Sliding-window causal mask built from absolute positions.
            q_len = q.shape[2]
            q_pos = torch.arange(
                past_len, past_len + q_len, device=q.device
            ).unsqueeze(1)
            k_pos = torch.arange(total, device=q.device).unsqueeze(0)
            window = (q_pos >= k_pos) & ((q_pos - k_pos) < self.sliding_window)
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=window.unsqueeze(0).unsqueeze(0), dropout_p=drop_p
            )

        # Per-head learned gate, then merge heads and project out.
        gate = torch.sigmoid(self.output_gate).view(1, self.n_heads, 1, 1)
        out = out * gate
        out = out.transpose(1, 2).contiguous().view(bsz, seq, self.n_heads * self.head_dim)
        return self.wo(out), new_kv
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
class SwiGLU(nn.Module):
    """Gated feed-forward block: down(silu(gate(x)) * up(x)).

    Dropout is applied only while training with gradients enabled.
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None:
        super().__init__()
        self.gate = nn.Linear(dim, hidden_dim, bias=False)
        self.up = nn.Linear(dim, hidden_dim, bias=False)
        self.down = nn.Linear(hidden_dim, dim, bias=False)
        self.drop = nn.Dropout(dropout)

        # Fan-in scaled normal init for each projection.
        for proj, fan_in in (
            (self.gate, dim),
            (self.up, dim),
            (self.down, hidden_dim),
        ):
            nn.init.normal_(proj.weight, std=fan_in ** -0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.silu(self.gate(x)) * self.up(x)
        projected = self.down(hidden)
        if self.training and torch.is_grad_enabled():
            projected = self.drop(projected)
        return projected
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
class EngramBlock(nn.Module):
    """Conditional memory via O(1) hashed N-gram lookup (DeepSeek Engram).

    For each n-gram order ``n`` in [2, max_ngram] and each hash head ``k``,
    the trailing n-gram ending at every position is hashed into a fixed-size
    embedding table.  The concatenated lookups are smoothed by a causal
    depthwise Conv1d and injected through a content-dependent sigmoid gate.
    Conv weights and bias are zero-initialised, so the smoothed memory — and
    hence the whole block's output — starts at zero (no-op at init).
    """

    def __init__(
        self,
        dim: int,
        engram_dim: int,
        n_heads: int = 4,
        table_size: int = 8192,
        max_ngram: int = 3,
    ) -> None:
        """Build embedding tables, hash coefficients, conv and gate.

        Args:
            dim: width of the residual stream this block reads and writes.
            engram_dim: width of each hashed-embedding branch.
            n_heads: number of independent hash functions per n-gram order.
            table_size: rows per embedding table (hash modulus).
            max_ngram: largest n-gram order looked up (orders 2..max_ngram).
        """
        super().__init__()
        self.dim = dim
        self.engram_dim = engram_dim
        self.n_heads = n_heads
        self.table_size = table_size
        self.max_ngram = max_ngram

        # One (table_size, engram_dim) table per (order n, hash head k).
        self.embeddings = nn.ParameterDict()
        for n in range(2, max_ngram + 1):
            for k in range(n_heads):
                self.embeddings[f"{n}_{k}"] = nn.Parameter(
                    torch.randn(table_size, engram_dim) * (engram_dim**-0.5)
                )

        # Per-branch multiplicative hash coefficients (a, b).  Seeds come
        # from md5 of the branch name, so they are identical across runs and
        # processes; buffers are non-persistent (rebuilt, never checkpointed).
        for n in range(2, max_ngram + 1):
            for k in range(n_heads):
                seed = int(hashlib.md5(f"engram_{n}_{k}".encode()).hexdigest()[:8], 16)
                rng = torch.Generator().manual_seed(seed)
                a = torch.randint(1, 2**31, (1,), generator=rng).item()
                b = torch.randint(0, 2**31, (1,), generator=rng).item()
                self.register_buffer(
                    f"hash_a_{n}_{k}", torch.tensor(a), persistent=False
                )
                self.register_buffer(
                    f"hash_b_{n}_{k}", torch.tensor(b), persistent=False
                )

        # Depthwise (groups == channels) conv over time; zero init means the
        # block contributes nothing until training moves these weights.
        total_branch_dim = engram_dim * n_heads * (max_ngram - 1)
        self.branch_conv = nn.Conv1d(
            total_branch_dim,
            total_branch_dim,
            kernel_size=4,
            dilation=max_ngram,
            padding=0,
            groups=total_branch_dim,
            bias=True,
        )
        nn.init.zeros_(self.branch_conv.weight)
        nn.init.zeros_(self.branch_conv.bias)

        # Scalar gate per position: sigmoid of a scaled dot product between
        # a query from the hidden state and a key from the memory.
        self.gate_query = nn.Linear(dim, engram_dim, bias=False)
        self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False)
        self.gate_value = nn.Linear(total_branch_dim, dim, bias=False)
        self.gate_scale = engram_dim**-0.5

    def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor:
        """Hash the trailing n-gram at each position into [0, table_size).

        Left-pads with token id 0 so every position has a full window,
        folds the window with a base-31 rolling hash, then mixes with the
        per-head (a, b) multiply/xor step.
        """
        a = getattr(self, f"hash_a_{n}_{k}")
        b = getattr(self, f"hash_b_{n}_{k}")
        B, T = token_ids.shape
        padded = F.pad(token_ids, (n - 1, 0), value=0)
        combined = torch.zeros(B, T, dtype=torch.long, device=token_ids.device)
        for i in range(n):
            combined = (combined * 31 + padded[:, i : i + T].long()) % self.table_size
        return ((a * combined) ^ b) % self.table_size

    def forward(
        self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return the gated memory contribution to add to the residual.

        Args:
            hidden: (B, T, dim) residual-stream activations.
            token_ids: (B, T) token ids used for hashing.  When absent, a
                crude surrogate is derived from the mean activation.

        Returns:
            (B, T, dim) tensor: ``gate * value`` memory injection.
        """
        B, T, _ = hidden.shape
        if token_ids is None:
            # Fallback pseudo-ids so the block still runs without token ids.
            token_ids = hidden.mean(dim=-1).long() % self.table_size
        all_indices = []
        all_tables = []
        for n in range(2, self.max_ngram + 1):
            for k in range(self.n_heads):
                all_indices.append(self._hash_ngram(token_ids, n, k))
                all_tables.append(self.embeddings[f"{n}_{k}"])
        # Per-branch table lookups, concatenated along the channel axis.
        branch_outputs = [tbl[idx] for idx, tbl in zip(all_indices, all_tables)]
        memory = torch.cat(branch_outputs, dim=-1)
        # Causal left padding: the depthwise conv only sees current + past.
        conv_in = memory.transpose(1, 2)
        conv_in = F.pad(
            conv_in,
            (self.branch_conv.dilation[0] * (self.branch_conv.kernel_size[0] - 1), 0),
        )
        conv_out = self.branch_conv(conv_in)
        memory = conv_out.transpose(1, 2)
        # Content-dependent scalar gate for each position.
        query = self.gate_query(hidden)
        key = self.gate_key(memory)
        gate = torch.sigmoid(
            (query * key).sum(dim=-1, keepdim=True) * self.gate_scale
        )
        value = self.gate_value(memory)
        return gate * value
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor:
|
| 619 |
+
M = torch.exp(logits.clamp(-10, 10))
|
| 620 |
+
for _ in range(n_iters):
|
| 621 |
+
M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10)
|
| 622 |
+
M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10)
|
| 623 |
+
return M
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
class ManifoldHyperConnection(nn.Module):
    """Manifold-Constrained Hyper-Connections (mHC) residual wrapper.

    Maintains ``expansion`` parallel residual streams.  Per token it
    predicts three data-dependent mixing maps from the expanded stream:
    H_pre (how streams combine into the sub-layer input), H_post (how the
    sub-layer output is written back to each stream), and H_res, a
    stream-to-stream matrix pushed towards doubly-stochastic form via
    Sinkhorn-Knopp.  All alpha_* gates start at 0, so initially the maps
    reduce to the learned biases alone.
    """

    def __init__(self, dim: int, expansion: int = 2) -> None:
        super().__init__()
        self.dim = dim
        self.expansion = expansion
        n = expansion

        # Static (input-independent) components of the three maps.
        self.bias_pre = nn.Parameter(torch.zeros(1, n))
        self.bias_post = nn.Parameter(torch.zeros(1, n))
        self.bias_res = nn.Parameter(torch.zeros(n, n))

        # Dynamic components predicted from the (normed) expanded stream.
        self.theta_pre = nn.Linear(n * dim, n, bias=False)
        self.theta_post = nn.Linear(n * dim, n, bias=False)
        self.theta_res = nn.Linear(n * dim, n * n, bias=False)

        # Zero-initialised scalar gates on the dynamic terms.
        self.alpha_pre = nn.Parameter(torch.tensor(0.0))
        self.alpha_post = nn.Parameter(torch.tensor(0.0))
        self.alpha_res = nn.Parameter(torch.tensor(0.0))

    def _compute_mappings(
        self, x_expanded: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Predict (H_pre, H_post, H_res) from the expanded stream.

        Args:
            x_expanded: (B, T, expansion * dim) concatenated streams.

        Returns:
            H_pre:  (B, T, 1, n) — sigmoid weights in (0, 1).
            H_post: (B, T, 1, n) — weights in (0, 2).
            H_res:  (B, T, n, n) — approx. doubly-stochastic stream mixer.
        """
        B, T, _ = x_expanded.shape
        n = self.expansion
        x_norm = F.rms_norm(x_expanded, [x_expanded.shape[-1]])
        d_pre = torch.tanh(self.theta_pre(x_norm))
        d_post = torch.tanh(self.theta_post(x_norm))
        d_res = self.theta_res(x_norm)
        H_pre_raw = torch.sigmoid(self.alpha_pre * d_pre + self.bias_pre)
        H_post_raw = 2.0 * torch.sigmoid(self.alpha_post * d_post + self.bias_post)
        H_res_raw = (self.alpha_res * d_res + self.bias_res.reshape(1, 1, -1)).reshape(
            B, T, n, n
        )
        # Project the residual mixer towards the Birkhoff polytope.
        H_res = _sinkhorn_knopp(H_res_raw)
        return H_pre_raw.unsqueeze(-2), H_post_raw.unsqueeze(-2), H_res

    def expand_stream(self, x: torch.Tensor) -> torch.Tensor:
        """Replicate (B, T, dim) into ``expansion`` streams, (B, T, n*dim)."""
        return x.repeat(1, 1, self.expansion)

    def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor:
        """Average the parallel streams back to a single (B, T, dim) stream."""
        B, T, _ = x_expanded.shape
        return x_expanded.view(B, T, self.expansion, self.dim).mean(dim=-2)

    def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor:
        """Weighted combination of streams forming the sub-layer input.

        Returns (B, T, dim): H_pre (B, T, 1, n) @ streams (B, T, n, dim).
        """
        B, T, _ = x_expanded.shape
        x_streams = x_expanded.view(B, T, self.expansion, self.dim)
        return (H_pre @ x_streams).squeeze(-2)

    def post_res_mix(
        self,
        layer_output: torch.Tensor,
        x_expanded: torch.Tensor,
        H_post: torch.Tensor,
        H_res: torch.Tensor,
    ) -> torch.Tensor:
        """Mix streams with H_res and add the H_post-weighted layer output.

        Args:
            layer_output: (B, T, dim) sub-layer output to write back.
            x_expanded: (B, T, expansion * dim) current streams.
            H_post: (B, T, 1, n) write weights per stream.
            H_res: (B, T, n, n) stream-to-stream mixing matrix.

        Returns:
            Updated expanded stream, (B, T, expansion * dim).
        """
        B, T, _ = x_expanded.shape
        x_streams = x_expanded.view(B, T, self.expansion, self.dim)
        mixed = torch.matmul(H_res, x_streams)
        # Broadcast the single layer output to all streams, scaled by H_post.
        post_out = torch.matmul(H_post.transpose(-2, -1), layer_output.unsqueeze(-2))
        return (mixed + post_out).reshape(B, T, self.expansion * self.dim)
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention + SwiGLU FFN.

    Optionally adds an Engram hashed-memory branch (``engram_dim > 0``) and
    replaces plain residual additions with ManifoldHyperConnection stream
    mixing (``mhc_expansion > 1``).
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        ffn_dim: int,
        dropout: float,
        sliding_window: int,
        rope_fraction: float,
        engram_dim: int = 0,
        engram_heads: int = 4,
        engram_table_size: int = 8192,
        engram_max_ngram: int = 3,
        mhc_expansion: int = 1,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.norm1 = RMSNorm(dim)
        self.attn = CausalSelfAttention(
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            head_dim=head_dim,
            dropout=dropout,
            sliding_window=sliding_window,
            rope_fraction=rope_fraction,
        )
        self.norm2 = RMSNorm(dim)
        self.ffn = SwiGLU(dim, ffn_dim, dropout)
        # Engram memory branch is built only when a positive width is given.
        self.use_engram = engram_dim > 0
        if self.use_engram:
            self.engram = EngramBlock(
                dim=dim,
                engram_dim=engram_dim,
                n_heads=engram_heads,
                table_size=engram_table_size,
                max_ngram=engram_max_ngram,
            )
            self.engram_norm = RMSNorm(dim)
        # mHC stream mixing only makes sense with more than one stream.
        self.use_mhc = mhc_expansion > 1
        if self.use_mhc:
            self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion)
            self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion)

    def forward(
        self,
        x: torch.Tensor,
        is_global: bool,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        token_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the block.

        Args:
            x: (B, T, dim) residual stream.
            is_global: passed to attention to select global vs windowed mode.
            past_kv: cached (k, v) from a previous step, for decoding.
            use_cache: when True the attention returns its updated (k, v).
            token_ids: raw token ids forwarded to the Engram hasher.

        Returns:
            (updated stream, new kv-cache or None).
        """
        if self.use_mhc:
            # mHC path: expand to parallel streams and mix around attention.
            x_exp = self.mhc_attn.expand_stream(x)
            H_pre, H_post, H_res = self.mhc_attn._compute_mappings(x_exp)
            attn_in = self.mhc_attn.pre_mix(x_exp, H_pre)
            attn_out, new_kv = self.attn(
                self.norm1(attn_in), is_global, past_kv, use_cache
            )
            x_exp = self.mhc_attn.post_res_mix(attn_out, x_exp, H_post, H_res)
            if self.use_engram:
                # Engram runs on the collapsed single stream, then the result
                # is re-expanded before the FFN's mHC wrapper.
                collapsed = self.mhc_attn.collapse_stream(x_exp)
                collapsed = collapsed + self.engram(
                    self.engram_norm(collapsed), token_ids=token_ids
                )
                x_exp = self.mhc_attn.expand_stream(collapsed)
            H_pre2, H_post2, H_res2 = self.mhc_ffn._compute_mappings(x_exp)
            ffn_in = self.mhc_ffn.pre_mix(x_exp, H_pre2)
            ffn_out = self.ffn(self.norm2(ffn_in))
            x_exp = self.mhc_ffn.post_res_mix(ffn_out, x_exp, H_post2, H_res2)
            # collapse_stream has no parameters, so reusing mhc_attn's is fine.
            x = self.mhc_attn.collapse_stream(x_exp)
        else:
            # Standard pre-norm residual path.
            attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache)
            x = x + attn_out
            if self.use_engram:
                x = x + self.engram(self.engram_norm(x), token_ids=token_ids)
            x = x + self.ffn(self.norm2(x))
        return x, new_kv
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
def _detect_engram_dim(state_dict: dict) -> int:
|
| 773 |
+
for key in state_dict:
|
| 774 |
+
if ".engram." in key and ".embeddings." in key:
|
| 775 |
+
return state_dict[key].shape[-1]
|
| 776 |
+
return 0
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
def _detect_mhc_expansion(state_dict: dict) -> int:
|
| 780 |
+
for key, val in state_dict.items():
|
| 781 |
+
if ".mhc_attn.bias_pre" in key and val.dim() == 2:
|
| 782 |
+
return val.shape[-1]
|
| 783 |
+
return 1
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
class TinyMemoryLM(nn.Module):
    """Decoder-only LM with tied embeddings, layer reuse, optional Engram
    memory, optional mHC residuals, and multi-token-prediction (MTP) heads.

    ``n_unique_layers`` physical blocks are reused in order to realise
    ``n_logical_layers`` logical layers (weight sharing).  Even logical
    indices use global attention, odd ones the sliding window.
    """

    def __init__(
        self,
        vocab_size: int,
        dim: int,
        n_unique_layers: int,
        n_logical_layers: int,
        n_heads: int,
        n_kv_heads: int,
        ffn_dim: int,
        dropout: float,
        mtp_horizons: Sequence[int],
        grad_checkpoint: bool,
        sliding_window: int = 512,
        rope_fraction: float = 0.5,
        embed_scale: bool = True,
        engram_dim: int = 0,
        engram_heads: int = 4,
        engram_table_size: int = 8192,
        engram_max_ngram: int = 3,
        mhc_expansion: int = 1,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.n_unique_layers = n_unique_layers
        self.n_logical_layers = n_logical_layers
        self.grad_checkpoint = grad_checkpoint
        # Embeddings are scaled by sqrt(dim) on input; forward() divides the
        # tied output logits by the same factor.
        self.embed_scale_factor = math.sqrt(dim) if embed_scale else 1.0
        head_dim = dim // n_heads

        # Tied input embedding / output projection, plus a learned logit bias.
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)
        self.head.weight = self.embed_tokens.weight

        self.output_bias = nn.Parameter(torch.zeros(vocab_size))

        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    dim=dim,
                    n_heads=n_heads,
                    n_kv_heads=n_kv_heads,
                    head_dim=head_dim,
                    ffn_dim=ffn_dim,
                    dropout=dropout,
                    sliding_window=sliding_window,
                    rope_fraction=rope_fraction,
                    engram_dim=engram_dim,
                    engram_heads=engram_heads,
                    engram_table_size=engram_table_size,
                    engram_max_ngram=engram_max_ngram,
                    mhc_expansion=mhc_expansion,
                )
                for _ in range(n_unique_layers)
            ]
        )
        self.norm = RMSNorm(dim)

        # MTP: one linear adapter + norm per horizon; horizons <= 1 dropped.
        self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1})
        self.mtp_adapters = nn.ModuleDict(
            {str(h): nn.Linear(dim, dim, bias=False) for h in self.mtp_horizons}
        )
        self.mtp_norms = nn.ModuleDict(
            {str(h): RMSNorm(dim) for h in self.mtp_horizons}
        )

        # Depth-scaled init of the residual-writing projections, scaled by the
        # number of *logical* layers since each block is applied more than once.
        res_scale = (2 * n_logical_layers) ** -0.5
        for block in self.blocks:
            block.attn.wo.weight.data.mul_(res_scale)
            block.ffn.down.weight.data.mul_(res_scale)

    def resize_token_embeddings(self, new_vocab_size: int) -> None:
        """Resize the vocabulary, preserving rows common to both sizes.

        Rebuilds the embedding, re-ties the output head to it, and copies
        the overlapping embedding rows and output-bias entries.  Rows beyond
        the old size keep the fresh default initialisation.
        """
        old_vocab_size = self.embed_tokens.num_embeddings
        if new_vocab_size == old_vocab_size:
            return
        device = self.embed_tokens.weight.device
        old_embed_weight = self.embed_tokens.weight.data.clone()
        self.embed_tokens = nn.Embedding(
            new_vocab_size, self.embed_tokens.embedding_dim
        ).to(device)
        self.head = nn.Linear(
            self.embed_tokens.embedding_dim, new_vocab_size, bias=False
        ).to(device)
        # Re-establish weight tying on the new modules.
        self.head.weight = self.embed_tokens.weight
        old_bias = self.output_bias.data.clone()
        self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device))
        copy_size = min(old_vocab_size, new_vocab_size)
        self.output_bias.data[:copy_size] = old_bias[:copy_size]
        self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size]

    def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]:
        """Map logical layer indices onto the shared physical blocks.

        The unique blocks are repeated once in order and truncated to
        ``n_logical_layers`` entries.  NOTE(review): this supports at most
        2 * n_unique_layers logical layers — confirm configs stay in bound.
        """
        logical = []
        blocks_list = list(self.blocks)
        full_sequence = blocks_list + blocks_list
        for logical_idx, block in enumerate(full_sequence[: self.n_logical_layers]):
            logical.append((block, logical_idx))
        return logical

    def forward(
        self,
        ids: torch.Tensor,
        use_cache: bool = False,
        past_key_values: Optional[
            List[Optional[Tuple[torch.Tensor, torch.Tensor]]]
        ] = None,
        return_hidden: bool = False,
    ) -> Tuple[
        torch.Tensor,
        Dict[int, torch.Tensor],
        Optional[torch.Tensor],
        Optional[List[Tuple[torch.Tensor, torch.Tensor]]],
    ]:
        """Run the language model.

        Args:
            ids: (B, T) token ids.
            use_cache: collect and return one KV cache per logical layer.
            past_key_values: caches from a previous call (one per logical
                layer), enabling incremental decoding.
            return_hidden: also return the final normed hidden states.

        Returns:
            (logits, mtp_logits_by_horizon, hidden_or_None, kv_caches_or_None).
            MTP logits are produced only in training mode.
        """
        B, T = ids.shape
        x = self.embed_tokens(ids) * self.embed_scale_factor

        logical_layers = self._build_logical_layers()
        new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = (
            [] if use_cache else None
        )

        for layer_idx, (block, logical_idx) in enumerate(logical_layers):
            # Alternate attention type by logical depth: even -> global,
            # odd -> sliding window.
            is_global = logical_idx % 2 == 0
            past_kv = (
                past_key_values[layer_idx]
                if past_key_values is not None and layer_idx < len(past_key_values)
                else None
            )

            if self.grad_checkpoint and self.training and not use_cache:
                # Gradient checkpointing: trade recompute for activation memory.
                x, layer_kv = checkpoint(
                    block, x, is_global, past_kv, use_cache, ids, use_reentrant=True
                )
            else:
                x, layer_kv = block(x, is_global, past_kv, use_cache, ids)

            if new_past_key_values is not None:
                new_past_key_values.append(layer_kv)

        x = self.norm(x)
        h_out = x if return_hidden else None
        logits = self.head(x)
        if self.embed_scale_factor != 1.0:
            # Undo the input embedding scale on the tied output projection.
            logits = logits / self.embed_scale_factor
        logits = logits + self.output_bias

        # Multi-token prediction (training only): hidden state at position t,
        # passed through the horizon-h adapter, predicts token t + h.
        mtp: Dict[int, torch.Tensor] = {}
        if self.mtp_horizons and self.training:
            for horizon in self.mtp_horizons:
                if horizon > 1 and horizon <= T - 1:
                    shifted_h = x[:, :-horizon, :]
                    adapted_h = self.mtp_adapters[str(horizon)](shifted_h)
                    adapted_h = self.mtp_norms[str(horizon)](adapted_h)
                    mtp_logits = self.head(adapted_h)
                    if self.embed_scale_factor != 1.0:
                        mtp_logits = mtp_logits / self.embed_scale_factor
                    mtp_logits = mtp_logits + self.output_bias
                    mtp[horizon] = mtp_logits

        return logits, mtp, h_out, new_past_key_values
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
# ---------------------------------------------------------------------------
|
| 948 |
+
# Generation (from ailay.generation)
|
| 949 |
+
# ---------------------------------------------------------------------------
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
def sample_text(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    branches: int,
    branch_len: int,
    device: torch.device,
    seq_len: int,
) -> str:
    """Sample a continuation, optionally with best-of-N branch reranking.

    With ``branches <= 1`` this is plain autoregressive sampling without a
    KV cache: the context (truncated to the last ``seq_len`` tokens) is
    re-encoded at every step.  With more branches, it repeatedly rolls out
    ``branches`` candidate segments of up to ``branch_len`` tokens and
    commits the segment whose sampled tokens had the lowest total NLL.

    Returns only the newly generated text (prompt excluded), decoded with
    special tokens stripped.
    """

    def _sample_id(logits: torch.Tensor) -> torch.Tensor:
        # Defensive sampling: neutralise non-finite logits so softmax /
        # multinomial cannot fail, then apply optional top-k filtering.
        if not torch.isfinite(logits).any():
            logits = torch.zeros_like(logits)
        logits = torch.where(
            torch.isfinite(logits), logits, torch.full_like(logits, -1e9)
        )
        if top_k > 0:
            v, idx = torch.topk(logits, k=min(top_k, logits.shape[-1]))
            p = torch.softmax(v, dim=-1)
            return idx.gather(-1, torch.multinomial(p, 1))
        p = torch.softmax(logits, dim=-1)
        return torch.multinomial(p, 1)

    model.eval()
    ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
    prompt_len = len(ids)
    x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)

    with torch.no_grad():
        generated = 0
        while generated < max_new_tokens:
            if branches <= 1:
                # Single-branch path: one token per model call.
                ctx = x[:, -seq_len:]
                logits, _, _, _ = model(ctx)
                nlogits = logits[:, -1, :] / max(temperature, 1e-6)
                nid = _sample_id(nlogits)
                x = torch.cat([x, nid], dim=1)
                generated += 1
                continue
            # Branch reranking: roll out several candidate segments and keep
            # the one whose sampled tokens have the lowest accumulated NLL.
            rollout = min(branch_len, max_new_tokens - generated)
            best_nll: Optional[float] = None
            best_tokens: Optional[List[torch.Tensor]] = None
            for _ in range(branches):
                cand = x
                cand_tokens: List[torch.Tensor] = []
                nll = 0.0
                for _ in range(rollout):
                    ctx = cand[:, -seq_len:]
                    logits, _, _, _ = model(ctx)
                    nlogits = logits[:, -1, :] / max(temperature, 1e-6)
                    nid = _sample_id(nlogits)
                    # NLL of the sampled token under the temperature-scaled
                    # distribution (full softmax, not the top-k subset).
                    lp = F.log_softmax(nlogits.float(), dim=-1)
                    nll += float(-lp.gather(-1, nid).item())
                    cand = torch.cat([cand, nid], dim=1)
                    cand_tokens.append(nid)
                if best_nll is None or nll < best_nll:
                    best_nll = nll
                    best_tokens = cand_tokens
            assert best_tokens is not None
            # Commit the winning segment to the running sequence.
            for t in best_tokens:
                x = torch.cat([x, t], dim=1)
                generated += 1

    generated_ids = x[0, prompt_len:].tolist()
    return tokenizer.decode(generated_ids, skip_special=True)
|
| 1019 |
+
|
| 1020 |
+
|
| 1021 |
+
def sample_text_cached(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    device: torch.device,
    seq_len: int,
) -> str:
    """Sample a continuation using the model's KV cache (one token/step).

    The whole prompt is encoded once to prime the cache; every later step
    feeds only the newly sampled token.  Stops after ``max_new_tokens`` or
    on EOS.  NOTE(review): ``prompt_len`` and ``seq_len`` are unused here —
    the cache is never truncated to ``seq_len``; confirm prompts plus
    generation stay within the model's supported context.

    Returns the decoded generated tokens (special tokens stripped).
    """
    model.eval()
    ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
    prompt_len = len(ids)
    x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)

    with torch.no_grad():
        # Prime the cache with the full prompt and sample the first token.
        logits, _, _, past_kv = model(x, use_cache=True)
        nlogits = logits[:, -1, :] / max(temperature, 1e-6)
        if top_k > 0:
            v, idx = torch.topk(nlogits, k=min(top_k, nlogits.shape[-1]))
            p = torch.softmax(v, dim=-1)
            nid = idx.gather(-1, torch.multinomial(p, 1))
        else:
            p = torch.softmax(nlogits, dim=-1)
            nid = torch.multinomial(p, 1)
        all_ids = [int(nid.item())]

        # Incremental decoding: each call processes only the latest token.
        for _ in range(max_new_tokens - 1):
            logits, _, _, past_kv = model(nid, use_cache=True, past_key_values=past_kv)
            nlogits = logits[:, -1, :] / max(temperature, 1e-6)
            if top_k > 0:
                v, idx = torch.topk(nlogits, k=min(top_k, nlogits.shape[-1]))
                p = torch.softmax(v, dim=-1)
                nid = idx.gather(-1, torch.multinomial(p, 1))
            else:
                p = torch.softmax(nlogits, dim=-1)
                nid = torch.multinomial(p, 1)
            tid = int(nid.item())
            all_ids.append(tid)
            if tid == tokenizer.eos_id:
                break

    return tokenizer.decode(all_ids, skip_special=True)
|
| 1064 |
+
|
| 1065 |
+
|
| 1066 |
+
def speculative_decode(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    device: torch.device,
    seq_len: int,
) -> str:
    """Self-speculative decoding using the model's MTP heads as drafters.

    The MTP adapters propose one draft token per horizon from the latest
    hidden state; the main model then verifies all drafts in a single
    forward pass, accepting the longest matching prefix and keeping the
    verifier's own sample at the first mismatch.  The KV cache is trimmed
    to drop entries written for rejected draft positions.

    Args:
        model: language model exposing ``mtp_horizons`` / ``mtp_adapters`` /
            ``mtp_norms`` / ``head`` / ``output_bias``.
        tokenizer: encodes the prompt and decodes the result.
        prompt: input text.
        max_new_tokens: approximate generation budget.
        temperature: softmax temperature (floored at 1e-6).
        top_k: if > 0, sample from the top-k logits only.
        seq_len: unused here; kept for signature parity with the other
            sampling helpers.

    Returns:
        The decoded generated text (special tokens stripped).
    """
    model.eval()
    ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
    x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
    all_generated: List[int] = []

    with torch.no_grad():
        logits, _, h_out, past_kv = model(x, use_cache=True, return_hidden=True)

        def _sample_from(lg: torch.Tensor) -> int:
            # Temperature + optional top-k sampling of a single token id.
            lg = lg / max(temperature, 1e-6)
            if top_k > 0:
                v, idx = torch.topk(lg, k=min(top_k, lg.shape[-1]))
                p = torch.softmax(v, dim=-1)
                return int(idx[torch.multinomial(p, 1)].item())
            p = torch.softmax(lg, dim=-1)
            return int(torch.multinomial(p, 1).item())

        main_token = _sample_from(logits[0, -1, :])
        all_generated.append(main_token)

        while len(all_generated) < max_new_tokens:
            if main_token == tokenizer.eos_id:
                break

            # Draft one token per MTP horizon from the latest hidden state.
            draft_tokens = []
            if h_out is not None and model.mtp_horizons:
                last_hidden = h_out[:, -1:, :]
                for h in model.mtp_horizons:
                    adapter = model.mtp_adapters[str(h)]
                    norm = model.mtp_norms[str(h)]
                    adapted = norm(adapter(last_hidden))
                    draft_logits = model.head(adapted)
                    # Match TinyMemoryLM.forward's MTP logit scaling so the
                    # draft distribution equals the training-time one.
                    if model.embed_scale_factor != 1.0:
                        draft_logits = draft_logits / model.embed_scale_factor
                    draft_logits = draft_logits + model.output_bias
                    draft_tok = _sample_from(draft_logits[0, 0, :])
                    draft_tokens.append(draft_tok)

            if not draft_tokens:
                # No drafter available: plain single-step cached decoding.
                nid = torch.tensor([[main_token]], dtype=torch.long, device=device)
                logits, _, h_out, past_kv = model(
                    nid, use_cache=True, past_key_values=past_kv, return_hidden=True
                )
                main_token = _sample_from(logits[0, -1, :])
                all_generated.append(main_token)
                continue

            # Verify all drafts at once: feed [main_token, drafts...] so the
            # logits at position i predict the token following draft i-1.
            verify_input = torch.tensor(
                [[main_token] + draft_tokens], dtype=torch.long, device=device
            )
            verify_logits, _, h_out, past_kv = model(
                verify_input,
                use_cache=True,
                past_key_values=past_kv,
                return_hidden=True,
            )

            # Accept drafts while they match the verifier's samples; at the
            # first mismatch keep the verifier's token and stop.  (A previous
            # revision had a dead conditional append here — `all_generated[-1]`
            # always equals `main_token` at this point, so it never fired.)
            accepted = 0
            for i, draft_tok in enumerate(draft_tokens):
                verified_tok = _sample_from(verify_logits[0, i, :])
                if verified_tok == draft_tok:
                    all_generated.append(draft_tok)
                    accepted += 1
                    if draft_tok == tokenizer.eos_id:
                        break
                else:
                    all_generated.append(verified_tok)
                    break

            if accepted < len(draft_tokens):
                # The verification pass cached keys/values for rejected draft
                # positions too; trim the cache back to the accepted sequence.
                trim_len = len(draft_tokens) - accepted - 1
                if trim_len > 0 and past_kv is not None:
                    past_kv = [
                        (k[:, :, :-trim_len, :], v[:, :, :-trim_len, :])
                        if k is not None
                        else None
                        for k, v in past_kv
                    ]

            main_token = all_generated[-1]

    return tokenizer.decode(all_generated, skip_special=True)
|
| 1158 |
+
|
| 1159 |
+
|
| 1160 |
+
def build_stop_token_ids(tokenizer: WordTokenizer) -> set:
    """Collect token ids at which generation should halt.

    Always includes the EOS id, plus any chat-role marker tokens
    (user/system/assistant) that exist in the tokenizer vocabulary.
    """
    role_markers = ("<|user|>", "<|system|>", "<|assistant|>")
    marker_ids = (tokenizer.token_to_id.get(marker) for marker in role_markers)
    stops = {int(marker_id) for marker_id in marker_ids if marker_id is not None}
    stops.add(tokenizer.eos_id)
    return stops
|
| 1167 |
+
|
| 1168 |
+
|
| 1169 |
+
def apply_no_repeat_ngram(
    logits: torch.Tensor,
    token_history: Sequence[int],
    ngram_size: int,
) -> torch.Tensor:
    """Mask tokens that would repeat an n-gram already present in history.

    Finds earlier occurrences of the current ``ngram_size - 1`` token
    suffix and sets the logits of every token that followed one of them to
    ``-inf``.  Returns ``logits`` unchanged (same object) when the feature
    is disabled, the history is too short, or no ban applies.
    """
    if ngram_size <= 1 or len(token_history) < ngram_size - 1:
        return logits
    history_len = len(token_history)
    suffix = tuple(token_history[history_len - (ngram_size - 1) :])
    blocked = {
        int(token_history[start + ngram_size - 1])
        for start in range(history_len - ngram_size + 1)
        if tuple(token_history[start : start + ngram_size - 1]) == suffix
    }
    if not blocked:
        return logits
    masked = logits.clone()
    idx = torch.tensor(sorted(blocked), device=logits.device, dtype=torch.long)
    masked[idx] = float("-inf")
    return masked
|
| 1187 |
+
|
| 1188 |
+
|
| 1189 |
+
def score_candidate(
    prompt: str,
    raw_text: str,
    visible_text: str,
    avg_logprob: float,
) -> float:
    """Heuristically score a generated candidate for reranking.

    Starts from the average token log-probability and applies additive
    bonuses/penalties for length, sentence-final punctuation, leaked chat
    markers, keyword overlap with the prompt, trigram repetition, and the
    fraction of normal-looking alphabetic words.  Empty candidates score
    -1e9.
    """
    text = visible_text.strip()
    if not text:
        return -1e9

    stopwords = {
        "what", "which", "when", "where", "why", "how", "are", "is",
        "the", "and", "for", "with", "that", "this", "from", "into",
        "about", "explain", "tell", "give", "list", "show", "write",
        "their", "there", "your",
    }
    word_pattern = r"[A-Za-z][A-Za-z'-]{2,}"
    tokens = text.lower().split()
    query_terms = {
        w for w in re.findall(word_pattern, prompt.lower()) if w not in stopwords
    }
    answer_terms = set(re.findall(word_pattern, text.lower()))

    total = avg_logprob

    # Length: penalise very short answers, mildly reward longer ones.
    if len(tokens) < 6:
        total -= 2.0
    else:
        total += min(2.0, len(tokens) * 0.03)

    # Reward a proper sentence-final punctuation mark.
    if text[-1:] in ".!?":
        total += 0.5

    # Penalise leaked chat-role markers in the raw generation.
    if "<|user|>" in raw_text or "<|system|>" in raw_text:
        total -= 4.0
    if raw_text.count("<|assistant|>") > 1:
        total -= 2.0

    # Keyword overlap between prompt and candidate.
    if query_terms:
        overlap = len(query_terms & answer_terms) / len(query_terms)
        total += -2.5 if overlap == 0.0 else min(3.5, overlap * 4.0)

    # Penalise unbalanced thought/solution tag pairs.
    tag_pairs = (
        ("<|begin_of_thought|>", "<|end_of_thought|>"),
        ("<|begin_of_solution|>", "<|end_of_solution|>"),
    )
    for opener, closer in tag_pairs:
        if (opener in raw_text) != (closer in raw_text):
            total -= 1.0

    # Trigram diversity: heavy repetition is penalised, variety rewarded.
    if len(tokens) >= 3:
        grams = [tuple(tokens[i : i + 3]) for i in range(len(tokens) - 2)]
        if grams:
            diversity = len(set(grams)) / len(grams)
            if diversity < 0.35:
                total -= 4.0
            elif diversity < 0.55:
                total -= 2.0
            else:
                total += min(1.0, (diversity - 0.55) * 2.0)

    # Fraction of short, mostly-alphabetic words; low ratios look like junk.
    normal_words = [
        w
        for w in tokens
        if len(w) <= 18 and (sum(c.isalpha() for c in w) / max(len(w), 1)) > 0.7
    ]
    normal_ratio = len(normal_words) / max(len(tokens), 1)
    if normal_ratio < 0.45:
        total -= 3.0
    elif normal_ratio < 0.65:
        total -= 1.0

    return total
|
| 1274 |
+
|
| 1275 |
+
|
| 1276 |
+
def generate_candidate(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    no_repeat_ngram_size: int,
    device: str,
    sft_mode: bool,
    force_thought: bool,
    stream: bool,
    context_window: int,
) -> Tuple[str, str, float, int]:
    """Sample a single completion from the model, token by token.

    Filtering is applied to the next-token logits in this order: repetition
    penalty, temperature, no-repeat-ngram masking, top-k, then a fixed top-p
    (nucleus, p=0.9) filter. Generation stops early on any stop token.

    Returns a 4-tuple:
      (visible_text, raw_text_including_special_tokens, mean_logprob, 0).
    """
    # In SFT (chat) mode wrap the prompt in the chat template.
    if sft_mode:
        full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
    else:
        full_prompt = prompt
    # Optionally open a thought block so the model starts in "reasoning" mode.
    if force_thought:
        full_prompt = f"{full_prompt}<|begin_of_thought|> "
    input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False)
    input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device)
    visible_tokens: List[str] = []  # tokens shown to the user (specials stripped)
    raw_tokens: List[str] = []      # every generated token, specials included
    stop_token_ids = build_stop_token_ids(tokenizer)
    total_logprob = 0.0
    sampled_tokens = 0
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Truncate the left side of the context to fit the window.
            ctx_ids = (
                input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t
            )
            logits, _, _, _ = model(ctx_ids)
            next_logits = logits[0, -1, :].clone()
            # Keep an unfiltered copy so we can recover if filtering masks everything.
            raw_next_logits = next_logits.clone()
            if repetition_penalty != 1.0:
                # Standard repetition penalty: shrink positive logits of seen
                # tokens, amplify (more negative) negative ones.
                seen = set(input_ids_t[0].tolist())
                for token_id in seen:
                    if next_logits[token_id] > 0:
                        next_logits[token_id] /= repetition_penalty
                    else:
                        next_logits[token_id] *= repetition_penalty
            if temperature != 1.0:
                next_logits = next_logits / max(temperature, 1e-6)
            if no_repeat_ngram_size > 1:
                next_logits = apply_no_repeat_ngram(
                    next_logits,
                    input_ids_t[0].tolist(),
                    no_repeat_ngram_size,
                )
            if top_k > 0:
                # Mask everything below the k-th largest logit.
                v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0)))
                next_logits[next_logits < v[-1]] = float("-inf")
            # Nucleus (top-p) filtering with a hard-coded p=0.9.
            top_p = 0.9
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                # Exclusive cumulative mass: a token is removed only if the mass
                # *before* it already reaches top_p (so the top token survives).
                remove_mask = cum_probs - torch.softmax(sorted_logits, dim=-1) >= top_p
                sorted_logits[remove_mask] = float("-inf")
                # Scatter the filtered logits back into vocabulary order.
                next_logits = sorted_logits.scatter(0, sorted_indices, sorted_logits)
            if not torch.isfinite(next_logits).any():
                # All candidates were masked out; fall back to the raw
                # distribution (re-applying temperature only).
                next_logits = raw_next_logits
                if temperature != 1.0:
                    next_logits = next_logits / max(temperature, 1e-6)
            probs = torch.softmax(next_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1).item()
            # Accumulate log-probability of the sampled token (epsilon-guarded).
            total_logprob += float(torch.log(probs[next_id] + 1e-12).item())
            sampled_tokens += 1
            if next_id in stop_token_ids:
                break
            token_str = (
                tokenizer.id_to_token[next_id]
                if next_id < len(tokenizer.id_to_token)
                else ""
            )
            raw_tokens.append(token_str)
            # Only non-special tokens are displayed/streamed.
            if token_str not in tokenizer.special:
                visible_tokens.append(token_str)
                if stream:
                    print(token_str, end="", flush=True)
            input_ids_t = torch.cat(
                [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1
            )
    if stream:
        print()
    avg_logprob = total_logprob / max(1, sampled_tokens)
    return "".join(visible_tokens), "".join(raw_tokens), avg_logprob, 0
|
| 1364 |
+
|
| 1365 |
+
|
| 1366 |
+
def generate_beam_search(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int = 60,
    beam_width: int = 8,
    length_penalty: float = 0.7,
    no_repeat_ngram_size: int = 3,
    device: str = "cuda",
    sft_mode: bool = False,
    context_window: int = 2048,
) -> str:
    """Deterministic beam-search decoding.

    Maintains ``beam_width`` live hypotheses; hypotheses that emit a stop
    token move to ``completed``. Beams are ranked by total log-probability
    normalized by ``gen_len ** length_penalty``. Returns the decoded text of
    the best hypothesis, truncated at its first newline (when past position
    5) and stripped.
    """
    if sft_mode:
        full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
    else:
        full_prompt = prompt
    prompt_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False)
    prompt_len = len(prompt_ids)
    stop_ids = build_stop_token_ids(tokenizer)
    # Each beam is (cumulative_logprob, token_ids_including_prompt).
    beams: List[Tuple[float, List[int]]] = [(0.0, list(prompt_ids))]
    completed: List[Tuple[float, List[int]]] = []
    for _step in range(max_new_tokens):
        if not beams:
            break
        candidates: List[Tuple[float, List[int]]] = []
        for beam_score, beam_ids in beams:
            x = torch.tensor(
                [beam_ids[-context_window:]], dtype=torch.long, device=device
            )
            with torch.no_grad():
                logits, _, _, _ = model(x)
            nl = logits[0, -1, :]
            log_probs = F.log_softmax(nl, dim=-1)
            gen_ids = beam_ids[prompt_len:]
            # Ban any token that would complete an n-gram already present in
            # this beam's generated suffix.
            if no_repeat_ngram_size > 1 and len(gen_ids) >= no_repeat_ngram_size - 1:
                prefix = tuple(gen_ids[-(no_repeat_ngram_size - 1) :])
                for i in range(len(gen_ids) - no_repeat_ngram_size + 1):
                    if tuple(gen_ids[i : i + no_repeat_ngram_size - 1]) == prefix:
                        log_probs[gen_ids[i + no_repeat_ngram_size - 1]] = float("-inf")
            # Expand each beam with its top-`beam_width` continuations.
            topk_lp, topk_ids = torch.topk(log_probs, beam_width)
            for i in range(beam_width):
                tid = topk_ids[i].item()
                new_score = beam_score + topk_lp[i].item()
                new_ids = beam_ids + [tid]
                if tid in stop_ids:
                    completed.append((new_score, new_ids))
                else:
                    candidates.append((new_score, new_ids))

        def _norm_score(pair):
            # Length-normalized score so longer beams aren't unfairly penalized.
            gen_len = max(1, len(pair[1]) - prompt_len)
            return pair[0] / (gen_len**length_penalty)

        candidates.sort(key=_norm_score, reverse=True)
        beams = candidates[:beam_width]

    # Compare finished hypotheses and still-live beams on the same footing.
    pool = completed + beams
    if not pool:
        return ""

    def _norm_score_final(pair):
        gen_len = max(1, len(pair[1]) - prompt_len)
        return pair[0] / (gen_len**length_penalty)

    pool.sort(key=_norm_score_final, reverse=True)
    best_ids = pool[0][1][prompt_len:]
    text = tokenizer.decode(best_ids, skip_special=True)
    # Keep only the first line of output (if the break is past a few chars).
    nl_pos = text.find("\n")
    if nl_pos > 5:
        text = text[:nl_pos]
    return text.strip()
|
| 1437 |
+
|
| 1438 |
+
|
| 1439 |
+
def generate(
    model: TinyMemoryLM,
    tokenizer: WordTokenizer,
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_k: int = 40,
    repetition_penalty: float = 1.0,
    device: str = "cuda",
    sft_mode: bool = False,
    force_thought: bool = False,
    stream: bool = True,
    decode_mode: str = "legacy",
    best_of: int = 3,
    no_repeat_ngram_size: int = 3,
    context_window: int = 2048,
    beam_width: int = 8,
    length_penalty: float = 0.7,
) -> str:
    """Top-level generation entry point; dispatches on *decode_mode*.

    Modes:
      - ``"beam"``:   deterministic beam search via ``generate_beam_search``.
      - ``"legacy"``: one streamed sampled candidate via ``generate_candidate``.
      - anything else: best-of-N — sample ``best_of`` candidates silently,
        score each with ``score_candidate``, and return the highest scorer.

    Returns the visible (special-token-free) generated text.
    """
    if decode_mode == "beam":
        text = generate_beam_search(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            beam_width=beam_width,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            device=device,
            sft_mode=sft_mode,
            context_window=context_window,
        )
        # Beam search cannot stream token-by-token; print the result at once.
        if stream:
            print(text)
        return text
    if decode_mode == "legacy":
        text, _, _, _ = generate_candidate(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            device=device,
            sft_mode=sft_mode,
            force_thought=force_thought,
            stream=stream,
            context_window=context_window,
        )
        return text
    # Best-of-N sampling: (score, visible_text, raw_text, avg_logprob) tuples.
    candidates: List[Tuple[float, str, str, float]] = []
    for _ in range(max(1, best_of)):
        candidate_text, raw_text, avg_logprob, _ = generate_candidate(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            device=device,
            sft_mode=sft_mode,
            force_thought=force_thought,
            stream=False,  # candidates are compared first, then printed
            context_window=context_window,
        )
        score = score_candidate(prompt, raw_text, candidate_text, avg_logprob)
        candidates.append((score, candidate_text, raw_text, avg_logprob))
    best_score, best_text, _, _ = max(candidates, key=lambda item: item[0])
    if stream:
        print(best_text, end="", flush=True)
        print()
    return best_text
|
| 1515 |
+
|
| 1516 |
+
|
| 1517 |
+
# ---------------------------------------------------------------------------
# Web server (from interactive.py)
# ---------------------------------------------------------------------------

# Make sibling modules importable when this file is executed as a script.
ROOT = Path(__file__).resolve().parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))


# Hugging Face organization whose model repos are listed and downloaded.
HF_ORG = "CompactAI"
# Base URL of the Hugging Face REST API.
HF_API = "https://huggingface.co/api"
# On-disk cache root for downloaded model/tokenizer artifacts.
CACHE_ROOT = Path.home() / ".cache" / "compactai_web"
# User-Agent header sent with every outbound HTTP request.
USER_AGENT = "Mozilla/5.0 CompactAI-Web"
# In-process cache of loaded bundles, keyed by (repo_id, model_type).
MODEL_CACHE: dict[tuple[str, str], dict[str, object]] = {}
# Reentrant lock guarding MODEL_CACHE (held for the whole load in load_bundle).
MODEL_CACHE_LOCK = threading.RLock()
# NOTE(review): presumably serializes concurrent generation requests —
# its acquisition site is outside this chunk; confirm before relying on it.
GENERATION_LOCK = threading.Lock()
|
| 1533 |
+
|
| 1534 |
+
|
| 1535 |
+
def request_json(url: str):
    """Fetch *url* with the shared User-Agent and parse the body as JSON."""
    request = Request(url, headers={"User-Agent": USER_AGENT})
    with urlopen(request, timeout=60) as response:
        payload = response.read()
    return json.loads(payload.decode("utf-8"))
|
| 1539 |
+
|
| 1540 |
+
|
| 1541 |
+
def request_text(url: str) -> str:
    """Fetch *url* and return the body as UTF-8 text (bad bytes replaced)."""
    request = Request(url, headers={"User-Agent": USER_AGENT})
    with urlopen(request, timeout=60) as response:
        body = response.read()
    return body.decode("utf-8", errors="replace")
|
| 1545 |
+
|
| 1546 |
+
|
| 1547 |
+
def download_file(url: str, destination: Path) -> None:
    """Download *url* to *destination* atomically.

    The payload is streamed into a ``.tmp`` sibling first and only renamed
    into place on success, so a partial transfer can never shadow a good
    file. Fix over the original: on any failure the stale ``.tmp`` file is
    removed before the exception propagates, instead of being left behind.

    Raises whatever ``urlopen``/filesystem error occurred.
    """
    destination.parent.mkdir(parents=True, exist_ok=True)
    temp_path = destination.with_suffix(destination.suffix + ".tmp")
    req = Request(url, headers={"User-Agent": USER_AGENT})
    try:
        with urlopen(req, timeout=120) as response, temp_path.open("wb") as handle:
            shutil.copyfileobj(response, handle)
        # Atomic rename: readers only ever see the complete file.
        temp_path.replace(destination)
    except Exception:
        # Clean up the partial download; missing_ok covers pre-open failures.
        temp_path.unlink(missing_ok=True)
        raise
|
| 1554 |
+
|
| 1555 |
+
|
| 1556 |
+
def normalize_repo_id(raw_repo_id: str) -> str:
    """Reduce any Hugging Face URL/path variant to a bare ``org/name`` id.

    Non-string or empty input yields ``""``. Percent-encoding is undone,
    known URL/API prefixes are stripped, and query/fragment suffixes plus
    surrounding slashes are removed.
    """
    if not isinstance(raw_repo_id, str):
        return ""
    candidate = raw_repo_id.strip()
    if not candidate:
        return ""
    try:
        candidate = unquote(candidate)
    except Exception:
        pass
    for prefix in (
        "https://huggingface.co/",
        "http://huggingface.co/",
        "api/models/",
        "models/",
    ):
        candidate = candidate.replace(prefix, "")
    candidate = candidate.split("?", 1)[0]
    candidate = candidate.split("#", 1)[0]
    return candidate.strip("/")
|
| 1575 |
+
|
| 1576 |
+
|
| 1577 |
+
def series_from_name(name: str) -> str | None:
|
| 1578 |
+
lower = (name or "").lower()
|
| 1579 |
+
if "haiku" in lower:
|
| 1580 |
+
return "Haiku"
|
| 1581 |
+
if "sonnet" in lower:
|
| 1582 |
+
return "Sonnet"
|
| 1583 |
+
if "opus" in lower:
|
| 1584 |
+
return "Opus"
|
| 1585 |
+
return None
|
| 1586 |
+
|
| 1587 |
+
|
| 1588 |
+
def encoded_repo_id(repo_id: str) -> str:
    """Percent-encode each path segment of the normalized repo id."""
    segments = [part for part in normalize_repo_id(repo_id).split("/") if part]
    return "/".join(quote(segment, safe="") for segment in segments)
|
| 1592 |
+
|
| 1593 |
+
|
| 1594 |
+
def hf_file_url(repo_id: str, filename: str) -> str:
    """Build the resolve-from-main download URL for *filename* in *repo_id*."""
    encoded_path = "/".join(
        quote(segment, safe="") for segment in filename.split("/") if segment
    )
    base = f"https://huggingface.co/{encoded_repo_id(repo_id)}"
    return f"{base}/resolve/main/{encoded_path}"
|
| 1601 |
+
|
| 1602 |
+
|
| 1603 |
+
def model_list() -> list[dict[str, object]]:
    """List loadable CompactAI models from the Hugging Face API.

    Keeps only repos that ship a checkpoint (``model.pt`` or ``pretrain.pt``,
    either at the repo root or under ``model/``) and whose name maps to a
    known series. Results are sorted by download count, descending.
    """
    data = request_json(f"{HF_API}/models?author={quote(HF_ORG)}&full=true&limit=200")
    models: list[dict[str, object]] = []
    for item in data:
        siblings = item.get("siblings") or []
        filenames = [s.get("rfilename", "") for s in siblings if isinstance(s, dict)]
        # Artifacts may live at the repo root or inside a model/ subfolder.
        has_model = "model.pt" in filenames or "model/model.pt" in filenames
        has_pretrain = "pretrain.pt" in filenames or "model/pretrain.pt" in filenames
        has_tokenizer = (
            "tokenizer.json" in filenames or "model/tokenizer.json" in filenames
        )
        if not has_model and not has_pretrain:
            continue
        name = (item.get("id") or "").split("/")[-1]
        series = series_from_name(name)
        if not series:
            # Only models that belong to a recognized series are listed.
            continue
        models.append(
            {
                "id": item.get("id", ""),
                "name": name,
                "series": series,
                "downloads": item.get("downloads", 0) or 0,
                "likes": item.get("likes", 0) or 0,
                "has_model": has_model,
                "has_pretrain": has_pretrain,
                "has_tokenizer": has_tokenizer,
            }
        )
    return sorted(models, key=lambda entry: entry["downloads"], reverse=True)
|
| 1633 |
+
|
| 1634 |
+
|
| 1635 |
+
def model_details(repo_id: str) -> dict[str, object] | None:
    """Fetch metadata for one repo: file sizes, README, artifact flags.

    Returns ``None`` when *repo_id* normalizes to an empty string. Files
    under ``model/`` are also indexed under their bare names so callers can
    look up either path. README download failures degrade to an empty string.
    """
    normalized = normalize_repo_id(repo_id)
    if not normalized:
        return None
    data = request_json(f"{HF_API}/models/{encoded_repo_id(normalized)}")
    siblings = data.get("siblings") or []
    files: dict[str, dict[str, float]] = {}
    has_model = False
    has_pretrain = False
    for sibling in siblings:
        if not isinstance(sibling, dict):
            continue
        filename = sibling.get("rfilename") or ""
        if not filename:
            continue
        # Size may be missing from the API response; treat it as 0 bytes.
        size_mb = round((sibling.get("size") or 0) / (1024 * 1024), 2)
        files[filename] = {"size_mb": size_mb}
        if filename.startswith("model/"):
            # Alias model/foo.pt as foo.pt for convenience.
            files[filename.removeprefix("model/")] = {"size_mb": size_mb}
        if filename in {"model.pt", "model/model.pt"}:
            has_model = True
        if filename in {"pretrain.pt", "model/pretrain.pt"}:
            has_pretrain = True
    readme_raw = ""
    try:
        readme_raw = request_text(
            f"https://huggingface.co/{encoded_repo_id(normalized)}/raw/main/README.md"
        )
    except Exception:
        # README is optional; missing or unreachable is not an error.
        readme_raw = ""
    name = (data.get("id") or normalized).split("/")[-1]
    return {
        "id": normalized,
        "name": name,
        # Unknown names fall back to the Sonnet series configuration.
        "series": series_from_name(name) or "Sonnet",
        "downloads": data.get("downloads", 0) or 0,
        "files": files,
        "readme_raw": readme_raw,
        "hf_model_id": normalized,
        "has_model": has_model,
        "has_pretrain": has_pretrain,
    }
|
| 1677 |
+
|
| 1678 |
+
|
| 1679 |
+
def cache_dir(repo_id: str, model_type: str) -> Path:
    """Local cache directory for one (repo, artifact-type) pair."""
    safe_name = normalize_repo_id(repo_id).replace("/", "__")
    return CACHE_ROOT / safe_name / model_type
|
| 1681 |
+
|
| 1682 |
+
|
| 1683 |
+
def artifact_candidates(model_type: str) -> list[str]:
    """Checkpoint paths to try on the Hub for *model_type*, preferred first."""
    if model_type == "pretrain":
        return ["model/pretrain.pt", "pretrain.pt"]
    return ["model/model.pt", "model.pt"]
|
| 1689 |
+
|
| 1690 |
+
|
| 1691 |
+
def ensure_artifact(repo_id: str, model_type: str, destination_name: str) -> Path:
    """Make sure one artifact exists in the local cache; return its path.

    Short-circuits if the cached file already exists. Otherwise tries each
    remote candidate in preference order — checkpoint names for ``.pt``
    destinations, tokenizer names otherwise — and raises ``RuntimeError``
    with the last download error when every candidate fails.
    """
    normalized = normalize_repo_id(repo_id)
    target = cache_dir(normalized, model_type) / destination_name
    if target.exists():
        return target
    last_error: Exception | None = None
    for candidate in (
        artifact_candidates(model_type)
        if destination_name.endswith(".pt")
        else ["model/tokenizer.json", "tokenizer.json"]
    ):
        try:
            download_file(hf_file_url(normalized, candidate), target)
            return target
        except Exception as exc:
            # Remember the failure and fall through to the next candidate path.
            last_error = exc
    raise RuntimeError(
        f"Unable to download {destination_name} for {normalized}: {last_error}"
    )
|
| 1710 |
+
|
| 1711 |
+
|
| 1712 |
+
def series_config(series: str) -> dict[str, object]:
    """Architecture preset for *series* (case-insensitive; Sonnet fallback)."""
    key = series.lower()
    if key in MODEL_SERIES:
        return MODEL_SERIES[key]
    return MODEL_SERIES["sonnet"]
|
| 1714 |
+
|
| 1715 |
+
|
| 1716 |
+
def load_bundle(repo_id: str, model_type: str) -> dict[str, object]:
    """Download (if needed), construct, and cache a ready-to-run model bundle.

    The whole load runs under MODEL_CACHE_LOCK, so concurrent requests for
    the same (repo, type) key do the work only once. The returned dict holds
    the model, tokenizer, device, and artifact paths.

    Raises RuntimeError when repo metadata cannot be fetched.
    """
    normalized = normalize_repo_id(repo_id)
    details = model_details(normalized)
    if not details:
        raise RuntimeError("Model details are unavailable.")
    series = str(details["series"])
    key = (normalized, model_type)
    with MODEL_CACHE_LOCK:
        cached = MODEL_CACHE.get(key)
        if cached:
            return cached
        bundle_dir = cache_dir(normalized, model_type)
        bundle_dir.mkdir(parents=True, exist_ok=True)
        model_path = bundle_dir / (
            "pretrain.pt" if model_type == "pretrain" else "model.pt"
        )
        tokenizer_path = bundle_dir / "tokenizer.json"
        if not model_path.exists():
            ensure_artifact(normalized, model_type, model_path.name)
        if not tokenizer_path.exists():
            ensure_artifact(normalized, model_type, tokenizer_path.name)
        tokenizer = WordTokenizer.load(tokenizer_path)
        # weights_only=False: checkpoint stores plain Python metadata alongside
        # tensors. NOTE(review): only load checkpoints from trusted repos.
        ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False)
        cfg = series_config(series)
        vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size))
        # Accept several checkpoint layouts: wrapped or raw state dicts.
        state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
        # Auto-detect new arch features from checkpoint weights
        engram_dim = _detect_engram_dim(state_dict) or int(
            cfg.get("engram_dim", model_config.engram_dim)
        )
        mhc_expansion = _detect_mhc_expansion(state_dict) or int(
            cfg.get("mhc_expansion", model_config.mhc_expansion)
        )
        # Every hyperparameter falls back from the series preset to the
        # module-level model_config defaults.
        model = TinyMemoryLM(
            vocab_size=vocab_size,
            dim=int(cfg.get("dim", model_config.dim)),
            n_unique_layers=int(
                cfg.get("n_unique_layers", model_config.n_unique_layers)
            ),
            n_logical_layers=int(
                cfg.get("n_logical_layers", model_config.n_logical_layers)
            ),
            n_heads=int(cfg.get("n_heads", model_config.n_heads)),
            n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)),
            ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)),
            dropout=float(cfg.get("dropout", model_config.dropout)),
            mtp_horizons=tuple(
                int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)
            ),
            grad_checkpoint=False,
            sliding_window=int(
                cfg.get("sliding_window_size", model_config.sliding_window_size)
            ),
            rope_fraction=float(
                cfg.get("rope_fraction", model_config.rope_fraction)
            ),
            embed_scale=bool(
                cfg.get("embed_scale", model_config.embed_scale)
            ),
            engram_dim=engram_dim,
            engram_heads=int(cfg.get("engram_heads", model_config.engram_heads)),
            engram_table_size=int(
                cfg.get("engram_table_size", model_config.engram_table_size)
            ),
            engram_max_ngram=int(
                cfg.get("engram_max_ngram", model_config.engram_max_ngram)
            ),
            mhc_expansion=mhc_expansion,
        )
        # strict=False tolerates minor arch/checkpoint mismatches.
        model.load_state_dict(state_dict, strict=False)
        model.eval()
        if tokenizer.vocab_size > vocab_size:
            # Tokenizer grew since training; widen the embedding table.
            model.resize_token_embeddings(tokenizer.vocab_size)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        bundle = {
            "repo_id": normalized,
            "name": details["name"],
            "series": series,
            "type": model_type,
            "model": model,
            "tokenizer": tokenizer,
            "device": device,
            "model_path": str(model_path),
            "tokenizer_path": str(tokenizer_path),
            "downloads": details["downloads"],
        }
        MODEL_CACHE[key] = bundle
        return bundle
|
| 1805 |
+
|
| 1806 |
+
|
| 1807 |
+
def ensure_port(start_port: int) -> int:
    """Find a bindable localhost TCP port.

    Scans ``start_port`` .. ``start_port + 49`` in order and returns the
    first port that accepts a bind; if the whole range is busy, binds port 0
    and returns the OS-assigned ephemeral port instead.
    """
    for candidate in range(start_port, start_port + 50):
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            probe.bind(("127.0.0.1", candidate))
        except OSError:
            continue
        else:
            return candidate
        finally:
            probe.close()
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as fallback:
        fallback.bind(("127.0.0.1", 0))
        return fallback.getsockname()[1]
|
| 1818 |
+
|
| 1819 |
+
|
| 1820 |
+
def page_html() -> str:
|
| 1821 |
+
return f"""<!doctype html>
|
| 1822 |
+
<html lang="en">
|
| 1823 |
+
<head>
|
| 1824 |
+
<meta charset="utf-8">
|
| 1825 |
+
<meta name="viewport" content="width=device-width, initial-scale=1">
|
| 1826 |
+
<title>CompactAI Web</title>
|
| 1827 |
+
<style>
|
| 1828 |
+
:root {{
|
| 1829 |
+
color-scheme: dark;
|
| 1830 |
+
--bg: #050505;
|
| 1831 |
+
--panel: #111111;
|
| 1832 |
+
--panel-2: #161616;
|
| 1833 |
+
--line: #262626;
|
| 1834 |
+
--text: #f5f5f5;
|
| 1835 |
+
--muted: #a3a3a3;
|
| 1836 |
+
--accent: #d97706;
|
| 1837 |
+
--accent-2: #b45309;
|
| 1838 |
+
--soft: #1f1f1f;
|
| 1839 |
+
}}
|
| 1840 |
+
* {{ box-sizing: border-box; }}
|
| 1841 |
+
body {{
|
| 1842 |
+
margin: 0;
|
| 1843 |
+
font-family: Geist, -apple-system, BlinkMacSystemFont, sans-serif;
|
| 1844 |
+
background: var(--bg);
|
| 1845 |
+
color: var(--text);
|
| 1846 |
+
line-height: 1.5;
|
| 1847 |
+
}}
|
| 1848 |
+
a {{ color: inherit; }}
|
| 1849 |
+
.wrap {{ max-width: 1120px; margin: 0 auto; padding: 28px 20px 40px; }}
|
| 1850 |
+
.hero {{
|
| 1851 |
+
display: flex;
|
| 1852 |
+
justify-content: space-between;
|
| 1853 |
+
align-items: end;
|
| 1854 |
+
gap: 16px;
|
| 1855 |
+
padding: 22px 0 28px;
|
| 1856 |
+
border-bottom: 1px solid var(--line);
|
| 1857 |
+
margin-bottom: 22px;
|
| 1858 |
+
}}
|
| 1859 |
+
h1 {{ margin: 0; font-size: clamp(2rem, 5vw, 3.5rem); letter-spacing: -0.04em; }}
|
| 1860 |
+
.subtitle {{ margin: 10px 0 0; color: var(--muted); max-width: 58ch; }}
|
| 1861 |
+
.grid {{
|
| 1862 |
+
display: grid;
|
| 1863 |
+
grid-template-columns: 1.1fr 1fr;
|
| 1864 |
+
gap: 18px;
|
| 1865 |
+
}}
|
| 1866 |
+
.panel {{
|
| 1867 |
+
background: var(--panel);
|
| 1868 |
+
border: 1px solid var(--line);
|
| 1869 |
+
border-radius: 18px;
|
| 1870 |
+
padding: 18px;
|
| 1871 |
+
}}
|
| 1872 |
+
.panel h2 {{ margin: 0 0 12px; font-size: 15px; letter-spacing: 0.02em; text-transform: uppercase; color: var(--muted); }}
|
| 1873 |
+
.row {{ display: flex; gap: 10px; flex-wrap: wrap; }}
|
| 1874 |
+
select, textarea, input {{
|
| 1875 |
+
width: 100%;
|
| 1876 |
+
background: var(--panel-2);
|
| 1877 |
+
color: var(--text);
|
| 1878 |
+
border: 1px solid var(--line);
|
| 1879 |
+
border-radius: 12px;
|
| 1880 |
+
padding: 12px 14px;
|
| 1881 |
+
font: inherit;
|
| 1882 |
+
outline: none;
|
| 1883 |
+
}}
|
| 1884 |
+
textarea {{ min-height: 170px; resize: vertical; }}
|
| 1885 |
+
select {{ appearance: none; }}
|
| 1886 |
+
.choice {{
|
| 1887 |
+
flex: 1 1 150px;
|
| 1888 |
+
display: flex;
|
| 1889 |
+
align-items: center;
|
| 1890 |
+
gap: 10px;
|
| 1891 |
+
padding: 10px 12px;
|
| 1892 |
+
border: 1px solid var(--line);
|
| 1893 |
+
border-radius: 12px;
|
| 1894 |
+
background: var(--panel-2);
|
| 1895 |
+
cursor: pointer;
|
| 1896 |
+
}}
|
| 1897 |
+
.choice input {{ width: auto; }}
|
| 1898 |
+
.btns {{ display: flex; flex-wrap: wrap; gap: 10px; }}
|
| 1899 |
+
button {{
|
| 1900 |
+
border: 1px solid var(--line);
|
| 1901 |
+
border-radius: 12px;
|
| 1902 |
+
padding: 11px 14px;
|
| 1903 |
+
background: var(--soft);
|
| 1904 |
+
color: var(--text);
|
| 1905 |
+
font: inherit;
|
| 1906 |
+
cursor: pointer;
|
| 1907 |
+
transition: transform 0.15s ease, border-color 0.15s ease, background 0.15s ease;
|
| 1908 |
+
}}
|
| 1909 |
+
button:hover {{ transform: translateY(-1px); border-color: #3a3a3a; }}
|
| 1910 |
+
.primary {{ background: var(--accent); border-color: var(--accent); color: #fff; }}
|
| 1911 |
+
.primary:hover {{ background: var(--accent-2); border-color: var(--accent-2); }}
|
| 1912 |
+
.status {{
|
| 1913 |
+
margin-top: 12px;
|
| 1914 |
+
color: var(--muted);
|
| 1915 |
+
font-size: 13px;
|
| 1916 |
+
min-height: 1.4em;
|
| 1917 |
+
}}
|
| 1918 |
+
.output {{
|
| 1919 |
+
white-space: pre-wrap;
|
| 1920 |
+
background: #0b0b0b;
|
| 1921 |
+
border: 1px solid var(--line);
|
| 1922 |
+
border-radius: 16px;
|
| 1923 |
+
min-height: 280px;
|
| 1924 |
+
padding: 16px;
|
| 1925 |
+
color: #e7e5e4;
|
| 1926 |
+
overflow: auto;
|
| 1927 |
+
}}
|
| 1928 |
+
.meta {{
|
| 1929 |
+
display: flex;
|
| 1930 |
+
flex-wrap: wrap;
|
| 1931 |
+
gap: 8px;
|
| 1932 |
+
margin-top: 8px;
|
| 1933 |
+
}}
|
| 1934 |
+
.chip {{
|
| 1935 |
+
display: inline-flex;
|
| 1936 |
+
align-items: center;
|
| 1937 |
+
gap: 6px;
|
| 1938 |
+
padding: 6px 10px;
|
| 1939 |
+
border-radius: 999px;
|
| 1940 |
+
border: 1px solid var(--line);
|
| 1941 |
+
background: var(--panel-2);
|
| 1942 |
+
font-size: 12px;
|
| 1943 |
+
color: var(--muted);
|
| 1944 |
+
}}
|
| 1945 |
+
.code {{
|
| 1946 |
+
margin-top: 14px;
|
| 1947 |
+
padding: 12px 14px;
|
| 1948 |
+
border-radius: 12px;
|
| 1949 |
+
border: 1px solid var(--line);
|
| 1950 |
+
background: #0b0b0b;
|
| 1951 |
+
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
|
| 1952 |
+
font-size: 13px;
|
| 1953 |
+
overflow-x: auto;
|
| 1954 |
+
}}
|
| 1955 |
+
@media (max-width: 900px) {{
|
| 1956 |
+
.grid {{ grid-template-columns: 1fr; }}
|
| 1957 |
+
.hero {{ align-items: start; flex-direction: column; }}
|
| 1958 |
+
}}
|
| 1959 |
+
</style>
|
| 1960 |
+
</head>
|
| 1961 |
+
<body>
|
| 1962 |
+
<div class="wrap">
|
| 1963 |
+
<div class="hero">
|
| 1964 |
+
<div>
|
| 1965 |
+
<h1>CompactAI Web</h1>
|
| 1966 |
+
<p class="subtitle">Pull a model from Hugging Face, keep it cached locally, and chat in the browser.</p>
|
| 1967 |
+
</div>
|
| 1968 |
+
<div class="meta">
|
| 1969 |
+
<span class="chip">Hugging Face: CompactAI</span>
|
| 1970 |
+
<span class="chip">pip install -r requirements.txt</span>
|
| 1971 |
+
<span class="chip">Local inference</span>
|
| 1972 |
+
</div>
|
| 1973 |
+
</div>
|
| 1974 |
+
|
| 1975 |
+
<div class="grid">
|
| 1976 |
+
<section class="panel">
|
| 1977 |
+
<h2>Model</h2>
|
| 1978 |
+
<select id="modelSelect"></select>
|
| 1979 |
+
<div class="row" style="margin-top: 10px;">
|
| 1980 |
+
<label class="choice"><input type="radio" name="type" value="model" checked> Instruct / final</label>
|
| 1981 |
+
<label class="choice"><input type="radio" name="type" value="pretrain"> Pretrain</label>
|
| 1982 |
+
</div>
|
| 1983 |
+
<div class="btns" style="margin-top: 12px;">
|
| 1984 |
+
<button id="downloadBtn">Download</button>
|
| 1985 |
+
<button id="refreshBtn">Refresh models</button>
|
| 1986 |
+
</div>
|
| 1987 |
+
<div class="status" id="modelStatus">Loading model list…</div>
|
| 1988 |
+
<div class="code">python3 interactive_web.py</div>
|
| 1989 |
+
</section>
|
| 1990 |
+
|
| 1991 |
+
<section class="panel">
|
| 1992 |
+
<h2>Prompt</h2>
|
| 1993 |
+
<textarea id="prompt" placeholder="Ask something…"></textarea>
|
| 1994 |
+
<div class="row" style="margin-top: 10px;">
|
| 1995 |
+
<input id="temperature" type="number" min="0.1" max="2" step="0.05" value="0.8" style="flex: 1 1 120px;">
|
| 1996 |
+
<input id="topK" type="number" min="1" max="100" step="1" value="40" style="flex: 1 1 120px;">
|
| 1997 |
+
<input id="maxTokens" type="number" min="16" max="2048" step="16" value="256" style="flex: 1 1 120px;">
|
| 1998 |
+
</div>
|
| 1999 |
+
<div class="btns" style="margin-top: 12px;">
|
| 2000 |
+
<button id="generateBtn" class="primary">Generate</button>
|
| 2001 |
+
</div>
|
| 2002 |
+
<div class="status" id="genStatus"></div>
|
| 2003 |
+
</section>
|
| 2004 |
+
</div>
|
| 2005 |
+
|
| 2006 |
+
<section class="panel" style="margin-top: 18px;">
|
| 2007 |
+
<h2>Response</h2>
|
| 2008 |
+
<div id="output" class="output"></div>
|
| 2009 |
+
</section>
|
| 2010 |
+
</div>
|
| 2011 |
+
|
| 2012 |
+
<script>
|
| 2013 |
+
const modelSelect = document.getElementById('modelSelect');
|
| 2014 |
+
const modelStatus = document.getElementById('modelStatus');
|
| 2015 |
+
const genStatus = document.getElementById('genStatus');
|
| 2016 |
+
const output = document.getElementById('output');
|
| 2017 |
+
const promptBox = document.getElementById('prompt');
|
| 2018 |
+
|
| 2019 |
+
async function api(path, body) {{
|
| 2020 |
+
const response = await fetch(path, {{
|
| 2021 |
+
method: body ? 'POST' : 'GET',
|
| 2022 |
+
headers: body ? {{ 'Content-Type': 'application/json' }} : undefined,
|
| 2023 |
+
body: body ? JSON.stringify(body) : undefined,
|
| 2024 |
+
}});
|
| 2025 |
+
return response.json();
|
| 2026 |
+
}}
|
| 2027 |
+
|
| 2028 |
+
function currentType() {{
|
| 2029 |
+
return document.querySelector('input[name="type"]:checked').value;
|
| 2030 |
+
}}
|
| 2031 |
+
|
| 2032 |
+
function currentModelId() {{
|
| 2033 |
+
return modelSelect.value;
|
| 2034 |
+
}}
|
| 2035 |
+
|
| 2036 |
+
function setModels(models) {{
|
| 2037 |
+
modelSelect.innerHTML = '';
|
| 2038 |
+
for (const model of models) {{
|
| 2039 |
+
const option = document.createElement('option');
|
| 2040 |
+
option.value = model.id;
|
| 2041 |
+
option.textContent = `${{model.name}} • ${{model.series}}`;
|
| 2042 |
+
modelSelect.appendChild(option);
|
| 2043 |
+
}}
|
| 2044 |
+
if (models.length === 0) {{
|
| 2045 |
+
const option = document.createElement('option');
|
| 2046 |
+
option.value = '';
|
| 2047 |
+
option.textContent = 'No CompactAI models found';
|
| 2048 |
+
modelSelect.appendChild(option);
|
| 2049 |
+
}}
|
| 2050 |
+
}}
|
| 2051 |
+
|
| 2052 |
+
async function refreshModels() {{
|
| 2053 |
+
modelStatus.textContent = 'Loading model list…';
|
| 2054 |
+
try {{
|
| 2055 |
+
const models = await api('/api/models');
|
| 2056 |
+
setModels(models);
|
| 2057 |
+
modelStatus.textContent = models.length ? `${{models.length}} models available from CompactAI` : 'No compatible models found.';
|
| 2058 |
+
}} catch (error) {{
|
| 2059 |
+
modelStatus.textContent = 'Failed to load model list.';
|
| 2060 |
+
}}
|
| 2061 |
+
}}
|
| 2062 |
+
|
| 2063 |
+
async function ensureModel() {{
|
| 2064 |
+
const modelId = currentModelId();
|
| 2065 |
+
if (!modelId) {{
|
| 2066 |
+
modelStatus.textContent = 'Pick a model first.';
|
| 2067 |
+
return null;
|
| 2068 |
+
}}
|
| 2069 |
+
modelStatus.textContent = 'Downloading model files…';
|
| 2070 |
+
const result = await api('/api/ensure', {{ modelId, type: currentType() }});
|
| 2071 |
+
if (!result.success) {{
|
| 2072 |
+
modelStatus.textContent = result.error || 'Download failed.';
|
| 2073 |
+
return null;
|
| 2074 |
+
}}
|
| 2075 |
+
modelStatus.textContent = `${{result.name}} ready on ${{result.series}}`;
|
| 2076 |
+
return result;
|
| 2077 |
+
}}
|
| 2078 |
+
|
| 2079 |
+
async function generate() {{
|
| 2080 |
+
output.textContent = '';
|
| 2081 |
+
genStatus.textContent = '';
|
| 2082 |
+
const modelId = currentModelId();
|
| 2083 |
+
const prompt = promptBox.value.trim();
|
| 2084 |
+
if (!modelId) {{
|
| 2085 |
+
genStatus.textContent = 'Pick a model first.';
|
| 2086 |
+
return;
|
| 2087 |
+
}}
|
| 2088 |
+
if (!prompt) {{
|
| 2089 |
+
genStatus.textContent = 'Enter a prompt first.';
|
| 2090 |
+
return;
|
| 2091 |
+
}}
|
| 2092 |
+
genStatus.textContent = 'Preparing model…';
|
| 2093 |
+
const result = await api('/api/generate', {{
|
| 2094 |
+
modelId,
|
| 2095 |
+
type: currentType(),
|
| 2096 |
+
prompt,
|
| 2097 |
+
temperature: Number(document.getElementById('temperature').value || 0.8),
|
| 2098 |
+
top_k: Number(document.getElementById('topK').value || 40),
|
| 2099 |
+
max_new_tokens: Number(document.getElementById('maxTokens').value || 256),
|
| 2100 |
+
}});
|
| 2101 |
+
if (!result.success) {{
|
| 2102 |
+
genStatus.textContent = result.error || 'Generation failed.';
|
| 2103 |
+
return;
|
| 2104 |
+
}}
|
| 2105 |
+
output.textContent = result.text || '';
|
| 2106 |
+
genStatus.textContent = 'Done.';
|
| 2107 |
+
}}
|
| 2108 |
+
|
| 2109 |
+
document.getElementById('refreshBtn').addEventListener('click', refreshModels);
|
| 2110 |
+
document.getElementById('downloadBtn').addEventListener('click', ensureModel);
|
| 2111 |
+
document.getElementById('generateBtn').addEventListener('click', generate);
|
| 2112 |
+
promptBox.addEventListener('keydown', (event) => {{
|
| 2113 |
+
if (event.key === 'Enter' && (event.ctrlKey || event.metaKey)) {{
|
| 2114 |
+
event.preventDefault();
|
| 2115 |
+
generate();
|
| 2116 |
+
}}
|
| 2117 |
+
}});
|
| 2118 |
+
|
| 2119 |
+
refreshModels();
|
| 2120 |
+
</script>
|
| 2121 |
+
</body>
|
| 2122 |
+
</html>"""
|
| 2123 |
+
|
| 2124 |
+
|
| 2125 |
+
class Handler(BaseHTTPRequestHandler):
    """HTTP request handler serving the single-page UI and its JSON API.

    Routes:
      GET  /                  -> the embedded HTML/JS page
      GET  /api/models        -> list of available models
      GET  /api/models/<id>   -> details for one model
      POST /api/ensure        -> download/prepare a model bundle
      POST /api/generate      -> run text generation (serialized by GENERATION_LOCK)

    All API errors are reported as JSON ``{"success": false, "error": ...}``.
    """

    def _send_json(self, payload, status: int = 200) -> None:
        """Serialize *payload* as JSON and send it with the given HTTP status."""
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        # no-store: the model list / generation results must never be cached.
        self.send_header("Cache-Control", "no-store")
        self.end_headers()
        self.wfile.write(body)

    def _send_html(self, payload: str, status: int = 200) -> None:
        """Send *payload* as a UTF-8 HTML response with the given HTTP status."""
        body = payload.encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.send_header("Cache-Control", "no-store")
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self) -> None:
        """Route GET requests: UI page, model listing, or per-model details."""
        parsed = urlparse(self.path)
        if parsed.path in {"/", "/index.html"}:
            # page_html() renders the embedded single-page UI (defined earlier in this file).
            self._send_html(page_html())
            return
        if parsed.path == "/api/models":
            try:
                self._send_json(model_list())
            except Exception as exc:
                # Boundary handler: surface any listing failure as a JSON 500.
                self._send_json({"success": False, "error": str(exc)}, 500)
            return
        if parsed.path.startswith("/api/models/"):
            # NOTE(review): the path segment is not unquote()d here; presumably
            # normalize_repo_id handles URL-encoded repo ids — confirm.
            repo_id = normalize_repo_id(parsed.path.removeprefix("/api/models/"))
            try:
                details = model_details(repo_id)
                if not details:
                    self._send_json(
                        {"success": False, "error": "Model not found."}, 404
                    )
                else:
                    self._send_json(details)
            except Exception as exc:
                self._send_json({"success": False, "error": str(exc)}, 500)
            return
        # Unknown GET path.
        self._send_json({"success": False, "error": "Not found."}, 404)

    def do_POST(self) -> None:
        """Route POST requests: /api/ensure (prepare model) and /api/generate."""
        parsed = urlparse(self.path)
        # Tolerate a missing or empty Content-Length header.
        length = int(self.headers.get("Content-Length", "0") or "0")
        raw = self.rfile.read(length).decode("utf-8") if length else "{}"
        try:
            payload = json.loads(raw or "{}")
        except Exception:
            # Malformed JSON bodies are treated as an empty request, not an error.
            payload = {}
        if parsed.path == "/api/ensure":
            try:
                repo_id = normalize_repo_id(payload.get("modelId", ""))
                model_type = payload.get("type", "model")
                if not repo_id:
                    self._send_json(
                        {"success": False, "error": "Missing model ID."}, 400
                    )
                    return
                details = model_details(repo_id)
                if not details:
                    self._send_json(
                        {"success": False, "error": "Model not found."}, 404
                    )
                    return
                # load_bundle downloads/caches the model files and returns a dict
                # with at least repo_id/name/series/type keys (project helper).
                bundle = load_bundle(repo_id, model_type)
                self._send_json(
                    {
                        "success": True,
                        "id": bundle["repo_id"],
                        "name": bundle["name"],
                        "series": bundle["series"],
                        "type": bundle["type"],
                    }
                )
            except Exception as exc:
                self._send_json({"success": False, "error": str(exc)}, 500)
            return
        if parsed.path == "/api/generate":
            try:
                repo_id = normalize_repo_id(payload.get("modelId", ""))
                model_type = payload.get("type", "model")
                prompt = str(payload.get("prompt", ""))
                if not repo_id:
                    self._send_json(
                        {"success": False, "error": "Missing model ID."}, 400
                    )
                    return
                bundle = load_bundle(repo_id, model_type)
                # Serialize generation: only one request runs the model at a time.
                with GENERATION_LOCK:
                    text = generate(
                        model=bundle["model"],
                        tokenizer=bundle["tokenizer"],
                        prompt=prompt,
                        # Defaults below mirror the embedded UI's form defaults.
                        max_new_tokens=int(payload.get("max_new_tokens", 256)),
                        temperature=float(payload.get("temperature", 0.8)),
                        top_k=int(payload.get("top_k", 40)),
                        repetition_penalty=float(
                            payload.get("repetition_penalty", 1.0)
                        ),
                        device=str(bundle["device"]),
                        # Pretrain checkpoints are raw LMs; everything else uses
                        # the SFT chat formatting path.
                        sft_mode=model_type != "pretrain",
                        force_thought=bool(payload.get("force_thought", False)),
                        stream=False,
                        decode_mode=str(payload.get("decode_mode", "legacy")),
                        best_of=int(payload.get("best_of", 3)),
                        no_repeat_ngram_size=int(
                            payload.get("no_repeat_ngram_size", 3)
                        ),
                        context_window=int(payload.get("context_window", 2048)),
                        beam_width=int(payload.get("beam_width", 8)),
                        length_penalty=float(payload.get("length_penalty", 0.7)),
                    )
                self._send_json(
                    {
                        "success": True,
                        "text": text,
                        "name": bundle["name"],
                        "series": bundle["series"],
                    }
                )
            except Exception as exc:
                self._send_json({"success": False, "error": str(exc)}, 500)
            return
        self._send_json({"success": False, "error": "Not found."}, 404)

    def log_message(self, format, *args) -> None:
        """Suppress the default per-request stderr logging."""
        return
| 2256 |
+
|
| 2257 |
+
|
| 2258 |
+
def main():
    """Start the local HTTP server, open a browser tab, and serve until interrupted.

    The port is taken from the PORT environment variable (default 7860) and
    passed through ensure_port; the chosen URL is printed to stdout so a
    wrapping process can pick it up.
    """
    CACHE_ROOT.mkdir(parents=True, exist_ok=True)
    requested_port = int(os.environ.get("PORT", "7860"))
    port = ensure_port(requested_port)
    server = ThreadingHTTPServer(("127.0.0.1", port), Handler)
    url = f"http://127.0.0.1:{port}"
    print(url, flush=True)
    try:
        webbrowser.open(url)
    except Exception:
        # Best effort only — headless environments may have no browser.
        pass
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server; exit quietly.
        pass
    finally:
        server.server_close()
+
|
| 2276 |
+
# Script entry point: start the local server when run directly.
if __name__ == "__main__":
    main()
requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|