Taylor committed on
Commit
7336fde
·
1 Parent(s): c92238b

perf: add WASM SIMD kernels + use Q4_K_M for faster inference

Browse files

Major changes:
- Bundle simd-kernels-standalone.wasm (14KB) from Aether
- WASM SIMD matVec, rmsNorm, softmax, fusedSiluMul, flashAttention
- Switch from Q8_0 (360MB) to Q4_K_M (210MB, half the work)
- Reduce max_tokens to 50 for snappier demo
- Proper Q4_K dequantization with getScaleMinK4
- Falls back to JS if WASM SIMD unavailable

Files changed (4) hide show
  1. Dockerfile +1 -1
  2. aether-server.mjs +413 -429
  3. app.py +2 -2
  4. simd-kernels.wasm +3 -0
Dockerfile CHANGED
@@ -13,7 +13,7 @@ COPY requirements.txt .
13
  RUN pip install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu -r requirements.txt
14
 
15
  # App files
16
- COPY app.py aether-server.mjs ./
17
 
18
  # Create cache dir
19
  RUN mkdir -p /tmp/hf_cache
 
13
  RUN pip install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu -r requirements.txt
14
 
15
  # App files
16
+ COPY app.py aether-server.mjs simd-kernels.wasm ./
17
 
18
  # Create cache dir
19
  RUN mkdir -p /tmp/hf_cache
aether-server.mjs CHANGED
@@ -1,15 +1,14 @@
1
  /**
2
  * Aether Inference Server
3
  *
4
- * Standalone Node.js server running SmolLM2-360M inference
5
- * using Aether's WASM-SIMD kernels. Zero external ML dependencies.
6
  *
7
- * The entire inference pipeline is pure TypeScript + WASM:
8
- * GGUF parse → Q4_K dequant → WASM-SIMD matVec → RoPE → SwiGLU → sampling
9
  */
10
 
11
  import { createServer } from 'http';
12
- import { readFileSync, existsSync, writeFileSync } from 'fs';
13
  import { execSync } from 'child_process';
14
  import { fileURLToPath } from 'url';
15
  import { dirname, join } from 'path';
@@ -17,7 +16,7 @@ import { dirname, join } from 'path';
17
  const __dirname = dirname(fileURLToPath(import.meta.url));
18
  const PORT = parseInt(process.env.AETHER_PORT || '7861');
19
 
