CompactAI committed on
Commit f6593cb · verified · 1 parent: 9a71e9c

Upload interactive.py

Files changed (1)
  1. downloads/interactive.py +1984 -1
downloads/interactive.py CHANGED
@@ -1,4 +1,1987 @@
- from interactive import main


  if __name__ == "__main__":
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ import math
6
+ import os
7
+ import re
8
+ import shutil
9
+ import socket
10
+ import string
11
+ import sys
12
+ import threading
13
+ import webbrowser
14
+ from dataclasses import dataclass
15
+ from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
16
+ from pathlib import Path
17
+ from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple
18
+ from urllib.parse import quote, unquote, urlparse
19
+ from urllib.request import Request, urlopen
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch.utils.checkpoint import checkpoint
25
+
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Config (from ailay.config)
29
+ # ---------------------------------------------------------------------------
30
+
31
+
32
+ @dataclass
33
+ class ModelConfig:
34
+ dim: int = 128
35
+ n_unique_layers: int = 8
36
+ n_logical_layers: int = 16
37
+ n_heads: int = 4
38
+ n_kv_heads: int = 2
39
+ ffn_dim: int = 224
40
+ dropout: float = 0.0
41
+ seq_len: int = 2048
42
+ sliding_window_size: int = 512
43
+ mtp_horizons: Tuple[int, ...] = (2, 3, 4)
44
+ rope_fraction: float = 0.25
45
+ embed_scale: bool = True
46
+ logit_soft_cap: float = -1.0
47
+ quantization: str = "nvfp4"
48
+
49
+ @property
50
+ def head_dim(self) -> int:
51
+ return self.dim // self.n_heads
52
+
53
+
54
+ model_config = ModelConfig()
55
+
56
+ MODEL_SERIES = {
57
+ "haiku": {
58
+ "dim": 128,
59
+ "n_unique_layers": 8,
60
+ "n_logical_layers": 16,
61
+ "n_heads": 4,
62
+ "n_kv_heads": 2,
63
+ "ffn_dim": 224,
64
+ "dropout": 0.0,
65
+ "seq_len": 2048,
66
+ "mtp_horizons": (2, 3, 4),
67
+ "batch_size": 48,
68
+ "grad_accum": 1,
69
+ "lr": 8e-4,
70
+ "min_lr": 1e-5,
71
+ "sft_lr": 2e-4,
72
+ "sft_min_lr": 1e-5,
73
+ "warmup_steps": 300,
74
+ "weight_decay": 0.02,
75
+ "pretrain_passes": 2,
76
+ "sft_passes": 3,
77
+ "max_sft_target_chars": 128,
78
+ "use_grad_checkpoint": False,
79
+ "use_torch_compile": True,
80
+ "num_workers": 24,
81
+ "prefetch_factor": 64,
82
+ "shuffle_buffer": 8192,
83
+ "max_pretrain_tokens": 0,
84
+ "min_pretrain_tokens": 100_000_000,
85
+ "quantization": "nvfp4",
86
+ },
87
+ "sonnet": {
88
+ "dim": 768,
89
+ "n_unique_layers": 18,
90
+ "n_logical_layers": 36,
91
+ "n_heads": 12,
92
+ "n_kv_heads": 4,
93
+ "ffn_dim": 2538,
94
+ "dropout": 0.0,
95
+ "seq_len": 2048,
96
+ "mtp_horizons": (2,),
97
+ "batch_size": 6,
98
+ "grad_accum": 1,
99
+ "lr": 2e-4,
100
+ "min_lr": 2e-5,
101
+ "sft_lr": 5e-5,
102
+ "sft_min_lr": 5e-6,
103
+ "warmup_steps": 250,
104
+ "weight_decay": 0.1,
105
+ "pretrain_passes": 1,
106
+ "sft_passes": 1,
107
+ "max_sft_target_chars": 512,
108
+ "use_grad_checkpoint": True,
109
+ "use_torch_compile": True,
110
+ "num_workers": 32,
111
+ "prefetch_factor": 48,
112
+ "shuffle_buffer": 16384,
113
+ "max_pretrain_tokens": 0,
114
+ "min_pretrain_tokens": 0,
115
+ "quantization": "nvfp4",
116
+ },
117
+ "opus": {
118
+ "dim": 1024,
119
+ "n_unique_layers": 20,
120
+ "n_logical_layers": 40,
121
+ "n_heads": 16,
122
+ "n_kv_heads": 4,
123
+ "ffn_dim": 3557,
124
+ "dropout": 0.0,
125
+ "seq_len": 2048,
126
+ "mtp_horizons": (2,),
127
+ "batch_size": 12,
128
+ "grad_accum": 1,
129
+ "lr": 1.6e-4,
130
+ "min_lr": 1.6e-5,
131
+ "sft_lr": 3e-5,
132
+ "sft_min_lr": 3e-6,
133
+ "warmup_steps": 200,
134
+ "weight_decay": 0.1,
135
+ "pretrain_passes": 1,
136
+ "sft_passes": 1,
137
+ "max_sft_target_chars": 1024,
138
+ "use_grad_checkpoint": True,
139
+ "use_torch_compile": True,
140
+ "num_workers": 48,
141
+ "prefetch_factor": 48,
142
+ "shuffle_buffer": 16384,
143
+ "max_pretrain_tokens": 0,
144
+ "min_pretrain_tokens": 0,
145
+ "quantization": "nvfp4",
146
+ },
147
+ }
148
+
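+ # Three width presets: haiku (dim 128), sonnet (dim 768), opus (dim 1024).
+ # All share the 2x layer-reuse scheme (n_logical_layers = 2 * n_unique_layers)
+ # and default to nvfp4 quantization.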
149
+
150
+ # ---------------------------------------------------------------------------
151
+ # Tokenizer (from ailay.tokenizer)
152
+ # ---------------------------------------------------------------------------
153
+
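+ # Word-level tokenization with per-character fallback: a word present in the
+ # vocabulary becomes a single id, while an out-of-vocabulary word is encoded
+ # character by character, so no text is ever dropped. Illustrative example
+ # (assuming "zyx" is not a vocab word): encode("zyx") -> [id("z"), id("y"), id("x")].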
154
+ FORMAT_TOKENS = [
155
+ "<|user|>",
156
+ "<|assistant|>",
157
+ "<|system|>",
158
+ "<|start_header_id|>",
159
+ "<|end_header_id|>",
160
+ "<|begin_of_thought|>",
161
+ "<|end_of_thought|>",
162
+ "<|begin_of_solution|>",
163
+ "<|end_of_solution|>",
164
+ ]
165
+
166
+
167
+ class WordTokenizer:
168
+ WORD_RE = re.compile(
169
+ r"\s+|[^\W\d_]+(?:['\u2019][^\W\d_]+)?|\d+|[^\w\s]+", re.UNICODE
170
+ )
171
+
172
+ def __init__(
173
+ self, extra_chars: str = "", format_tokens: Optional[List[str]] = None
174
+ ) -> None:
175
+ base = string.ascii_letters + string.digits + string.punctuation + " \n\t\r"
176
+ fallback_chars = sorted(set(base + extra_chars))
177
+ self.core_special = ["<PAD>", "<BOS>", "<EOS>", "<UNK>"]
178
+ self.format_tokens = (
179
+ list(format_tokens) if format_tokens else list(FORMAT_TOKENS)
180
+ )
181
+ self.special = list(self.core_special) + list(self.format_tokens)
182
+ self.id_to_token: List[str] = (
183
+ list(self.core_special) + self.format_tokens + fallback_chars
184
+ )
185
+ self.token_to_id: Dict[str, int] = {
186
+ t: i for i, t in enumerate(self.id_to_token)
187
+ }
188
+ self.special_multi_tokens = sorted(
189
+ [t for t in self.special if len(t) > 1], key=len, reverse=True
190
+ )
191
+ self.multi_char_tokens = self.special_multi_tokens
192
+ self.dynamic_additions = 0
193
+
194
+ @property
195
+ def pad_id(self) -> int:
196
+ return self.token_to_id["<PAD>"]
197
+
198
+ @property
199
+ def bos_id(self) -> int:
200
+ return self.token_to_id["<BOS>"]
201
+
202
+ @property
203
+ def eos_id(self) -> int:
204
+ return self.token_to_id["<EOS>"]
205
+
206
+ @property
207
+ def unk_id(self) -> int:
208
+ return self.token_to_id["<UNK>"]
209
+
210
+ @property
211
+ def vocab_size(self) -> int:
212
+ return len(self.id_to_token)
213
+
214
+ def maybe_add_char(self, ch: str) -> bool:
215
+ if ch in self.token_to_id:
216
+ return False
217
+ self.token_to_id[ch] = len(self.id_to_token)
218
+ self.id_to_token.append(ch)
219
+ self.dynamic_additions += 1
220
+ return True
221
+
222
+ def maybe_add_token(self, token: str) -> bool:
223
+ if token in self.token_to_id:
224
+ return False
225
+ self.token_to_id[token] = len(self.id_to_token)
226
+ self.id_to_token.append(token)
227
+ self.dynamic_additions += 1
228
+ return True
229
+
230
+ def iter_lexical_tokens(self, text: str) -> Iterator[str]:
231
+ i = 0
232
+ n = len(text)
233
+ while i < n:
234
+ matched_special = False
235
+ for token in self.special_multi_tokens:
236
+ if text.startswith(token, i):
237
+ yield token
238
+ i += len(token)
239
+ matched_special = True
240
+ break
241
+ if matched_special:
242
+ continue
243
+ m = self.WORD_RE.match(text, i)
244
+ if m is None:
245
+ yield text[i]
246
+ i += 1
247
+ continue
248
+ tok = m.group(0)
249
+ yield tok
250
+ i = m.end()
251
+
252
+ def encode(
253
+ self, text: str, add_bos: bool = False, add_eos: bool = False
254
+ ) -> List[int]:
255
+ out: List[int] = []
256
+ if add_bos:
257
+ out.append(self.bos_id)
258
+ unk = self.unk_id
259
+ t2i = self.token_to_id
260
+ for tok in self.iter_lexical_tokens(text):
261
+ tid = t2i.get(tok)
262
+ if tid is not None:
263
+ out.append(tid)
264
+ continue
265
+ for ch in tok:
266
+ out.append(t2i.get(ch, unk))
267
+ if add_eos:
268
+ out.append(self.eos_id)
269
+ return out
270
+
271
+ def decode(self, ids: Sequence[int], skip_special: bool = True) -> str:
272
+ pieces: List[str] = []
273
+ for idx in ids:
274
+ if int(idx) < 0 or int(idx) >= len(self.id_to_token):
275
+ continue
276
+ tok = self.id_to_token[int(idx)]
277
+ if skip_special and tok in self.special:
278
+ continue
279
+ pieces.append(tok)
280
+ return "".join(pieces)
281
+
282
+ def save(self, path: Path) -> None:
283
+ with path.open("w", encoding="utf-8") as f:
284
+ json.dump(
285
+ {
286
+ "id_to_token": self.id_to_token,
287
+ "format_tokens": self.format_tokens,
288
+ "core_special": self.core_special,
289
+ "tokenizer_type": "word_level_v1",
290
+ },
291
+ f,
292
+ ensure_ascii=False,
293
+ indent=2,
294
+ )
295
+
296
+ @classmethod
297
+ def load(cls, path: Path) -> WordTokenizer:
298
+ with path.open("r", encoding="utf-8") as f:
299
+ data = json.load(f)
300
+ format_tokens = data.get("format_tokens", FORMAT_TOKENS)
301
+ tokenizer = cls(extra_chars="", format_tokens=format_tokens)
302
+ tokenizer.id_to_token = data["id_to_token"]
303
+ tokenizer.token_to_id = {t: i for i, t in enumerate(tokenizer.id_to_token)}
304
+ tokenizer.special = list(tokenizer.core_special) + list(tokenizer.format_tokens)
305
+ tokenizer.special_multi_tokens = sorted(
306
+ [t for t in tokenizer.special if len(t) > 1], key=len, reverse=True
307
+ )
308
+ tokenizer.multi_char_tokens = tokenizer.special_multi_tokens
309
+ return tokenizer
310
+
311
+
312
+ LetterTokenizer = WordTokenizer  # alias kept for callers that import the old name
313
+
314
+
315
+ # ---------------------------------------------------------------------------
316
+ # Model (from ailay.model)
317
+ # ---------------------------------------------------------------------------
318
+
319
+
320
+ class RMSNorm(nn.Module):
321
+ def __init__(self, dim: int, eps: float = 1e-6) -> None:
322
+ super().__init__()
323
+ self.weight = nn.Parameter(torch.ones(dim))
324
+ self.eps = eps
325
+
326
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
327
+ if hasattr(torch.nn.functional, "rms_norm"):
328
+ return torch.nn.functional.rms_norm(
329
+ x, self.weight.shape, self.weight, self.eps
330
+ )
331
+ x_fp = x.float()
332
+ rms = torch.rsqrt(x_fp.pow(2).mean(dim=-1, keepdim=True) + self.eps)
333
+ return (x_fp * rms).to(dtype=x.dtype) * self.weight
334
+
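+ # RMSNorm computes y = x / sqrt(mean(x^2) + eps) * weight. The fused
+ # F.rms_norm kernel is used when the running PyTorch build provides it;
+ # the explicit fp32 fallback above computes the same thing on older builds.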
335
+
336
+ class RotaryEmbedding(nn.Module):
337
+ def __init__(self, dim: int, base: float = 10000.0) -> None:
338
+ super().__init__()
339
+ inv = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
340
+ self.register_buffer("inv_freq", inv, persistent=False)
341
+
342
+ def cos_sin(
343
+ self, seq_len: int, device: torch.device, dtype: torch.dtype
344
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
345
+ t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
346
+ freqs = torch.outer(t, self.inv_freq)
347
+ emb = torch.cat([freqs, freqs], dim=-1)
348
+ cos = emb.cos()[None, None, :, :].to(dtype=dtype)
349
+ sin = emb.sin()[None, None, :, :].to(dtype=dtype)
350
+ return cos, sin
351
+
352
+
353
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
354
+ x1 = x[..., : x.shape[-1] // 2]
355
+ x2 = x[..., x.shape[-1] // 2 :]
356
+ return torch.cat((-x2, x1), dim=-1)
357
+
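+ # Partial RoPE: only the first rope_dim channels of each head are rotated by
+ # position; the remaining channels pass through position-free. With haiku's
+ # head_dim of 32 and the default rope_fraction of 0.25, that is
+ # max(2, int(32 * 0.25) // 2 * 2) = 8 rotary channels per head.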
358
+
359
+ class CausalSelfAttention(nn.Module):
360
+ def __init__(
361
+ self,
362
+ dim: int,
363
+ n_heads: int,
364
+ n_kv_heads: int,
365
+ head_dim: int,
366
+ dropout: float,
367
+ sliding_window: int,
368
+ rope_fraction: float,
369
+ ) -> None:
370
+ super().__init__()
371
+ self.dim = dim
372
+ self.n_heads = n_heads
373
+ self.n_kv_heads = n_kv_heads
374
+ self.head_dim = head_dim
375
+ self.n_rep = n_heads // n_kv_heads
376
+ self.dropout = dropout
377
+ self.sliding_window = sliding_window
378
+
379
+ self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
380
+ self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
381
+ self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
382
+ self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)
383
+
384
+ self.rope_dim = max(2, int(head_dim * rope_fraction) // 2 * 2)
385
+ self.rope = RotaryEmbedding(self.rope_dim)
386
+
387
+ self.q_norm = RMSNorm(head_dim)
388
+ self.k_norm = RMSNorm(head_dim)
389
+
390
+ self.output_gate = nn.Parameter(torch.zeros(n_heads))
391
+
392
+ def forward(
393
+ self,
394
+ x: torch.Tensor,
395
+ is_global: bool,
396
+ past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
397
+ use_cache: bool = False,
398
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
399
+ B, T, _ = x.shape
400
+
401
+ q = self.wq(x).view(B, T, self.n_heads, self.head_dim)
402
+ k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim)
403
+ v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim)
404
+
405
+ q = self.q_norm(q)
406
+ k = self.k_norm(k)
407
+
408
+ q = q.transpose(1, 2)
409
+ k = k.transpose(1, 2)
410
+ v = v.transpose(1, 2)
411
+
412
+ past_len = past_kv[0].shape[2] if past_kv is not None else 0
413
+ cos, sin = self.rope.cos_sin(T + past_len, x.device, q.dtype)
414
+ cos_slice = cos[:, :, past_len : past_len + T, :]
415
+ sin_slice = sin[:, :, past_len : past_len + T, :]
416
+
417
+ q_rope = q[..., : self.rope_dim]
418
+ q_pass = q[..., self.rope_dim :]
419
+ k_rope = k[..., : self.rope_dim]
420
+ k_pass = k[..., self.rope_dim :]
421
+
422
+ q_rope = (q_rope * cos_slice) + (_rotate_half(q_rope) * sin_slice)
423
+ k_rope = (k_rope * cos_slice) + (_rotate_half(k_rope) * sin_slice)
424
+
425
+ q = torch.cat([q_rope, q_pass], dim=-1)
426
+ k = torch.cat([k_rope, k_pass], dim=-1)
427
+
428
+ if past_kv is not None:
429
+ k = torch.cat([past_kv[0], k], dim=2)
430
+ v = torch.cat([past_kv[1], v], dim=2)
431
+
432
+ new_kv = (k, v) if use_cache else None
433
+
434
+ S = k.shape[2]
435
+ if self.n_rep > 1:
436
+ k = (
437
+ k[:, :, None, :, :]
438
+ .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim)
439
+ .reshape(B, self.n_heads, S, self.head_dim)
440
+ )
441
+ v = (
442
+ v[:, :, None, :, :]
443
+ .expand(B, self.n_kv_heads, self.n_rep, S, self.head_dim)
444
+ .reshape(B, self.n_heads, S, self.head_dim)
445
+ )
446
+
447
+ drop_p = self.dropout if self.training else 0.0
448
+
449
+ if is_global:
450
+ if past_kv is None and T > 1:
451
+ out = F.scaled_dot_product_attention(
452
+ q, k, v, is_causal=True, dropout_p=drop_p
453
+ )
454
+ else:
455
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
456
+ else:
457
+ T_q = q.shape[2]
458
+ q_pos = torch.arange(past_len, past_len + T_q, device=q.device).unsqueeze(1)
459
+ k_pos = torch.arange(S, device=q.device).unsqueeze(0)
460
+ mask = (q_pos >= k_pos) & ((q_pos - k_pos) < self.sliding_window)
461
+ out = F.scaled_dot_product_attention(
462
+ q, k, v, attn_mask=mask.unsqueeze(0).unsqueeze(0), dropout_p=drop_p
463
+ )
464
+
465
+ gate = torch.sigmoid(self.output_gate).view(1, self.n_heads, 1, 1)
466
+ out = out * gate
467
+
468
+ out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
469
+ out = self.wo(out)
470
+
471
+ return out, new_kv
472
+
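+ # Attention notes: queries use n_heads while keys/values use n_kv_heads
+ # (grouped-query attention), with KV heads expanded n_rep times before SDPA.
+ # Layers flagged is_global attend causally over the whole context; the rest
+ # apply an explicit boolean mask limiting lookback to sliding_window tokens.
+ # The per-head output_gate starts at sigmoid(0) = 0.5 and lets training
+ # rescale each head's contribution.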
473
+
474
+ class SwiGLU(nn.Module):
475
+ def __init__(self, dim: int, hidden_dim: int, dropout: float) -> None:
476
+ super().__init__()
477
+ self.gate = nn.Linear(dim, hidden_dim, bias=False)
478
+ self.up = nn.Linear(dim, hidden_dim, bias=False)
479
+ self.down = nn.Linear(hidden_dim, dim, bias=False)
480
+ self.drop = nn.Dropout(dropout)
481
+
482
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
483
+ h = F.silu(self.gate(x)) * self.up(x)
484
+ return self.drop(self.down(h))
485
+
486
+
487
+ class TransformerBlock(nn.Module):
488
+ def __init__(
489
+ self,
490
+ dim: int,
491
+ n_heads: int,
492
+ n_kv_heads: int,
493
+ head_dim: int,
494
+ ffn_dim: int,
495
+ dropout: float,
496
+ sliding_window: int,
497
+ rope_fraction: float,
498
+ ) -> None:
499
+ super().__init__()
500
+ self.norm1 = RMSNorm(dim)
501
+ self.attn = CausalSelfAttention(
502
+ dim=dim,
503
+ n_heads=n_heads,
504
+ n_kv_heads=n_kv_heads,
505
+ head_dim=head_dim,
506
+ dropout=dropout,
507
+ sliding_window=sliding_window,
508
+ rope_fraction=rope_fraction,
509
+ )
510
+ self.norm2 = RMSNorm(dim)
511
+ self.ffn = SwiGLU(dim, ffn_dim, dropout)
512
+
513
+ def forward(
514
+ self,
515
+ x: torch.Tensor,
516
+ is_global: bool,
517
+ past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
518
+ use_cache: bool = False,
519
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
520
+ attn_out, new_kv = self.attn(self.norm1(x), is_global, past_kv, use_cache)
521
+ x = x + attn_out
522
+ x = x + self.ffn(self.norm2(x))
523
+ return x, new_kv
524
+
525
+
526
+ class TinyMemoryLM(nn.Module):
527
+ def __init__(
528
+ self,
529
+ vocab_size: int,
530
+ dim: int,
531
+ n_unique_layers: int,
532
+ n_logical_layers: int,
533
+ n_heads: int,
534
+ n_kv_heads: int,
535
+ ffn_dim: int,
536
+ dropout: float,
537
+ mtp_horizons: Sequence[int],
538
+ grad_checkpoint: bool,
539
+ sliding_window: int = 512,
540
+ rope_fraction: float = 0.25,
541
+ embed_scale: bool = True,
542
+ ) -> None:
543
+ super().__init__()
544
+ self.dim = dim
545
+ self.n_unique_layers = n_unique_layers
546
+ self.n_logical_layers = n_logical_layers
547
+ self.grad_checkpoint = grad_checkpoint
548
+ self.embed_scale_factor = math.sqrt(dim) if embed_scale else 1.0
549
+ head_dim = dim // n_heads
550
+
551
+ self.embed_tokens = nn.Embedding(vocab_size, dim)
552
+ self.head = nn.Linear(dim, vocab_size, bias=False)
553
+ self.head.weight = self.embed_tokens.weight
554
+
555
+ self.output_bias = nn.Parameter(torch.zeros(vocab_size))
556
+
557
+ self.blocks = nn.ModuleList(
558
+ [
559
+ TransformerBlock(
560
+ dim=dim,
561
+ n_heads=n_heads,
562
+ n_kv_heads=n_kv_heads,
563
+ head_dim=head_dim,
564
+ ffn_dim=ffn_dim,
565
+ dropout=dropout,
566
+ sliding_window=sliding_window,
567
+ rope_fraction=rope_fraction,
568
+ )
569
+ for _ in range(n_unique_layers)
570
+ ]
571
+ )
572
+ self.norm = RMSNorm(dim)
573
+
574
+ self.mtp_horizons = sorted({int(h) for h in mtp_horizons if int(h) > 1})
575
+ self.mtp_adapters = nn.ModuleDict(
576
+ {str(h): nn.Linear(dim, dim, bias=False) for h in self.mtp_horizons}
577
+ )
578
+ self.mtp_norms = nn.ModuleDict(
579
+ {str(h): RMSNorm(dim) for h in self.mtp_horizons}
580
+ )
581
+
582
+ res_scale = (2 * n_logical_layers) ** -0.5
583
+ for block in self.blocks:
584
+ block.attn.wo.weight.data.mul_(res_scale)
585
+ block.ffn.down.weight.data.mul_(res_scale)
586
+
587
+ def resize_token_embeddings(self, new_vocab_size: int) -> None:
588
+ old_vocab_size = self.embed_tokens.num_embeddings
589
+ if new_vocab_size == old_vocab_size:
590
+ return
591
+ device = self.embed_tokens.weight.device
592
+ old_embed_weight = self.embed_tokens.weight.data.clone()
593
+ self.embed_tokens = nn.Embedding(
594
+ new_vocab_size, self.embed_tokens.embedding_dim
595
+ ).to(device)
596
+ self.head = nn.Linear(
597
+ self.embed_tokens.embedding_dim, new_vocab_size, bias=False
598
+ ).to(device)
599
+ self.head.weight = self.embed_tokens.weight
600
+ old_bias = self.output_bias.data.clone()
601
+ self.output_bias = nn.Parameter(torch.zeros(new_vocab_size, device=device))
602
+ copy_size = min(old_vocab_size, new_vocab_size)
603
+ self.output_bias.data[:copy_size] = old_bias[:copy_size]
604
+ self.embed_tokens.weight.data[:copy_size] = old_embed_weight[:copy_size]
605
+
606
+ def _build_logical_layers(self) -> List[Tuple[nn.Module, int]]:
607
+ logical = []
608
+ blocks_list = list(self.blocks)
609
+ full_sequence = blocks_list + blocks_list
610
+ for logical_idx, block in enumerate(full_sequence[: self.n_logical_layers]):
611
+ logical.append((block, logical_idx))
612
+ return logical
613
+
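+ # Weight sharing: the n_unique_layers physical blocks are concatenated with
+ # themselves and the first n_logical_layers entries define the execution
+ # order, so each block runs up to twice per forward pass with shared weights.
+ # In forward(), every third logical layer (logical_idx % 3 == 0) uses global
+ # attention; the others use the sliding window.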
614
+ def forward(
615
+ self,
616
+ ids: torch.Tensor,
617
+ use_cache: bool = False,
618
+ past_key_values: Optional[
619
+ List[Optional[Tuple[torch.Tensor, torch.Tensor]]]
620
+ ] = None,
621
+ return_hidden: bool = False,
622
+ ) -> Tuple[
623
+ torch.Tensor,
624
+ Dict[int, torch.Tensor],
625
+ Optional[torch.Tensor],
626
+ Optional[List[Tuple[torch.Tensor, torch.Tensor]]],
627
+ ]:
628
+ B, T = ids.shape
629
+ x = self.embed_tokens(ids) * self.embed_scale_factor
630
+
631
+ logical_layers = self._build_logical_layers()
632
+ new_past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = (
633
+ [] if use_cache else None
634
+ )
635
+
636
+ for layer_idx, (block, logical_idx) in enumerate(logical_layers):
637
+ is_global = logical_idx % 3 == 0
638
+ past_kv = (
639
+ past_key_values[layer_idx]
640
+ if past_key_values is not None and layer_idx < len(past_key_values)
641
+ else None
642
+ )
643
+
644
+ if self.grad_checkpoint and self.training and not use_cache:
645
+ x, layer_kv = checkpoint(
646
+ block, x, is_global, past_kv, use_cache, use_reentrant=False
647
+ )
648
+ else:
649
+ x, layer_kv = block(x, is_global, past_kv, use_cache)
650
+
651
+ if new_past_key_values is not None:
652
+ new_past_key_values.append(layer_kv)
653
+
654
+ x = self.norm(x)
655
+ h_out = x if return_hidden else None
656
+ logits = self.head(x) + self.output_bias
657
+
658
+ mtp: Dict[int, torch.Tensor] = {}
659
+ if self.mtp_horizons and self.training:
660
+ for horizon in self.mtp_horizons:
661
+ if horizon > 1 and horizon <= T - 1:
662
+ shifted_h = x[:, :-horizon, :]
663
+ adapted_h = self.mtp_adapters[str(horizon)](shifted_h)
664
+ adapted_h = self.mtp_norms[str(horizon)](adapted_h)
665
+ mtp_logits = self.head(adapted_h) + self.output_bias
666
+ mtp[horizon] = mtp_logits
667
+
668
+ return logits, mtp, h_out, new_past_key_values
669
+
670
+
671
+ # ---------------------------------------------------------------------------
672
+ # Generation (from ailay.generation)
673
+ # ---------------------------------------------------------------------------
674
+
675
+
676
+ def sample_text(
677
+ model: TinyMemoryLM,
678
+ tokenizer: WordTokenizer,
679
+ prompt: str,
680
+ max_new_tokens: int,
681
+ temperature: float,
682
+ top_k: int,
683
+ branches: int,
684
+ branch_len: int,
685
+ device: torch.device,
686
+ seq_len: int,
687
+ ) -> str:
688
+ def _sample_id(logits: torch.Tensor) -> torch.Tensor:
689
+ if not torch.isfinite(logits).any():
690
+ logits = torch.zeros_like(logits)
691
+ logits = torch.where(
692
+ torch.isfinite(logits), logits, torch.full_like(logits, -1e9)
693
+ )
694
+ if top_k > 0:
695
+ v, idx = torch.topk(logits, k=min(top_k, logits.shape[-1]))
696
+ p = torch.softmax(v, dim=-1)
697
+ return idx.gather(-1, torch.multinomial(p, 1))
698
+ p = torch.softmax(logits, dim=-1)
699
+ return torch.multinomial(p, 1)
700
+
701
+ model.eval()
702
+ ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
703
+ prompt_len = len(ids)
704
+ x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
705
+
706
+ with torch.no_grad():
707
+ generated = 0
708
+ while generated < max_new_tokens:
709
+ if branches <= 1:
710
+ ctx = x[:, -seq_len:]
711
+ logits, _, _, _ = model(ctx)
712
+ nlogits = logits[:, -1, :] / max(temperature, 1e-6)
713
+ nid = _sample_id(nlogits)
714
+ x = torch.cat([x, nid], dim=1)
715
+ generated += 1
716
+ continue
717
+ rollout = min(branch_len, max_new_tokens - generated)
718
+ best_nll: Optional[float] = None
719
+ best_tokens: Optional[List[torch.Tensor]] = None
720
+ for _ in range(branches):
721
+ cand = x
722
+ cand_tokens: List[torch.Tensor] = []
723
+ nll = 0.0
724
+ for _ in range(rollout):
725
+ ctx = cand[:, -seq_len:]
726
+ logits, _, _, _ = model(ctx)
727
+ nlogits = logits[:, -1, :] / max(temperature, 1e-6)
728
+ nid = _sample_id(nlogits)
729
+ lp = F.log_softmax(nlogits.float(), dim=-1)
730
+ nll += float(-lp.gather(-1, nid).item())
731
+ cand = torch.cat([cand, nid], dim=1)
732
+ cand_tokens.append(nid)
733
+ if best_nll is None or nll < best_nll:
734
+ best_nll = nll
735
+ best_tokens = cand_tokens
736
+ assert best_tokens is not None
737
+ for t in best_tokens:
738
+ x = torch.cat([x, t], dim=1)
739
+ generated += 1
740
+
741
+ generated_ids = x[0, prompt_len:].tolist()
742
+ return tokenizer.decode(generated_ids, skip_special=True)
743
+
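+ # sample_text with branches > 1 is best-of-n rollout selection: it samples
+ # `branches` independent continuations of length branch_len, scores each by
+ # summed negative log-probability, and commits the lowest-NLL rollout before
+ # continuing. Every step re-runs the full context, trading speed for quality;
+ # sample_text_cached below is the fast KV-cached single-path variant.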
744
+
745
+ def sample_text_cached(
746
+ model: TinyMemoryLM,
747
+ tokenizer: WordTokenizer,
748
+ prompt: str,
749
+ max_new_tokens: int,
750
+ temperature: float,
751
+ top_k: int,
752
+ device: torch.device,
753
+ seq_len: int,
754
+ ) -> str:
755
+ model.eval()
756
+ ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
757
+ prompt_len = len(ids)
758
+ x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
759
+
760
+ with torch.no_grad():
761
+ logits, _, _, past_kv = model(x, use_cache=True)
762
+ nlogits = logits[:, -1, :] / max(temperature, 1e-6)
763
+ if top_k > 0:
764
+ v, idx = torch.topk(nlogits, k=min(top_k, nlogits.shape[-1]))
765
+ p = torch.softmax(v, dim=-1)
766
+ nid = idx.gather(-1, torch.multinomial(p, 1))
767
+ else:
768
+ p = torch.softmax(nlogits, dim=-1)
769
+ nid = torch.multinomial(p, 1)
770
+ all_ids = [int(nid.item())]
771
+
772
+ for _ in range(max_new_tokens - 1):
773
+ logits, _, _, past_kv = model(nid, use_cache=True, past_key_values=past_kv)
774
+ nlogits = logits[:, -1, :] / max(temperature, 1e-6)
775
+ if top_k > 0:
776
+ v, idx = torch.topk(nlogits, k=min(top_k, nlogits.shape[-1]))
777
+ p = torch.softmax(v, dim=-1)
778
+ nid = idx.gather(-1, torch.multinomial(p, 1))
779
+ else:
780
+ p = torch.softmax(nlogits, dim=-1)
781
+ nid = torch.multinomial(p, 1)
782
+ tid = int(nid.item())
783
+ all_ids.append(tid)
784
+ if tid == tokenizer.eos_id:
785
+ break
786
+
787
+ return tokenizer.decode(all_ids, skip_special=True)
788
+
789
+
790
+ def speculative_decode(
791
+ model: TinyMemoryLM,
792
+ tokenizer: WordTokenizer,
793
+ prompt: str,
794
+ max_new_tokens: int,
795
+ temperature: float,
796
+ top_k: int,
797
+ device: torch.device,
798
+ seq_len: int,
799
+ ) -> str:
800
+ model.eval()
801
+ ids = tokenizer.encode(prompt, add_bos=True, add_eos=False)
802
+ x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
803
+ all_generated: List[int] = []
804
+
805
+ with torch.no_grad():
806
+ logits, _, h_out, past_kv = model(x, use_cache=True, return_hidden=True)
807
+
808
+ def _sample_from(lg: torch.Tensor) -> int:
809
+ lg = lg / max(temperature, 1e-6)
810
+ if top_k > 0:
811
+ v, idx = torch.topk(lg, k=min(top_k, lg.shape[-1]))
812
+ p = torch.softmax(v, dim=-1)
813
+ return int(idx[torch.multinomial(p, 1)].item())
814
+ p = torch.softmax(lg, dim=-1)
815
+ return int(torch.multinomial(p, 1).item())
816
+
817
+ main_token = _sample_from(logits[0, -1, :])
818
+ all_generated.append(main_token)
819
+
820
+ while len(all_generated) < max_new_tokens:
821
+ if main_token == tokenizer.eos_id:
822
+ break
823
+
824
+ draft_tokens = []
825
+ if h_out is not None and model.mtp_horizons:
826
+ last_hidden = h_out[:, -1:, :]
827
+ for h in model.mtp_horizons:
828
+ adapter = model.mtp_adapters[str(h)]
829
+ norm = model.mtp_norms[str(h)]
830
+ adapted = norm(adapter(last_hidden))
831
+ draft_logits = model.head(adapted) + model.output_bias
832
+ draft_tok = _sample_from(draft_logits[0, 0, :])
833
+ draft_tokens.append(draft_tok)
834
+
835
+ if not draft_tokens:
836
+ nid = torch.tensor([[main_token]], dtype=torch.long, device=device)
837
+ logits, _, h_out, past_kv = model(
838
+ nid, use_cache=True, past_key_values=past_kv, return_hidden=True
839
+ )
840
+ main_token = _sample_from(logits[0, -1, :])
841
+ all_generated.append(main_token)
842
+ continue
843
+
844
+ verify_input = torch.tensor(
845
+ [[main_token] + draft_tokens], dtype=torch.long, device=device
846
+ )
847
+ verify_logits, _, h_out, past_kv = model(
848
+ verify_input,
849
+ use_cache=True,
850
+ past_key_values=past_kv,
851
+ return_hidden=True,
852
+ )
853
+
854
+ accepted = 0
855
+ if main_token not in all_generated[-1:]:
+ all_generated.append(main_token)
858
+ for i, draft_tok in enumerate(draft_tokens):
859
+ verified_tok = _sample_from(verify_logits[0, i, :])
860
+ if verified_tok == draft_tok:
861
+ all_generated.append(draft_tok)
862
+ accepted += 1
863
+ if draft_tok == tokenizer.eos_id:
864
+ break
865
+ else:
866
+ all_generated.append(verified_tok)
867
+ break
868
+
869
+ if accepted < len(draft_tokens):
870
+ trim_len = len(draft_tokens) - accepted - 1
871
+ if trim_len > 0 and past_kv is not None:
872
+ past_kv = [
+ (kv[0][:, :, :-trim_len, :], kv[1][:, :, :-trim_len, :])
+ if kv is not None
+ else None
+ for kv in past_kv
+ ]
878
+
879
+ main_token = all_generated[-1]
880
+
881
+ return tokenizer.decode(all_generated, skip_special=True)
882
+
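+ # Self-speculative decoding: the MTP adapter heads draft one token per
+ # horizon from the last hidden state, the main model verifies the drafts in
+ # a single forward pass, and the KV cache is rolled back past any rejected
+ # positions. Accepted drafts yield multiple tokens per model call.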
883
+
884
+ def build_stop_token_ids(tokenizer: WordTokenizer) -> set:
885
+ stop_tokens = {tokenizer.eos_id}
886
+ for tok in ("<|user|>", "<|system|>", "<|assistant|>"):
887
+ tid = tokenizer.token_to_id.get(tok)
888
+ if tid is not None:
889
+ stop_tokens.add(int(tid))
890
+ return stop_tokens
891
+
892
+
893
+ def apply_no_repeat_ngram(
894
+ logits: torch.Tensor,
895
+ token_history: Sequence[int],
896
+ ngram_size: int,
897
+ ) -> torch.Tensor:
898
+ if ngram_size <= 1 or len(token_history) < max(0, ngram_size - 1):
899
+ return logits
900
+ prefix = tuple(token_history[-(ngram_size - 1) :]) if ngram_size > 1 else tuple()
901
+ banned: set = set()
902
+ for i in range(len(token_history) - ngram_size + 1):
903
+ if tuple(token_history[i : i + ngram_size - 1]) == prefix:
904
+ banned.add(int(token_history[i + ngram_size - 1]))
905
+ if not banned:
906
+ return logits
907
+ out = logits.clone()
908
+ banned_ids = torch.tensor(sorted(banned), device=logits.device, dtype=torch.long)
909
+ out[banned_ids] = float("-inf")
910
+ return out
911
+
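+ # No-repeat-ngram example: with ngram_size=3 and history [5, 9, 2, 5, 9],
+ # the current prefix is (5, 9), which earlier preceded token 2, so logit 2
+ # is set to -inf and the trigram (5, 9, 2) cannot recur.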
912
+
913
+ def score_candidate(
914
+ prompt: str,
915
+ raw_text: str,
916
+ visible_text: str,
917
+ avg_logprob: float,
918
+ ) -> float:
919
+ clean = visible_text.strip()
920
+ if not clean:
921
+ return -1e9
922
+ score = avg_logprob
923
+ words = clean.lower().split()
924
+ prompt_words = re.findall(r"[A-Za-z][A-Za-z'-]{2,}", prompt.lower())
925
+ prompt_stop = {
926
+ "what",
927
+ "which",
928
+ "when",
929
+ "where",
930
+ "why",
931
+ "how",
932
+ "are",
933
+ "is",
934
+ "the",
935
+ "and",
936
+ "for",
937
+ "with",
938
+ "that",
939
+ "this",
940
+ "from",
941
+ "into",
942
+ "about",
943
+ "explain",
944
+ "tell",
945
+ "give",
946
+ "list",
947
+ "show",
948
+ "write",
949
+ "their",
950
+ "there",
951
+ "your",
952
+ }
953
+ prompt_keywords = {w for w in prompt_words if w not in prompt_stop}
954
+ candidate_keywords = set(re.findall(r"[A-Za-z][A-Za-z'-]{2,}", clean.lower()))
955
+ if len(words) < 6:
956
+ score -= 2.0
957
+ else:
958
+ score += min(2.0, len(words) * 0.03)
959
+ if clean[-1:] in ".!?":
960
+ score += 0.5
961
+ if "<|user|>" in raw_text or "<|system|>" in raw_text:
962
+ score -= 4.0
963
+ if raw_text.count("<|assistant|>") > 1:
964
+ score -= 2.0
965
+ if prompt_keywords:
966
+ overlap = len(prompt_keywords & candidate_keywords) / len(prompt_keywords)
967
+ if overlap == 0.0:
968
+ score -= 2.5
969
+ else:
970
+ score += min(3.5, overlap * 4.0)
971
+ for open_tok, close_tok in [
972
+ ("<|begin_of_thought|>", "<|end_of_thought|>"),
973
+ ("<|begin_of_solution|>", "<|end_of_solution|>"),
974
+ ]:
975
+ if (open_tok in raw_text) != (close_tok in raw_text):
976
+ score -= 1.0
977
+ if len(words) >= 3:
978
+ trigrams = [tuple(words[i : i + 3]) for i in range(len(words) - 2)]
979
+ if trigrams:
980
+ unique_ratio = len(set(trigrams)) / len(trigrams)
981
+ if unique_ratio < 0.35:
982
+ score -= 4.0
983
+ elif unique_ratio < 0.55:
984
+ score -= 2.0
985
+ else:
986
+ score += min(1.0, (unique_ratio - 0.55) * 2.0)
987
+ alpha_words = [
988
+ w
989
+ for w in words
990
+ if len(w) <= 18 and (sum(ch.isalpha() for ch in w) / max(len(w), 1)) > 0.7
991
+ ]
992
+ alpha_ratio = len(alpha_words) / max(len(words), 1)
993
+ if alpha_ratio < 0.45:
994
+ score -= 3.0
995
+ elif alpha_ratio < 0.65:
996
+ score -= 1.0
997
+ return score
998
+
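+ # score_candidate is a heuristic reranker: it starts from the average token
+ # log-probability and adjusts for length, terminal punctuation, leaked role
+ # tokens, keyword overlap with the prompt, unbalanced thought/solution tags,
+ # trigram repetition, and the share of ordinary alphabetic words.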
999
+
1000
+ def generate_candidate(
1001
+ model: TinyMemoryLM,
1002
+ tokenizer: WordTokenizer,
1003
+ prompt: str,
1004
+ max_new_tokens: int,
1005
+ temperature: float,
1006
+ top_k: int,
1007
+ repetition_penalty: float,
1008
+ no_repeat_ngram_size: int,
1009
+ device: str,
1010
+ sft_mode: bool,
1011
+ force_thought: bool,
1012
+ stream: bool,
1013
+ context_window: int,
1014
+ ) -> Tuple[str, str, float, int]:
1015
+ if sft_mode:
1016
+ full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
1017
+ else:
1018
+ full_prompt = prompt
1019
+ if force_thought:
1020
+ full_prompt = f"{full_prompt}<|begin_of_thought|> "
1021
+ input_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False)
1022
+ input_ids_t = torch.tensor([input_ids], dtype=torch.long, device=device)
1023
+ visible_tokens: List[str] = []
1024
+ raw_tokens: List[str] = []
1025
+ stop_token_ids = build_stop_token_ids(tokenizer)
1026
+ total_logprob = 0.0
1027
+ sampled_tokens = 0
1028
+ with torch.no_grad():
1029
+ for _ in range(max_new_tokens):
1030
+ ctx_ids = (
1031
+ input_ids_t[:, -context_window:] if context_window > 0 else input_ids_t
1032
+ )
1033
+ logits, _, _, _ = model(ctx_ids)
1034
+ next_logits = logits[0, -1, :].clone()
1035
+ raw_next_logits = next_logits.clone()
1036
+ if repetition_penalty != 1.0:
1037
+ seen = set(input_ids_t[0].tolist())
1038
+ for token_id in seen:
1039
+ if next_logits[token_id] > 0:
1040
+ next_logits[token_id] /= repetition_penalty
1041
+ else:
1042
+ next_logits[token_id] *= repetition_penalty
1043
+ if temperature != 1.0:
1044
+ next_logits = next_logits / max(temperature, 1e-6)
1045
+ if no_repeat_ngram_size > 1:
1046
+ next_logits = apply_no_repeat_ngram(
1047
+ next_logits,
1048
+ input_ids_t[0].tolist(),
1049
+ no_repeat_ngram_size,
1050
+ )
1051
+ if top_k > 0:
1052
+ v, _ = torch.topk(next_logits, min(top_k, next_logits.size(0)))
1053
+ next_logits[next_logits < v[-1]] = float("-inf")
1054
+ top_p = 0.9  # fixed nucleus threshold: tokens beyond the top 90% probability mass are masked
1055
+ if top_p < 1.0:
1056
+ sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
1057
+ cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
1058
+ remove_mask = cum_probs - torch.softmax(sorted_logits, dim=-1) >= top_p
1059
+ sorted_logits[remove_mask] = float("-inf")
1060
+ next_logits = sorted_logits.scatter(0, sorted_indices, sorted_logits)
1061
+ if not torch.isfinite(next_logits).any():
1062
+ next_logits = raw_next_logits
1063
+ if temperature != 1.0:
1064
+ next_logits = next_logits / max(temperature, 1e-6)
1065
+ probs = torch.softmax(next_logits, dim=-1)
1066
+ next_id = torch.multinomial(probs, num_samples=1).item()
1067
+ total_logprob += float(torch.log(probs[next_id] + 1e-12).item())
1068
+ sampled_tokens += 1
1069
+ if next_id in stop_token_ids:
1070
+ break
1071
+ token_str = (
1072
+ tokenizer.id_to_token[next_id]
1073
+ if next_id < len(tokenizer.id_to_token)
1074
+ else ""
1075
+ )
1076
+ raw_tokens.append(token_str)
1077
+ if token_str not in tokenizer.special:
1078
+ visible_tokens.append(token_str)
1079
+ if stream:
1080
+ print(token_str, end="", flush=True)
1081
+ input_ids_t = torch.cat(
1082
+ [input_ids_t, torch.tensor([[next_id]], device=device)], dim=1
1083
+ )
1084
+ if stream:
1085
+ print()
1086
+ avg_logprob = total_logprob / max(1, sampled_tokens)
1087
+ return "".join(visible_tokens), "".join(raw_tokens), avg_logprob, 0
1088
+
1089
+
1090
+ def generate_beam_search(
1091
+ model: TinyMemoryLM,
1092
+ tokenizer: WordTokenizer,
1093
+ prompt: str,
1094
+ max_new_tokens: int = 60,
1095
+ beam_width: int = 8,
1096
+ length_penalty: float = 0.7,
1097
+ no_repeat_ngram_size: int = 3,
1098
+ device: str = "cuda",
1099
+ sft_mode: bool = False,
1100
+ context_window: int = 2048,
1101
+ ) -> str:
1102
+ if sft_mode:
1103
+ full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
1104
+ else:
1105
+ full_prompt = prompt
1106
+ prompt_ids = tokenizer.encode(full_prompt, add_bos=True, add_eos=False)
1107
+ prompt_len = len(prompt_ids)
1108
+ stop_ids = build_stop_token_ids(tokenizer)
1109
+ beams: List[Tuple[float, List[int]]] = [(0.0, list(prompt_ids))]
1110
+ completed: List[Tuple[float, List[int]]] = []
1111
+ for _step in range(max_new_tokens):
1112
+ if not beams:
1113
+ break
1114
+ candidates: List[Tuple[float, List[int]]] = []
1115
+ for beam_score, beam_ids in beams:
1116
+ x = torch.tensor(
1117
+ [beam_ids[-context_window:]], dtype=torch.long, device=device
1118
+ )
1119
+ with torch.no_grad():
1120
+ logits, _, _, _ = model(x)
1121
+ nl = logits[0, -1, :]
1122
+ log_probs = F.log_softmax(nl, dim=-1)
1123
+ gen_ids = beam_ids[prompt_len:]
1124
+ if no_repeat_ngram_size > 1 and len(gen_ids) >= no_repeat_ngram_size - 1:
1125
+ prefix = tuple(gen_ids[-(no_repeat_ngram_size - 1) :])
1126
+ for i in range(len(gen_ids) - no_repeat_ngram_size + 1):
1127
+ if tuple(gen_ids[i : i + no_repeat_ngram_size - 1]) == prefix:
1128
+ log_probs[gen_ids[i + no_repeat_ngram_size - 1]] = float("-inf")
1129
+ topk_lp, topk_ids = torch.topk(log_probs, beam_width)
1130
+ for i in range(beam_width):
1131
+ tid = topk_ids[i].item()
1132
+ new_score = beam_score + topk_lp[i].item()
1133
+ new_ids = beam_ids + [tid]
1134
+ if tid in stop_ids:
1135
+ completed.append((new_score, new_ids))
1136
+ else:
1137
+ candidates.append((new_score, new_ids))
1138
+
1139
+ def _norm_score(pair):
1140
+ gen_len = max(1, len(pair[1]) - prompt_len)
1141
+ return pair[0] / (gen_len**length_penalty)
1142
+
1143
+ candidates.sort(key=_norm_score, reverse=True)
1144
+ beams = candidates[:beam_width]
1145
+
1146
+ pool = completed + beams
1147
+ if not pool:
1148
+ return ""
1149
+
1150
+ def _norm_score_final(pair):
1151
+ gen_len = max(1, len(pair[1]) - prompt_len)
1152
+ return pair[0] / (gen_len**length_penalty)
1153
+
1154
+ pool.sort(key=_norm_score_final, reverse=True)
1155
+ best_ids = pool[0][1][prompt_len:]
1156
+ text = tokenizer.decode(best_ids, skip_special=True)
1157
+ nl_pos = text.find("\n")
1158
+ if nl_pos > 5:
1159
+ text = text[:nl_pos]
1160
+ return text.strip()
1161
+
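+ # Beams are ranked by length-normalized score, sum_logprob / gen_len ** length_penalty
+ # (0.7 by default); an exponent below 1 softens the bias toward short outputs
+ # that raw summed log-probabilities would otherwise impose.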
1162
+
1163
+ def generate(
1164
+ model: TinyMemoryLM,
1165
+ tokenizer: WordTokenizer,
1166
+ prompt: str,
1167
+ max_new_tokens: int = 256,
1168
+ temperature: float = 0.8,
1169
+ top_k: int = 40,
1170
+ repetition_penalty: float = 1.0,
1171
+ device: str = "cuda",
1172
+ sft_mode: bool = False,
1173
+ force_thought: bool = False,
1174
+ stream: bool = True,
1175
+ decode_mode: str = "legacy",
1176
+ best_of: int = 3,
1177
+ no_repeat_ngram_size: int = 3,
1178
+ context_window: int = 2048,
1179
+ beam_width: int = 8,
1180
+ length_penalty: float = 0.7,
1181
+ ) -> str:
1182
+ if decode_mode == "beam":
1183
+ text = generate_beam_search(
1184
+ model=model,
1185
+ tokenizer=tokenizer,
1186
+ prompt=prompt,
1187
+ max_new_tokens=max_new_tokens,
1188
+ beam_width=beam_width,
1189
+ length_penalty=length_penalty,
1190
+ no_repeat_ngram_size=no_repeat_ngram_size,
1191
+ device=device,
1192
+ sft_mode=sft_mode,
1193
+ context_window=context_window,
1194
+ )
1195
+ if stream:
1196
+ print(text)
1197
+ return text
1198
+ if decode_mode == "legacy":
1199
+ text, _, _, _ = generate_candidate(
1200
+ model=model,
1201
+ tokenizer=tokenizer,
1202
+ prompt=prompt,
1203
+ max_new_tokens=max_new_tokens,
1204
+ temperature=temperature,
1205
+ top_k=top_k,
1206
+ repetition_penalty=repetition_penalty,
1207
+ no_repeat_ngram_size=no_repeat_ngram_size,
1208
+ device=device,
1209
+ sft_mode=sft_mode,
1210
+ force_thought=force_thought,
1211
+ stream=stream,
1212
+ context_window=context_window,
1213
+ )
1214
+ return text
1215
+ candidates: List[Tuple[float, str, str, float]] = []
1216
+ for _ in range(max(1, best_of)):
1217
+ candidate_text, raw_text, avg_logprob, _ = generate_candidate(
1218
+ model=model,
1219
+ tokenizer=tokenizer,
1220
+ prompt=prompt,
1221
+ max_new_tokens=max_new_tokens,
1222
+ temperature=temperature,
1223
+ top_k=top_k,
1224
+ repetition_penalty=repetition_penalty,
1225
+ no_repeat_ngram_size=no_repeat_ngram_size,
1226
+ device=device,
1227
+ sft_mode=sft_mode,
1228
+ force_thought=force_thought,
1229
+ stream=False,
1230
+ context_window=context_window,
1231
+ )
1232
+ score = score_candidate(prompt, raw_text, candidate_text, avg_logprob)
1233
+ candidates.append((score, candidate_text, raw_text, avg_logprob))
1234
+ best_score, best_text, _, _ = max(candidates, key=lambda item: item[0])
1235
+ if stream:
1236
+ print(best_text, end="", flush=True)
1237
+ print()
1238
+ return best_text
1239
+
1240
+
1241
+ # ---------------------------------------------------------------------------
1242
+ # Web server (from interactive.py)
1243
+ # ---------------------------------------------------------------------------
1244
+
1245
+ ROOT = Path(__file__).resolve().parent
1246
+ if str(ROOT) not in sys.path:
1247
+ sys.path.insert(0, str(ROOT))
1248
+
1249
+
1250
+ HF_ORG = "CompactAI"
1251
+ HF_API = "https://huggingface.co/api"
1252
+ CACHE_ROOT = Path.home() / ".cache" / "compactai_web"
1253
+ USER_AGENT = "Mozilla/5.0 CompactAI-Web"
1254
+ MODEL_CACHE: dict[tuple[str, str], dict[str, object]] = {}
1255
+ MODEL_CACHE_LOCK = threading.RLock()
1256
+ GENERATION_LOCK = threading.Lock()
1257
+
1258
+
1259
+ def request_json(url: str):
1260
+ req = Request(url, headers={"User-Agent": USER_AGENT})
1261
+ with urlopen(req, timeout=60) as response:
1262
+ return json.loads(response.read().decode("utf-8"))
1263
+
1264
+
1265
+ def request_text(url: str) -> str:
1266
+ req = Request(url, headers={"User-Agent": USER_AGENT})
1267
+ with urlopen(req, timeout=60) as response:
1268
+ return response.read().decode("utf-8", errors="replace")
1269
+
1270
+
1271
+ def download_file(url: str, destination: Path) -> None:
1272
+ destination.parent.mkdir(parents=True, exist_ok=True)
1273
+ temp_path = destination.with_suffix(destination.suffix + ".tmp")
1274
+ req = Request(url, headers={"User-Agent": USER_AGENT})
1275
+ with urlopen(req, timeout=120) as response, temp_path.open("wb") as handle:
1276
+ shutil.copyfileobj(response, handle)
1277
+ temp_path.replace(destination)
1278
+
1279
+
1280
+ def normalize_repo_id(raw_repo_id: str) -> str:
1281
+ if not isinstance(raw_repo_id, str):
1282
+ return ""
1283
+ repo_id = raw_repo_id.strip()
1284
+ if not repo_id:
1285
+ return ""
1286
+ try:
1287
+ repo_id = unquote(repo_id)
1288
+ except Exception:
1289
+ pass
1290
+ return (
1291
+ repo_id.replace("https://huggingface.co/", "")
1292
+ .replace("http://huggingface.co/", "")
1293
+ .replace("api/models/", "")
1294
+ .replace("models/", "")
1295
+ .split("?", 1)[0]
1296
+ .split("#", 1)[0]
1297
+ .strip("/")
1298
+ )
1299
+
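+ # Accepts bare ids, full URLs, and API paths, e.g.
+ # normalize_repo_id("https://huggingface.co/CompactAI/foo?tab=files")
+ # -> "CompactAI/foo".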
1300
+
1301
+ def series_from_name(name: str) -> str | None:
1302
+ lower = (name or "").lower()
1303
+ if "haiku" in lower:
1304
+ return "Haiku"
1305
+ if "sonnet" in lower:
1306
+ return "Sonnet"
1307
+ if "opus" in lower:
1308
+ return "Opus"
1309
+ return None
1310
+
1311
+
1312
+ def encoded_repo_id(repo_id: str) -> str:
1313
+ return "/".join(
1314
+ quote(part, safe="") for part in normalize_repo_id(repo_id).split("/") if part
1315
+ )
1316
+
1317
+
1318
+ def hf_file_url(repo_id: str, filename: str) -> str:
1319
+ encoded_name = "/".join(
1320
+ quote(part, safe="") for part in filename.split("/") if part
1321
+ )
1322
+ return (
1323
+ f"https://huggingface.co/{encoded_repo_id(repo_id)}/resolve/main/{encoded_name}"
1324
+ )
1325
+
1326
+
1327
+ def model_list() -> list[dict[str, object]]:
1328
+ data = request_json(f"{HF_API}/models?author={quote(HF_ORG)}&full=true&limit=200")
1329
+ models: list[dict[str, object]] = []
1330
+ for item in data:
1331
+ siblings = item.get("siblings") or []
1332
+ filenames = [s.get("rfilename", "") for s in siblings if isinstance(s, dict)]
1333
+ has_model = "model.pt" in filenames or "model/model.pt" in filenames
1334
+ has_pretrain = "pretrain.pt" in filenames or "model/pretrain.pt" in filenames
1335
+ has_tokenizer = (
1336
+ "tokenizer.json" in filenames or "model/tokenizer.json" in filenames
1337
+ )
1338
+ if not has_model and not has_pretrain:
1339
+ continue
1340
+ name = (item.get("id") or "").split("/")[-1]
1341
+ series = series_from_name(name)
1342
+ if not series:
1343
+ continue
1344
+ models.append(
1345
+ {
1346
+ "id": item.get("id", ""),
1347
+ "name": name,
1348
+ "series": series,
1349
+ "downloads": item.get("downloads", 0) or 0,
1350
+ "likes": item.get("likes", 0) or 0,
1351
+ "has_model": has_model,
1352
+ "has_pretrain": has_pretrain,
1353
+ "has_tokenizer": has_tokenizer,
1354
+ }
1355
+ )
1356
+ return sorted(models, key=lambda entry: entry["downloads"], reverse=True)
1357
+
1358
+
1359
+ def model_details(repo_id: str) -> dict[str, object] | None:
1360
+ normalized = normalize_repo_id(repo_id)
1361
+ if not normalized:
1362
+ return None
1363
+ data = request_json(f"{HF_API}/models/{encoded_repo_id(normalized)}")
1364
+ siblings = data.get("siblings") or []
1365
+ files: dict[str, dict[str, float]] = {}
1366
+ has_model = False
1367
+ has_pretrain = False
1368
+ for sibling in siblings:
1369
+ if not isinstance(sibling, dict):
1370
+ continue
1371
+ filename = sibling.get("rfilename") or ""
1372
+ if not filename:
1373
+ continue
1374
+ size_mb = round((sibling.get("size") or 0) / (1024 * 1024), 2)
1375
+ files[filename] = {"size_mb": size_mb}
1376
+ if filename.startswith("model/"):
1377
+ files[filename.removeprefix("model/")] = {"size_mb": size_mb}
1378
+ if filename in {"model.pt", "model/model.pt"}:
1379
+ has_model = True
1380
+ if filename in {"pretrain.pt", "model/pretrain.pt"}:
1381
+ has_pretrain = True
1382
+ readme_raw = ""
1383
+ try:
1384
+ readme_raw = request_text(
1385
+ f"https://huggingface.co/{encoded_repo_id(normalized)}/raw/main/README.md"
1386
+ )
1387
+ except Exception:
1388
+ readme_raw = ""
1389
+ name = (data.get("id") or normalized).split("/")[-1]
1390
+ return {
1391
+ "id": normalized,
1392
+ "name": name,
1393
+ "series": series_from_name(name) or "Sonnet",
1394
+ "downloads": data.get("downloads", 0) or 0,
1395
+ "files": files,
1396
+ "readme_raw": readme_raw,
1397
+ "hf_model_id": normalized,
1398
+ "has_model": has_model,
1399
+ "has_pretrain": has_pretrain,
1400
+ }
1401
+
1402
+
1403
+ def cache_dir(repo_id: str, model_type: str) -> Path:
1404
+ return CACHE_ROOT / normalize_repo_id(repo_id).replace("/", "__") / model_type
1405
+
1406
+
1407
+ def artifact_candidates(model_type: str) -> list[str]:
1408
+ return (
1409
+ ["model/pretrain.pt", "pretrain.pt"]
1410
+ if model_type == "pretrain"
1411
+ else ["model/model.pt", "model.pt"]
1412
+ )
1413
+
1414
+
1415
+ def ensure_artifact(repo_id: str, model_type: str, destination_name: str) -> Path:
1416
+ normalized = normalize_repo_id(repo_id)
1417
+ target = cache_dir(normalized, model_type) / destination_name
1418
+ if target.exists():
1419
+ return target
1420
+ last_error: Exception | None = None
1421
+ for candidate in (
1422
+ artifact_candidates(model_type)
1423
+ if destination_name.endswith(".pt")
1424
+ else ["model/tokenizer.json", "tokenizer.json"]
1425
+ ):
1426
+ try:
1427
+ download_file(hf_file_url(normalized, candidate), target)
1428
+ return target
1429
+ except Exception as exc:
1430
+ last_error = exc
1431
+ raise RuntimeError(
1432
+ f"Unable to download {destination_name} for {normalized}: {last_error}"
1433
+ )
1434
+
1435
+
1436
+ def series_config(series: str) -> dict[str, object]:
1437
+ return MODEL_SERIES.get(series.lower(), MODEL_SERIES["sonnet"])
1438
+
1439
+
1440
+ def load_bundle(repo_id: str, model_type: str) -> dict[str, object]:
1441
+ normalized = normalize_repo_id(repo_id)
1442
+ details = model_details(normalized)
1443
+ if not details:
1444
+ raise RuntimeError("Model details are unavailable.")
1445
+ series = str(details["series"])
1446
+ key = (normalized, model_type)
1447
+ with MODEL_CACHE_LOCK:
1448
+ cached = MODEL_CACHE.get(key)
1449
+ if cached:
1450
+ return cached
1451
+ bundle_dir = cache_dir(normalized, model_type)
1452
+ bundle_dir.mkdir(parents=True, exist_ok=True)
1453
+ model_path = bundle_dir / (
1454
+ "pretrain.pt" if model_type == "pretrain" else "model.pt"
1455
+ )
1456
+ tokenizer_path = bundle_dir / "tokenizer.json"
1457
+ if not model_path.exists():
1458
+ ensure_artifact(normalized, model_type, model_path.name)
1459
+ if not tokenizer_path.exists():
1460
+ ensure_artifact(normalized, model_type, tokenizer_path.name)
1461
+ tokenizer = WordTokenizer.load(tokenizer_path)
1462
+ ckpt = torch.load(str(model_path), map_location="cpu", weights_only=False)
1463
+ cfg = series_config(series)
1464
+ vocab_size = int(ckpt.get("vocab_size", tokenizer.vocab_size))
1465
+ model = TinyMemoryLM(
1466
+ vocab_size=vocab_size,
1467
+ dim=int(cfg.get("dim", model_config.dim)),
1468
+ n_unique_layers=int(
1469
+ cfg.get("n_unique_layers", model_config.n_unique_layers)
1470
+ ),
1471
+ n_logical_layers=int(
1472
+ cfg.get("n_logical_layers", model_config.n_logical_layers)
1473
+ ),
1474
+ n_heads=int(cfg.get("n_heads", model_config.n_heads)),
1475
+ n_kv_heads=int(cfg.get("n_kv_heads", model_config.n_kv_heads)),
1476
+ ffn_dim=int(cfg.get("ffn_dim", model_config.ffn_dim)),
1477
+ dropout=float(cfg.get("dropout", model_config.dropout)),
1478
+ mtp_horizons=tuple(
1479
+ int(v) for v in cfg.get("mtp_horizons", model_config.mtp_horizons)
1480
+ ),
1481
+ grad_checkpoint=False,
1482
+ sliding_window=int(
1483
+ cfg.get(
1484
+ "sliding_window_size",
1485
+ getattr(model_config, "sliding_window_size", 512),
1486
+ )
1487
+ ),
1488
+ rope_fraction=float(
1489
+ cfg.get("rope_fraction", getattr(model_config, "rope_fraction", 0.25))
1490
+ ),
1491
+ embed_scale=bool(
1492
+ cfg.get("embed_scale", getattr(model_config, "embed_scale", True))
1493
+ ),
1494
+ )
1495
+ state_dict = ckpt.get("model_state") or ckpt.get("state_dict") or ckpt
1496
+ model.load_state_dict(state_dict, strict=False)
1497
+ model.eval()
1498
+ if tokenizer.vocab_size > vocab_size:
1499
+ model.resize_token_embeddings(tokenizer.vocab_size)
1500
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1501
+ model = model.to(device)
1502
+ bundle = {
1503
+ "repo_id": normalized,
1504
+ "name": details["name"],
1505
+ "series": series,
1506
+ "type": model_type,
1507
+ "model": model,
1508
+ "tokenizer": tokenizer,
1509
+ "device": device,
1510
+ "model_path": str(model_path),
1511
+ "tokenizer_path": str(tokenizer_path),
1512
+ "downloads": details["downloads"],
1513
+ }
1514
+ MODEL_CACHE[key] = bundle
1515
+ return bundle
1516
+
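+ # load_bundle resolves a repo to a cached (model, tokenizer) pair: artifacts
+ # are downloaded once into ~/.cache/compactai_web, the architecture is rebuilt
+ # from the series preset matched by name, weights load with strict=False to
+ # tolerate minor key mismatches, and embeddings are resized if the tokenizer
+ # grew after the checkpoint was saved.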
1517
+
1518
+ def ensure_port(start_port: int) -> int:
1519
+ for port in range(start_port, start_port + 50):
1520
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
1521
+ try:
1522
+ sock.bind(("127.0.0.1", port))
1523
+ except OSError:
1524
+ continue
1525
+ return port
1526
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
1527
+ sock.bind(("127.0.0.1", 0))
1528
+ return sock.getsockname()[1]
1529
+
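+ # Port selection: probe start_port..start_port+49 for a free local port and
+ # fall back to an OS-assigned ephemeral port (bind to 0) if all are taken.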
1530
+
+ def page_html() -> str:
+     return f"""<!doctype html>
+ <html lang="en">
+ <head>
+   <meta charset="utf-8">
+   <meta name="viewport" content="width=device-width, initial-scale=1">
+   <title>CompactAI Web</title>
+   <style>
+     :root {{
+       color-scheme: dark;
+       --bg: #050505;
+       --panel: #111111;
+       --panel-2: #161616;
+       --line: #262626;
+       --text: #f5f5f5;
+       --muted: #a3a3a3;
+       --accent: #d97706;
+       --accent-2: #b45309;
+       --soft: #1f1f1f;
+     }}
+     * {{ box-sizing: border-box; }}
+     body {{
+       margin: 0;
+       font-family: Geist, -apple-system, BlinkMacSystemFont, sans-serif;
+       background: var(--bg);
+       color: var(--text);
+       line-height: 1.5;
+     }}
+     a {{ color: inherit; }}
+     .wrap {{ max-width: 1120px; margin: 0 auto; padding: 28px 20px 40px; }}
+     .hero {{
+       display: flex;
+       justify-content: space-between;
+       align-items: end;
+       gap: 16px;
+       padding: 22px 0 28px;
+       border-bottom: 1px solid var(--line);
+       margin-bottom: 22px;
+     }}
+     h1 {{ margin: 0; font-size: clamp(2rem, 5vw, 3.5rem); letter-spacing: -0.04em; }}
+     .subtitle {{ margin: 10px 0 0; color: var(--muted); max-width: 58ch; }}
+     .grid {{
+       display: grid;
+       grid-template-columns: 1.1fr 1fr;
+       gap: 18px;
+     }}
+     .panel {{
+       background: var(--panel);
+       border: 1px solid var(--line);
+       border-radius: 18px;
+       padding: 18px;
+     }}
+     .panel h2 {{ margin: 0 0 12px; font-size: 15px; letter-spacing: 0.02em; text-transform: uppercase; color: var(--muted); }}
+     .row {{ display: flex; gap: 10px; flex-wrap: wrap; }}
+     select, textarea, input {{
+       width: 100%;
+       background: var(--panel-2);
+       color: var(--text);
+       border: 1px solid var(--line);
+       border-radius: 12px;
+       padding: 12px 14px;
+       font: inherit;
+       outline: none;
+     }}
+     textarea {{ min-height: 170px; resize: vertical; }}
+     select {{ appearance: none; }}
+     .choice {{
+       flex: 1 1 150px;
+       display: flex;
+       align-items: center;
+       gap: 10px;
+       padding: 10px 12px;
+       border: 1px solid var(--line);
+       border-radius: 12px;
+       background: var(--panel-2);
+       cursor: pointer;
+     }}
+     .choice input {{ width: auto; }}
+     .btns {{ display: flex; flex-wrap: wrap; gap: 10px; }}
+     button {{
+       border: 1px solid var(--line);
+       border-radius: 12px;
+       padding: 11px 14px;
+       background: var(--soft);
+       color: var(--text);
+       font: inherit;
+       cursor: pointer;
+       transition: transform 0.15s ease, border-color 0.15s ease, background 0.15s ease;
+     }}
+     button:hover {{ transform: translateY(-1px); border-color: #3a3a3a; }}
+     .primary {{ background: var(--accent); border-color: var(--accent); color: #fff; }}
+     .primary:hover {{ background: var(--accent-2); border-color: var(--accent-2); }}
+     .status {{
+       margin-top: 12px;
+       color: var(--muted);
+       font-size: 13px;
+       min-height: 1.4em;
+     }}
+     .output {{
+       white-space: pre-wrap;
+       background: #0b0b0b;
+       border: 1px solid var(--line);
+       border-radius: 16px;
+       min-height: 280px;
+       padding: 16px;
+       color: #e7e5e4;
+       overflow: auto;
+     }}
+     .meta {{
+       display: flex;
+       flex-wrap: wrap;
+       gap: 8px;
+       margin-top: 8px;
+     }}
+     .chip {{
+       display: inline-flex;
+       align-items: center;
+       gap: 6px;
+       padding: 6px 10px;
+       border-radius: 999px;
+       border: 1px solid var(--line);
+       background: var(--panel-2);
+       font-size: 12px;
+       color: var(--muted);
+     }}
+     .code {{
+       margin-top: 14px;
+       padding: 12px 14px;
+       border-radius: 12px;
+       border: 1px solid var(--line);
+       background: #0b0b0b;
+       font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
+       font-size: 13px;
+       overflow-x: auto;
+     }}
+     @media (max-width: 900px) {{
+       .grid {{ grid-template-columns: 1fr; }}
+       .hero {{ align-items: start; flex-direction: column; }}
+     }}
+   </style>
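+   <!-- All colors route through the custom properties declared in :root, so
+        the dark theme can be retuned in one place. -->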
+ </head>
+ <body>
+   <div class="wrap">
+     <div class="hero">
+       <div>
+         <h1>CompactAI Web</h1>
+         <p class="subtitle">Pull a model from Hugging Face, keep it cached locally, and chat in the browser.</p>
+       </div>
+       <div class="meta">
+         <span class="chip">Hugging Face: CompactAI</span>
+         <span class="chip">Auto-installs deps</span>
+         <span class="chip">Local inference</span>
+       </div>
+     </div>
+
+     <div class="grid">
+       <section class="panel">
+         <h2>Model</h2>
+         <select id="modelSelect"></select>
+         <div class="row" style="margin-top: 10px;">
+           <label class="choice"><input type="radio" name="type" value="model" checked> Instruct / final</label>
+           <label class="choice"><input type="radio" name="type" value="pretrain"> Pretrain</label>
+         </div>
+         <div class="btns" style="margin-top: 12px;">
+           <button id="downloadBtn">Download</button>
+           <button id="refreshBtn">Refresh models</button>
+         </div>
+         <div class="status" id="modelStatus">Loading model list…</div>
+         <div class="code">python3 interactive_web.py</div>
+       </section>
+
+       <section class="panel">
+         <h2>Prompt</h2>
+         <textarea id="prompt" placeholder="Ask something…"></textarea>
+         <div class="row" style="margin-top: 10px;">
+           <input id="temperature" title="Temperature" type="number" min="0.1" max="2" step="0.05" value="0.8" style="flex: 1 1 120px;">
+           <input id="topK" title="Top-k" type="number" min="1" max="100" step="1" value="40" style="flex: 1 1 120px;">
+           <input id="maxTokens" title="Max new tokens" type="number" min="16" max="2048" step="16" value="256" style="flex: 1 1 120px;">
+         </div>
+         <div class="btns" style="margin-top: 12px;">
+           <button id="generateBtn" class="primary">Generate</button>
+         </div>
+         <div class="status" id="genStatus"></div>
+       </section>
+     </div>
+
+     <section class="panel" style="margin-top: 18px;">
+       <h2>Response</h2>
+       <div id="output" class="output"></div>
+     </section>
+   </div>
+
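+   <!-- The script below drives three endpoints: GET /api/models (list),
+        POST /api/ensure (download/load), and POST /api/generate (inference). -->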
+   <script>
+     const modelSelect = document.getElementById('modelSelect');
+     const modelStatus = document.getElementById('modelStatus');
+     const genStatus = document.getElementById('genStatus');
+     const output = document.getElementById('output');
+     const promptBox = document.getElementById('prompt');
+
+     async function api(path, body) {{
+       const response = await fetch(path, {{
+         method: body ? 'POST' : 'GET',
+         headers: body ? {{ 'Content-Type': 'application/json' }} : undefined,
+         body: body ? JSON.stringify(body) : undefined,
+       }});
+       return response.json();
+     }}
+
+     function currentType() {{
+       return document.querySelector('input[name="type"]:checked').value;
+     }}
+
+     function currentModelId() {{
+       return modelSelect.value;
+     }}
+
+     function setModels(models) {{
+       modelSelect.innerHTML = '';
+       for (const model of models) {{
+         const option = document.createElement('option');
+         option.value = model.id;
+         option.textContent = `${{model.name}} • ${{model.series}}`;
+         modelSelect.appendChild(option);
+       }}
+       if (models.length === 0) {{
+         const option = document.createElement('option');
+         option.value = '';
+         option.textContent = 'No CompactAI models found';
+         modelSelect.appendChild(option);
+       }}
+     }}
+
+     async function refreshModels() {{
+       modelStatus.textContent = 'Loading model list…';
+       try {{
+         const models = await api('/api/models');
+         setModels(models);
+         modelStatus.textContent = models.length ? `${{models.length}} models available from CompactAI` : 'No compatible models found.';
+       }} catch (error) {{
+         modelStatus.textContent = 'Failed to load model list.';
+       }}
+     }}
+
+     async function ensureModel() {{
+       const modelId = currentModelId();
+       if (!modelId) {{
+         modelStatus.textContent = 'Pick a model first.';
+         return null;
+       }}
+       modelStatus.textContent = 'Downloading model files…';
+       let result;
+       try {{
+         result = await api('/api/ensure', {{ modelId, type: currentType() }});
+       }} catch (error) {{
+         modelStatus.textContent = 'Download failed.';
+         return null;
+       }}
+       if (!result.success) {{
+         modelStatus.textContent = result.error || 'Download failed.';
+         return null;
+       }}
+       modelStatus.textContent = `${{result.name}} ready on ${{result.series}}`;
+       return result;
+     }}
+
+     async function generate() {{
+       output.textContent = '';
+       genStatus.textContent = '';
+       const modelId = currentModelId();
+       const prompt = promptBox.value.trim();
+       if (!modelId) {{
+         genStatus.textContent = 'Pick a model first.';
+         return;
+       }}
+       if (!prompt) {{
+         genStatus.textContent = 'Enter a prompt first.';
+         return;
+       }}
+       genStatus.textContent = 'Preparing model…';
+       // Wrap the round trip so a network failure updates the status line
+       // instead of leaving it stuck on "Preparing model…".
+       let result;
+       try {{
+         result = await api('/api/generate', {{
+           modelId,
+           type: currentType(),
+           prompt,
+           temperature: Number(document.getElementById('temperature').value || 0.8),
+           top_k: Number(document.getElementById('topK').value || 40),
+           max_new_tokens: Number(document.getElementById('maxTokens').value || 256),
+         }});
+       }} catch (error) {{
+         genStatus.textContent = 'Request failed.';
+         return;
+       }}
+       if (!result.success) {{
+         genStatus.textContent = result.error || 'Generation failed.';
+         return;
+       }}
+       output.textContent = result.text || '';
+       genStatus.textContent = 'Done.';
+     }}
+
+     document.getElementById('refreshBtn').addEventListener('click', refreshModels);
+     document.getElementById('downloadBtn').addEventListener('click', ensureModel);
+     document.getElementById('generateBtn').addEventListener('click', generate);
+     promptBox.addEventListener('keydown', (event) => {{
+       if (event.key === 'Enter' && (event.ctrlKey || event.metaKey)) {{
+         event.preventDefault();
+         generate();
+       }}
+     }});
+
+     refreshModels();
+   </script>
+ </body>
+ </html>"""
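+ # NOTE: page_html returns an f-string, so every literal brace in the CSS and
+ # JS above is doubled ({{ }}) and renders as a single brace in the served
+ # page; JS template literals are written as ${{...}} for the same reason.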
+
+
+ class Handler(BaseHTTPRequestHandler):
+     def _send_json(self, payload, status=200):
+         body = json.dumps(payload).encode("utf-8")
+         self.send_response(status)
+         self.send_header("Content-Type", "application/json; charset=utf-8")
+         self.send_header("Content-Length", str(len(body)))
+         self.send_header("Cache-Control", "no-store")
+         self.end_headers()
+         self.wfile.write(body)
+
+     def _send_html(self, payload: str, status=200):
+         body = payload.encode("utf-8")
+         self.send_response(status)
+         self.send_header("Content-Type", "text/html; charset=utf-8")
+         self.send_header("Content-Length", str(len(body)))
+         self.send_header("Cache-Control", "no-store")
+         self.end_headers()
+         self.wfile.write(body)
+
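+     # GET routes: "/" and "/index.html" serve the page, "/api/models" returns
+     # the model list, and "/api/models/<repo>" returns one model's details.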
+     def do_GET(self):
+         parsed = urlparse(self.path)
+         if parsed.path in {"/", "/index.html"}:
+             self._send_html(page_html())
+             return
+         if parsed.path == "/api/models":
+             try:
+                 self._send_json(model_list())
+             except Exception as exc:
+                 self._send_json({"success": False, "error": str(exc)}, 500)
+             return
+         if parsed.path.startswith("/api/models/"):
+             repo_id = normalize_repo_id(parsed.path.removeprefix("/api/models/"))
+             try:
+                 details = model_details(repo_id)
+                 if not details:
+                     self._send_json({"success": False, "error": "Model not found."}, 404)
+                 else:
+                     self._send_json(details)
+             except Exception as exc:
+                 self._send_json({"success": False, "error": str(exc)}, 500)
+             return
+         self._send_json({"success": False, "error": "Not found."}, 404)
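+     # Example, assuming the default port and a hypothetical repo id:
+     #   curl http://127.0.0.1:7860/api/models
+     #   curl http://127.0.0.1:7860/api/models/CompactAI/some-model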
+
+     def do_POST(self):
+         parsed = urlparse(self.path)
+         length = int(self.headers.get("Content-Length", "0") or "0")
+         raw = self.rfile.read(length).decode("utf-8") if length else "{}"
+         try:
+             payload = json.loads(raw or "{}")
+         except Exception:
+             payload = {}
+         if parsed.path == "/api/ensure":
+             try:
+                 repo_id = normalize_repo_id(payload.get("modelId", ""))
+                 model_type = payload.get("type", "model")
+                 if not repo_id:
+                     self._send_json({"success": False, "error": "Missing model ID."}, 400)
+                     return
+                 details = model_details(repo_id)
+                 if not details:
+                     self._send_json({"success": False, "error": "Model not found."}, 404)
+                     return
+                 bundle = load_bundle(repo_id, model_type)
+                 self._send_json(
+                     {
+                         "success": True,
+                         "id": bundle["repo_id"],
+                         "name": bundle["name"],
+                         "series": bundle["series"],
+                         "type": bundle["type"],
+                     }
+                 )
+             except Exception as exc:
+                 self._send_json({"success": False, "error": str(exc)}, 500)
+             return
+         if parsed.path == "/api/generate":
+             try:
+                 repo_id = normalize_repo_id(payload.get("modelId", ""))
+                 model_type = payload.get("type", "model")
+                 prompt = str(payload.get("prompt", ""))
+                 if not repo_id:
+                     self._send_json({"success": False, "error": "Missing model ID."}, 400)
+                     return
+                 bundle = load_bundle(repo_id, model_type)
+                 # Serialize inference: the server is threaded, but only one
+                 # request should run generation at a time.
+                 with GENERATION_LOCK:
+                     text = generate(
+                         model=bundle["model"],
+                         tokenizer=bundle["tokenizer"],
+                         prompt=prompt,
+                         max_new_tokens=int(payload.get("max_new_tokens", 256)),
+                         temperature=float(payload.get("temperature", 0.8)),
+                         top_k=int(payload.get("top_k", 40)),
+                         repetition_penalty=float(payload.get("repetition_penalty", 1.0)),
+                         device=str(bundle["device"]),
+                         sft_mode=model_type != "pretrain",
+                         force_thought=bool(payload.get("force_thought", False)),
+                         stream=False,
+                         decode_mode=str(payload.get("decode_mode", "legacy")),
+                         best_of=int(payload.get("best_of", 3)),
+                         no_repeat_ngram_size=int(payload.get("no_repeat_ngram_size", 3)),
+                         context_window=int(payload.get("context_window", 2048)),
+                         beam_width=int(payload.get("beam_width", 8)),
+                         length_penalty=float(payload.get("length_penalty", 0.7)),
+                     )
+                 self._send_json(
+                     {
+                         "success": True,
+                         "text": text,
+                         "name": bundle["name"],
+                         "series": bundle["series"],
+                     }
+                 )
+             except Exception as exc:
+                 self._send_json({"success": False, "error": str(exc)}, 500)
+             return
+         self._send_json({"success": False, "error": "Not found."}, 404)
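+     # Example request (hypothetical repo id; the other sampling knobs fall
+     # back to the defaults above):
+     #   curl -X POST http://127.0.0.1:7860/api/generate \
+     #     -H 'Content-Type: application/json' \
+     #     -d '{"modelId": "CompactAI/some-model", "prompt": "Hello", "max_new_tokens": 64}'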
+
+     def log_message(self, format, *args):
+         # Keep the standard override signature, but drop the default
+         # per-request stderr logging so the printed server URL stays visible.
+         return
+
+
+ def main():
+     CACHE_ROOT.mkdir(parents=True, exist_ok=True)
+     port = ensure_port(int(os.environ.get("PORT", "7860")))
+     server = ThreadingHTTPServer(("127.0.0.1", port), Handler)
+     url = f"http://127.0.0.1:{port}"
+     print(url, flush=True)
+     try:
+         webbrowser.open(url)
+     except Exception:
+         pass
+     try:
+         server.serve_forever()
+     except KeyboardInterrupt:
+         pass
+     finally:
+         server.server_close()
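+ # Usage note: main() honors the PORT env var (default 7860), scans upward for
+ # a free port via ensure_port, and serves on 127.0.0.1 only. For example:
+ #   PORT=8000 python3 interactive_web.py
+ # (filename assumed from the hint rendered in the UI above).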
 
 
  if __name__ == "__main__":