CompactAI committed
Commit 093ccf9 · verified · 1 Parent(s): 92ed286

Upload 88 files
downloads/CompactAI Studio Setup 1.0.0.exe CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:245e455f3dd33a7d2d1f91f456d9aa926705b7bd69c78f068e8c56e3e846ae3a
-size 133
+oid sha256:f685d561cb3d1e9b8f41bfea7b50c8d5bd0b72007000c7f70f63747127c5a57f
+size 128
downloads/CompactAI Studio-1.0.0.AppImage CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:82ac5097ddba3db79ec3a05998820f44e1cd536ce60ae1eeaae08b391b466d1d
-size 133
+oid sha256:fa6103e15bfdc80bfea40471ef418e54ca4f6f8b6c90ec166072d821d93dfe3c
+size 128
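
Editor's note: both installers are stored as Git LFS pointer files, so the diffs above only swap the sha256 oid and byte size of the stored object. A minimal sketch (not part of this commit; the file paths are hypothetical) of checking a downloaded artifact against such a pointer:

import hashlib
from pathlib import Path

def verify_lfs_pointer(pointer_path: str, blob_path: str) -> bool:
    # An LFS spec-v1 pointer has three lines: "version <url>", "oid sha256:<hex>", "size <bytes>".
    fields = dict(
        line.split(" ", 1)
        for line in Path(pointer_path).read_text().splitlines()
        if line.strip()
    )
    expected_oid = fields["oid"].split(":", 1)[1]
    expected_size = int(fields["size"])
    blob = Path(blob_path).read_bytes()
    return len(blob) == expected_size and hashlib.sha256(blob).hexdigest() == expected_oid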
downloads/index.html CHANGED
The diff for this file is too large to render. See raw diff
 
downloads/interactive.py CHANGED
@@ -18,6 +18,8 @@ from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
 from urllib.parse import quote, unquote, urlparse
 from urllib.request import Request, urlopen
 
+import hashlib
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -41,10 +43,17 @@ class ModelConfig:
     seq_len: int = 2048
     sliding_window_size: int = 512
     mtp_horizons: Tuple[int, ...] = (2, 3, 4)
-    rope_fraction: float = 0.25
+    rope_fraction: float = 0.5
     embed_scale: bool = True
     logit_soft_cap: float = -1.0
     quantization: str = "nvfp4"
+    # Engram (conditional memory) config
+    engram_dim: int = 0
+    engram_heads: int = 4
+    engram_table_size: int = 8192
+    engram_max_ngram: int = 3
+    # mHC (Manifold-Constrained Hyper-Connections) config
+    mhc_expansion: int = 1
 
     @property
     def head_dim(self) -> int:
@@ -55,16 +64,17 @@ model_config = ModelConfig()
 
 MODEL_SERIES = {
     "haiku": {
-        "dim": 128,
-        "n_unique_layers": 8,
-        "n_logical_layers": 16,
+        "dim": 64,
+        "n_unique_layers": 12,
+        "n_logical_layers": 24,
         "n_heads": 4,
         "n_kv_heads": 2,
-        "ffn_dim": 224,
+        "ffn_dim": 384,
         "dropout": 0.0,
         "seq_len": 2048,
         "mtp_horizons": (2, 3, 4),
-        "batch_size": 48,
+        "rope_fraction": 0.5,
+        "batch_size": 80,
         "grad_accum": 1,
         "lr": 8e-4,
         "min_lr": 1e-5,
@@ -74,29 +84,34 @@ MODEL_SERIES = {
         "weight_decay": 0.02,
         "pretrain_passes": 2,
         "sft_passes": 3,
-        "max_sft_target_chars": 128,
-        "use_grad_checkpoint": False,
-        "use_torch_compile": True,
+        "max_sft_target_chars": 0,
+        "use_grad_checkpoint": True,
         "num_workers": 24,
         "prefetch_factor": 64,
         "shuffle_buffer": 8192,
         "max_pretrain_tokens": 0,
         "min_pretrain_tokens": 100_000_000,
         "quantization": "nvfp4",
+        "engram_dim": 8,
+        "engram_heads": 2,
+        "engram_table_size": 64,
+        "engram_max_ngram": 2,
+        "mhc_expansion": 2,
     },
     "sonnet": {
-        "dim": 768,
-        "n_unique_layers": 18,
-        "n_logical_layers": 36,
-        "n_heads": 12,
+        "dim": 1024,
+        "n_unique_layers": 20,
+        "n_logical_layers": 40,
+        "n_heads": 16,
         "n_kv_heads": 4,
-        "ffn_dim": 2538,
+        "ffn_dim": 4096,
         "dropout": 0.0,
         "seq_len": 2048,
         "mtp_horizons": (2,),
-        "batch_size": 6,
+        "rope_fraction": 0.5,
+        "batch_size": 24,
         "grad_accum": 1,
-        "lr": 2e-4,
+        "lr": 1e-4,
         "min_lr": 2e-5,
         "sft_lr": 5e-5,
         "sft_min_lr": 5e-6,
@@ -104,27 +119,32 @@ MODEL_SERIES = {
         "weight_decay": 0.1,
         "pretrain_passes": 1,
         "sft_passes": 1,
-        "max_sft_target_chars": 512,
+        "max_sft_target_chars": 0,
         "use_grad_checkpoint": True,
-        "use_torch_compile": True,
         "num_workers": 32,
-        "prefetch_factor": 48,
+        "prefetch_factor": 64,
        "shuffle_buffer": 16384,
         "max_pretrain_tokens": 0,
-        "min_pretrain_tokens": 0,
+        "min_pretrain_tokens": 100_000_000,
         "quantization": "nvfp4",
+        "engram_dim": 32,
+        "engram_heads": 8,
+        "engram_table_size": 4096,
+        "engram_max_ngram": 2,
+        "mhc_expansion": 2,
     },
     "opus": {
-        "dim": 1024,
-        "n_unique_layers": 20,
-        "n_logical_layers": 40,
+        "dim": 1536,
+        "n_unique_layers": 18,
+        "n_logical_layers": 36,
         "n_heads": 16,
         "n_kv_heads": 4,
-        "ffn_dim": 3557,
+        "ffn_dim": 5888,
         "dropout": 0.0,
         "seq_len": 2048,
         "mtp_horizons": (2,),
-        "batch_size": 12,
+        "rope_fraction": 0.5,
+        "batch_size": 24,
         "grad_accum": 1,
         "lr": 1.6e-4,
         "min_lr": 1.6e-5,
@@ -134,15 +154,19 @@ MODEL_SERIES = {
         "weight_decay": 0.1,
         "pretrain_passes": 1,
         "sft_passes": 1,
-        "max_sft_target_chars": 1024,
+        "max_sft_target_chars": 0,
         "use_grad_checkpoint": True,
-        "use_torch_compile": True,
         "num_workers": 48,
-        "prefetch_factor": 48,
+        "prefetch_factor": 64,
         "shuffle_buffer": 16384,
         "max_pretrain_tokens": 0,
-        "min_pretrain_tokens": 0,
+        "min_pretrain_tokens": 100_000_000,
         "quantization": "nvfp4",
+        "engram_dim": 64,
+        "engram_heads": 8,
+        "engram_table_size": 8192,
+        "engram_max_ngram": 2,
+        "mhc_expansion": 4,
     },
 }
 
@@ -381,6 +405,10 @@ class CausalSelfAttention(nn.Module):
         self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)
 
+        for lin in (self.wq, self.wk, self.wv):
+            nn.init.normal_(lin.weight, std=dim ** -0.5)
+        nn.init.normal_(self.wo.weight, std=(n_heads * head_dim) ** -0.5)
+
         self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2)
         self.rope = RotaryEmbedding(self.rope_dim)
 
@@ -444,7 +472,7 @@
                 .reshape(B, self.n_heads, S, self.head_dim)
             )
 
-        drop_p = self.dropout if self.training else 0.0
+        drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0
 
         if is_global:
             if past_kv is None and T > 1:
@@ -479,9 +507,184 @@ class SwiGLU(nn.Module):
         self.down = nn.Linear(hidden_dim, dim, bias=False)
         self.drop = nn.Dropout(dropout)
 
+        nn.init.normal_(self.gate.weight, std=dim ** -0.5)
+        nn.init.normal_(self.up.weight, std=dim ** -0.5)
+        nn.init.normal_(self.down.weight, std=hidden_dim ** -0.5)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         h = F.silu(self.gate(x)) * self.up(x)
-        return self.drop(self.down(h))
+        out = self.down(h)
+        if self.training and torch.is_grad_enabled():
+            out = self.drop(out)
+        return out
+
+
+class EngramBlock(nn.Module):
+    """Conditional memory via O(1) hashed N-gram lookup (DeepSeek Engram)."""
+
+    def __init__(
+        self,
+        dim: int,
+        engram_dim: int,
+        n_heads: int = 4,
+        table_size: int = 8192,
+        max_ngram: int = 3,
+    ) -> None:
+        super().__init__()
+        self.dim = dim
+        self.engram_dim = engram_dim
+        self.n_heads = n_heads
+        self.table_size = table_size
+        self.max_ngram = max_ngram
+
+        self.embeddings = nn.ParameterDict()
+        for n in range(2, max_ngram + 1):
+            for k in range(n_heads):
+                self.embeddings[f"{n}_{k}"] = nn.Parameter(
+                    torch.randn(table_size, engram_dim) * (engram_dim**-0.5)
+                )
+
+        for n in range(2, max_ngram + 1):
+            for k in range(n_heads):
+                seed = int(hashlib.md5(f"engram_{n}_{k}".encode()).hexdigest()[:8], 16)
+                rng = torch.Generator().manual_seed(seed)
+                a = torch.randint(1, 2**31, (1,), generator=rng).item()
+                b = torch.randint(0, 2**31, (1,), generator=rng).item()
+                self.register_buffer(
+                    f"hash_a_{n}_{k}", torch.tensor(a), persistent=False
+                )
+                self.register_buffer(
+                    f"hash_b_{n}_{k}", torch.tensor(b), persistent=False
+                )
+
+        total_branch_dim = engram_dim * n_heads * (max_ngram - 1)
+        self.branch_conv = nn.Conv1d(
+            total_branch_dim,
+            total_branch_dim,
+            kernel_size=4,
+            dilation=max_ngram,
+            padding=0,
+            groups=total_branch_dim,
+            bias=True,
+        )
+        nn.init.zeros_(self.branch_conv.weight)
+        nn.init.zeros_(self.branch_conv.bias)
+
+        self.gate_query = nn.Linear(dim, engram_dim, bias=False)
+        self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False)
+        self.gate_value = nn.Linear(total_branch_dim, dim, bias=False)
+        self.gate_scale = engram_dim**-0.5
+
+    def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor:
+        a = getattr(self, f"hash_a_{n}_{k}")
+        b = getattr(self, f"hash_b_{n}_{k}")
+        B, T = token_ids.shape
+        padded = F.pad(token_ids, (n - 1, 0), value=0)
+        combined = torch.zeros(B, T, dtype=torch.long, device=token_ids.device)
+        for i in range(n):
+            combined = (combined * 31 + padded[:, i : i + T].long()) % self.table_size
+        return ((a * combined) ^ b) % self.table_size
+
+    def forward(
+        self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        B, T, _ = hidden.shape
+        if token_ids is None:
+            token_ids = hidden.mean(dim=-1).long() % self.table_size
+        all_indices = []
+        all_tables = []
+        for n in range(2, self.max_ngram + 1):
+            for k in range(self.n_heads):
+                all_indices.append(self._hash_ngram(token_ids, n, k))
+                all_tables.append(self.embeddings[f"{n}_{k}"])
+        branch_outputs = [tbl[idx] for idx, tbl in zip(all_indices, all_tables)]
+        memory = torch.cat(branch_outputs, dim=-1)
+        conv_in = memory.transpose(1, 2)
+        conv_in = F.pad(
+            conv_in,
+            (self.branch_conv.dilation[0] * (self.branch_conv.kernel_size[0] - 1), 0),
+        )
+        conv_out = self.branch_conv(conv_in)
+        memory = conv_out.transpose(1, 2)
+        query = self.gate_query(hidden)
+        key = self.gate_key(memory)
+        gate = torch.sigmoid(
+            (query * key).sum(dim=-1, keepdim=True) * self.gate_scale
+        )
+        value = self.gate_value(memory)
+        return gate * value
+
+
+def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor:
+    M = torch.exp(logits.clamp(-10, 10))
+    for _ in range(n_iters):
+        M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10)
+        M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10)
+    return M
+
+
+class ManifoldHyperConnection(nn.Module):
+    """Manifold-Constrained Hyper-Connections (mHC) residual wrapper."""
+
+    def __init__(self, dim: int, expansion: int = 2) -> None:
+        super().__init__()
+        self.dim = dim
+        self.expansion = expansion
+        n = expansion
+
+        self.bias_pre = nn.Parameter(torch.zeros(1, n))
+        self.bias_post = nn.Parameter(torch.zeros(1, n))
+        self.bias_res = nn.Parameter(torch.zeros(n, n))
+
+        self.theta_pre = nn.Linear(n * dim, n, bias=False)
+        self.theta_post = nn.Linear(n * dim, n, bias=False)
+        self.theta_res = nn.Linear(n * dim, n * n, bias=False)
+
+        self.alpha_pre = nn.Parameter(torch.tensor(0.0))
+        self.alpha_post = nn.Parameter(torch.tensor(0.0))
+        self.alpha_res = nn.Parameter(torch.tensor(0.0))
+
+    def _compute_mappings(
+        self, x_expanded: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        B, T, _ = x_expanded.shape
+        n = self.expansion
+        x_norm = F.rms_norm(x_expanded, [x_expanded.shape[-1]])
+        d_pre = torch.tanh(self.theta_pre(x_norm))
+        d_post = torch.tanh(self.theta_post(x_norm))
+        d_res = self.theta_res(x_norm)
+        H_pre_raw = torch.sigmoid(self.alpha_pre * d_pre + self.bias_pre)
+        H_post_raw = 2.0 * torch.sigmoid(self.alpha_post * d_post + self.bias_post)
+        H_res_raw = (self.alpha_res * d_res + self.bias_res.reshape(1, 1, -1)).reshape(
+            B, T, n, n
+        )
+        H_res = _sinkhorn_knopp(H_res_raw)
+        return H_pre_raw.unsqueeze(-2), H_post_raw.unsqueeze(-2), H_res
+
+    def expand_stream(self, x: torch.Tensor) -> torch.Tensor:
+        return x.repeat(1, 1, self.expansion)
+
+    def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor:
+        B, T, _ = x_expanded.shape
+        return x_expanded.view(B, T, self.expansion, self.dim).mean(dim=-2)
+
+    def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor:
+        B, T, _ = x_expanded.shape
+        x_streams = x_expanded.view(B, T, self.expansion, self.dim)
+        return (H_pre @ x_streams).squeeze(-2)
+
+    def post_res_mix(
+        self,
+        layer_output: torch.Tensor,
+        x_expanded: torch.Tensor,
+        H_post: torch.Tensor,
+        H_res: torch.Tensor,
+    ) -> torch.Tensor:
+        B, T, _ = x_expanded.shape
+        x_streams = x_expanded.view(B, T, self.expansion, self.dim)
+        mixed = torch.matmul(H_res, x_streams)
+        post_out = torch.matmul(H_post.transpose(-2, -1), layer_output.unsqueeze(-2))
+        return (mixed + post_out).reshape(B, T, self.expansion * self.dim)
 
 
 class TransformerBlock(nn.Module):
@@ -495,8 +698,14 @@ class TransformerBlock(nn.Module):
         dropout: float,
         sliding_window: int,
         rope_fraction: float,
+        engram_dim: int = 0,
+        engram_heads: int = 4,
+        engram_table_size: int = 8192,
+        engram_max_ngram: int = 3,
+        mhc_expansion: int = 1,
     ) -> None:
         super().__init__()
+        self.dim = dim
         self.norm1 = RMSNorm(dim)
         self.attn = CausalSelfAttention(
             dim=dim,
@@ -509,6 +718,20 @@ class TransformerBlock(nn.Module):
         )
         self.norm2 = RMSNorm(dim)
         self.ffn = SwiGLU(dim, ffn_dim, dropout)
+        self.use_engram = engram_dim > 0
+        if self.use_engram:
+            self.engram = EngramBlock(
+                dim=dim,
+                engram_dim=engram_dim,
+                n_heads=engram_heads,
+                table_size=engram_table_size,
+                max_ngram=engram_max_ngram,
+            )
+            self.engram_norm = RMSNorm(dim)
+        self.use_mhc = mhc_expansion > 1
+        if self.use_mhc:
+            self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion)
+            self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion)
 
     def forward(
         self,
@@ -516,13 +739,50 @@ class TransformerBlock(nn.Module):
         is_global: bool,
         past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
+        token_ids: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
-        attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache)
-        x = x + attn_out
-        x = x + self.ffn(self.norm2(x))
+        if self.use_mhc:
+            x_exp = self.mhc_attn.expand_stream(x)
+            H_pre, H_post, H_res = self.mhc_attn._compute_mappings(x_exp)
+            attn_in = self.mhc_attn.pre_mix(x_exp, H_pre)
+            attn_out, new_kv = self.attn(
+                self.norm1(attn_in), is_global, past_kv, use_cache
+            )
+            x_exp = self.mhc_attn.post_res_mix(attn_out, x_exp, H_post, H_res)
+            if self.use_engram:
+                collapsed = self.mhc_attn.collapse_stream(x_exp)
+                collapsed = collapsed + self.engram(
+                    self.engram_norm(collapsed), token_ids=token_ids
+                )
+                x_exp = self.mhc_attn.expand_stream(collapsed)
+            H_pre2, H_post2, H_res2 = self.mhc_ffn._compute_mappings(x_exp)
+            ffn_in = self.mhc_ffn.pre_mix(x_exp, H_pre2)
+            ffn_out = self.ffn(self.norm2(ffn_in))
+            x_exp = self.mhc_ffn.post_res_mix(ffn_out, x_exp, H_post2, H_res2)
+            x = self.mhc_attn.collapse_stream(x_exp)
+        else:
+            attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache)
+            x = x + attn_out
+            if self.use_engram:
+                x = x + self.engram(self.engram_norm(x), token_ids=token_ids)
+            x = x + self.ffn(self.norm2(x))
         return x, new_kv
 
 
+def _detect_engram_dim(state_dict: dict) -> int:
+    for key in state_dict:
+        if ".engram." in key and ".embeddings." in key:
+            return state_dict[key].shape[-1]
+    return 0
+
+
+def _detect_mhc_expansion(state_dict: dict) -> int:
+    for key, val in state_dict.items():
+        if ".mhc_attn.bias_pre" in key and val.dim() == 2:
+            return val.shape[-1]
+    return 1
+
+
 class TinyMemoryLM(nn.Module):
     def __init__(
         self,
@@ -537,8 +797,13 @@ class TinyMemoryLM(nn.Module):
         mtp_horizons: Sequence[int],
         grad_checkpoint: bool,
         sliding_window: int = 512,
-        rope_fraction: float = 0.25,
+        rope_fraction: float = 0.5,
         embed_scale: bool = True,
+        engram_dim: int = 0,
+        engram_heads: int = 4,
+        engram_table_size: int = 8192,
+        engram_max_ngram: int = 3,
+        mhc_expansion: int = 1,
     ) -> None:
         super().__init__()
         self.dim = dim
@@ -565,6 +830,11 @@ class TinyMemoryLM(nn.Module):
                     dropout=dropout,
                     sliding_window=sliding_window,
                     rope_fraction=rope_fraction,
+                    engram_dim=engram_dim,
+                    engram_heads=engram_heads,
+                    engram_table_size=engram_table_size,
+                    engram_max_ngram=engram_max_ngram,
+                    mhc_expansion=mhc_expansion,
                 )
                 for _ in range(n_unique_layers)
             ]
@@ -634,7 +904,7 @@ class TinyMemoryLM(nn.Module):
         )
 
         for layer_idx, (block, logical_idx) in enumerate(logical_layers):
-            is_global = logical_idx % 3 == 0
+            is_global = logical_idx % 2 == 0
            past_kv = (
                 past_key_values[layer_idx]
                 if past_key_values is not None and layer_idx < len(past_key_values)
@@ -643,17 +913,20 @@
 
             if self.grad_checkpoint and self.training and not use_cache:
                 x, layer_kv = checkpoint(
-                    block, x, is_global, past_kv, use_cache, use_reentrant=False
+                    block, x, is_global, past_kv, use_cache, ids, use_reentrant=True
                 )
             else:
-                x, layer_kv = block(x, is_global, past_kv, use_cache)
+                x, layer_kv = block(x, is_global, past_kv, use_cache, ids)
 
             if new_past_key_values is not None:
                 new_past_key_values.append(layer_kv)
 
         x = self.norm(x)
        h_out = x if return_hidden else None
-        logits = self.head(x) + self.output_bias
+        logits = self.head(x)
+        if self.embed_scale_factor != 1.0:
+            logits = logits / self.embed_scale_factor
+        logits = logits + self.output_bias
 
         mtp: Dict[int, torch.Tensor] = {}
         if self.mtp_horizons and self.training:
@@ -662,7 +935,10 @@
                 shifted_h = x[:, :-horizon, :]
                 adapted_h = self.mtp_adapters[str(horizon)](shifted_h)
                 adapted_h = self.mtp_norms[str(horizon)](adapted_h)
-                mtp_logits = self.head(adapted_h) + self.output_bias
+                mtp_logits = self.head(adapted_h)
+                if self.embed_scale_factor != 1.0:
+                    mtp_logits = mtp_logits / self.embed_scale_factor
+                mtp_logits = mtp_logits + self.output_bias
                 mtp[horizon] = mtp_logits
 
         return logits, mtp, h_out, new_past_key_values
@@ -1462,6 +1738,14 @@ def load_bundle(repo_id: str, model_type: str) -> dict[str, object]:
     ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False)
     cfg = series_config(series)
     vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size))
+    state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
+    # Auto-detect new arch features from checkpoint weights
+    engram_dim = _detect_engram_dim(state_dict) or int(
+        cfg.get("engram_dim", model_config.engram_dim)
+    )
+    mhc_expansion = _detect_mhc_expansion(state_dict) or int(
+        cfg.get("mhc_expansion", model_config.mhc_expansion)
+    )
     model = TinyMemoryLM(
         vocab_size=vocab_size,
         dim=int(cfg.get("dim", model_config.dim)),
@@ -1480,19 +1764,24 @@ def load_bundle(repo_id: str, model_type: str) -> dict[str, object]:
         ),
         grad_checkpoint=False,
         sliding_window=int(
-            cfg.get(
-                "sliding_window_size",
-                getattr(model_config, "sliding_window_size", 512),
-            )
+            cfg.get("sliding_window_size", model_config.sliding_window_size)
         ),
         rope_fraction=float(
-            cfg.get("rope_fraction", getattr(model_config, "rope_fraction", 0.25))
+            cfg.get("rope_fraction", model_config.rope_fraction)
        ),
         embed_scale=bool(
-            cfg.get("embed_scale", getattr(model_config, "embed_scale", True))
+            cfg.get("embed_scale", model_config.embed_scale)
+        ),
+        engram_dim=engram_dim,
+        engram_heads=int(cfg.get("engram_heads", model_config.engram_heads)),
+        engram_table_size=int(
+            cfg.get("engram_table_size", model_config.engram_table_size)
        ),
+        engram_max_ngram=int(
+            cfg.get("engram_max_ngram", model_config.engram_max_ngram)
+        ),
+        mhc_expansion=mhc_expansion,
     )
-    state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
     model.load_state_dict(state_dict, strict=False)
     model.eval()
     if tokenizer.vocab_size > vocab_size:
@@ -1678,7 +1967,7 @@ def page_html() -> str:
       </div>
       <div class="meta">
         <span class="chip">Hugging Face: CompactAI</span>
-        <span class="chip">Auto-installs deps</span>
+        <span class="chip">pip install -r requirements.txt</span>
         <span class="chip">Local inference</span>
       </div>
     </div>
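
Editor's note: the mHC residual mixing introduced above constrains each H_res mixing matrix to be approximately doubly stochastic via Sinkhorn-Knopp normalization. A quick standalone check of that property, reusing _sinkhorn_knopp exactly as defined in the diff (the demo itself is not part of the commit):

import torch

def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor:
    # Alternate row and column normalization of exp(logits); for a positive
    # matrix this converges toward a doubly stochastic matrix.
    M = torch.exp(logits.clamp(-10, 10))
    for _ in range(n_iters):
        M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10)
        M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10)
    return M

H = _sinkhorn_knopp(torch.randn(4, 4))
print(H.sum(dim=-1))  # row sums ~= 1
print(H.sum(dim=-2))  # column sums ~= 1 (the final step normalizes columns)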
interactive.py ADDED
@@ -0,0 +1,2277 @@
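
Editor's note: a notable piece of the file below is EngramBlock, which gates in embeddings fetched by a stateless multiply-xor hash of each trailing token n-gram. A self-contained sketch of that O(1) lookup (the hash constants here are arbitrary illustrations, not the md5-seeded buffers used in the class):

import torch
import torch.nn.functional as F

def hash_ngram(token_ids, n, a, b, table_size):
    # Left-pad so every position has an n-token history, fold the window into
    # one integer with a rolling base-31 hash, then scatter it with (a*x)^b.
    B, T = token_ids.shape
    padded = F.pad(token_ids, (n - 1, 0), value=0)
    combined = torch.zeros(B, T, dtype=torch.long)
    for i in range(n):
        combined = (combined * 31 + padded[:, i : i + T].long()) % table_size
    return ((a * combined) ^ b) % table_size

ids = torch.randint(0, 1000, (1, 8))
slots = hash_ngram(ids, n=2, a=1103515245, b=12345, table_size=64)
table = torch.randn(64, 8)   # one (table_size, engram_dim) branch table
memory = table[slots]        # per-position gather, shape (1, 8, 8)
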
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ import math
6
+ import os
7
+ import re
8
+ import shutil
9
+ import socket
10
+ import string
11
+ import sys
12
+ import threading
13
+ import webbrowser
14
+ from dataclasses import dataclass
15
+ from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
16
+ from pathlib import Path
17
+ from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
18
+ from urllib.parse import quote, unquote, urlparse
19
+ from urllib.request import Request, urlopen
20
+
21
+ import hashlib
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ from torch.utils.checkpoint import checkpoint
27
+
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # Config (from ailay.config)
31
+ # ---------------------------------------------------------------------------
32
+
33
+
34
+ @dataclass
35
+ class ModelConfig:
36
+ dim: int = 128
37
+ n_unique_layers: int = 8
38
+ n_logical_layers: int = 16
39
+ n_heads: int = 4
40
+ n_kv_heads: int = 2
41
+ ffn_dim: int = 224
42
+ dropout: float = 0.0
43
+ seq_len: int = 2048
44
+ sliding_window_size: int = 512
45
+ mtp_horizons: Tuple[int, ...] = (2, 3, 4)
46
+ rope_fraction: float = 0.5
47
+ embed_scale: bool = True
48
+ logit_soft_cap: float = -1.0
49
+ quantization: str = "nvfp4"
50
+ # Engram (conditional memory) config
51
+ engram_dim: int = 0
52
+ engram_heads: int = 4
53
+ engram_table_size: int = 8192
54
+ engram_max_ngram: int = 3
55
+ # mHC (Manifold-Constrained Hyper-Connections) config
56
+ mhc_expansion: int = 1
57
+
58
+ @property
59
+ def head_dim(self) -> int:
60
+ return self.dim // self.n_heads
61
+
62
+
63
+ model_config = ModelConfig()
64
+
65
+ MODEL_SERIES = {
66
+ "haiku": {
67
+ "dim": 64,
68
+ "n_unique_layers": 12,
69
+ "n_logical_layers": 24,
70
+ "n_heads": 4,
71
+ "n_kv_heads": 2,
72
+ "ffn_dim": 384,
73
+ "dropout": 0.0,
74
+ "seq_len": 2048,
75
+ "mtp_horizons": (2, 3, 4),
76
+ "rope_fraction": 0.5,
77
+ "batch_size": 80,
78
+ "grad_accum": 1,
79
+ "lr": 8e-4,
80
+ "min_lr": 1e-5,
81
+ "sft_lr": 2e-4,
82
+ "sft_min_lr": 1e-5,
83
+ "warmup_steps": 300,
84
+ "weight_decay": 0.02,
85
+ "pretrain_passes": 2,
86
+ "sft_passes": 3,
87
+ "max_sft_target_chars": 0,
88
+ "use_grad_checkpoint": True,
89
+ "num_workers": 24,
90
+ "prefetch_factor": 64,
91
+ "shuffle_buffer": 8192,
92
+ "max_pretrain_tokens": 0,
93
+ "min_pretrain_tokens": 100_000_000,
94
+ "quantization": "nvfp4",
95
+ "engram_dim": 8,
96
+ "engram_heads": 2,
97
+ "engram_table_size": 64,
98
+ "engram_max_ngram": 2,
99
+ "mhc_expansion": 2,
100
+ },
101
+ "sonnet": {
102
+ "dim": 1024,
103
+ "n_unique_layers": 20,
104
+ "n_logical_layers": 40,
105
+ "n_heads": 16,
106
+ "n_kv_heads": 4,
107
+ "ffn_dim": 4096,
108
+ "dropout": 0.0,
109
+ "seq_len": 2048,
110
+ "mtp_horizons": (2,),
111
+ "rope_fraction": 0.5,
112
+ "batch_size": 24,
113
+ "grad_accum": 1,
114
+ "lr": 1e-4,
115
+ "min_lr": 2e-5,
116
+ "sft_lr": 5e-5,
117
+ "sft_min_lr": 5e-6,
118
+ "warmup_steps": 250,
119
+ "weight_decay": 0.1,
120
+ "pretrain_passes": 1,
121
+ "sft_passes": 1,
122
+ "max_sft_target_chars": 0,
123
+ "use_grad_checkpoint": True,
124
+ "num_workers": 32,
125
+ "prefetch_factor": 64,
126
+ "shuffle_buffer": 16384,
127
+ "max_pretrain_tokens": 0,
128
+ "min_pretrain_tokens": 100_000_000,
129
+ "quantization": "nvfp4",
130
+ "engram_dim": 32,
131
+ "engram_heads": 8,
132
+ "engram_table_size": 4096,
133
+ "engram_max_ngram": 2,
134
+ "mhc_expansion": 2,
135
+ },
136
+ "opus": {
137
+ "dim": 1536,
138
+ "n_unique_layers": 18,
139
+ "n_logical_layers": 36,
140
+ "n_heads": 16,
141
+ "n_kv_heads": 4,
142
+ "ffn_dim": 5888,
143
+ "dropout": 0.0,
144
+ "seq_len": 2048,
145
+ "mtp_horizons": (2,),
146
+ "rope_fraction": 0.5,
147
+ "batch_size": 24,
148
+ "grad_accum": 1,
149
+ "lr": 1.6e-4,
150
+ "min_lr": 1.6e-5,
151
+ "sft_lr": 3e-5,
152
+ "sft_min_lr": 3e-6,
153
+ "warmup_steps": 200,
154
+ "weight_decay": 0.1,
155
+ "pretrain_passes": 1,
156
+ "sft_passes": 1,
157
+ "max_sft_target_chars": 0,
158
+ "use_grad_checkpoint": True,
159
+ "num_workers": 48,
160
+ "prefetch_factor": 64,
161
+ "shuffle_buffer": 16384,
162
+ "max_pretrain_tokens": 0,
163
+ "min_pretrain_tokens": 100_000_000,
164
+ "quantization": "nvfp4",
165
+ "engram_dim": 64,
166
+ "engram_heads": 8,
167
+ "engram_table_size": 8192,
168
+ "engram_max_ngram": 2,
169
+ "mhc_expansion": 4,
170
+ },
171
+ }
172
+
173
+
174
+ # ---------------------------------------------------------------------------
175
+ # Tokenizer (from ailay.tokenizer)
176
+ # ---------------------------------------------------------------------------
177
+
178
+ FORMAT_TOKENS = [
179
+ "<|user|>",
180
+ "<|assistant|>",
181
+ "<|system|>",
182
+ "<|start_header_id|>",
183
+ "<|end_header_id|>",
184
+ "<|begin_of_thought|>",
185
+ "<|end_of_thought|>",
186
+ "<|begin_of_solution|>",
187
+ "<|end_of_solution|>",
188
+ ]
189
+
190
+
191
+ class WordTokenizer:
192
+ WORD_RE = re.compile(
193
+ r"\s+|[^\W\d_]+(?:['\u2019][^\W\d_]+)?|\d+|[^\w\s]+", re.UNICODE
194
+ )
195
+
196
+ def __init__(
197
+ self, extra_chars: str = "", format_tokens: Optional[List[str]] = None
198
+ ) -> None:
199
+ base = string.ascii_letters + string.digits + string.punctuation + " \n\t\r"
200
+ fallback_chars = sorted(set(base + extra_chars))
201
+ self.core_special = ["<PAD>", "<BOS>", "<EOS>", "<UNK>"]
202
+ self.format_tokens = (
203
+ list(format_tokens) if format_tokens else list(FORMAT_TOKENS)
204
+ )
205
+ self.special = list(self.core_special) + list(self.format_tokens)
206
+ self.id_to_token: List[str] = (
207
+ list(self.core_special) + self.format_tokens + fallback_chars
208
+ )
209
+ self.token_to_id: Dict[str, int] = {
210
+ t: i for i, t in enumerate(self.id_to_token)
211
+ }
212
+ self.special_multi_tokens = sorted(
213
+ [t for t in self.special if len(t) > 1], key=len, reverse=True
214
+ )
215
+ self.multi_char_tokens = self.special_multi_tokens
216
+ self.dynamic_additions = 0
217
+
218
+ @property
219
+ def pad_id(self) -> int:
220
+ return self.token_to_id["<PAD>"]
221
+
222
+ @property
223
+ def bos_id(self) -> int:
224
+ return self.token_to_id["<BOS>"]
225
+
226
+ @property
227
+ def eos_id(self) -> int:
228
+ return self.token_to_id["<EOS>"]
229
+
230
+ @property
231
+ def unk_id(self) -> int:
232
+ return self.token_to_id["<UNK>"]
233
+
234
+ @property
235
+ def vocab_size(self) -> int:
236
+ return len(self.id_to_token)
237
+
238
+ def maybe_add_char(self, ch: str) -> bool:
239
+ if ch in self.token_to_id:
240
+ return False
241
+ self.token_to_id[ch] = len(self.id_to_token)
242
+ self.id_to_token.append(ch)
243
+ self.dynamic_additions += 1
244
+ return True
245
+
246
+ def maybe_add_token(self, token: str) -> bool:
247
+ if token in self.token_to_id:
248
+ return False
249
+ self.token_to_id[token] = len(self.id_to_token)
250
+ self.id_to_token.append(token)
251
+ self.dynamic_additions += 1
252
+ return True
253
+
254
+ def iter_lexical_tokens(self, text: str) -> Iterator[str]:
255
+ i = 0
256
+ n = len(text)
257
+ while i < n:
258
+ matched_special = False
259
+ for token in self.special_multi_tokens:
260
+ if text.startswith(token, i):
261
+ yield token
262
+ i += len(token)
263
+ matched_special = True
264
+ break
265
+ if matched_special:
266
+ continue
267
+ m = self.WORD_RE.match(text, i)
268
+ if m is None:
269
+ yield text[i]
270
+ i += 1
271
+ continue
272
+ tok = m.group(0)
273
+ yield tok
274
+ i = m.end()
275
+
276
+ def encode(
277
+ self, text: str, add_bos: bool = False, add_eos: bool = False
278
+ ) -> List[int]:
279
+ out: List[int] = []
280
+ if add_bos:
281
+ out.append(self.bos_id)
282
+ unk = self.unk_id
283
+ t2i = self.token_to_id
284
+ for tok in self.iter_lexical_tokens(text):
285
+ tid = t2i.get(tok)
286
+ if tid is not None:
287
+ out.append(tid)
288
+ continue
289
+ for ch in tok:
290
+ out.append(t2i.get(ch, unk))
291
+ if add_eos:
292
+ out.append(self.eos_id)
293
+ return out
294
+
295
+ def decode(self, ids: Sequence[int], skip_special: bool = True) -> str:
296
+ pieces: List[str] = []
297
+ for idx in ids:
298
+ if int(idx) < 0 or int(idx) >= len(self.id_to_token):
299
+ continue
300
+ tok = self.id_to_token[int(idx)]
301
+ if skip_special and tok in self.special:
302
+ continue
303
+ pieces.append(tok)
304
+ return "".join(pieces)
305
+
306
+ def save(self, path: Path) -> None:
307
+ with path.open("w", encoding="utf-8") as f:
308
+ json.dump(
309
+ {
310
+ "id_to_token": self.id_to_token,
311
+ "format_tokens": self.format_tokens,
312
+ "core_special": self.core_special,
313
+ "tokenizer_type": "word_level_v1",
314
+ },
315
+ f,
316
+ ensure_ascii=False,
317
+ indent=2,
318
+ )
319
+
320
+ @classmethod
321
+ def load(cls, path: Path) -> WordTokenizer:
322
+ with path.open("r", encoding="utf-8") as f:
323
+ data = json.load(f)
324
+ format_tokens = data.get("format_tokens", FORMAT_TOKENS)
325
+ tokenizer = cls(extra_chars="", format_tokens=format_tokens)
326
+ tokenizer.id_to_token = data["id_to_token"]
327
+ tokenizer.token_to_id = {t: i for i, t in enumerate(tokenizer.id_to_token)}
328
+ tokenizer.special = list(tokenizer.core_special) + list(tokenizer.format_tokens)
329
+ tokenizer.special_multi_tokens = sorted(
330
+ [t for t in tokenizer.special if len(t) > 1], key=len, reverse=True
331
+ )
332
+ tokenizer.multi_char_tokens = tokenizer.special_multi_tokens
333
+ return tokenizer
334
+
335
+
336
+ LetterTokenizer = WordTokenizer
337
+
338
+
339
+ # ---------------------------------------------------------------------------
340
+ # Model (from ailay.model)
341
+ # ---------------------------------------------------------------------------
342
+
343
+
344
+ class RMSNorm(nn.Module):
345
+ def __init__(self, dim: int, eps: float = 1e-6) -> None:
346
+ super().__init__()
347
+ self.weight = nn.Parameter(torch.ones(dim))
348
+ self.eps = eps
349
+
350
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
351
+ if hasattr(torch.nn.functional, "rms_norm"):
352
+ return torch.nn.functional.rms_norm(
353
+ x, self.weight.shape, self.weight, self.eps
354
+ )
355
+ x_fp = x.float()
356
+ rms = torch.rsqrt(x_fp.pow(2).mean(dim=-1, keepdim=True) + self.eps)
357
+ return (x_fp * rms).to(dtype=x.dtype) * self.weight
358
+
359
+
360
+ class RotaryEmbedding(nn.Module):
361
+ def __init__(self, dim: int, base: float = 10000.0) -> None:
362
+ super().__init__()
363
+ inv = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
364
+ self.register_buffer("inv_freq", inv, persistent=False)
365
+
366
+ def cos_sin(
367
+ self, seq_len: int, device: torch.device, dtype: torch.dtype
368
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
369
+ t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
370
+ freqs = torch.outer(t, self.inv_freq)
371
+ emb = torch.cat([freqs, freqs], dim=-1)
372
+ cos = emb.cos()[None, None, :, :].to(dtype=dtype)
373
+ sin = emb.sin()[None, None, :, :].to(dtype=dtype)
374
+ return cos, sin
375
+
376
+
377
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
378
+ x1 = x[..., : x.shape[-1] // 2]
379
+ x2 = x[..., x.shape[-1] // 2 :]
380
+ return torch.cat((-x2, x1), dim=-1)
381
+
382
+
383
+ class CausalSelfAttention(nn.Module):
384
+ def __init__(
385
+ self,
386
+ dim: int,
387
+ n_heads: int,
388
+ n_kv_heads: int,
389
+ head_dim: int,
390
+ dropout: float,
391
+ sliding_window: int,
392
+ rope_fraction: float,
393
+ ) -> None:
394
+ super().__init__()
395
+ self.dim = dim
396
+ self.n_heads = n_heads
397
+ self.n_kv_heads = n_kv_heads
398
+ self.head_dim = head_dim
399
+ self.n_rep = n_heads // n_kv_heads
400
+ self.dropout = dropout
401
+ self.sliding_window = sliding_window
402
+
403
+ self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
404
+ self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
405
+ self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
406
+ self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)
407
+
408
+ for lin in (self.wq, self.wk, self.wv):
409
+ nn.init.normal_(lin.weight, std=dim ** -0.5)
410
+ nn.init.normal_(self.wo.weight, std=(n_heads * head_dim) ** -0.5)
411
+
412
+ self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2)
413
+ self.rope = RotaryEmbedding(self.rope_dim)
414
+
415
+ self.q_norm = RMSNorm(head_dim)
416
+ self.k_norm = RMSNorm(head_dim)
417
+
418
+ self.output_gate = nn.Parameter(torch.zeros(n_heads))
419
+
420
+ def forward(
421
+ self,
422
+ x: torch.Tensor,
423
+ is_global: bool,
424
+ past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
425
+ use_cache: bool = False,
426
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
427
+ B, T, _ = x.shape
428
+
429
+ q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
430
+ k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
431
+ v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)
432
+
433
+ q = self.q_norm(q)
434
+ k = self.k_norm(k)
435
+
436
+ q = q.transpose(1, 2)
437
+ k = k.transpose(1, 2)
438
+ v = v.transpose(1, 2)
439
+
440
+ past_len = past_kv[0].shape[2] if past_kv is not None else 0
441
+ cos, sin = self.rope.cos_sin(T + past_len, x.device, q.dtype)
442
+ cos_slice = cos[:, :, past_len : past_len + T, :]
443
+ sin_slice = sin[:, :, past_len : past_len + T, :]
444
+
445
+ q_rope = q[..., : self.rope_dim]
446
+ q_pass = q[..., self.rope_dim :]
447
+ k_rope = k[..., : self.rope_dim]
448
+ k_pass = k[..., self.rope_dim :]
449
+
450
+ q_rope = (q_rope * cos_slice) + (_rotate_half(q_rope) * sin_slice)
451
+ k_rope = (k_rope * cos_slice) + (_rotate_half(k_rope) * sin_slice)
452
+
453
+ q = torch.cat([q_rope, q_pass], dim=-1)
454
+ k = torch.cat([k_rope, k_pass], dim=-1)
455
+
456
+ if past_kv is not None:
457
+ k = torch.cat([past_kv[0], k], dim=2)
458
+ v = torch.cat([past_kv[1], v], dim=2)
459
+
460
+ new_kv = (k, v) if use_cache else None
461
+
462
+ S = k.shape[2]
463
+ if self.n_rep > 1:
464
+ k = (
465
+ k[:, :, None, :, :]
466
+ .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim)
467
+ .reshape(B, self.n_heads, S, self.head_dim)
468
+ )
469
+ v = (
470
+ v[:, :, None, :, :]
471
+ .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim)
472
+ .reshape(B, self.n_heads, S, self.head_dim)
473
+ )
474
+
475
+ drop_p = self.dropout if (self.training and torch.is_grad_enabled()) else 0.0
476
+
477
+ if is_global:
478
+ if past_kv is None and T > 1:
479
+ out = F.scaled_dot_product_attention(
480
+ q, k, v, is_causal=True, dropout_p=drop_p
481
+ )
482
+ else:
483
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
484
+ else:
485
+ T_q = q.shape[2]
486
+ q_pos = torch.arange(past_len, past_len + T_q, device=q.device).unsqueeze(1)
487
+ k_pos = torch.arange(S, device=q.device).unsqueeze(0)
488
+ mask = (q_pos >= k_pos) & ((q_pos - k_pos) < self.sliding_window)
489
+ out = F.scaled_dot_product_attention(
490
+ q, k, v, attn_mask=mask.unsqueeze(0).unsqueeze(0), dropout_p=drop_p
491
+ )
492
+
493
+ gate = torch.sigmoid(self.output_gate).view(1, self.n_heads, 1, 1)
494
+ out = out * gate
495
+
496
+ out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
497
+ out = self.wo(out)
498
+
499
+ return out, new_kv
500
+
501
+
502
+ class SwiGLU(nn.Module):
503
+ def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None:
504
+ super().__init__()
505
+ self.gate = nn.Linear(dim, hidden_dim, bias=False)
506
+ self.up = nn.Linear(dim, hidden_dim, bias=False)
507
+ self.down = nn.Linear(hidden_dim, dim, bias=False)
508
+ self.drop = nn.Dropout(dropout)
509
+
510
+ nn.init.normal_(self.gate.weight, std=dim ** -0.5)
511
+ nn.init.normal_(self.up.weight, std=dim ** -0.5)
512
+ nn.init.normal_(self.down.weight, std=hidden_dim ** -0.5)
513
+
514
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
515
+ h = F.silu(self.gate(x)) * self.up(x)
516
+ out = self.down(h)
517
+ if self.training and torch.is_grad_enabled():
518
+ out = self.drop(out)
519
+ return out
520
+
521
+
522
+ class EngramBlock(nn.Module):
523
+ """Conditional memory via O(1) hashed N-gram lookup (DeepSeek Engram)."""
524
+
525
+ def __init__(
526
+ self,
527
+ dim: int,
528
+ engram_dim: int,
529
+ n_heads: int = 4,
530
+ table_size: int = 8192,
531
+ max_ngram: int = 3,
532
+ ) -> None:
533
+ super().__init__()
534
+ self.dim = dim
535
+ self.engram_dim = engram_dim
536
+ self.n_heads = n_heads
537
+ self.table_size = table_size
538
+ self.max_ngram = max_ngram
539
+
540
+ self.embeddings = nn.ParameterDict()
541
+ for n in range(2, max_ngram + 1):
542
+ for k in range(n_heads):
543
+ self.embeddings[f"{n}_{k}"] = nn.Parameter(
544
+ torch.randn(table_size, engram_dim) * (engram_dim**-0.5)
545
+ )
546
+
547
+ for n in range(2, max_ngram + 1):
548
+ for k in range(n_heads):
549
+ seed = int(hashlib.md5(f"engram_{n}_{k}".encode()).hexdigest()[:8], 16)
550
+ rng = torch.Generator().manual_seed(seed)
551
+ a = torch.randint(1, 2**31, (1,), generator=rng).item()
552
+ b = torch.randint(0, 2**31, (1,), generator=rng).item()
553
+ self.register_buffer(
554
+ f"hash_a_{n}_{k}", torch.tensor(a), persistent=False
555
+ )
556
+ self.register_buffer(
557
+ f"hash_b_{n}_{k}", torch.tensor(b), persistent=False
558
+ )
559
+
560
+ total_branch_dim = engram_dim * n_heads * (max_ngram - 1)
561
+ self.branch_conv = nn.Conv1d(
562
+ total_branch_dim,
563
+ total_branch_dim,
564
+ kernel_size=4,
565
+ dilation=max_ngram,
566
+ padding=0,
567
+ groups=total_branch_dim,
568
+ bias=True,
569
+ )
570
+ nn.init.zeros_(self.branch_conv.weight)
571
+ nn.init.zeros_(self.branch_conv.bias)
572
+
573
+ self.gate_query = nn.Linear(dim, engram_dim, bias=False)
574
+ self.gate_key = nn.Linear(total_branch_dim, engram_dim, bias=False)
575
+ self.gate_value = nn.Linear(total_branch_dim, dim, bias=False)
576
+ self.gate_scale = engram_dim**-0.5
577
+
578
+ def _hash_ngram(self, token_ids: torch.Tensor, n: int, k: int) -> torch.Tensor:
579
+ a = getattr(self, f"hash_a_{n}_{k}")
580
+ b = getattr(self, f"hash_b_{n}_{k}")
581
+ B, T = token_ids.shape
582
+ padded = F.pad(token_ids, (n - 1, 0), value=0)
583
+ combined = torch.zeros(B, T, dtype=torch.long, device=token_ids.device)
584
+ for i in range(n):
585
+ combined = (combined * 31 + padded[:, i : i + T].long()) % self.table_size
586
+ return ((a * combined) ^ b) % self.table_size
587
+
588
+ def forward(
589
+ self, hidden: torch.Tensor, token_ids: Optional[torch.Tensor] = None
590
+ ) -> torch.Tensor:
591
+ B, T, _ = hidden.shape
592
+ if token_ids is None:
593
+ token_ids = hidden.mean(dim=-1).long() % self.table_size
594
+ all_indices = []
595
+ all_tables = []
596
+ for n in range(2, self.max_ngram + 1):
597
+ for k in range(self.n_heads):
598
+ all_indices.append(self._hash_ngram(token_ids, n, k))
599
+ all_tables.append(self.embeddings[f"{n}_{k}"])
600
+ branch_outputs = [tbl[idx] for idx, tbl in zip(all_indices, all_tables)]
601
+ memory = torch.cat(branch_outputs, dim=-1)
602
+ conv_in = memory.transpose(1, 2)
603
+ conv_in = F.pad(
604
+ conv_in,
605
+ (self.branch_conv.dilation[0] * (self.branch_conv.kernel_size[0] - 1), 0),
606
+ )
607
+ conv_out = self.branch_conv(conv_in)
608
+ memory = conv_out.transpose(1, 2)
609
+ query = self.gate_query(hidden)
610
+ key = self.gate_key(memory)
611
+ gate = torch.sigmoid(
612
+ (query * key).sum(dim=-1, keepdim=True) * self.gate_scale
613
+ )
614
+ value = self.gate_value(memory)
615
+ return gate * value
616
+
617
+
618
+ def _sinkhorn_knopp(logits: torch.Tensor, n_iters: int = 7) -> torch.Tensor:
619
+ M = torch.exp(logits.clamp(-10, 10))
620
+ for _ in range(n_iters):
621
+ M = M / M.sum(dim=-1, keepdim=True).clamp(min=1e-10)
622
+ M = M / M.sum(dim=-2, keepdim=True).clamp(min=1e-10)
623
+ return M
624
+
625
+
626
+ class ManifoldHyperConnection(nn.Module):
627
+ """Manifold-Constrained Hyper-Connections (mHC) residual wrapper."""
628
+
629
+ def __init__(self, dim: int, expansion: int = 2) -> None:
630
+ super().__init__()
631
+ self.dim = dim
632
+ self.expansion = expansion
633
+ n = expansion
634
+
635
+ self.bias_pre = nn.Parameter(torch.zeros(1, n))
636
+ self.bias_post = nn.Parameter(torch.zeros(1, n))
637
+ self.bias_res = nn.Parameter(torch.zeros(n, n))
638
+
639
+ self.theta_pre = nn.Linear(n * dim, n, bias=False)
640
+ self.theta_post = nn.Linear(n * dim, n, bias=False)
641
+ self.theta_res = nn.Linear(n * dim, n * n, bias=False)
642
+
643
+ self.alpha_pre = nn.Parameter(torch.tensor(0.0))
644
+ self.alpha_post = nn.Parameter(torch.tensor(0.0))
645
+ self.alpha_res = nn.Parameter(torch.tensor(0.0))
646
+
647
+ def _compute_mappings(
648
+ self, x_expanded: torch.Tensor
649
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
650
+ B, T, _ = x_expanded.shape
651
+ n = self.expansion
652
+ x_norm = F.rms_norm(x_expanded, [x_expanded.shape[-1]])
653
+ d_pre = torch.tanh(self.theta_pre(x_norm))
654
+ d_post = torch.tanh(self.theta_post(x_norm))
655
+ d_res = self.theta_res(x_norm)
656
+ H_pre_raw = torch.sigmoid(self.alpha_pre * d_pre + self.bias_pre)
657
+ H_post_raw = 2.0 * torch.sigmoid(self.alpha_post * d_post + self.bias_post)
658
+ H_res_raw = (self.alpha_res * d_res + self.bias_res.reshape(1, 1, -1)).reshape(
659
+ B, T, n, n
660
+ )
661
+ H_res = _sinkhorn_knopp(H_res_raw)
662
+ return H_pre_raw.unsqueeze(-2), H_post_raw.unsqueeze(-2), H_res
663
+
664
+ def expand_stream(self, x: torch.Tensor) -> torch.Tensor:
665
+ return x.repeat(1, 1, self.expansion)
666
+
667
+ def collapse_stream(self, x_expanded: torch.Tensor) -> torch.Tensor:
668
+ B, T, _ = x_expanded.shape
669
+ return x_expanded.view(B, T, self.expansion, self.dim).mean(dim=-2)
670
+
671
+ def pre_mix(self, x_expanded: torch.Tensor, H_pre: torch.Tensor) -> torch.Tensor:
672
+ B, T, _ = x_expanded.shape
673
+ x_streams = x_expanded.view(B, T, self.expansion, self.dim)
674
+ return (H_pre @ x_streams).squeeze(-2)
675
+
676
+ def post_res_mix(
677
+ self,
678
+ layer_output: torch.Tensor,
679
+ x_expanded: torch.Tensor,
680
+ H_post: torch.Tensor,
681
+ H_res: torch.Tensor,
682
+ ) -> torch.Tensor:
683
+ B, T, _ = x_expanded.shape
684
+ x_streams = x_expanded.view(B, T, self.expansion, self.dim)
685
+ mixed = torch.matmul(H_res, x_streams)
686
+ post_out = torch.matmul(H_post.transpose(-2, -1), layer_output.unsqueeze(-2))
687
+ return (mixed + post_out).reshape(B, T, self.expansion * self.dim)
688
+
689
+
690
+ class TransformerBlock(nn.Module):
691
+ def __init__(
692
+ self,
693
+ dim: int,
694
+ n_heads: int,
695
+ n_kv_heads: int,
696
+ head_dim: int,
697
+ ffn_dim: int,
698
+ dropout: float,
699
+ sliding_window: int,
700
+ rope_fraction: float,
701
+ engram_dim: int = 0,
702
+ engram_heads: int = 4,
703
+ engram_table_size: int = 8192,
704
+ engram_max_ngram: int = 3,
705
+ mhc_expansion: int = 1,
706
+ ) -> None:
707
+ super().__init__()
708
+ self.dim = dim
709
+ self.norm1 = RMSNorm(dim)
710
+ self.attn = CausalSelfAttention(
711
+ dim=dim,
712
+ n_heads=n_heads,
713
+ n_kv_heads=n_kv_heads,
714
+ head_dim=head_dim,
715
+ dropout=dropout,
716
+ sliding_window=sliding_window,
717
+ rope_fraction=rope_fraction,
718
+ )
719
+ self.norm2 = RMSNorm(dim)
720
+ self.ffn = SwiGLU(dim, ffn_dim, dropout)
721
+ self.use_engram = engram_dim > 0
722
+ if self.use_engram:
723
+ self.engram = EngramBlock(
724
+ dim=dim,
725
+ engram_dim=engram_dim,
726
+ n_heads=engram_heads,
727
+ table_size=engram_table_size,
728
+ max_ngram=engram_max_ngram,
729
+ )
730
+ self.engram_norm = RMSNorm(dim)
731
+ self.use_mhc = mhc_expansion > 1
732
+ if self.use_mhc:
733
+ self.mhc_attn = ManifoldHyperConnection(dim, expansion=mhc_expansion)
734
+ self.mhc_ffn = ManifoldHyperConnection(dim, expansion=mhc_expansion)
735
+
736
+ def forward(
737
+ self,
738
+ x: torch.Tensor,
739
+ is_global: bool,
740
+ past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
741
+ use_cache: bool = False,
742
+ token_ids: Optional[torch.Tensor] = None,
743
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
744
+ if self.use_mhc:
745
+ x_exp = self.mhc_attn.expand_stream(x)
746
+ H_pre, H_post, H_res = self.mhc_attn._compute_mappings(x_exp)
747
+ attn_in = self.mhc_attn.pre_mix(x_exp, H_pre)
748
+ attn_out, new_kv = self.attn(
749
+ self.norm1(attn_in), is_global, past_kv, use_cache
750
+ )
751
+ x_exp = self.mhc_attn.post_res_mix(attn_out, x_exp, H_post, H_res)
752
+ if self.use_engram:
753
+ collapsed = self.mhc_attn.collapse_stream(x_exp)
754
+ collapsed = collapsed + self.engram(
755
+ self.engram_norm(collapsed), token_ids=token_ids
756
+ )
757
+ x_exp = self.mhc_attn.expand_stream(collapsed)
758
+ H_pre2, H_post2, H_res2 = self.mhc_ffn._compute_mappings(x_exp)
759
+ ffn_in = self.mhc_ffn.pre_mix(x_exp, H_pre2)
760
+ ffn_out = self.ffn(self.norm2(ffn_in))
761
+ x_exp = self.mhc_ffn.post_res_mix(ffn_out, x_exp, H_post2, H_res2)
762
+ x = self.mhc_attn.collapse_stream(x_exp)
763
+ else:
764
+ attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache)
765
+ x = x + attn_out
766
+ if self.use_engram:
767
+ x = x + self.engram(self.engram_norm(x), token_ids=token_ids)
768
+ x = x + self.ffn(self.norm2(x))
769
+ return x, new_kv
770
+
771
+
+def _detect_engram_dim(state_dict: dict) -> int:
+    for key in state_dict:
+        if ".engram." in key and ".embeddings." in key:
+            return state_dict[key].shape[-1]
+    return 0
+
+
+def _detect_mhc_expansion(state_dict: dict) -> int:
+    for key, val in state_dict.items():
+        if ".mhc_attn.bias_pre" in key and val.dim() == 2:
+            return val.shape[-1]
+    return 1
+
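+# Detection sketch (hypothetical checkpoint keys matching the patterns above):
+# an entry like "blocks.0.engram.embeddings.weight" with shape (8192, 12)
+# yields engram_dim=12, and "blocks.0.mhc_attn.bias_pre" with shape (1, 2)
+# yields mhc_expansion=2; when no key matches, the fallbacks are 0 and 1.
+
+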
+class TinyMemoryLM(nn.Module):
+    def __init__(
+        self,
+        vocab_size: int,
+        dim: int,
+        n_unique_layers: int,
+        n_logical_layers: int,
+        n_heads: int,
+        n_kv_heads: int,
+        ffn_dim: int,
+        dropout: float,
+        mtp_horizons: Sequence[int],
+        grad_checkpoint: bool,
+        sliding_window: int = 512,
+        rope_fraction: float = 0.5,
+        embed_scale: bool = True,
+        engram_dim: int = 0,
+        engram_heads: int = 4,
+        engram_table_size: int = 8192,
+        engram_max_ngram: int = 3,
+        mhc_expansion: int = 1,
+    ) -> None:
+        super().__init__()
+        self.dim = dim
+        self.n_unique_layers = n_unique_layers
+        self.n_logical_layers = n_logical_layers
+        self.grad_checkpoint = grad_checkpoint
+        self.embed_scale_factor = math.sqrt(dim) if embed_scale else 1.0
+        head_dim = dim // n_heads
+
+        self.embed_tokens = nn.Embedding(vocab_size, dim)
+        self.head = nn.Linear(dim, vocab_size, bias=False)
+        # Tie the output projection to the input embedding matrix.
+        self.head.weight = self.embed_tokens.weight
+
+        self.output_bias = nn.Parameter(torch.zeros(vocab_size))
+
+        self.blocks = nn.ModuleList(
+            [
+                TransformerBlock(
+                    dim=dim,
+                    n_heads=n_heads,
+                    n_kv_heads=n_kv_heads,
+                    head_dim=head_dim,
+                    ffn_dim=ffn_dim,
+                    dropout=dropout,
+                    sliding_window=sliding_window,
+                    rope_fraction=rope_fraction,
+                    engram_dim=engram_dim,
+                    engram_heads=engram_heads,
+                    engram_table_size=engram_table_size,
+                    engram_max_ngram=engram_max_ngram,
+                    mhc_expansion=mhc_expansion,
+                )
+                for _ in range(n_unique_layers)
+            ]
+        )
+        self.norm = RMSNorm(dim)
+
+        self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1})
+        self.mtp_adapters = nn.ModuleDict(
+            {str(h): nn.Linear(dim, dim, bias=False) for h in self.mtp_horizons}
+        )
+        self.mtp_norms = nn.ModuleDict(
+            {str(h): RMSNorm(dim) for h in self.mtp_horizons}
+        )
+
+        # Depth-scaled init: damp residual-branch outputs by (2L)^-0.5, where L
+        # is the logical (weight-shared) depth rather than the unique depth.
+        res_scale = (2 * n_logical_layers) ** -0.5
+        for block in self.blocks:
+            block.attn.wo.weight.data.mul_(res_scale)
+            block.ffn.down.weight.data.mul_(res_scale)
+
+    def resize_token_embeddings(self, new_vocab_size: int) -> None:
+        old_vocab_size = self.embed_tokens.num_embeddings
+        if new_vocab_size == old_vocab_size:
+            return
+        device = self.embed_tokens.weight.device
+        old_embed_weight = self.embed_tokens.weight.data.clone()
+        self.embed_tokens = nn.Embedding(
+            new_vocab_size, self.embed_tokens.embedding_dim
+        ).to(device)
+        self.head = nn.Linear(
+            self.embed_tokens.embedding_dim, new_vocab_size, bias=False
+        ).to(device)
+        self.head.weight = self.embed_tokens.weight
+        old_bias = self.output_bias.data.clone()
+        self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device))
+        copy_size = min(old_vocab_size, new_vocab_size)
+        self.output_bias.data[:copy_size] = old_bias[:copy_size]
+        self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size]
+
+    def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]:
+        logical = []
+        blocks_list = list(self.blocks)
+        full_sequence = blocks_list + blocks_list
+        for logical_idx, block in enumerate(full_sequence[: self.n_logical_layers]):
+            logical.append((block, logical_idx))
+        return logical
+
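+    # Layer-reuse sketch: with 8 unique blocks and 16 logical layers, the
+    # physical stack is traversed twice, so block i also serves as logical
+    # layer i + 8 with shared weights. Doubling the list above assumes
+    # n_logical_layers <= 2 * n_unique_layers.
+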
+    def forward(
+        self,
+        ids: torch.Tensor,
+        use_cache: bool = False,
+        past_key_values: Optional[
+            List[Optional[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
+        return_hidden: bool = False,
+    ) -> Tuple[
+        torch.Tensor,
+        Dict[int, torch.Tensor],
+        Optional[torch.Tensor],
+        Optional[List[Tuple[torch.Tensor, torch.Tensor]]],
+    ]:
+        B, T = ids.shape
+        x = self.embed_tokens(ids) * self.embed_scale_factor
+
+        logical_layers = self._build_logical_layers()
+        new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = (
+            [] if use_cache else None
+        )
+
+        for layer_idx, (block, logical_idx) in enumerate(logical_layers):
+            # Even logical layers attend globally; odd ones use the sliding window.
+            is_global = logical_idx % 2 == 0
+            past_kv = (
+                past_key_values[layer_idx]
+                if past_key_values is not None and layer_idx < len(past_key_values)
+                else None
+            )
+
+            if self.grad_checkpoint and self.training and not use_cache:
+                x, layer_kv = checkpoint(
+                    block, x, is_global, past_kv, use_cache, ids, use_reentrant=True
+                )
+            else:
+                x, layer_kv = block(x, is_global, past_kv, use_cache, ids)
+
+            if new_past_key_values is not None:
+                new_past_key_values.append(layer_kv)
+
+        x = self.norm(x)
+        h_out = x if return_hidden else None
+        logits = self.head(x)
+        if self.embed_scale_factor != 1.0:
+            logits = logits / self.embed_scale_factor
+        logits = logits + self.output_bias
+
+        # Multi-token-prediction heads are only evaluated during training.
+        mtp: Dict[int, torch.Tensor] = {}
+        if self.mtp_horizons and self.training:
+            for horizon in self.mtp_horizons:
+                if horizon > 1 and horizon <= T - 1:
+                    shifted_h = x[:, :-horizon, :]
+                    adapted_h = self.mtp_adapters[str(horizon)](shifted_h)
+                    adapted_h = self.mtp_norms[str(horizon)](adapted_h)
+                    mtp_logits = self.head(adapted_h)
+                    if self.embed_scale_factor != 1.0:
+                        mtp_logits = mtp_logits / self.embed_scale_factor
+                    mtp_logits = mtp_logits + self.output_bias
+                    mtp[horizon] = mtp_logits
+
+        return logits, mtp, h_out, new_past_key_values
+
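+# Output-shape sketch: for ids of shape (B, T), `logits` is (B, T, vocab). In
+# training, each horizon h drops the last h positions, so mtp[h] has shape
+# (B, T - h, vocab) and would be scored against targets h steps ahead in the
+# training loop (not shown here).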
+
+# ---------------------------------------------------------------------------
+# Generation (from ailay.generation)
+# ---------------------------------------------------------------------------
+
+
+def sample_text(
+    model: TinyMemoryLM,
+    tokenizer: WordTokenizer,
+    prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_k: int,
+    branches: int,
+    branch_len: int,
+    device: torch.device,
+    seq_len: int,
+) -> str:
+    def _sample_id(logits: torch.Tensor) -> torch.Tensor:
+        if not torch.isfinite(logits).any():
+            logits = torch.zeros_like(logits)
+        logits = torch.where(
+            torch.isfinite(logits), logits, torch.full_like(logits, -1e9)
+        )
+        if top_k > 0:
+            v, idx = torch.topk(logits, k=min(top_k, logits.shape[-1]))
+            p = torch.softmax(v, dim=-1)
+            return idx.gather(-1, torch.multinomial(p, 1))
+        p = torch.softmax(logits, dim=-1)
+        return torch.multinomial(p, 1)
+
+    model.eval()
+    ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
+    prompt_len = len(ids)
+    x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
+
+    with torch.no_grad():
+        generated = 0
+        while generated < max_new_tokens:
+            if branches <= 1:
+                ctx = x[:, -seq_len:]
+                logits, _, _, _ = model(ctx)
+                nlogits = logits[:, -1, :] / max(temperature, 1e-6)
+                nid = _sample_id(nlogits)
+                x = torch.cat([x, nid], dim=1)
+                generated += 1
+                continue
+            rollout = min(branch_len, max_new_tokens - generated)
+            best_nll: Optional[float] = None
+            best_tokens: Optional[List[torch.Tensor]] = None
+            for _ in range(branches):
+                cand = x
+                cand_tokens: List[torch.Tensor] = []
+                nll = 0.0
+                for _ in range(rollout):
+                    ctx = cand[:, -seq_len:]
+                    logits, _, _, _ = model(ctx)
+                    nlogits = logits[:, -1, :] / max(temperature, 1e-6)
+                    nid = _sample_id(nlogits)
+                    lp = F.log_softmax(nlogits.float(), dim=-1)
+                    nll += float(-lp.gather(-1, nid).item())
+                    cand = torch.cat([cand, nid], dim=1)
+                    cand_tokens.append(nid)
+                if best_nll is None or nll < best_nll:
+                    best_nll = nll
+                    best_tokens = cand_tokens
+            assert best_tokens is not None
+            for t in best_tokens:
+                x = torch.cat([x, t], dim=1)
+                generated += 1
+
+    generated_ids = x[0, prompt_len:].tolist()
+    return tokenizer.decode(generated_ids, skip_special=True)
+
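+# Usage sketch (assumes a loaded `model` and `tokenizer`): branches=4 with
+# branch_len=8 rolls out four 8-token continuations per step and keeps the
+# one with the lowest summed NLL, e.g.:
+#
+#   text = sample_text(model, tokenizer, "Once upon a time", 64,
+#                      temperature=0.8, top_k=40, branches=4, branch_len=8,
+#                      device=torch.device("cpu"), seq_len=2048)
+
+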
+def sample_text_cached(
+    model: TinyMemoryLM,
+    tokenizer: WordTokenizer,
+    prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_k: int,
+    device: torch.device,
+    seq_len: int,
+) -> str:
+    model.eval()
+    ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
+    prompt_len = len(ids)
+    x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
+
+    with torch.no_grad():
+        logits, _, _, past_kv = model(x, use_cache=True)
+        nlogits = logits[:, -1, :] / max(temperature, 1e-6)
+        if top_k > 0:
+            v, idx = torch.topk(nlogits, k=min(top_k, nlogits.shape[-1]))
+            p = torch.softmax(v, dim=-1)
+            nid = idx.gather(-1, torch.multinomial(p, 1))
+        else:
+            p = torch.softmax(nlogits, dim=-1)
+            nid = torch.multinomial(p, 1)
+        all_ids = [int(nid.item())]
+
+        for _ in range(max_new_tokens - 1):
+            # Checking at the top of the loop also stops when the very first
+            # sampled token was already EOS.
+            if all_ids[-1] == tokenizer.eos_id:
+                break
+            logits, _, _, past_kv = model(nid, use_cache=True, past_key_values=past_kv)
+            nlogits = logits[:, -1, :] / max(temperature, 1e-6)
+            if top_k > 0:
+                v, idx = torch.topk(nlogits, k=min(top_k, nlogits.shape[-1]))
+                p = torch.softmax(v, dim=-1)
+                nid = idx.gather(-1, torch.multinomial(p, 1))
+            else:
+                p = torch.softmax(nlogits, dim=-1)
+                nid = torch.multinomial(p, 1)
+            all_ids.append(int(nid.item()))
+
+    return tokenizer.decode(all_ids, skip_special=True)
+
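+# Compared with sample_text, this variant encodes the prompt once and then
+# feeds only the newest token plus past_key_values each step, avoiding a full
+# re-run of the prefix. Usage sketch:
+#
+#   text = sample_text_cached(model, tokenizer, "Hello", 64, 0.8, 40,
+#                             torch.device("cpu"), seq_len=2048)
+
+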
+def speculative_decode(
+    model: TinyMemoryLM,
+    tokenizer: WordTokenizer,
+    prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_k: int,
+    device: torch.device,
+    seq_len: int,
+) -> str:
+    model.eval()
+    ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
+    x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
+    all_generated: List[int] = []
+
+    with torch.no_grad():
+        logits, _, h_out, past_kv = model(x, use_cache=True, return_hidden=True)
+
+        def _sample_from(lg: torch.Tensor) -> int:
+            lg = lg / max(temperature, 1e-6)
+            if top_k > 0:
+                v, idx = torch.topk(lg, k=min(top_k, lg.shape[-1]))
+                p = torch.softmax(v, dim=-1)
+                return int(idx[torch.multinomial(p, 1)].item())
+            p = torch.softmax(lg, dim=-1)
+            return int(torch.multinomial(p, 1).item())
+
+        main_token = _sample_from(logits[0, -1, :])
+        all_generated.append(main_token)
+
+        while len(all_generated) < max_new_tokens:
+            if main_token == tokenizer.eos_id:
+                break
+
+            # Draft one token per MTP horizon from the last hidden state.
+            draft_tokens = []
+            if h_out is not None and model.mtp_horizons:
+                last_hidden = h_out[:, -1:, :]
+                for h in model.mtp_horizons:
+                    adapter = model.mtp_adapters[str(h)]
+                    norm = model.mtp_norms[str(h)]
+                    adapted = norm(adapter(last_hidden))
+                    draft_logits = model.head(adapted) + model.output_bias
+                    draft_tok = _sample_from(draft_logits[0, 0, :])
+                    draft_tokens.append(draft_tok)
+
+            if not draft_tokens:
+                nid = torch.tensor([[main_token]], dtype=torch.long, device=device)
+                logits, _, h_out, past_kv = model(
+                    nid, use_cache=True, past_key_values=past_kv, return_hidden=True
+                )
+                main_token = _sample_from(logits[0, -1, :])
+                all_generated.append(main_token)
+                continue
+
+            # Verify the whole draft block in a single forward pass.
+            verify_input = torch.tensor(
+                [[main_token] + draft_tokens], dtype=torch.long, device=device
+            )
+            verify_logits, _, h_out, past_kv = model(
+                verify_input,
+                use_cache=True,
+                past_key_values=past_kv,
+                return_hidden=True,
+            )
+
+            accepted = 0
+            for i, draft_tok in enumerate(draft_tokens):
+                verified_tok = _sample_from(verify_logits[0, i, :])
+                if verified_tok == draft_tok:
+                    all_generated.append(draft_tok)
+                    accepted += 1
+                    if draft_tok == tokenizer.eos_id:
+                        break
+                else:
+                    all_generated.append(verified_tok)
+                    break
+
+            if accepted < len(draft_tokens):
+                # Roll the KV cache back past the rejected draft positions.
+                trim_len = len(draft_tokens) - accepted - 1
+                if trim_len > 0 and past_kv is not None:
+                    trimmed = []
+                    for kv in past_kv:
+                        if kv is None:
+                            trimmed.append(None)
+                        else:
+                            k, v = kv
+                            trimmed.append(
+                                (k[:, :, :-trim_len, :], v[:, :, :-trim_len, :])
+                            )
+                    past_kv = trimmed
+
+            main_token = all_generated[-1]
+
+    return tokenizer.decode(all_generated, skip_special=True)
+
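+# Drafting note: this is self-speculative decoding. The MTP adapters propose
+# one token per horizon from the last hidden state, the base head verifies
+# [main_token] + drafts in a single forward pass, and the KV cache is rolled
+# back past the first mismatch. Acceptance requires an exact match between the
+# drafted and re-sampled token, so near-greedy settings tend to accept more.
+
+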
+def build_stop_token_ids(tokenizer: WordTokenizer) -> set:
+    stop_tokens = {tokenizer.eos_id}
+    for tok in ("<|user|>", "<|system|>", "<|assistant|>"):
+        tid = tokenizer.token_to_id.get(tok)
+        if tid is not None:
+            stop_tokens.add(int(tid))
+    return stop_tokens
+
+
+def apply_no_repeat_ngram(
+    logits: torch.Tensor,
+    token_history: Sequence[int],
+    ngram_size: int,
+) -> torch.Tensor:
+    if ngram_size <= 1 or len(token_history) < ngram_size - 1:
+        return logits
+    prefix = tuple(token_history[-(ngram_size - 1) :])
+    banned: set = set()
+    for i in range(len(token_history) - ngram_size + 1):
+        if tuple(token_history[i : i + ngram_size - 1]) == prefix:
+            banned.add(int(token_history[i + ngram_size - 1]))
+    if not banned:
+        return logits
+    out = logits.clone()
+    banned_ids = torch.tensor(sorted(banned), device=logits.device, dtype=torch.long)
+    out[banned_ids] = float("-inf")
+    return out
+
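+# Worked example (ngram_size=3): with history [5, 9, 2, 5, 9] the prefix is
+# (5, 9); the earlier window (5, 9) at position 0 was followed by token 2, so
+# 2 is banned (set to -inf) for the next step.
+
+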
+def score_candidate(
+    prompt: str,
+    raw_text: str,
+    visible_text: str,
+    avg_logprob: float,
+) -> float:
+    clean = visible_text.strip()
+    if not clean:
+        return -1e9
+    score = avg_logprob
+    words = clean.lower().split()
+    prompt_words = re.findall(r"[A-Za-z][A-Za-z'-]{2,}", prompt.lower())
+    prompt_stop = {
+        "what", "which", "when", "where", "why", "how", "are", "is", "the",
+        "and", "for", "with", "that", "this", "from", "into", "about",
+        "explain", "tell", "give", "list", "show", "write", "their", "there",
+        "your",
+    }
+    prompt_keywords = {w for w in prompt_words if w not in prompt_stop}
+    candidate_keywords = set(re.findall(r"[A-Za-z][A-Za-z'-]{2,}", clean.lower()))
+    # Length: penalize very short replies, mildly reward longer ones.
+    if len(words) < 6:
+        score -= 2.0
+    else:
+        score += min(2.0, len(words) * 0.03)
+    if clean[-1:] in ".!?":
+        score += 0.5
+    # Penalize leaked chat-turn markers.
+    if "<|user|>" in raw_text or "<|system|>" in raw_text:
+        score -= 4.0
+    if raw_text.count("<|assistant|>") > 1:
+        score -= 2.0
+    # Reward keyword overlap with the prompt.
+    if prompt_keywords:
+        overlap = len(prompt_keywords & candidate_keywords) / len(prompt_keywords)
+        if overlap == 0.0:
+            score -= 2.5
+        else:
+            score += min(3.5, overlap * 4.0)
+    # Penalize unbalanced thought/solution tags.
+    for open_tok, close_tok in [
+        ("<|begin_of_thought|>", "<|end_of_thought|>"),
+        ("<|begin_of_solution|>", "<|end_of_solution|>"),
+    ]:
+        if (open_tok in raw_text) != (close_tok in raw_text):
+            score -= 1.0
+    # Trigram diversity: punish repetition loops.
+    if len(words) >= 3:
+        trigrams = [tuple(words[i : i + 3]) for i in range(len(words) - 2)]
+        if trigrams:
+            unique_ratio = len(set(trigrams)) / len(trigrams)
+            if unique_ratio < 0.35:
+                score -= 4.0
+            elif unique_ratio < 0.55:
+                score -= 2.0
+            else:
+                score += min(1.0, (unique_ratio - 0.55) * 2.0)
+    # Penalize mostly non-alphabetic output.
+    alpha_words = [
+        w
+        for w in words
+        if len(w) <= 18 and (sum(ch.isalpha() for ch in w) / max(len(w), 1)) > 0.7
+    ]
+    alpha_ratio = len(alpha_words) / max(len(words), 1)
+    if alpha_ratio < 0.45:
+        score -= 3.0
+    elif alpha_ratio < 0.65:
+        score -= 1.0
+    return score
+
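+# Scoring sketch: a candidate averaging -2.0 logprob that runs 20 words
+# (+0.6), ends in "." (+0.5), and covers half the prompt keywords (+2.0)
+# lands near 1.1, while a leaked "<|user|>" marker (-4.0) or a trigram
+# uniqueness ratio under 0.35 (-4.0) drags it well below zero.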
+
+def generate_candidate(
+    model: TinyMemoryLM,
+    tokenizer: WordTokenizer,
+    prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_k: int,
+    repetition_penalty: float,
+    no_repeat_ngram_size: int,
+    device: str,
+    sft_mode: bool,
+    force_thought: bool,
+    stream: bool,
+    context_window: int,
+) -> Tuple[str, str, float, int]:
+    if sft_mode:
+        full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
+    else:
+        full_prompt = prompt
+    if force_thought:
+        full_prompt = f"{full_prompt}<|begin_of_thought|> "
+    input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False)
+    input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device)
+    visible_tokens: List[str] = []
+    raw_tokens: List[str] = []
+    stop_token_ids = build_stop_token_ids(tokenizer)
+    total_logprob = 0.0
+    sampled_tokens = 0
+    with torch.no_grad():
+        for _ in range(max_new_tokens):
+            ctx_ids = (
+                input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t
+            )
+            logits, _, _, _ = model(ctx_ids)
+            next_logits = logits[0, -1, :].clone()
+            raw_next_logits = next_logits.clone()
+            if repetition_penalty != 1.0:
+                seen = set(input_ids_t[0].tolist())
+                for token_id in seen:
+                    if next_logits[token_id] > 0:
+                        next_logits[token_id] /= repetition_penalty
+                    else:
+                        next_logits[token_id] *= repetition_penalty
+            if temperature != 1.0:
+                next_logits = next_logits / max(temperature, 1e-6)
+            if no_repeat_ngram_size > 1:
+                next_logits = apply_no_repeat_ngram(
+                    next_logits,
+                    input_ids_t[0].tolist(),
+                    no_repeat_ngram_size,
+                )
+            if top_k > 0:
+                v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0)))
+                next_logits[next_logits < v[-1]] = float("-inf")
+            # Nucleus filtering with a fixed threshold (top_p is hardcoded).
+            top_p = 0.9
+            if top_p < 1.0:
+                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
+                cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+                # Keep tokens whose cumulative mass *before* them is < top_p.
+                remove_mask = cum_probs - torch.softmax(sorted_logits, dim=-1) >= top_p
+                sorted_logits[remove_mask] = float("-inf")
+                # The scatter covers every index, so this restores vocab order.
+                next_logits = sorted_logits.scatter(0, sorted_indices, sorted_logits)
+            if not torch.isfinite(next_logits).any():
+                # Every candidate was filtered out; fall back to raw logits.
+                next_logits = raw_next_logits
+                if temperature != 1.0:
+                    next_logits = next_logits / max(temperature, 1e-6)
+            probs = torch.softmax(next_logits, dim=-1)
+            next_id = torch.multinomial(probs, num_samples=1).item()
+            total_logprob += float(torch.log(probs[next_id] + 1e-12).item())
+            sampled_tokens += 1
+            if next_id in stop_token_ids:
+                break
+            token_str = (
+                tokenizer.id_to_token[next_id]
+                if next_id < len(tokenizer.id_to_token)
+                else ""
+            )
+            raw_tokens.append(token_str)
+            if token_str not in tokenizer.special:
+                visible_tokens.append(token_str)
+                if stream:
+                    print(token_str, end="", flush=True)
+            input_ids_t = torch.cat(
+                [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1
+            )
+    if stream:
+        print()
+    avg_logprob = total_logprob / max(1, sampled_tokens)
+    return "".join(visible_tokens), "".join(raw_tokens), avg_logprob, 0
+
+
+def generate_beam_search(
+    model: TinyMemoryLM,
+    tokenizer: WordTokenizer,
+    prompt: str,
+    max_new_tokens: int = 60,
+    beam_width: int = 8,
+    length_penalty: float = 0.7,
+    no_repeat_ngram_size: int = 3,
+    device: str = "cuda",
+    sft_mode: bool = False,
+    context_window: int = 2048,
+) -> str:
+    if sft_mode:
+        full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
+    else:
+        full_prompt = prompt
+    prompt_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False)
+    prompt_len = len(prompt_ids)
+    stop_ids = build_stop_token_ids(tokenizer)
+    beams: List[Tuple[float, List[int]]] = [(0.0, list(prompt_ids))]
+    completed: List[Tuple[float, List[int]]] = []
+    for _step in range(max_new_tokens):
+        if not beams:
+            break
+        candidates: List[Tuple[float, List[int]]] = []
+        for beam_score, beam_ids in beams:
+            x = torch.tensor(
+                [beam_ids[-context_window:]], dtype=torch.long, device=device
+            )
+            with torch.no_grad():
+                logits, _, _, _ = model(x)
+            nl = logits[0, -1, :]
+            log_probs = F.log_softmax(nl, dim=-1)
+            gen_ids = beam_ids[prompt_len:]
+            if no_repeat_ngram_size > 1 and len(gen_ids) >= no_repeat_ngram_size - 1:
+                prefix = tuple(gen_ids[-(no_repeat_ngram_size - 1) :])
+                for i in range(len(gen_ids) - no_repeat_ngram_size + 1):
+                    if tuple(gen_ids[i : i + no_repeat_ngram_size - 1]) == prefix:
+                        log_probs[gen_ids[i + no_repeat_ngram_size - 1]] = float("-inf")
+            topk_lp, topk_ids = torch.topk(log_probs, beam_width)
+            for i in range(beam_width):
+                tid = topk_ids[i].item()
+                new_score = beam_score + topk_lp[i].item()
+                new_ids = beam_ids + [tid]
+                if tid in stop_ids:
+                    completed.append((new_score, new_ids))
+                else:
+                    candidates.append((new_score, new_ids))
+
+        def _norm_score(pair):
+            gen_len = max(1, len(pair[1]) - prompt_len)
+            return pair[0] / (gen_len**length_penalty)
+
+        candidates.sort(key=_norm_score, reverse=True)
+        beams = candidates[:beam_width]
+
+    pool = completed + beams
+    if not pool:
+        return ""
+
+    def _norm_score_final(pair):
+        gen_len = max(1, len(pair[1]) - prompt_len)
+        return pair[0] / (gen_len**length_penalty)
+
+    pool.sort(key=_norm_score_final, reverse=True)
+    best_ids = pool[0][1][prompt_len:]
+    text = tokenizer.decode(best_ids, skip_special=True)
+    nl_pos = text.find("\n")
+    if nl_pos > 5:
+        text = text[:nl_pos]
+    return text.strip()
+
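+# Beam-search note: scores are normalized by gen_len**length_penalty, so with
+# length_penalty=0.7 longer beams are penalized less than raw log-prob sums
+# would imply, and the final "\n" truncation keeps answers to one line, e.g.:
+#
+#   text = generate_beam_search(model, tokenizer, "What is RMSNorm?",
+#                               beam_width=4, device="cpu", sft_mode=True)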
+
+def generate(
+    model: TinyMemoryLM,
+    tokenizer: WordTokenizer,
+    prompt: str,
+    max_new_tokens: int = 256,
+    temperature: float = 0.8,
+    top_k: int = 40,
+    repetition_penalty: float = 1.0,
+    device: str = "cuda",
+    sft_mode: bool = False,
+    force_thought: bool = False,
+    stream: bool = True,
+    decode_mode: str = "legacy",
+    best_of: int = 3,
+    no_repeat_ngram_size: int = 3,
+    context_window: int = 2048,
+    beam_width: int = 8,
+    length_penalty: float = 0.7,
+) -> str:
+    if decode_mode == "beam":
+        text = generate_beam_search(
+            model=model,
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_new_tokens=max_new_tokens,
+            beam_width=beam_width,
+            length_penalty=length_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            device=device,
+            sft_mode=sft_mode,
+            context_window=context_window,
+        )
+        if stream:
+            print(text)
+        return text
+    if decode_mode == "legacy":
+        text, _, _, _ = generate_candidate(
+            model=model,
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            device=device,
+            sft_mode=sft_mode,
+            force_thought=force_thought,
+            stream=stream,
+            context_window=context_window,
+        )
+        return text
+    candidates: List[Tuple[float, str, str, float]] = []
+    for _ in range(max(1, best_of)):
+        candidate_text, raw_text, avg_logprob, _ = generate_candidate(
+            model=model,
+            tokenizer=tokenizer,
+            prompt=prompt,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            device=device,
+            sft_mode=sft_mode,
+            force_thought=force_thought,
+            stream=False,
+            context_window=context_window,
+        )
+        score = score_candidate(prompt, raw_text, candidate_text, avg_logprob)
+        candidates.append((score, candidate_text, raw_text, avg_logprob))
+    _, best_text, _, _ = max(candidates, key=lambda item: item[0])
+    if stream:
+        print(best_text, end="", flush=True)
+        print()
+    return best_text
+
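+# Dispatch sketch: decode_mode="beam" runs generate_beam_search,
+# decode_mode="legacy" streams one sampled candidate, and any other value
+# (e.g. the arbitrary, non-reserved string "rerank") samples best_of
+# candidates silently and returns the highest score_candidate pick:
+#
+#   text = generate(model, tokenizer, "Explain GQA", decode_mode="rerank",
+#                   best_of=3, device="cpu", sft_mode=True, stream=False)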
+
+# ---------------------------------------------------------------------------
+# Web server (from interactive.py)
+# ---------------------------------------------------------------------------
+
+ROOT = Path(__file__).resolve().parent
+if str(ROOT) not in sys.path:
+    sys.path.insert(0, str(ROOT))
+
+
+HF_ORG = "CompactAI"
+HF_API = "https://huggingface.co/api"
+CACHE_ROOT = Path.home() / ".cache" / "compactai_web"
+USER_AGENT = "Mozilla/5.0 CompactAI-Web"
+MODEL_CACHE: dict[tuple[str, str], dict[str, object]] = {}
+MODEL_CACHE_LOCK = threading.RLock()
+GENERATION_LOCK = threading.Lock()
+
+
+def request_json(url: str):
+    req = Request(url, headers={"User-Agent": USER_AGENT})
+    with urlopen(req, timeout=60) as response:
+        return json.loads(response.read().decode("utf-8"))
+
+
+def request_text(url: str) -> str:
+    req = Request(url, headers={"User-Agent": USER_AGENT})
+    with urlopen(req, timeout=60) as response:
+        return response.read().decode("utf-8", errors="replace")
+
+
+def download_file(url: str, destination: Path) -> None:
+    destination.parent.mkdir(parents=True, exist_ok=True)
+    temp_path = destination.with_suffix(destination.suffix + ".tmp")
+    req = Request(url, headers={"User-Agent": USER_AGENT})
+    with urlopen(req, timeout=120) as response, temp_path.open("wb") as handle:
+        shutil.copyfileobj(response, handle)
+    temp_path.replace(destination)
+
+
+def normalize_repo_id(raw_repo_id: str) -> str:
+    if not isinstance(raw_repo_id, str):
+        return ""
+    repo_id = raw_repo_id.strip()
+    if not repo_id:
+        return ""
+    try:
+        repo_id = unquote(repo_id)
+    except Exception:
+        pass
+    return (
+        repo_id.replace("https://huggingface.co/", "")
+        .replace("http://huggingface.co/", "")
+        .replace("api/models/", "")
+        .replace("models/", "")
+        .split("?", 1)[0]
+        .split("#", 1)[0]
+        .strip("/")
+    )
+
+
+def series_from_name(name: str) -> str | None:
+    lower = (name or "").lower()
+    if "haiku" in lower:
+        return "Haiku"
+    if "sonnet" in lower:
+        return "Sonnet"
+    if "opus" in lower:
+        return "Opus"
+    return None
+
+
+def encoded_repo_id(repo_id: str) -> str:
+    return "/".join(
+        quote(part, safe="") for part in normalize_repo_id(repo_id).split("/") if part
+    )
+
+
+def hf_file_url(repo_id: str, filename: str) -> str:
+    encoded_name = "/".join(
+        quote(part, safe="") for part in filename.split("/") if part
+    )
+    return (
+        f"https://huggingface.co/{encoded_repo_id(repo_id)}/resolve/main/{encoded_name}"
+    )
+
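+# URL sketch: hf_file_url("CompactAI/some-model", "model/model.pt") yields
+# https://huggingface.co/CompactAI/some-model/resolve/main/model/model.pt,
+# with each path segment percent-encoded individually ("some-model" is a
+# placeholder repo name, not a published checkpoint).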
+
+def model_list() -> list[dict[str, object]]:
+    data = request_json(f"{HF_API}/models?author={quote(HF_ORG)}&full=true&limit=200")
+    models: list[dict[str, object]] = []
+    for item in data:
+        siblings = item.get("siblings") or []
+        filenames = [s.get("rfilename", "") for s in siblings if isinstance(s, dict)]
+        has_model = "model.pt" in filenames or "model/model.pt" in filenames
+        has_pretrain = "pretrain.pt" in filenames or "model/pretrain.pt" in filenames
+        has_tokenizer = (
+            "tokenizer.json" in filenames or "model/tokenizer.json" in filenames
+        )
+        if not has_model and not has_pretrain:
+            continue
+        name = (item.get("id") or "").split("/")[-1]
+        series = series_from_name(name)
+        if not series:
+            continue
+        models.append(
+            {
+                "id": item.get("id", ""),
+                "name": name,
+                "series": series,
+                "downloads": item.get("downloads", 0) or 0,
+                "likes": item.get("likes", 0) or 0,
+                "has_model": has_model,
+                "has_pretrain": has_pretrain,
+                "has_tokenizer": has_tokenizer,
+            }
+        )
+    return sorted(models, key=lambda entry: entry["downloads"], reverse=True)
+
+
+def model_details(repo_id: str) -> dict[str, object] | None:
+    normalized = normalize_repo_id(repo_id)
+    if not normalized:
+        return None
+    data = request_json(f"{HF_API}/models/{encoded_repo_id(normalized)}")
+    siblings = data.get("siblings") or []
+    files: dict[str, dict[str, float]] = {}
+    has_model = False
+    has_pretrain = False
+    for sibling in siblings:
+        if not isinstance(sibling, dict):
+            continue
+        filename = sibling.get("rfilename") or ""
+        if not filename:
+            continue
+        size_mb = round((sibling.get("size") or 0) / (1024 * 1024), 2)
+        files[filename] = {"size_mb": size_mb}
+        if filename.startswith("model/"):
+            files[filename.removeprefix("model/")] = {"size_mb": size_mb}
+        if filename in {"model.pt", "model/model.pt"}:
+            has_model = True
+        if filename in {"pretrain.pt", "model/pretrain.pt"}:
+            has_pretrain = True
+    readme_raw = ""
+    try:
+        readme_raw = request_text(
+            f"https://huggingface.co/{encoded_repo_id(normalized)}/raw/main/README.md"
+        )
+    except Exception:
+        readme_raw = ""
+    name = (data.get("id") or normalized).split("/")[-1]
+    return {
+        "id": normalized,
+        "name": name,
+        "series": series_from_name(name) or "Sonnet",
+        "downloads": data.get("downloads", 0) or 0,
+        "files": files,
+        "readme_raw": readme_raw,
+        "hf_model_id": normalized,
+        "has_model": has_model,
+        "has_pretrain": has_pretrain,
+    }
+
+
+def cache_dir(repo_id: str, model_type: str) -> Path:
+    return CACHE_ROOT / normalize_repo_id(repo_id).replace("/", "__") / model_type
+
+
+def artifact_candidates(model_type: str) -> list[str]:
+    return (
+        ["model/pretrain.pt", "pretrain.pt"]
+        if model_type == "pretrain"
+        else ["model/model.pt", "model.pt"]
+    )
+
+
+def ensure_artifact(repo_id: str, model_type: str, destination_name: str) -> Path:
+    normalized = normalize_repo_id(repo_id)
+    target = cache_dir(normalized, model_type) / destination_name
+    if target.exists():
+        return target
+    last_error: Exception | None = None
+    for candidate in (
+        artifact_candidates(model_type)
+        if destination_name.endswith(".pt")
+        else ["model/tokenizer.json", "tokenizer.json"]
+    ):
+        try:
+            download_file(hf_file_url(normalized, candidate), target)
+            return target
+        except Exception as exc:
+            last_error = exc
+    raise RuntimeError(
+        f"Unable to download {destination_name} for {normalized}: {last_error}"
+    )
+
+
+def series_config(series: str) -> dict[str, object]:
+    return MODEL_SERIES.get(series.lower(), MODEL_SERIES["sonnet"])
+
1715
+
1716
+ def load_bundle(repo_id: str, model_type: str) -> dict[str, object]:
1717
+ normalized = normalize_repo_id(repo_id)
1718
+ details = model_details(normalized)
1719
+ if not details:
1720
+ raise RuntimeError("Model details are unavailable.")
1721
+ series = str(details["series"])
1722
+ key = (normalized, model_type)
1723
+ with MODEL_CACHE_LOCK:
1724
+ cached = MODEL_CACHE.get(key)
1725
+ if cached:
1726
+ return cached
1727
+ bundle_dir = cache_dir(normalized, model_type)
1728
+ bundle_dir.mkdir(parents=True, exist_ok=True)
1729
+ model_path = bundle_dir / (
1730
+ "pretrain.pt" if model_type == "pretrain" else "model.pt"
1731
+ )
1732
+ tokenizer_path = bundle_dir / "tokenizer.json"
1733
+ if not model_path.exists():
1734
+ ensure_artifact(normalized, model_type, model_path.name)
1735
+ if not tokenizer_path.exists():
1736
+ ensure_artifact(normalized, model_type, tokenizer_path.name)
1737
+ tokenizer = WordTokenizer.load(tokenizer_path)
1738
+ ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False)
1739
+ cfg = series_config(series)
1740
+ vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size))
1741
+ state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
1742
+ # Auto-detect new arch features from checkpoint weights
1743
+ engram_dim = _detect_engram_dim(state_dict) or int(
1744
+ cfg.get("engram_dim", model_config.engram_dim)
1745
+ )
1746
+ mhc_expansion = _detect_mhc_expansion(state_dict) or int(
1747
+ cfg.get("mhc_expansion", model_config.mhc_expansion)
1748
+ )
1749
+ model = TinyMemoryLM(
1750
+ vocab_size=vocab_size,
1751
+ dim=int(cfg.get("dim", model_config.dim)),
1752
+ n_unique_layers=int(
1753
+ cfg.get("n_unique_layers", model_config.n_unique_layers)
1754
+ ),
1755
+ n_logical_layers=int(
1756
+ cfg.get("n_logical_layers", model_config.n_logical_layers)
1757
+ ),
1758
+ n_heads=int(cfg.get("n_heads", model_config.n_heads)),
1759
+ n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)),
1760
+ ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)),
1761
+ dropout=float(cfg.get("dropout", model_config.dropout)),
1762
+ mtp_horizons=tuple(
1763
+ int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)
1764
+ ),
1765
+ grad_checkpoint=False,
1766
+ sliding_window=int(
1767
+ cfg.get("sliding_window_size", model_config.sliding_window_size)
1768
+ ),
1769
+ rope_fraction=float(
1770
+ cfg.get("rope_fraction", model_config.rope_fraction)
1771
+ ),
1772
+ embed_scale=bool(
1773
+ cfg.get("embed_scale", model_config.embed_scale)
1774
+ ),
1775
+ engram_dim=engram_dim,
1776
+ engram_heads=int(cfg.get("engram_heads", model_config.engram_heads)),
1777
+ engram_table_size=int(
1778
+ cfg.get("engram_table_size", model_config.engram_table_size)
1779
+ ),
1780
+ engram_max_ngram=int(
1781
+ cfg.get("engram_max_ngram", model_config.engram_max_ngram)
1782
+ ),
1783
+ mhc_expansion=mhc_expansion,
1784
+ )
1785
+ model.load_state_dict(state_dict, strict=False)
1786
+ model.eval()
1787
+ if tokenizer.vocab_size > vocab_size:
1788
+ model.resize_token_embeddings(tokenizer.vocab_size)
1789
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1790
+ model = model.to(device)
1791
+ bundle = {
1792
+ "repo_id": normalized,
1793
+ "name": details["name"],
1794
+ "series": series,
1795
+ "type": model_type,
1796
+ "model": model,
1797
+ "tokenizer": tokenizer,
1798
+ "device": device,
1799
+ "model_path": str(model_path),
1800
+ "tokenizer_path": str(tokenizer_path),
1801
+ "downloads": details["downloads"],
1802
+ }
1803
+ MODEL_CACHE[key] = bundle
1804
+ return bundle
1805
+
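+# Bundle lifecycle sketch: the first load_bundle("CompactAI/<repo>", "model")
+# call downloads model.pt and tokenizer.json into ~/.cache/compactai_web/,
+# rebuilds the architecture from the series config plus the engram/mHC sizes
+# detected in the checkpoint, and keeps the live model in MODEL_CACHE; later
+# calls with the same (repo, type) key return the cached bundle.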
1806
+
1807
+ def ensure_port(start_port: int) -> int:
1808
+ for port in range(start_port, start_port + 50):
1809
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
1810
+ try:
1811
+ sock.bind(("127.0.0.1", port))
1812
+ except OSError:
1813
+ continue
1814
+ return port
1815
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
1816
+ sock.bind(("127.0.0.1", 0))
1817
+ return sock.getsockname()[1]
1818
+
1819
+
1820
+ def page_html() -> str:
1821
+ return f"""<!doctype html>
1822
+ <html lang="en">
1823
+ <head>
1824
+ <meta charset="utf-8">
1825
+ <meta name="viewport" content="width=device-width, initial-scale=1">
1826
+ <title>CompactAI Web</title>
1827
+ <style>
1828
+ :root {{
1829
+ color-scheme: dark;
1830
+ --bg: #050505;
1831
+ --panel: #111111;
1832
+ --panel-2: #161616;
1833
+ --line: #262626;
1834
+ --text: #f5f5f5;
1835
+ --muted: #a3a3a3;
1836
+ --accent: #d97706;
1837
+ --accent-2: #b45309;
1838
+ --soft: #1f1f1f;
1839
+ }}
1840
+ * {{ box-sizing: border-box; }}
1841
+ body {{
1842
+ margin: 0;
1843
+ font-family: Geist, -apple-system, BlinkMacSystemFont, sans-serif;
1844
+ background: var(--bg);
1845
+ color: var(--text);
1846
+ line-height: 1.5;
1847
+ }}
1848
+ a {{ color: inherit; }}
1849
+ .wrap {{ max-width: 1120px; margin: 0 auto; padding: 28px 20px 40px; }}
1850
+ .hero {{
1851
+ display: flex;
1852
+ justify-content: space-between;
1853
+ align-items: end;
1854
+ gap: 16px;
1855
+ padding: 22px 0 28px;
1856
+ border-bottom: 1px solid var(--line);
1857
+ margin-bottom: 22px;
1858
+ }}
1859
+ h1 {{ margin: 0; font-size: clamp(2rem, 5vw, 3.5rem); letter-spacing: -0.04em; }}
1860
+ .subtitle {{ margin: 10px 0 0; color: var(--muted); max-width: 58ch; }}
1861
+ .grid {{
1862
+ display: grid;
1863
+ grid-template-columns: 1.1fr 1fr;
1864
+ gap: 18px;
1865
+ }}
1866
+ .panel {{
1867
+ background: var(--panel);
1868
+ border: 1px solid var(--line);
1869
+ border-radius: 18px;
1870
+ padding: 18px;
1871
+ }}
1872
+ .panel h2 {{ margin: 0 0 12px; font-size: 15px; letter-spacing: 0.02em; text-transform: uppercase; color: var(--muted); }}
1873
+ .row {{ display: flex; gap: 10px; flex-wrap: wrap; }}
1874
+ select, textarea, input {{
1875
+ width: 100%;
1876
+ background: var(--panel-2);
1877
+ color: var(--text);
1878
+ border: 1px solid var(--line);
1879
+ border-radius: 12px;
1880
+ padding: 12px 14px;
1881
+ font: inherit;
1882
+ outline: none;
1883
+ }}
1884
+ textarea {{ min-height: 170px; resize: vertical; }}
1885
+ select {{ appearance: none; }}
1886
+ .choice {{
1887
+ flex: 1 1 150px;
1888
+ display: flex;
1889
+ align-items: center;
1890
+ gap: 10px;
1891
+ padding: 10px 12px;
1892
+ border: 1px solid var(--line);
1893
+ border-radius: 12px;
1894
+ background: var(--panel-2);
1895
+ cursor: pointer;
1896
+ }}
1897
+ .choice input {{ width: auto; }}
1898
+ .btns {{ display: flex; flex-wrap: wrap; gap: 10px; }}
1899
+ button {{
1900
+ border: 1px solid var(--line);
1901
+ border-radius: 12px;
1902
+ padding: 11px 14px;
1903
+ background: var(--soft);
1904
+ color: var(--text);
1905
+ font: inherit;
1906
+ cursor: pointer;
1907
+ transition: transform 0.15s ease, border-color 0.15s ease, background 0.15s ease;
1908
+ }}
1909
+ button:hover {{ transform: translateY(-1px); border-color: #3a3a3a; }}
1910
+ .primary {{ background: var(--accent); border-color: var(--accent); color: #fff; }}
1911
+ .primary:hover {{ background: var(--accent-2); border-color: var(--accent-2); }}
1912
+ .status {{
1913
+ margin-top: 12px;
1914
+ color: var(--muted);
1915
+ font-size: 13px;
1916
+ min-height: 1.4em;
1917
+ }}
1918
+ .output {{
1919
+ white-space: pre-wrap;
1920
+ background: #0b0b0b;
1921
+ border: 1px solid var(--line);
1922
+ border-radius: 16px;
1923
+ min-height: 280px;
1924
+ padding: 16px;
1925
+ color: #e7e5e4;
1926
+ overflow: auto;
1927
+ }}
1928
+ .meta {{
1929
+ display: flex;
1930
+ flex-wrap: wrap;
1931
+ gap: 8px;
1932
+ margin-top: 8px;
1933
+ }}
1934
+ .chip {{
1935
+ display: inline-flex;
1936
+ align-items: center;
1937
+ gap: 6px;
1938
+ padding: 6px 10px;
1939
+ border-radius: 999px;
1940
+ border: 1px solid var(--line);
1941
+ background: var(--panel-2);
1942
+ font-size: 12px;
1943
+ color: var(--muted);
1944
+ }}
1945
+ .code {{
1946
+ margin-top: 14px;
1947
+ padding: 12px 14px;
1948
+ border-radius: 12px;
1949
+ border: 1px solid var(--line);
1950
+ background: #0b0b0b;
1951
+ font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
1952
+ font-size: 13px;
1953
+ overflow-x: auto;
1954
+ }}
1955
+ @media (max-width: 900px) {{
1956
+ .grid {{ grid-template-columns: 1fr; }}
1957
+ .hero {{ align-items: start; flex-direction: column; }}
1958
+ }}
1959
+ </style>
1960
+ </head>
1961
+ <body>
1962
+ <div class="wrap">
1963
+ <div class="hero">
1964
+ <div>
1965
+ <h1>CompactAI Web</h1>
1966
+ <p class="subtitle">Pull a model from Hugging Face, keep it cached locally, and chat in the browser.</p>
1967
+ </div>
1968
+ <div class="meta">
1969
+ <span class="chip">Hugging Face: CompactAI</span>
1970
+ <span class="chip">pip install -r requirements.txt</span>
1971
+ <span class="chip">Local inference</span>
1972
+ </div>
1973
+ </div>
1974
+
1975
+ <div class="grid">
1976
+ <section class="panel">
1977
+ <h2>Model</h2>
1978
+ <select id="modelSelect"></select>
1979
+ <div class="row" style="margin-top: 10px;">
1980
+ <label class="choice"><input type="radio" name="type" value="model" checked> Instruct / final</label>
1981
+ <label class="choice"><input type="radio" name="type" value="pretrain"> Pretrain</label>
1982
+ </div>
1983
+ <div class="btns" style="margin-top: 12px;">
1984
+ <button id="downloadBtn">Download</button>
1985
+ <button id="refreshBtn">Refresh models</button>
1986
+ </div>
1987
+ <div class="status" id="modelStatus">Loading model list…</div>
1988
+ <div class="code">python3 interactive_web.py</div>
1989
+ </section>
1990
+
1991
+ <section class="panel">
1992
+ <h2>Prompt</h2>
1993
+ <textarea id="prompt" placeholder="Ask something…"></textarea>
1994
+ <div class="row" style="margin-top: 10px;">
1995
+ <input id="temperature" type="number" min="0.1" max="2" step="0.05" value="0.8" style="flex: 1 1 120px;">
1996
+ <input id="topK" type="number" min="1" max="100" step="1" value="40" style="flex: 1 1 120px;">
1997
+ <input id="maxTokens" type="number" min="16" max="2048" step="16" value="256" style="flex: 1 1 120px;">
1998
+ </div>
1999
+ <div class="btns" style="margin-top: 12px;">
2000
+ <button id="generateBtn" class="primary">Generate</button>
2001
+ </div>
2002
+ <div class="status" id="genStatus"></div>
2003
+ </section>
2004
+ </div>
2005
+
2006
+ <section class="panel" style="margin-top: 18px;">
2007
+ <h2>Response</h2>
2008
+ <div id="output" class="output"></div>
2009
+ </section>
2010
+ </div>
2011
+
2012
+ <script>
2013
+ const modelSelect = document.getElementById('modelSelect');
2014
+ const modelStatus = document.getElementById('modelStatus');
2015
+ const genStatus = document.getElementById('genStatus');
2016
+ const output = document.getElementById('output');
2017
+ const promptBox = document.getElementById('prompt');
2018
+
2019
+ async function api(path, body) {{
2020
+ const response = await fetch(path, {{
2021
+ method: body ? 'POST' : 'GET',
2022
+ headers: body ? {{ 'Content-Type': 'application/json' }} : undefined,
2023
+ body: body ? JSON.stringify(body) : undefined,
2024
+ }});
2025
+ return response.json();
2026
+ }}
2027
+
2028
+ function currentType() {{
2029
+ return document.querySelector('input[name="type"]:checked').value;
2030
+ }}
2031
+
2032
+ function currentModelId() {{
2033
+ return modelSelect.value;
2034
+ }}
2035
+
2036
+ function setModels(models) {{
2037
+ modelSelect.innerHTML = '';
2038
+ for (const model of models) {{
2039
+ const option = document.createElement('option');
2040
+ option.value = model.id;
2041
+ option.textContent = `${{model.name}} • ${{model.series}}`;
2042
+ modelSelect.appendChild(option);
2043
+ }}
2044
+ if (models.length === 0) {{
2045
+ const option = document.createElement('option');
2046
+ option.value = '';
2047
+ option.textContent = 'No CompactAI models found';
2048
+ modelSelect.appendChild(option);
2049
+ }}
2050
+ }}
2051
+
2052
+ async function refreshModels() {{
2053
+ modelStatus.textContent = 'Loading model list…';
2054
+ try {{
2055
+ const models = await api('/api/models');
2056
+ setModels(models);
2057
+ modelStatus.textContent = models.length ? `${{models.length}} models available from CompactAI` : 'No compatible models found.';
2058
+ }} catch (error) {{
2059
+ modelStatus.textContent = 'Failed to load model list.';
2060
+ }}
2061
+ }}
2062
+
2063
+ async function ensureModel() {{
2064
+ const modelId = currentModelId();
2065
+ if (!modelId) {{
2066
+ modelStatus.textContent = 'Pick a model first.';
2067
+ return null;
2068
+ }}
2069
+ modelStatus.textContent = 'Downloading model files…';
2070
+ const result = await api('/api/ensure', {{ modelId, type: currentType() }});
2071
+ if (!result.success) {{
2072
+ modelStatus.textContent = result.error || 'Download failed.';
2073
+ return null;
2074
+ }}
2075
+ modelStatus.textContent = `${{result.name}} ready on ${{result.series}}`;
2076
+ return result;
2077
+ }}
2078
+
2079
+ async function generate() {{
2080
+ output.textContent = '';
2081
+ genStatus.textContent = '';
2082
+ const modelId = currentModelId();
2083
+ const prompt = promptBox.value.trim();
2084
+ if (!modelId) {{
2085
+ genStatus.textContent = 'Pick a model first.';
2086
+ return;
2087
+ }}
2088
+ if (!prompt) {{
2089
+ genStatus.textContent = 'Enter a prompt first.';
2090
+ return;
2091
+ }}
2092
+ genStatus.textContent = 'Preparing model…';
2093
+ const result = await api('/api/generate', {{
2094
+ modelId,
2095
+ type: currentType(),
2096
+ prompt,
2097
+ temperature: Number(document.getElementById('temperature').value || 0.8),
2098
+ top_k: Number(document.getElementById('topK').value || 40),
2099
+ max_new_tokens: Number(document.getElementById('maxTokens').value || 256),
2100
+ }});
2101
+ if (!result.success) {{
2102
+ genStatus.textContent = result.error || 'Generation failed.';
2103
+ return;
2104
+ }}
2105
+ output.textContent = result.text || '';
2106
+ genStatus.textContent = 'Done.';
2107
+ }}
2108
+
2109
+ document.getElementById('refreshBtn').addEventListener('click', refreshModels);
2110
+ document.getElementById('downloadBtn').addEventListener('click', ensureModel);
2111
+ document.getElementById('generateBtn').addEventListener('click', generate);
2112
+ promptBox.addEventListener('keydown', (event) => {{
2113
+ if (event.key === 'Enter' && (event.ctrlKey || event.metaKey)) {{
2114
+ event.preventDefault();
2115
+ generate();
2116
+ }}
2117
+ }});
2118
+
2119
+ refreshModels();
2120
+ </script>
2121
+ </body>
2122
+ </html>"""
2123
+
+
+class Handler(BaseHTTPRequestHandler):
+    def _send_json(self, payload, status=200):
+        body = json.dumps(payload).encode("utf-8")
+        self.send_response(status)
+        self.send_header("Content-Type", "application/json; charset=utf-8")
+        self.send_header("Content-Length", str(len(body)))
+        self.send_header("Cache-Control", "no-store")
+        self.end_headers()
+        self.wfile.write(body)
+
+    def _send_html(self, payload: str, status=200):
+        body = payload.encode("utf-8")
+        self.send_response(status)
+        self.send_header("Content-Type", "text/html; charset=utf-8")
+        self.send_header("Content-Length", str(len(body)))
+        self.send_header("Cache-Control", "no-store")
+        self.end_headers()
+        self.wfile.write(body)
+
+    def do_GET(self):
+        parsed = urlparse(self.path)
+        if parsed.path in {"/", "/index.html"}:
+            self._send_html(page_html())
+            return
+        if parsed.path == "/api/models":
+            try:
+                self._send_json(model_list())
+            except Exception as exc:
+                self._send_json({"success": False, "error": str(exc)}, 500)
+            return
+        if parsed.path.startswith("/api/models/"):
+            repo_id = normalize_repo_id(parsed.path.removeprefix("/api/models/"))
+            try:
+                details = model_details(repo_id)
+                if not details:
+                    self._send_json(
+                        {"success": False, "error": "Model not found."}, 404
+                    )
+                else:
+                    self._send_json(details)
+            except Exception as exc:
+                self._send_json({"success": False, "error": str(exc)}, 500)
+            return
+        self._send_json({"success": False, "error": "Not found."}, 404)
+
+    def do_POST(self):
+        parsed = urlparse(self.path)
+        length = int(self.headers.get("Content-Length", "0") or "0")
+        raw = self.rfile.read(length).decode("utf-8") if length else "{}"
+        try:
+            payload = json.loads(raw or "{}")
+        except Exception:
+            payload = {}
+        if parsed.path == "/api/ensure":
+            try:
+                repo_id = normalize_repo_id(payload.get("modelId", ""))
+                model_type = payload.get("type", "model")
+                if not repo_id:
+                    self._send_json(
+                        {"success": False, "error": "Missing model ID."}, 400
+                    )
+                    return
+                details = model_details(repo_id)
+                if not details:
+                    self._send_json(
+                        {"success": False, "error": "Model not found."}, 404
+                    )
+                    return
+                bundle = load_bundle(repo_id, model_type)
+                self._send_json(
+                    {
+                        "success": True,
+                        "id": bundle["repo_id"],
+                        "name": bundle["name"],
+                        "series": bundle["series"],
+                        "type": bundle["type"],
+                    }
+                )
+            except Exception as exc:
+                self._send_json({"success": False, "error": str(exc)}, 500)
+            return
+        if parsed.path == "/api/generate":
+            try:
+                repo_id = normalize_repo_id(payload.get("modelId", ""))
+                model_type = payload.get("type", "model")
+                prompt = str(payload.get("prompt", ""))
+                if not repo_id:
+                    self._send_json(
+                        {"success": False, "error": "Missing model ID."}, 400
+                    )
+                    return
+                bundle = load_bundle(repo_id, model_type)
+                with GENERATION_LOCK:
+                    text = generate(
+                        model=bundle["model"],
+                        tokenizer=bundle["tokenizer"],
+                        prompt=prompt,
+                        max_new_tokens=int(payload.get("max_new_tokens", 256)),
+                        temperature=float(payload.get("temperature", 0.8)),
+                        top_k=int(payload.get("top_k", 40)),
+                        repetition_penalty=float(
+                            payload.get("repetition_penalty", 1.0)
+                        ),
+                        device=str(bundle["device"]),
+                        sft_mode=model_type != "pretrain",
+                        force_thought=bool(payload.get("force_thought", False)),
+                        stream=False,
+                        decode_mode=str(payload.get("decode_mode", "legacy")),
+                        best_of=int(payload.get("best_of", 3)),
+                        no_repeat_ngram_size=int(
+                            payload.get("no_repeat_ngram_size", 3)
+                        ),
+                        context_window=int(payload.get("context_window", 2048)),
+                        beam_width=int(payload.get("beam_width", 8)),
+                        length_penalty=float(payload.get("length_penalty", 0.7)),
+                    )
+                self._send_json(
+                    {
+                        "success": True,
+                        "text": text,
+                        "name": bundle["name"],
+                        "series": bundle["series"],
+                    }
+                )
+            except Exception as exc:
+                self._send_json({"success": False, "error": str(exc)}, 500)
+            return
+        self._send_json({"success": False, "error": "Not found."}, 404)
+
+    def log_message(self, format, *args):
+        return
+
+
+def main():
+    CACHE_ROOT.mkdir(parents=True, exist_ok=True)
+    port = ensure_port(int(os.environ.get("PORT", "7860")))
+    server = ThreadingHTTPServer(("127.0.0.1", port), Handler)
+    url = f"http://127.0.0.1:{port}"
+    print(url, flush=True)
+    try:
+        webbrowser.open(url)
+    except Exception:
+        pass
+    try:
+        server.serve_forever()
+    except KeyboardInterrupt:
+        pass
+    finally:
+        server.server_close()
+
+
+if __name__ == "__main__":
+    main()
requirements.txt ADDED
@@ -0,0 +1 @@
+torch>=2.0.0