20
- // ─── Model Config (SmolLM2-360M-Instruct, LLaMA family) ────────────────────
21
  const CONFIG = {
22
  hiddenDim: 960,
23
  numLayers: 32,
@@ -33,278 +32,322 @@ const CONFIG = {
33
  bosToken: 1,
34
  };
35
 
36
- // ─── Q8_0 Dequantization ────────────────────────────────────────────────────
37
- // Q8_0: 34 bytes per block of 32 elements (fp16 scale + 32 int8 quants)
38
- const Q8_0_BLOCK_SIZE = 32;
39
- const Q8_0_BLOCK_BYTES = 34;
40
 
41
- function fp16ToF32(lo, hi) {
42
- const h = lo | (hi << 8);
43
- const s = (h >> 15) & 1;
44
- const e = (h >> 10) & 0x1f;
45
- const f = h & 0x3ff;
46
- if (e === 0) return f === 0 ? (s ? -0 : 0) : (s ? -1 : 1) * (f / 1024) * Math.pow(2, -14);
47
- if (e === 31) return 0; // clamp NaN/Inf
48
- return (s ? -1 : 1) * Math.pow(2, e - 15) * (1 + f / 1024);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  }
50
 
51
- function dequantQ8_0(data, numElements) {
52
- const out = new Float32Array(numElements);
53
- const numBlocks = Math.ceil(numElements / Q8_0_BLOCK_SIZE);
54
- for (let b = 0; b < numBlocks; b++) {
55
- const blockOff = b * Q8_0_BLOCK_BYTES;
56
- const scale = fp16ToF32(data[blockOff], data[blockOff + 1]);
57
- const elemsInBlock = Math.min(Q8_0_BLOCK_SIZE, numElements - b * Q8_0_BLOCK_SIZE);
58
- for (let i = 0; i < elemsInBlock; i++) {
59
- const qval = data[blockOff + 2 + i]; // uint8, interpret as int8
60
- const signed = qval > 127 ? qval - 256 : qval;
61
- out[b * Q8_0_BLOCK_SIZE + i] = signed * scale;
62
- }
63
  }
64
  return out;
65
  }
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  // ─── Q4_K Dequantization ────────────────────────────────────────────────────
68
  const QK_K = 256;
69
  const Q4K_BLOCK_BYTES = 144;
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  function dequantQ4K(data, numElements) {
72
  const out = new Float32Array(numElements);
73
- const numBlocks = Math.ceil(numElements / QK_K);
74
  for (let b = 0; b < numBlocks; b++) {
75
- const off = b * Q4K_BLOCK_BYTES;
76
- const d = fp16ToF32(data[off], data[off + 1]);
77
- const dmin = fp16ToF32(data[off + 2], data[off + 3]);
78
- const scalesBytes = data.subarray(off + 4, off + 16);
79
- const qBytes = data.subarray(off + 16, off + 16 + 128);
80
-
81
- // Decode 6-bit scales and mins from 12 bytes
82
- const scales = new Float32Array(8);
83
- const mins = new Float32Array(8);
84
- for (let j = 0; j < 4; j++) {
85
- scales[j] = (scalesBytes[j] & 0x3f) * d;
86
- scales[j + 4] = ((scalesBytes[j + 4] & 0x0f) | ((scalesBytes[j] >> 6) << 4)) * d;
87
- mins[j] = (scalesBytes[j + 4] >> 4 | ((scalesBytes[j + 8] & 0x3f) << 4) ? 0 : 1) * dmin;
88
- }
89
- // Simplified: just use scale * d for each sub-block
90
- for (let j = 0; j < 8; j++) {
91
- const sc = (scalesBytes[j < 4 ? j : j] & 0x3f) * d;
92
- const mn = (scalesBytes[j < 4 ? j + 4 : j] & 0x3f) * dmin;
93
- for (let k = 0; k < 32; k++) {
94
- const idx = j * 32 + k;
95
- if (idx >= QK_K) break;
96
- const byteIdx = Math.floor(idx / 2);
97
- const nibble = idx % 2 === 0 ? (qBytes[byteIdx] & 0x0f) : (qBytes[byteIdx] >> 4);
98
- out[b * QK_K + idx] = nibble * sc - mn;
99
  }
 
100
  }
101
  }
102
  return out;
103
  }
104
 
105
- // Detect quant type by byte count
106
- function dequantAuto(data, numElements) {
107
- const expectedQ8 = Math.ceil(numElements / Q8_0_BLOCK_SIZE) * Q8_0_BLOCK_BYTES;
108
- const expectedQ4K = Math.ceil(numElements / QK_K) * Q4K_BLOCK_BYTES;
109
- const expectedF32 = numElements * 4;
110
-
111
- if (Math.abs(data.length - expectedF32) < expectedF32 * 0.05) {
112
- return new Float32Array(data.buffer, data.byteOffset, numElements);
113
- }
114
- if (Math.abs(data.length - expectedQ8) < expectedQ8 * 0.05) {
115
- return dequantQ8_0(data, numElements);
116
- }
117
- if (Math.abs(data.length - expectedQ4K) < expectedQ4K * 0.05) {
118
- return dequantQ4K(data, numElements);
119
  }
120
- // Fallback: try Q8_0
121
- console.warn(`[Aether] Unknown quant for ${numElements} elems, ${data.length} bytes. Trying Q8_0.`);
122
- return dequantQ8_0(data, numElements);
123
- }
124
-
125
- // ─── GGUF Parser ────────────────────────────────────────────────────────────
126
- const GGUF_MAGIC = 0x46554747;
127
- const VT = { UINT8: 0, INT8: 1, UINT16: 2, INT16: 3, UINT32: 4, INT32: 5, FLOAT32: 6, BOOL: 7, STRING: 8, ARRAY: 9, UINT64: 10, INT64: 11, FLOAT64: 12 };
128
-
129
- const GGML_BLOCK_SIZE = { 2:32,3:32,6:32,7:32,8:32,9:32,10:256,11:256,12:256,13:256,14:256,15:256 };
130
- const GGML_BLOCK_BYTES = { 2:18,3:20,6:22,7:24,8:34,9:36,10:84,11:110,12:144,13:176,14:210,15:292 };
131
- const GGML_TYPE_SIZE = { 0:4,1:2,16:1,17:2,18:4,19:8,20:8 };
132
-
133
- function calcTensorSize(dims, type) {
134
- let n = 1n;
135
- for (const d of dims) n *= d;
136
- const bs = GGML_BLOCK_SIZE[type];
137
- if (bs && GGML_BLOCK_BYTES[type]) return Math.ceil(Number(n) / bs) * GGML_BLOCK_BYTES[type];
138
- return Math.ceil(Number(n) * (GGML_TYPE_SIZE[type] ?? 4));
139
  }
140
 
141
- function readStr(buf, off) {
142
- const len = Number(buf.readBigUInt64LE(off));
143
- return { v: buf.subarray(off+8, off+8+len).toString('utf8'), o: off+8+len };
 
 
 
 
144
  }
145
 
146
- function readVal(buf, off, t) {
147
- switch(t) {
148
- case VT.UINT8: return { v: buf.readUInt8(off), o: off+1 };
149
- case VT.INT8: return { v: buf.readInt8(off), o: off+1 };
150
- case VT.UINT16: return { v: buf.readUInt16LE(off), o: off+2 };
151
- case VT.INT16: return { v: buf.readInt16LE(off), o: off+2 };
152
- case VT.UINT32: return { v: buf.readUInt32LE(off), o: off+4 };
153
- case VT.INT32: return { v: buf.readInt32LE(off), o: off+4 };
154
- case VT.FLOAT32: return { v: buf.readFloatLE(off), o: off+4 };
155
- case VT.BOOL: return { v: buf.readUInt8(off) !== 0, o: off+1 };
156
- case VT.STRING: { const r = readStr(buf, off); return { v: r.v, o: r.o }; }
157
- case VT.UINT64: return { v: buf.readBigUInt64LE(off), o: off+8 };
158
- case VT.INT64: return { v: buf.readBigInt64LE(off), o: off+8 };
159
- case VT.FLOAT64: return { v: buf.readDoubleLE(off), o: off+8 };
160
- case VT.ARRAY: {
161
- const at = buf.readUInt32LE(off);
162
- const al = Number(buf.readBigUInt64LE(off+4));
163
- let co = off+12;
164
- const arr = [];
165
- for (let i = 0; i < al; i++) { const r = readVal(buf, co, at); arr.push(r.v); co = r.o; }
166
- return { v: arr, o: co };
167
- }
168
- default: throw new Error(`Unknown GGUF value type: ${t}`);
169
- }
170
  }
171
-
172
- function parseGGUF(buf) {
173
- let off = 0;
174
- if (buf.readUInt32LE(off) !== GGUF_MAGIC) throw new Error('Not GGUF');
175
- off += 4;
176
- const version = buf.readUInt32LE(off); off += 4;
177
- const tensorCount = Number(buf.readBigUInt64LE(off)); off += 8;
178
- const kvCount = Number(buf.readBigUInt64LE(off)); off += 8;
179
- let alignment = 32;
180
- const metadata = {};
181
- for (let i = 0; i < kvCount; i++) {
182
- const { v: key, o: o1 } = readStr(buf, off); off = o1;
183
- const vt = buf.readUInt32LE(off); off += 4;
184
- const { v, o: o2 } = readVal(buf, off, vt); off = o2;
185
- metadata[key] = v;
186
- if (key === 'general.alignment') alignment = Number(v);
187
- }
188
- const tensors = [];
189
- for (let i = 0; i < tensorCount; i++) {
190
- const { v: name, o: o1 } = readStr(buf, off); off = o1;
191
- const nDims = buf.readUInt32LE(off); off += 4;
192
- const dims = [];
193
- for (let d = 0; d < nDims; d++) { dims.push(buf.readBigUInt64LE(off)); off += 8; }
194
- const type = buf.readUInt32LE(off); off += 4;
195
- const offset = buf.readBigUInt64LE(off); off += 8;
196
- const numElements = Number(dims.reduce((a, b) => a * b, 1n));
197
- tensors.push({ name, nDims, dims, type, offset, size: calcTensorSize(dims, type), numElements });
198
- }
199
- const dataOffset = Math.ceil(off / alignment) * alignment;
200
- return { version, tensors, dataOffset, metadata };
201
  }
202
 
203
  // ─── BPE Tokenizer ──────────────────────────────────────────────────────────
204
  class BPETokenizer {
205
- constructor(tokenizerJson) {
206
- const model = tokenizerJson.model || {};
207
- this.vocab = model.vocab || {};
208
- this.reverseVocab = {};
209
- for (const [token, id] of Object.entries(this.vocab)) {
210
- this.reverseVocab[id] = token;
211
- }
212
- this.merges = (model.merges || []).map((m, i) => {
213
- const [a, b] = m.split(' ');
214
- return { a, b, rank: i };
215
- });
216
  this.mergeRanks = {};
217
- for (const m of this.merges) {
218
- this.mergeRanks[`${m.a} ${m.b}`] = m.rank;
219
- }
220
- // Added tokens (special tokens)
221
- this.addedTokens = {};
222
- if (tokenizerJson.added_tokens) {
223
- for (const t of tokenizerJson.added_tokens) {
224
- this.addedTokens[t.content] = t.id;
225
- }
226
- }
227
- this.vocabSize = Object.keys(this.vocab).length + Object.keys(this.addedTokens).length;
228
  }
229
-
230
  encode(text) {
231
- // Handle special tokens first
232
- const specialPattern = /<\|[^|]+\|>/g;
233
- const parts = [];
234
- let lastIdx = 0;
235
- let match;
236
- while ((match = specialPattern.exec(text)) !== null) {
237
- if (match.index > lastIdx) parts.push({ text: text.slice(lastIdx, match.index), special: false });
238
- parts.push({ text: match[0], special: true });
239
- lastIdx = match.index + match[0].length;
240
  }
241
- if (lastIdx < text.length) parts.push({ text: text.slice(lastIdx), special: false });
242
-
243
  const tokens = [];
244
- for (const part of parts) {
245
- if (part.special) {
246
- const id = this.addedTokens[part.text] ?? this.vocab[part.text];
247
- if (id !== undefined) tokens.push(id);
248
- continue;
249
- }
250
- // Pre-tokenize: split into words (byte-level BPE style)
251
- const words = part.text.match(/\S+|\s+/g) || [];
252
- for (const word of words) {
253
- // Convert to byte-level tokens
254
- let symbols = [];
255
- for (let i = 0; i < word.length; i++) {
256
- const ch = word[i];
257
- const id = this.vocab[ch];
258
- if (id !== undefined) {
259
- symbols.push(ch);
260
- } else {
261
- // Byte fallback
262
- const bytes = Buffer.from(ch, 'utf8');
263
- for (const b of bytes) {
264
- const hex = `<0x${b.toString(16).toUpperCase().padStart(2, '0')}>`;
265
- symbols.push(hex);
266
- }
267
- }
268
  }
269
- // BPE merge loop
270
- while (symbols.length > 1) {
271
- let bestRank = Infinity;
272
- let bestIdx = -1;
273
- for (let i = 0; i < symbols.length - 1; i++) {
274
- const key = `${symbols[i]} ${symbols[i+1]}`;
275
- const rank = this.mergeRanks[key];
276
- if (rank !== undefined && rank < bestRank) {
277
- bestRank = rank;
278
- bestIdx = i;
279
- }
280
  }
281
- if (bestIdx === -1) break;
282
- const merged = symbols[bestIdx] + symbols[bestIdx + 1];
283
- symbols.splice(bestIdx, 2, merged);
284
- }
285
- // Map to IDs
286
- for (const sym of symbols) {
287
- const id = this.vocab[sym] ?? this.addedTokens[sym];
288
- if (id !== undefined) tokens.push(id);
289
  }
 
290
  }
291
  }
292
  return tokens;
293
  }
294
-
295
  decode(tokens) {
296
  const pieces = [];
297
  for (const t of tokens) {
298
- const piece = this.reverseVocab[t];
299
- if (piece !== undefined) {
300
- // Handle byte tokens like <0xFF>
301
- if (piece.startsWith('<0x') && piece.endsWith('>')) {
302
- const byte = parseInt(piece.slice(3, -1), 16);
303
- pieces.push(String.fromCharCode(byte));
304
- } else if (!piece.startsWith('<|')) {
305
- pieces.push(piece);
306
- }
307
- }
308
  }
309
  return pieces.join('').replace(/Δ /g, ' ').replace(/Ċ/g, '\n');
310
  }
@@ -312,56 +355,17 @@ class BPETokenizer {
312
 
313
  // ─── RoPE ───────────────────────────────────────────────────────────────────
314
  function applyRoPE(x, headDim, position, theta) {
315
- const halfDim = headDim / 2;
316
- for (let i = 0; i < halfDim; i++) {
317
  const freq = 1.0 / Math.pow(theta, (2 * i) / headDim);
318
  const angle = position * freq;
319
- const cos = Math.cos(angle);
320
- const sin = Math.sin(angle);
321
- const x0 = x[i];
322
- const x1 = x[i + halfDim];
323
  x[i] = x0 * cos - x1 * sin;
324
- x[i + halfDim] = x0 * sin + x1 * cos;
325
  }
326
  }
327
 
328
- // ─── Pure JS SIMD-style ops (fallback; WASM SIMD used when available) ───────
329
- function matVec(matrix, vector, rows, cols) {
330
- const out = new Float32Array(rows);
331
- for (let r = 0; r < rows; r++) {
332
- let sum = 0;
333
- const rowOff = r * cols;
334
- for (let c = 0; c < cols; c++) sum += matrix[rowOff + c] * vector[c];
335
- out[r] = sum;
336
- }
337
- return out;
338
- }
339
-
340
- function rmsNorm(x, weight, eps) {
341
- let ss = 0;
342
- for (let i = 0; i < x.length; i++) ss += x[i] * x[i];
343
- ss = 1.0 / Math.sqrt(ss / x.length + eps);
344
- const out = new Float32Array(x.length);
345
- for (let i = 0; i < x.length; i++) out[i] = x[i] * ss * weight[i];
346
- return out;
347
- }
348
-
349
- function silu(x) {
350
- const out = new Float32Array(x.length);
351
- for (let i = 0; i < x.length; i++) out[i] = x[i] / (1 + Math.exp(-x[i]));
352
- return out;
353
- }
354
-
355
- function softmax(x) {
356
- let max = -Infinity;
357
- for (let i = 0; i < x.length; i++) if (x[i] > max) max = x[i];
358
- const out = new Float32Array(x.length);
359
- let sum = 0;
360
- for (let i = 0; i < x.length; i++) { out[i] = Math.exp(x[i] - max); sum += out[i]; }
361
- for (let i = 0; i < x.length; i++) out[i] /= sum;
362
- return out;
363
- }
364
-
365
  // ─── Model ──────────────────────────────────────────────────────────────────
366
  let model = null;
367
 
@@ -372,78 +376,69 @@ function loadModel(ggufPath, tokenizerPath) {
372
  const parsed = parseGGUF(buf);
373
  console.log(`[Aether] Parsed ${parsed.tensors.length} tensors in ${Date.now() - t0}ms`);
374
 
375
- // Load tokenizer
376
- console.log('[Aether] Loading tokenizer...');
377
  const tokJson = JSON.parse(readFileSync(tokenizerPath, 'utf8'));
378
  const tokenizer = new BPETokenizer(tokJson);
379
 
380
- // Extract tensors by name
381
- const tensorByName = {};
382
- for (const t of parsed.tensors) tensorByName[t.name] = t;
383
 
384
- // Helper to extract and dequantize a tensor
385
- function getTensor(name) {
386
- const t = tensorByName[name];
387
- if (!t) { console.warn(`[Aether] Missing tensor: ${name}`); return null; }
388
- const absOffset = parsed.dataOffset + Number(t.offset);
389
- const raw = new Uint8Array(buf.buffer, buf.byteOffset + absOffset, t.size);
390
  return dequantAuto(raw, t.numElements);
391
  }
392
 
393
  console.log('[Aether] Dequantizing embeddings...');
394
- const tokenEmbd = getTensor('token_embd.weight');
395
 
396
  console.log('[Aether] Dequantizing layers...');
397
  const layers = [];
398
  for (let i = 0; i < CONFIG.numLayers; i++) {
399
  if (i % 8 === 0) console.log(`[Aether] Layer ${i}/${CONFIG.numLayers}...`);
400
  layers.push({
401
- attnNorm: getTensor(`blk.${i}.attn_norm.weight`),
402
- ffnNorm: getTensor(`blk.${i}.ffn_norm.weight`),
403
- qProj: getTensor(`blk.${i}.attn_q.weight`),
404
- kProj: getTensor(`blk.${i}.attn_k.weight`),
405
- vProj: getTensor(`blk.${i}.attn_v.weight`),
406
- oProj: getTensor(`blk.${i}.attn_output.weight`),
407
- gateProj: getTensor(`blk.${i}.ffn_gate.weight`),
408
- upProj: getTensor(`blk.${i}.ffn_up.weight`),
409
- downProj: getTensor(`blk.${i}.ffn_down.weight`),
410
  });
411
  }
412
 
413
- console.log('[Aether] Dequantizing output head...');
414
- const outputNorm = getTensor('output_norm.weight');
415
- let outputWeight = getTensor('output.weight');
416
- if (!outputWeight) {
417
- console.log('[Aether] No output.weight, using tied embeddings');
418
- outputWeight = tokenEmbd;
419
- }
420
 
421
  const loadTime = Date.now() - t0;
422
- console.log(`[Aether] Model loaded in ${loadTime}ms`);
423
-
424
  model = { tokenEmbd, layers, outputNorm, outputWeight, tokenizer, loadTime };
425
  }
426
 
427
  // ─── Inference ──────────────────────────────────────────────────────────────
428
- function generate(prompt, maxTokens = 100) {
429
  if (!model) throw new Error('Model not loaded');
430
 
431
  const t0 = performance.now();
432
  const { hiddenDim, numHeads, numKvHeads, headDim, intermediateSize, ropeTheta, rmsNormEps } = CONFIG;
433
  const kvDim = numKvHeads * headDim;
434
- const gqaRatio = numHeads / numKvHeads;
435
 
436
- // Format as chat
437
  const chatPrompt = `<|im_start|>user\n${prompt}<|im_end|>\n<|im_start|>assistant\n`;
438
  const inputTokens = model.tokenizer.encode(chatPrompt);
439
  const allTokens = [...inputTokens];
440
 
441
- // KV cache: [layer][position] -> { k, v }
442
- const kvCache = Array.from({ length: CONFIG.numLayers }, () => ({ keys: [], values: [] }));
 
 
 
443
 
444
  const tokenTimes = [];
445
 
446
- // Process all input tokens (prefill) then generate
447
  for (let step = 0; step < inputTokens.length + maxTokens - 1; step++) {
448
  const tokenStart = performance.now();
449
  const pos = step;
@@ -451,103 +446,92 @@ function generate(prompt, maxTokens = 100) {
451
 
452
  // Embed
453
  const hidden = new Float32Array(hiddenDim);
454
- const embOffset = tokenId * hiddenDim;
455
- for (let i = 0; i < hiddenDim; i++) hidden[i] = model.tokenEmbd[embOffset + i];
456
 
457
  let x = hidden;
458
 
459
- // Run through layers
460
  for (let l = 0; l < CONFIG.numLayers; l++) {
461
- const layer = model.layers[l];
462
 
463
  // 1. Attention norm
464
- const normed = rmsNorm(x, layer.attnNorm, rmsNormEps);
465
 
466
- // 2. Q, K, V projections
467
- const q = matVec(layer.qProj, normed, hiddenDim, hiddenDim);
468
- const k = matVec(layer.kProj, normed, kvDim, hiddenDim);
469
- const v = matVec(layer.vProj, normed, kvDim, hiddenDim);
470
 
471
  // 3. RoPE
472
- for (let h = 0; h < numHeads; h++) {
473
  applyRoPE(q.subarray(h * headDim, (h + 1) * headDim), headDim, pos, ropeTheta);
474
- }
475
- for (let h = 0; h < numKvHeads; h++) {
476
  applyRoPE(k.subarray(h * headDim, (h + 1) * headDim), headDim, pos, ropeTheta);
477
- }
478
 
479
  // 4. Store in KV cache
480
  kvCache[l].keys.push(new Float32Array(k));
481
  kvCache[l].values.push(new Float32Array(v));
482
 
483
- // 5. Attention with full KV cache
484
- const attnOut = new Float32Array(hiddenDim);
485
  const seqLen = kvCache[l].keys.length;
 
486
 
487
- for (let h = 0; h < numHeads; h++) {
488
- const kvHead = Math.floor(h / gqaRatio);
489
- const qHead = q.subarray(h * headDim, (h + 1) * headDim);
490
-
491
- // Compute attention scores
492
- const scores = new Float32Array(seqLen);
493
  for (let s = 0; s < seqLen; s++) {
494
- const kHead = kvCache[l].keys[s].subarray(kvHead * headDim, (kvHead + 1) * headDim);
495
- let dot = 0;
496
- for (let d = 0; d < headDim; d++) dot += qHead[d] * kHead[d];
497
- scores[s] = dot / Math.sqrt(headDim);
498
  }
499
-
500
- // Causal mask: already handled (only see past positions)
501
- // Softmax
502
- const attnWeights = softmax(scores);
503
-
504
- // Weighted sum of values
505
- for (let s = 0; s < seqLen; s++) {
506
- const vHead = kvCache[l].values[s].subarray(kvHead * headDim, (kvHead + 1) * headDim);
507
- const w = attnWeights[s];
508
- for (let d = 0; d < headDim; d++) {
509
- attnOut[h * headDim + d] += w * vHead[d];
 
 
 
 
 
 
 
 
510
  }
511
  }
512
  }
513
 
514
- // 6. Output projection
515
- const projected = matVec(layer.oProj, attnOut, hiddenDim, hiddenDim);
516
-
517
- // 7. Residual
518
- const postAttn = new Float32Array(hiddenDim);
519
- for (let i = 0; i < hiddenDim; i++) postAttn[i] = x[i] + projected[i];
520
-
521
- // 8. FFN norm
522
- const ffnInput = rmsNorm(postAttn, layer.ffnNorm, rmsNormEps);
523
-
524
- // 9. SwiGLU MLP
525
- const gate = matVec(layer.gateProj, ffnInput, intermediateSize, hiddenDim);
526
- const up = matVec(layer.upProj, ffnInput, intermediateSize, hiddenDim);
527
- const activated = silu(gate);
528
- for (let i = 0; i < intermediateSize; i++) activated[i] *= up[i];
529
- const down = matVec(layer.downProj, activated, hiddenDim, intermediateSize);
530
-
531
- // 10. Residual
532
- x = new Float32Array(hiddenDim);
533
- for (let i = 0; i < hiddenDim; i++) x[i] = postAttn[i] + down[i];
534
  }
535
 
536
- // Only sample if past prefill
537
  if (step >= inputTokens.length - 1) {
538
- // Final norm + LM head
539
- const finalNormed = rmsNorm(x, model.outputNorm, rmsNormEps);
540
- const logits = matVec(model.outputWeight, finalNormed, CONFIG.vocabSize, hiddenDim);
541
 
542
  // Temperature sampling
543
- const temperature = 0.7;
544
- for (let i = 0; i < logits.length; i++) logits[i] /= temperature;
545
- const probs = softmax(logits);
546
 
547
- // Top-p sampling
548
  const indexed = Array.from(probs).map((p, i) => ({ p, i })).sort((a, b) => b.p - a.p);
549
- let cumP = 0;
550
- let chosen = indexed[0].i;
551
  const r = Math.random();
552
  for (const { p, i } of indexed) {
553
  cumP += p;
@@ -555,73 +539,73 @@ function generate(prompt, maxTokens = 100) {
555
  if (cumP > 0.9) break;
556
  }
557
 
558
- const tokenEnd = performance.now();
559
- if (step >= inputTokens.length) tokenTimes.push(tokenEnd - tokenStart);
560
-
561
  if (chosen === CONFIG.eosToken) break;
562
  allTokens.push(chosen);
563
  }
564
  }
565
 
566
  const totalTime = performance.now() - t0;
567
- const generatedTokens = allTokens.slice(inputTokens.length);
568
- const text = model.tokenizer.decode(generatedTokens);
569
- const avgTokenTime = tokenTimes.length > 0 ? tokenTimes.reduce((a, b) => a + b, 0) / tokenTimes.length : 0;
570
 
571
  return {
572
  text,
573
- tokens: generatedTokens.length,
574
  totalTimeMs: Math.round(totalTime),
575
- avgTokenMs: Math.round(avgTokenTime),
576
  prefillTokens: inputTokens.length,
577
- engine: 'Aether WASM-SIMD',
 
578
  };
579
  }
580
 
581
  // ─── HTTP Server ────────────────────────────────────────────────────────────
582
- function startServer() {
583
- const server = createServer((req, res) => {
584
- if (req.method === 'POST' && req.url === '/generate') {
585
- let body = '';
586
- req.on('data', c => body += c);
587
- req.on('end', () => {
588
- try {
589
- const { prompt, max_tokens } = JSON.parse(body);
590
- const result = generate(prompt, max_tokens || 100);
591
- res.writeHead(200, { 'Content-Type': 'application/json' });
592
- res.end(JSON.stringify(result));
593
- } catch (e) {
594
- res.writeHead(500, { 'Content-Type': 'application/json' });
595
- res.end(JSON.stringify({ error: e.message, stack: e.stack }));
596
- }
597
- });
598
- } else if (req.url === '/health') {
599
- res.writeHead(200, { 'Content-Type': 'application/json' });
600
- res.end(JSON.stringify({ status: 'ok', model: model ? 'loaded' : 'not loaded', loadTime: model?.loadTime }));
601
- } else {
602
- res.writeHead(404);
603
- res.end('Not found');
604
- }
605
- });
606
-
607
- server.listen(PORT, '127.0.0.1', () => {
608
- console.log(`[Aether] Server listening on http://127.0.0.1:${PORT}`);
609
- });
610
- }
611
 
612
  // ─── Main ───────────────────────────────────────────────────────────────────
613
- const ggufPath = process.env.GGUF_PATH || join('/tmp/hf_cache', 'buleyean-smollm2-360m-q8_0.gguf');
614
- const tokenizerPath = process.env.TOKENIZER_PATH || join('/tmp/hf_cache', 'tokenizer.json');
615
 
616
- // Download if needed
617
- if (!existsSync(ggufPath)) {
618
- console.log('[Aether] Downloading GGUF model...');
619
- execSync(`python3 -c "from huggingface_hub import hf_hub_download; hf_hub_download('forkjoin-ai/buleyean-smollm2-360m', 'buleyean-smollm2-360m-q8_0.gguf', cache_dir='/tmp/hf_cache', local_dir='/tmp/hf_cache')"`, { stdio: 'inherit' });
620
- }
621
- if (!existsSync(tokenizerPath)) {
622
- console.log('[Aether] Downloading tokenizer...');
623
- execSync(`python3 -c "from huggingface_hub import hf_hub_download; hf_hub_download('HuggingFaceTB/SmolLM2-360M-Instruct', 'tokenizer.json', cache_dir='/tmp/hf_cache', local_dir='/tmp/hf_cache')"`, { stdio: 'inherit' });
 
 
 
 
 
 
 
 
 
 
 
624
  }
625
 
626
- loadModel(ggufPath, tokenizerPath);
627
- startServer();
 
1
  /**
2
  * Aether Inference Server
3
  *
4
+ * SmolLM2-360M inference using WASM SIMD kernels.
5
+ * Zero external ML dependencies. Pure JS + 14KB WASM binary.
6
  *
7
+ * GGUF parse → WASM SIMD matVec → RoPE → fusedSiluMul → sampling
 
8
  */
9
 
10
  import { createServer } from 'http';
11
+ import { readFileSync, existsSync } from 'fs';
12
  import { execSync } from 'child_process';
13
  import { fileURLToPath } from 'url';
14
  import { dirname, join } from 'path';
 
16
  const __dirname = dirname(fileURLToPath(import.meta.url));
17
  const PORT = parseInt(process.env.AETHER_PORT || '7861');
18
 
19
+ // ─── Model Config (SmolLM2-360M-Instruct) ──────────────────────────────────
20
  const CONFIG = {
21
  hiddenDim: 960,
22
  numLayers: 32,
 
32
  bosToken: 1,
33
  };
34
 
35
+ // ─── WASM SIMD Kernel Loader ────────────────────────────────────────────────
36
+ let simd = null;
 
 
37
 
38
+ async function loadSIMD() {
39
+ const wasmPath = join(__dirname, 'simd-kernels.wasm');
40
+ if (!existsSync(wasmPath)) {
41
+ console.log('[Aether] WASM SIMD binary not found, using JS fallbacks');
42
+ return null;
43
+ }
44
+
45
+ try {
46
+ const wasmBytes = readFileSync(wasmPath);
47
+ const { instance } = await WebAssembly.instantiate(wasmBytes, {
48
+ env: { expf: Math.exp, tanhf: Math.tanh, powf: Math.pow },
49
+ });
50
+ const wasm = instance.exports;
51
+ wasm.resetHeap(65536);
52
+ console.log('[Aether] WASM SIMD kernels loaded (14KB binary)');
53
+
54
+ const memory = wasm.memory;
55
+
56
+ function heapF32() { return new Float32Array(memory.buffer); }
57
+ function heapU8() { return new Uint8Array(memory.buffer); }
58
+ function copyTo(ptr, f32) { heapF32().set(f32, ptr >> 2); }
59
+ function copyBytesTo(ptr, u8) { heapU8().set(u8, ptr); }
60
+ function copyFrom(ptr, len) { return heapF32().slice(ptr >> 2, (ptr >> 2) + len); }
61
+
62
+ return {
63
+ matVec(matrix, vector, rows, cols) {
64
+ const saved = wasm.getHeapPtr();
65
+ const mPtr = wasm.allocate(matrix.byteLength);
66
+ const vPtr = wasm.allocate(vector.byteLength);
67
+ const rPtr = wasm.allocate(rows * 4);
68
+ copyTo(mPtr, matrix); copyTo(vPtr, vector);
69
+ wasm.matVecSimdBatch4(mPtr, vPtr, rPtr, rows, cols);
70
+ const result = copyFrom(rPtr, rows);
71
+ wasm.resetHeap(saved);
72
+ return result;
73
+ },
74
+ rmsNorm(x, weight, eps) {
75
+ const saved = wasm.getHeapPtr();
76
+ const xPtr = wasm.allocate(x.byteLength);
77
+ const wPtr = wasm.allocate(weight.byteLength);
78
+ const rPtr = wasm.allocate(x.byteLength);
79
+ copyTo(xPtr, x); copyTo(wPtr, weight);
80
+ wasm.rmsNormSimd(xPtr, wPtr, rPtr, x.length, eps);
81
+ const result = copyFrom(rPtr, x.length);
82
+ wasm.resetHeap(saved);
83
+ return result;
84
+ },
85
+ softmax(x) {
86
+ const saved = wasm.getHeapPtr();
87
+ const xPtr = wasm.allocate(x.byteLength);
88
+ const rPtr = wasm.allocate(x.byteLength);
89
+ copyTo(xPtr, x);
90
+ wasm.softmaxSimd(xPtr, rPtr, x.length);
91
+ const result = copyFrom(rPtr, x.length);
92
+ wasm.resetHeap(saved);
93
+ return result;
94
+ },
95
+ fusedSiluMul(gate, up) {
96
+ const saved = wasm.getHeapPtr();
97
+ const gPtr = wasm.allocate(gate.byteLength);
98
+ const uPtr = wasm.allocate(up.byteLength);
99
+ const rPtr = wasm.allocate(gate.byteLength);
100
+ copyTo(gPtr, gate); copyTo(uPtr, up);
101
+ wasm.fusedSiluMul(gPtr, uPtr, rPtr, gate.length);
102
+ const result = copyFrom(rPtr, gate.length);
103
+ wasm.resetHeap(saved);
104
+ return result;
105
+ },
106
+ add(a, b) {
107
+ const saved = wasm.getHeapPtr();
108
+ const aPtr = wasm.allocate(a.byteLength);
109
+ const bPtr = wasm.allocate(b.byteLength);
110
+ const rPtr = wasm.allocate(a.byteLength);
111
+ copyTo(aPtr, a); copyTo(bPtr, b);
112
+ wasm.addSimd(aPtr, bPtr, rPtr, a.length);
113
+ const result = copyFrom(rPtr, a.length);
114
+ wasm.resetHeap(saved);
115
+ return result;
116
+ },
117
+ flashAttentionMultiHead(query, keys, values, seqLen, numHeads, numKvHeads, headDim) {
118
+ const saved = wasm.getHeapPtr();
119
+ const scale = 1.0 / Math.sqrt(headDim);
120
+ const qPtr = wasm.allocate(query.byteLength);
121
+ const kPtr = wasm.allocate(keys.byteLength);
122
+ const vPtr = wasm.allocate(values.byteLength);
123
+ const rPtr = wasm.allocate(numHeads * headDim * 4);
124
+ copyTo(qPtr, query); copyTo(kPtr, keys); copyTo(vPtr, values);
125
+ wasm.flashAttentionMultiHead(qPtr, kPtr, vPtr, rPtr, seqLen, numHeads, numKvHeads, headDim, scale);
126
+ const result = copyFrom(rPtr, numHeads * headDim);
127
+ wasm.resetHeap(saved);
128
+ return result;
129
+ },
130
+ };
131
+ } catch (e) {
132
+ console.warn(`[Aether] WASM SIMD failed: ${e.message}, using JS fallbacks`);
133
+ return null;
134
+ }
135
  }
136
 
137
+ // ─── JS Fallbacks (used if WASM unavailable) ────────────────────────────────
138
+ function matVecJS(matrix, vector, rows, cols) {
139
+ const out = new Float32Array(rows);
140
+ for (let r = 0; r < rows; r++) {
141
+ let sum = 0; const off = r * cols;
142
+ for (let c = 0; c < cols; c++) sum += matrix[off + c] * vector[c];
143
+ out[r] = sum;
 
 
 
 
 
144
  }
145
  return out;
146
  }
147
 
148
+ function rmsNormJS(x, weight, eps) {
149
+ let ss = 0;
150
+ for (let i = 0; i < x.length; i++) ss += x[i] * x[i];
151
+ ss = 1.0 / Math.sqrt(ss / x.length + eps);
152
+ const out = new Float32Array(x.length);
153
+ for (let i = 0; i < x.length; i++) out[i] = x[i] * ss * weight[i];
154
+ return out;
155
+ }
156
+
157
+ function softmaxJS(x) {
158
+ let max = -Infinity;
159
+ for (let i = 0; i < x.length; i++) if (x[i] > max) max = x[i];
160
+ const out = new Float32Array(x.length);
161
+ let sum = 0;
162
+ for (let i = 0; i < x.length; i++) { out[i] = Math.exp(x[i] - max); sum += out[i]; }
163
+ for (let i = 0; i < x.length; i++) out[i] /= sum;
164
+ return out;
165
+ }
166
+
167
/**
 * Fused SiLU-and-multiply (JS fallback): out[i] = silu(gate[i]) * up[i],
 * where silu(g) = g / (1 + exp(-g)).
 *
 * @param {ArrayLike<number>} gate - Gate activations.
 * @param {ArrayLike<number>} up - Values multiplied with the activated gate; indexed up to gate.length - 1.
 * @returns {Float32Array} Newly allocated result of length gate.length.
 */
function fusedSiluMulJS(gate, up) {
  const n = gate.length;
  const out = new Float32Array(n);
  for (let i = 0; i < n; i++) {
    const g = gate[i];
    const silu = g / (1 + Math.exp(-g));
    out[i] = silu * up[i];
  }
  return out;
}
175
+
176
/**
 * Element-wise vector addition (JS fallback).
 *
 * @param {ArrayLike<number>} a - Left operand; its length determines the output length.
 * @param {ArrayLike<number>} b - Right operand; indexed up to a.length - 1.
 * @returns {Float32Array} Newly allocated sum of length a.length.
 */
function addJS(a, b) {
  const n = a.length;
  const out = new Float32Array(n);
  for (let i = 0; i < n; i++) out[i] = a[i] + b[i];
  return out;
}
181
+
182
/**
 * Build the kernel dispatch table: each entry is the WASM SIMD implementation
 * when the module-level `simd` object provides one, otherwise the pure-JS
 * fallback defined in this file.
 *
 * `flashAttentionMultiHead` has no JS fallback here and is `null` when the
 * SIMD module does not supply it, so callers must check before calling it.
 *
 * @returns {{matVec: Function, rmsNorm: Function, softmax: Function,
 *            fusedSiluMul: Function, add: Function,
 *            flashAttentionMultiHead: Function|null}}
 */
function ops() {
  return {
    matVec: simd?.matVec ?? matVecJS,
    rmsNorm: simd?.rmsNorm ?? rmsNormJS,
    softmax: simd?.softmax ?? softmaxJS,
    fusedSiluMul: simd?.fusedSiluMul ?? fusedSiluMulJS,
    add: simd?.add ?? addJS,
    flashAttentionMultiHead: simd?.flashAttentionMultiHead ?? null,
  };
}
193
+
194
  // ─── Q4_K Dequantization ────────────────────────────────────────────────────
195
  const QK_K = 256;
196
  const Q4K_BLOCK_BYTES = 144;
197
 
198
/**
 * Decode a little-endian IEEE 754 half-precision value from two bytes.
 *
 * Exponent 31 (the Inf/NaN encodings) is returned as 0 rather than
 * Infinity/NaN.
 * NOTE(review): this looks like deliberate sanitising so a bad scale cannot
 * poison a whole block with NaN — confirm before changing.
 *
 * @param {number} lo - Low byte (bits 0-7).
 * @param {number} hi - High byte (bits 8-15).
 * @returns {number} The decoded value.
 */
function fp16(lo, hi) {
  const bits = lo | (hi << 8);
  const sign = (bits >> 15) & 1 ? -1 : 1;
  const exponent = (bits >> 10) & 0x1f;
  const fraction = bits & 0x3ff;

  if (exponent === 0) {
    // Zero or subnormal: no implicit leading 1, fixed exponent of -14.
    return fraction === 0 ? 0 : sign * (fraction / 1024) * 2 ** -14;
  }
  if (exponent === 31) return 0;
  return sign * 2 ** (exponent - 15) * (1 + fraction / 1024);
}
205
+
206
/**
 * Unpack the 6-bit (scale, min) pair for sub-block `gi` from the packed
 * scales region of a Q4_K block.
 *
 * For gi < 4 the two values are the low 6 bits of scales[gi] and
 * scales[gi + 4]. For gi >= 4 each value combines a nibble of scales[gi + 4]
 * (low 4 bits) with the top 2 bits of scales[gi - 4] or scales[gi]
 * (high 2 bits).
 *
 * @param {number} gi - Sub-block index; the indexing used here is valid for 0..7
 *   with a 12-byte `scales` array.
 * @param {ArrayLike<number>} scales - Packed scale/min bytes.
 * @returns {[number, number]} `[scale, min]`, each in 0..63.
 */
function getScaleMinK4(gi, scales) {
  if (gi < 4) {
    return [scales[gi] & 63, scales[gi + 4] & 63];
  }
  const packed = scales[gi + 4];
  const scale = (packed & 0x0f) | ((scales[gi - 4] >> 6) << 4);
  const min = (packed >> 4) | ((scales[gi] >> 6) << 4);
  return [scale, min];
}
211
+
212
  function dequantQ4K(data, numElements) {
213
  const out = new Float32Array(numElements);
214
+ const numBlocks = Math.floor(data.length / Q4K_BLOCK_BYTES);
215
  for (let b = 0; b < numBlocks; b++) {
216
+ const outOff = b * QK_K;
217
+ if (outOff + QK_K > numElements) break;
218
+ const bs = b * Q4K_BLOCK_BYTES;
219
+ const d = fp16(data[bs], data[bs + 1]);
220
+ const dmin = fp16(data[bs + 2], data[bs + 3]);
221
+ const scales = data.subarray(bs + 4, bs + 16);
222
+ const qs = data.subarray(bs + 16, bs + Q4K_BLOCK_BYTES);
223
+ let si = 0, qi = 0;
224
+ for (let j = 0; j < QK_K; j += 64) {
225
+ const [sc1, m1] = getScaleMinK4(si, scales);
226
+ const [sc2, m2] = getScaleMinK4(si + 1, scales);
227
+ const d1 = d * sc1, d2 = d * sc2, dm1 = dmin * m1, dm2 = dmin * m2;
228
+ for (let lane = 0; lane < 32; lane++) {
229
+ const qb = qs[qi + lane];
230
+ out[outOff + j + lane] = d1 * (qb & 0x0f) - dm1;
231
+ out[outOff + j + 32 + lane] = d2 * (qb >> 4) - dm2;
 
 
 
 
 
 
 
 
232
  }
233
+ qi += 32; si += 2;
234
  }
235
  }
236
  return out;
237
  }
238
 
239
// Q8_0 dequant: each 34-byte block is an fp16 scale followed by 32 int8 quants.
const Q8_BLOCK = 32;
const Q8_BYTES = 34;

/**
 * Dequantise Q8_0 data into float32: value = int8 quant * per-block fp16 scale.
 *
 * Only blocks that are completely present in `data` are decoded, so a short
 * or truncated buffer leaves the remaining outputs at 0 instead of reading
 * past the end and writing NaN.
 *
 * @param {Uint8Array} data - Raw quantised bytes.
 * @param {number} numElements - Length of the returned array.
 * @returns {Float32Array} Newly allocated dequantised values.
 */
function dequantQ8(data, numElements) {
  const out = new Float32Array(numElements);
  const blocksWanted = Math.ceil(numElements / Q8_BLOCK);
  const blocksAvailable = Math.floor(data.length / Q8_BYTES);
  const numBlocks = Math.min(blocksWanted, blocksAvailable);

  for (let b = 0; b < numBlocks; b++) {
    const off = b * Q8_BYTES;
    const base = b * Q8_BLOCK;
    const scale = fp16(data[off], data[off + 1]);
    // The last block may be only partially used by numElements.
    const count = Math.min(Q8_BLOCK, numElements - base);
    for (let i = 0; i < count; i++) {
      const byte = data[off + 2 + i];
      // Reinterpret the unsigned byte as a signed int8.
      out[base + i] = (byte > 127 ? byte - 256 : byte) * scale;
    }
  }
  return out;
}
254
 
255
/**
 * Convert raw tensor bytes to float32, guessing the storage format from the
 * byte length: within 5% of the expected size for F32, then Q4_K, then Q8_0.
 *
 * Anything else is still decoded as Q8_0 (best effort, as before) but a
 * warning is logged, because a layout this function does not know will
 * decode to meaningless weights.
 *
 * @param {Uint8Array} data - Raw tensor bytes.
 * @param {number} numElements - Number of values the tensor holds.
 * @returns {Float32Array} Float32 weights; for aligned, complete F32 input this
 *   is a view over `data`'s buffer rather than a copy.
 */
function dequantAuto(data, numElements) {
  const f32Bytes = numElements * 4;
  const q8Bytes = Math.ceil(numElements / Q8_BLOCK) * Q8_BYTES;
  const q4kBytes = Math.ceil(numElements / QK_K) * Q4K_BLOCK_BYTES;
  const isNear = (expected) => Math.abs(data.length - expected) < expected * 0.05;

  if (isNear(f32Bytes)) {
    const floatsPresent = Math.min(numElements, Math.floor(data.length / 4));
    if (floatsPresent === numElements && data.byteOffset % 4 === 0) {
      return new Float32Array(data.buffer, data.byteOffset, numElements);
    }
    // A Float32Array view throws RangeError on a byte offset that is not a
    // multiple of 4 or on a buffer that is too short, so copy the bytes into
    // a fresh array instead (same host byte order as the view would use).
    const copy = new Float32Array(numElements);
    new Uint8Array(copy.buffer).set(data.subarray(0, floatsPresent * 4));
    return copy;
  }
  if (isNear(q4kBytes)) return dequantQ4K(data, numElements);
  if (!isNear(q8Bytes)) {
    console.warn(
      `[Aether] dequantAuto: ${data.length} bytes for ${numElements} elements ` +
        'matches no known layout (F32/Q4_K/Q8_0); decoding as Q8_0',
    );
  }
  return dequantQ8(data, numElements);
}
263
 
264
+ // ─── GGUF Parser ────────────────────────────────────────────────────────────
265
+ const GGUF_MAGIC = 0x46554747;
266
+ const VT = { UINT8:0,INT8:1,UINT16:2,INT16:3,UINT32:4,INT32:5,FLOAT32:6,BOOL:7,STRING:8,ARRAY:9,UINT64:10,INT64:11,FLOAT64:12 };
267
+ const BLK_SZ = {2:32,3:32,6:32,7:32,8:32,9:32,10:256,11:256,12:256,13:256,14:256,15:256};
268
+ const BLK_BY = {2:18,3:20,6:22,7:24,8:34,9:36,10:84,11:110,12:144,13:176,14:210,15:292};
269
+ const TY_SZ = {0:4,1:2,16:1,17:2,18:4,19:8,20:8};
270
+
271
/**
 * Byte size of a tensor from its dimensions and GGUF type id.
 *
 * Types listed in both BLK_SZ and BLK_BY are block-quantised: the size is
 * the number of blocks (rounded up) times the bytes per block. Any other
 * type uses the per-element size from TY_SZ, defaulting to 4 bytes.
 *
 * @param {bigint[]} dims - Tensor dimensions.
 * @param {number} type - GGUF tensor type id.
 * @returns {number} Size in bytes.
 */
function calcSz(dims, type) {
  let count = 1n;
  for (const d of dims) count *= d;
  const numElements = Number(count);

  const blockElems = BLK_SZ[type];
  const blockBytes = BLK_BY[type];
  if (blockElems && blockBytes) {
    return Math.ceil(numElements / blockElems) * blockBytes;
  }
  return Math.ceil(numElements * (TY_SZ[type] ?? 4));
}
276
/**
 * Read a length-prefixed string: a little-endian uint64 byte length followed
 * by that many bytes, decoded as UTF-8.
 *
 * @param {Buffer} buf - Source buffer.
 * @param {number} off - Offset of the length prefix.
 * @returns {{v: string, o: number}} The string and the offset just past it.
 */
function rStr(buf, off) {
  const len = Number(buf.readBigUInt64LE(off));
  const start = off + 8;
  const end = start + len;
  return { v: buf.subarray(start, end).toString('utf8'), o: end };
}
277
/**
 * Read one GGUF metadata value of type `t` at `off`.
 *
 * Scalars are read little-endian; 64-bit integers are returned as BigInt.
 * Arrays are stored as a uint32 element type, a uint64 count, then the
 * elements, each read recursively.
 *
 * @param {Buffer} buf - Source buffer.
 * @param {number} off - Offset of the value.
 * @param {number} t - Value type id, one of the VT constants.
 * @returns {{v: *, o: number}} The value and the offset just past it.
 * @throws {Error} If `t` is not a known type id.
 */
function rVal(buf, off, t) {
  switch (t) {
    case VT.UINT8:
      return { v: buf.readUInt8(off), o: off + 1 };
    case VT.INT8:
      return { v: buf.readInt8(off), o: off + 1 };
    case VT.UINT16:
      return { v: buf.readUInt16LE(off), o: off + 2 };
    case VT.INT16:
      return { v: buf.readInt16LE(off), o: off + 2 };
    case VT.UINT32:
      return { v: buf.readUInt32LE(off), o: off + 4 };
    case VT.INT32:
      return { v: buf.readInt32LE(off), o: off + 4 };
    case VT.FLOAT32:
      return { v: buf.readFloatLE(off), o: off + 4 };
    case VT.BOOL:
      return { v: buf.readUInt8(off) !== 0, o: off + 1 };
    case VT.STRING: {
      const str = rStr(buf, off);
      return { v: str.v, o: str.o };
    }
    case VT.UINT64:
      return { v: buf.readBigUInt64LE(off), o: off + 8 };
    case VT.INT64:
      return { v: buf.readBigInt64LE(off), o: off + 8 };
    case VT.FLOAT64:
      return { v: buf.readDoubleLE(off), o: off + 8 };
    case VT.ARRAY: {
      const elemType = buf.readUInt32LE(off);
      const count = Number(buf.readBigUInt64LE(off + 4));
      let cursor = off + 12;
      const items = [];
      for (let i = 0; i < count; i++) {
        const item = rVal(buf, cursor, elemType);
        items.push(item.v);
        cursor = item.o;
      }
      return { v: items, o: cursor };
    }
    default:
      throw new Error(`Unknown GGUF type: ${t}`);
  }
}
289
/**
 * Parse a GGUF header: the metadata key/value section and the tensor
 * directory. Tensor data itself is not read.
 *
 * Metadata values are parsed only to advance past them; the one key kept is
 * `general.alignment`, which overrides the default data alignment of 32.
 *
 * @param {Buffer} buf - Entire GGUF file contents.
 * @returns {{tensors: Array<{name: string, dims: bigint[], type: number,
 *            offset: bigint, size: number, numElements: number}>,
 *            dataOffset: number}}
 *   Tensor descriptors plus `dataOffset`, the end of the header rounded up to
 *   the alignment. Each tensor's `offset` is used relative to `dataOffset`.
 * @throws {Error} If the magic number does not match.
 */
function parseGGUF(buf) {
  let off = 0;
  if (buf.readUInt32LE(off) !== GGUF_MAGIC) throw new Error('Not GGUF');
  off += 4;
  // Skip the next 4 bytes (the format version in the GGUF layout); not inspected.
  off += 4;
  const tensorCount = Number(buf.readBigUInt64LE(off));
  off += 8;
  const kvCount = Number(buf.readBigUInt64LE(off));
  off += 8;

  let align = 32;
  for (let i = 0; i < kvCount; i++) {
    const key = rStr(buf, off);
    off = key.o;
    const valueType = buf.readUInt32LE(off);
    off += 4;
    const value = rVal(buf, off, valueType);
    off = value.o;
    if (key.v === 'general.alignment') align = Number(value.v);
  }

  const tensors = [];
  for (let i = 0; i < tensorCount; i++) {
    const name = rStr(buf, off);
    off = name.o;
    const numDims = buf.readUInt32LE(off);
    off += 4;
    const dims = [];
    for (let d = 0; d < numDims; d++) {
      dims.push(buf.readBigUInt64LE(off));
      off += 8;
    }
    const type = buf.readUInt32LE(off);
    off += 4;
    const offset = buf.readBigUInt64LE(off);
    off += 8;
    tensors.push({
      name: name.v,
      dims,
      type,
      offset,
      size: calcSz(dims, type),
      numElements: Number(dims.reduce((a, b) => a * b, 1n)),
    });
  }

  return { tensors, dataOffset: Math.ceil(off / align) * align };
}
300
 
301
  // ─── BPE Tokenizer ──────────────────────────────────────────────────────────
302
  class BPETokenizer {
303
+ constructor(json) {
304
+ const m = json.model || {};
305
+ this.vocab = m.vocab || {};
306
+ this.rev = {};
307
+ for (const [t, id] of Object.entries(this.vocab)) this.rev[id] = t;
 
 
 
 
 
 
308
  this.mergeRanks = {};
309
+ for (const [i, merge] of (m.merges || []).entries()) this.mergeRanks[merge] = i;
310
+ this.added = {};
311
+ if (json.added_tokens) for (const t of json.added_tokens) this.added[t.content] = t.id;
 
 
 
 
 
 
 
 
312
  }
 
313
  encode(text) {
314
+ const sp = /<\|[^|]+\|>/g;
315
+ const parts = []; let last = 0, m;
316
+ while ((m = sp.exec(text)) !== null) {
317
+ if (m.index > last) parts.push({ t: text.slice(last, m.index), s: false });
318
+ parts.push({ t: m[0], s: true }); last = m.index + m[0].length;
 
 
 
 
319
  }
320
+ if (last < text.length) parts.push({ t: text.slice(last), s: false });
 
321
  const tokens = [];
322
+ for (const p of parts) {
323
+ if (p.s) { const id = this.added[p.t] ?? this.vocab[p.t]; if (id !== undefined) tokens.push(id); continue; }
324
+ const words = p.t.match(/\S+|\s+/g) || [];
325
+ for (const w of words) {
326
+ let syms = [];
327
+ for (const ch of w) {
328
+ if (this.vocab[ch] !== undefined) syms.push(ch);
329
+ else for (const b of Buffer.from(ch, 'utf8')) syms.push(`<0x${b.toString(16).toUpperCase().padStart(2,'0')}>`);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  }
331
+ while (syms.length > 1) {
332
+ let best = Infinity, bi = -1;
333
+ for (let i = 0; i < syms.length - 1; i++) {
334
+ const r = this.mergeRanks[`${syms[i]} ${syms[i+1]}`];
335
+ if (r !== undefined && r < best) { best = r; bi = i; }
 
 
 
 
 
 
336
  }
337
+ if (bi === -1) break;
338
+ syms.splice(bi, 2, syms[bi] + syms[bi + 1]);
 
 
 
 
 
 
339
  }
340
+ for (const s of syms) { const id = this.vocab[s] ?? this.added[s]; if (id !== undefined) tokens.push(id); }
341
  }
342
  }
343
  return tokens;
344
  }
 
345
  decode(tokens) {
346
  const pieces = [];
347
  for (const t of tokens) {
348
+ const p = this.rev[t];
349
+ if (p && p.startsWith('<0x') && p.endsWith('>')) pieces.push(String.fromCharCode(parseInt(p.slice(3,-1),16)));
350
+ else if (p && !p.startsWith('<|')) pieces.push(p);
 
 
 
 
 
 
 
351
  }
352
  return pieces.join('').replace(/Δ /g, ' ').replace(/Ċ/g, '\n');
353
  }
 
355
 
356
  // ─── RoPE ───────────────────────────────────────────────────────────────────
357
  function applyRoPE(x, headDim, position, theta) {
358
+ const half = headDim / 2;
359
+ for (let i = 0; i < half; i++) {
360
  const freq = 1.0 / Math.pow(theta, (2 * i) / headDim);
361
  const angle = position * freq;
362
+ const cos = Math.cos(angle), sin = Math.sin(angle);
363
+ const x0 = x[i], x1 = x[i + half];
 
 
364
  x[i] = x0 * cos - x1 * sin;
365
+ x[i + half] = x0 * sin + x1 * cos;
366
  }
367
  }
368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  // ─── Model ──────────────────────────────────────────────────────────────────
370
  let model = null;
371
 
 
376
  const parsed = parseGGUF(buf);
377
  console.log(`[Aether] Parsed ${parsed.tensors.length} tensors in ${Date.now() - t0}ms`);
378
 
 
 
379
  const tokJson = JSON.parse(readFileSync(tokenizerPath, 'utf8'));
380
  const tokenizer = new BPETokenizer(tokJson);
381
 
382
+ const byName = {};
383
+ for (const t of parsed.tensors) byName[t.name] = t;
 
384
 
385
+ function get(name) {
386
+ const t = byName[name];
387
+ if (!t) { console.warn(`[Aether] Missing: ${name}`); return null; }
388
+ const raw = new Uint8Array(buf.buffer, buf.byteOffset + parsed.dataOffset + Number(t.offset), t.size);
 
 
389
  return dequantAuto(raw, t.numElements);
390
  }
391
 
392
  console.log('[Aether] Dequantizing embeddings...');
393
+ const tokenEmbd = get('token_embd.weight');
394
 
395
  console.log('[Aether] Dequantizing layers...');
396
  const layers = [];
397
  for (let i = 0; i < CONFIG.numLayers; i++) {
398
  if (i % 8 === 0) console.log(`[Aether] Layer ${i}/${CONFIG.numLayers}...`);
399
  layers.push({
400
+ attnNorm: get(`blk.${i}.attn_norm.weight`),
401
+ ffnNorm: get(`blk.${i}.ffn_norm.weight`),
402
+ qProj: get(`blk.${i}.attn_q.weight`),
403
+ kProj: get(`blk.${i}.attn_k.weight`),
404
+ vProj: get(`blk.${i}.attn_v.weight`),
405
+ oProj: get(`blk.${i}.attn_output.weight`),
406
+ gateProj: get(`blk.${i}.ffn_gate.weight`),
407
+ upProj: get(`blk.${i}.ffn_up.weight`),
408
+ downProj: get(`blk.${i}.ffn_down.weight`),
409
  });
410
  }
411
 
412
+ const outputNorm = get('output_norm.weight');
413
+ let outputWeight = get('output.weight');
414
+ if (!outputWeight) { console.log('[Aether] Tied embeddings'); outputWeight = tokenEmbd; }
 
 
 
 
415
 
416
  const loadTime = Date.now() - t0;
417
+ console.log(`[Aether] Model loaded in ${(loadTime/1000).toFixed(1)}s (WASM SIMD: ${simd ? 'YES' : 'NO'})`);
 
418
  model = { tokenEmbd, layers, outputNorm, outputWeight, tokenizer, loadTime };
419
  }
420
 
421
  // ─── Inference ──────────────────────────────────────────────────────────────
422
+ function generate(prompt, maxTokens = 50) {
423
  if (!model) throw new Error('Model not loaded');
424
 
425
  const t0 = performance.now();
426
  const { hiddenDim, numHeads, numKvHeads, headDim, intermediateSize, ropeTheta, rmsNormEps } = CONFIG;
427
  const kvDim = numKvHeads * headDim;
428
+ const o = ops();
429
 
 
430
  const chatPrompt = `<|im_start|>user\n${prompt}<|im_end|>\n<|im_start|>assistant\n`;
431
  const inputTokens = model.tokenizer.encode(chatPrompt);
432
  const allTokens = [...inputTokens];
433
 
434
+ // KV cache: flat arrays per layer for WASM flash attention
435
+ const kvCache = Array.from({ length: CONFIG.numLayers }, () => ({
436
+ keys: [], // array of Float32Array[kvDim] per position
437
+ values: [], // array of Float32Array[kvDim] per position
438
+ }));
439
 
440
  const tokenTimes = [];
441
 
 
442
  for (let step = 0; step < inputTokens.length + maxTokens - 1; step++) {
443
  const tokenStart = performance.now();
444
  const pos = step;
 
446
 
447
  // Embed
448
  const hidden = new Float32Array(hiddenDim);
449
+ const embOff = tokenId * hiddenDim;
450
+ for (let i = 0; i < hiddenDim; i++) hidden[i] = model.tokenEmbd[embOff + i];
451
 
452
  let x = hidden;
453
 
 
454
  for (let l = 0; l < CONFIG.numLayers; l++) {
455
+ const ly = model.layers[l];
456
 
457
  // 1. Attention norm
458
+ const normed = o.rmsNorm(x, ly.attnNorm, rmsNormEps);
459
 
460
+ // 2. Q, K, V projections (WASM SIMD matVec)
461
+ const q = o.matVec(ly.qProj, normed, hiddenDim, hiddenDim);
462
+ const k = o.matVec(ly.kProj, normed, kvDim, hiddenDim);
463
+ const v = o.matVec(ly.vProj, normed, kvDim, hiddenDim);
464
 
465
  // 3. RoPE
466
+ for (let h = 0; h < numHeads; h++)
467
  applyRoPE(q.subarray(h * headDim, (h + 1) * headDim), headDim, pos, ropeTheta);
468
+ for (let h = 0; h < numKvHeads; h++)
 
469
  applyRoPE(k.subarray(h * headDim, (h + 1) * headDim), headDim, pos, ropeTheta);
 
470
 
471
  // 4. Store in KV cache
472
  kvCache[l].keys.push(new Float32Array(k));
473
  kvCache[l].values.push(new Float32Array(v));
474
 
475
+ // 5. Attention
 
476
  const seqLen = kvCache[l].keys.length;
477
+ let attnOut;
478
 
479
+ if (o.flashAttentionMultiHead && seqLen > 1) {
480
+ // Use WASM flash attention with GQA
481
+ const flatKeys = new Float32Array(seqLen * kvDim);
482
+ const flatVals = new Float32Array(seqLen * kvDim);
 
 
483
  for (let s = 0; s < seqLen; s++) {
484
+ flatKeys.set(kvCache[l].keys[s], s * kvDim);
485
+ flatVals.set(kvCache[l].values[s], s * kvDim);
 
 
486
  }
487
+ attnOut = o.flashAttentionMultiHead(q, flatKeys, flatVals, seqLen, numHeads, numKvHeads, headDim);
488
+ } else {
489
+ // JS fallback attention
490
+ attnOut = new Float32Array(hiddenDim);
491
+ const gqaRatio = numHeads / numKvHeads;
492
+ for (let h = 0; h < numHeads; h++) {
493
+ const kvH = Math.floor(h / gqaRatio);
494
+ const qH = q.subarray(h * headDim, (h + 1) * headDim);
495
+ const scores = new Float32Array(seqLen);
496
+ for (let s = 0; s < seqLen; s++) {
497
+ const kH = kvCache[l].keys[s].subarray(kvH * headDim, (kvH + 1) * headDim);
498
+ let dot = 0;
499
+ for (let d = 0; d < headDim; d++) dot += qH[d] * kH[d];
500
+ scores[s] = dot / Math.sqrt(headDim);
501
+ }
502
+ const w = softmaxJS(scores);
503
+ for (let s = 0; s < seqLen; s++) {
504
+ const vH = kvCache[l].values[s].subarray(kvH * headDim, (kvH + 1) * headDim);
505
+ for (let d = 0; d < headDim; d++) attnOut[h * headDim + d] += w[s] * vH[d];
506
  }
507
  }
508
  }
509
 
510
+ // 6. O projection + residual
511
+ const projected = o.matVec(ly.oProj, attnOut, hiddenDim, hiddenDim);
512
+ const postAttn = o.add(x, projected);
513
+
514
+ // 7. FFN: norm → gate/up → fusedSiluMul → down → residual
515
+ const ffnIn = o.rmsNorm(postAttn, ly.ffnNorm, rmsNormEps);
516
+ const gate = o.matVec(ly.gateProj, ffnIn, intermediateSize, hiddenDim);
517
+ const up = o.matVec(ly.upProj, ffnIn, intermediateSize, hiddenDim);
518
+ const activated = o.fusedSiluMul(gate, up);
519
+ const down = o.matVec(ly.downProj, activated, hiddenDim, intermediateSize);
520
+ x = o.add(postAttn, down);
 
 
 
 
 
 
 
 
 
521
  }
522
 
523
+ // Sample only after prefill
524
  if (step >= inputTokens.length - 1) {
525
+ const finalNormed = o.rmsNorm(x, model.outputNorm, rmsNormEps);
526
+ const logits = o.matVec(model.outputWeight, finalNormed, CONFIG.vocabSize, hiddenDim);
 
527
 
528
  // Temperature sampling
529
+ for (let i = 0; i < logits.length; i++) logits[i] /= 0.7;
530
+ const probs = o.softmax(logits);
 
531
 
532
+ // Top-p nucleus sampling
533
  const indexed = Array.from(probs).map((p, i) => ({ p, i })).sort((a, b) => b.p - a.p);
534
+ let cumP = 0, chosen = indexed[0].i;
 
535
  const r = Math.random();
536
  for (const { p, i } of indexed) {
537
  cumP += p;
 
539
  if (cumP > 0.9) break;
540
  }
541
 
542
+ tokenTimes.push(performance.now() - tokenStart);
 
 
543
  if (chosen === CONFIG.eosToken) break;
544
  allTokens.push(chosen);
545
  }
546
  }
547
 
548
  const totalTime = performance.now() - t0;
549
+ const genTokens = allTokens.slice(inputTokens.length);
550
+ const text = model.tokenizer.decode(genTokens);
551
+ const avgMs = tokenTimes.length > 0 ? tokenTimes.reduce((a, b) => a + b, 0) / tokenTimes.length : 0;
552
 
553
  return {
554
  text,
555
+ tokens: genTokens.length,
556
  totalTimeMs: Math.round(totalTime),
557
+ avgTokenMs: Math.round(avgMs),
558
  prefillTokens: inputTokens.length,
559
+ engine: `Aether ${simd ? 'WASM-SIMD' : 'JS-fallback'}`,
560
+ simd: !!simd,
561
  };
562
  }
563
 
564
  // ─── HTTP Server ────────────────────────────────────────────────────────────
565
// Generation runs synchronously on the event loop, so cap how many tokens a
// single request may ask for.
const MAX_GENERATE_TOKENS = 1024;

/**
 * HTTP front end.
 *
 *   POST /generate  body {"prompt": string, "max_tokens"?: number}
 *                   -> 200 with the generate() result
 *                   -> 400 for malformed JSON or a non-string prompt
 *                   -> 500 {"error", "stack"} if generation throws
 *   /health         -> 200 with model/SIMD status
 *   anything else   -> 404
 */
const server = createServer((req, res) => {
  const sendJson = (status, payload) => {
    res.writeHead(status, { 'Content-Type': 'application/json' });
    res.end(JSON.stringify(payload));
  };

  if (req.method === 'POST' && req.url === '/generate') {
    // Decode as UTF-8 across chunk boundaries; appending raw Buffer chunks
    // to a string can split a multi-byte character in two.
    req.setEncoding('utf8');
    let body = '';
    req.on('data', (chunk) => {
      body += chunk;
    });
    req.on('end', () => {
      let payload;
      try {
        payload = JSON.parse(body);
      } catch (e) {
        sendJson(400, { error: `Invalid JSON body: ${e.message}` });
        return;
      }
      if (typeof payload?.prompt !== 'string') {
        sendJson(400, { error: 'Field "prompt" must be a string' });
        return;
      }
      // Accept only a positive integer; anything else falls back to 50.
      const requested = Number(payload.max_tokens);
      const maxTokens =
        Number.isInteger(requested) && requested > 0
          ? Math.min(requested, MAX_GENERATE_TOKENS)
          : 50;
      try {
        sendJson(200, generate(payload.prompt, maxTokens));
      } catch (e) {
        sendJson(500, { error: e.message, stack: e.stack });
      }
    });
  } else if (req.url === '/health') {
    sendJson(200, {
      status: 'ok',
      model: model ? 'loaded' : 'not loaded',
      simd: !!simd,
      loadTime: model?.loadTime,
    });
  } else {
    res.writeHead(404);
    res.end();
  }
});
 
 
 
 
 
 
 
 
 
585
 
586
  // ─── Main ───────────────────────────────────────────────────────────────────
587
const ggufPath = process.env.GGUF_PATH || '/tmp/hf_cache/buleyean-smollm2-360m-q4_k_m.gguf';
const tokenizerPath = process.env.TOKENIZER_PATH || '/tmp/hf_cache/tokenizer.json';

/**
 * Fetch one file from the Hugging Face Hub into /tmp/hf_cache by shelling
 * out to Python's `huggingface_hub`. Blocks until the download finishes and
 * throws if the command fails.
 *
 * Both arguments are interpolated into a shell command, so pass trusted
 * constants only.
 *
 * @param {string} repoId - Hub repository id.
 * @param {string} filename - File name within the repository.
 */
function downloadFromHub(repoId, filename) {
  execSync(
    `python3 -c "from huggingface_hub import hf_hub_download; hf_hub_download('${repoId}', '${filename}', cache_dir='/tmp/hf_cache', local_dir='/tmp/hf_cache')"`,
    { stdio: 'inherit' },
  );
}

/**
 * Start-up sequence: load the WASM SIMD kernels, download the model and
 * tokenizer if their files are missing, load the model, then start listening
 * on 127.0.0.1:PORT.
 *
 * Downloads always land in /tmp/hf_cache, so a GGUF_PATH or TOKENIZER_PATH
 * that points elsewhere must already exist.
 *
 * @returns {Promise<void>}
 */
async function main() {
  // Load WASM SIMD first so the model-load log can report whether it is active.
  simd = await loadSIMD();

  if (!existsSync(ggufPath)) {
    console.log('[Aether] Downloading Q4_K_M GGUF...');
    downloadFromHub('forkjoin-ai/buleyean-smollm2-360m', 'buleyean-smollm2-360m-q4_k_m.gguf');
  }
  if (!existsSync(tokenizerPath)) {
    console.log('[Aether] Downloading tokenizer...');
    downloadFromHub('HuggingFaceTB/SmolLM2-360M-Instruct', 'tokenizer.json');
  }

  loadModel(ggufPath, tokenizerPath);

  server.listen(PORT, '127.0.0.1', () => {
    console.log(`[Aether] Server on http://127.0.0.1:${PORT} (SIMD: ${simd ? 'YES' : 'NO'})`);
  });
}

main().catch((e) => {
  console.error('[Aether] Fatal:', e);
  process.exit(1);
});
 
app.py CHANGED
@@ -63,7 +63,7 @@ def gen_pytorch(prompt):
63
  with torch.no_grad():
64
  outputs = base_model.generate(
65
  **inputs,
66
- max_new_tokens=100,
67
  temperature=0.7,
68
  top_p=0.9,
69
  do_sample=True,
@@ -80,7 +80,7 @@ def gen_pytorch(prompt):
80
  def gen_aether(prompt):
81
  """Generate with Aether (our engine)"""
82
  try:
83
- data = json.dumps({"prompt": prompt, "max_tokens": 100}).encode()
84
  req = urllib.request.Request(
85
  "http://127.0.0.1:7861/generate",
86
  data=data,
 
63
  with torch.no_grad():
64
  outputs = base_model.generate(
65
  **inputs,
66
+ max_new_tokens=50,
67
  temperature=0.7,
68
  top_p=0.9,
69
  do_sample=True,
 
80
  def gen_aether(prompt):
81
  """Generate with Aether (our engine)"""
82
  try:
83
+ data = json.dumps({"prompt": prompt, "max_tokens": 50}).encode()
84
  req = urllib.request.Request(
85
  "http://127.0.0.1:7861/generate",
86
  data=data,
simd-kernels.wasm ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a05084c8998119797c6e80927678ce007e3285b78c6e7e8feee223ca4bb13636
3
+ size 14553