NOT-OMEGA committed on
Commit
472497b
·
verified ·
1 Parent(s): 21f4792

Upload 4 files

Browse files
Files changed (4) hide show
  1. index.html +693 -0
  2. inference.cpp +409 -0
  3. main.py +152 -0
  4. tokenizer.bin +3 -0
index.html ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>SLM · Story Engine</title>
7
+ <link href="https://fonts.googleapis.com/css2?family=Playfair+Display:ital,wght@0,400;0,700;1,400&family=IBM+Plex+Mono:wght@300;400&display=swap" rel="stylesheet">
8
+ <style>
9
+ :root {
10
+ --ink: #1a1209;
11
+ --paper: #f5f0e8;
12
+ --aged: #e8e0cc;
13
+ --sepia: #8b6914;
14
+ --rust: #c0392b;
15
+ --green: #27ae60;
16
+ --shadow: rgba(26,18,9,0.15);
17
+ }
18
+
19
+ *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
20
+
21
+ body {
22
+ background: var(--paper);
23
+ color: var(--ink);
24
+ font-family: 'Playfair Display', Georgia, serif;
25
+ min-height: 100vh;
26
+ display: flex;
27
+ flex-direction: column;
28
+ align-items: center;
29
+ padding: 40px 20px 80px;
30
+ background-image:
31
+ repeating-linear-gradient(
32
+ 0deg,
33
+ transparent,
34
+ transparent 27px,
35
+ rgba(139,105,20,0.08) 28px
36
+ );
37
+ background-size: 100% 28px;
38
+ }
39
+
40
+ /* ---- Status Badge ---- */
41
+ .status-badge {
42
+ position: fixed;
43
+ top: 20px;
44
+ right: 20px;
45
+ font-family: 'IBM Plex Mono', monospace;
46
+ font-size: 0.65rem;
47
+ padding: 6px 12px;
48
+ border-radius: 20px;
49
+ display: flex;
50
+ align-items: center;
51
+ gap: 6px;
52
+ z-index: 100;
53
+ transition: all 0.3s;
54
+ }
55
+ .status-badge.connected {
56
+ background: rgba(39, 174, 96, 0.15);
57
+ color: var(--green);
58
+ border: 1px solid var(--green);
59
+ }
60
+ .status-badge.disconnected {
61
+ background: rgba(192, 57, 43, 0.15);
62
+ color: var(--rust);
63
+ border: 1px solid var(--rust);
64
+ }
65
+ .status-dot {
66
+ width: 8px;
67
+ height: 8px;
68
+ border-radius: 50%;
69
+ animation: pulse 2s ease-in-out infinite;
70
+ }
71
+ .status-badge.connected .status-dot { background: var(--green); }
72
+ .status-badge.disconnected .status-dot { background: var(--rust); }
73
+ @keyframes pulse {
74
+ 0%, 100% { opacity: 1; }
75
+ 50% { opacity: 0.4; }
76
+ }
77
+
78
+ /* ---- Header ---- */
79
+ header {
80
+ text-align: center;
81
+ margin-bottom: 48px;
82
+ position: relative;
83
+ }
84
+ header::after {
85
+ content: '';
86
+ display: block;
87
+ width: 120px;
88
+ height: 2px;
89
+ margin: 16px auto 0;
90
+ background: linear-gradient(90deg, transparent, var(--sepia), transparent);
91
+ }
92
+ .masthead {
93
+ font-size: clamp(2.2rem, 6vw, 3.6rem);
94
+ font-weight: 700;
95
+ letter-spacing: -1px;
96
+ line-height: 1;
97
+ color: var(--ink);
98
+ }
99
+ .masthead em { color: var(--sepia); font-style: italic; }
100
+ .subtitle {
101
+ font-family: 'IBM Plex Mono', monospace;
102
+ font-size: 0.72rem;
103
+ font-weight: 300;
104
+ letter-spacing: 4px;
105
+ text-transform: uppercase;
106
+ color: var(--sepia);
107
+ margin-top: 10px;
108
+ }
109
+
110
+ /* ---- Card ---- */
111
+ .card {
112
+ width: 100%;
113
+ max-width: 760px;
114
+ background: #faf7f0;
115
+ border: 1px solid var(--aged);
116
+ border-radius: 2px;
117
+ box-shadow: 4px 4px 0 var(--shadow), 8px 8px 0 rgba(26,18,9,0.06);
118
+ padding: 36px 40px;
119
+ position: relative;
120
+ }
121
+ .card::before {
122
+ content: '';
123
+ position: absolute;
124
+ top: 0; left: 36px; right: 36px;
125
+ height: 3px;
126
+ background: linear-gradient(90deg, transparent, var(--sepia) 30%, var(--sepia) 70%, transparent);
127
+ opacity: 0.5;
128
+ }
129
+
130
+ /* ---- Performance Stats (NEW) ---- */
131
+ .perf-stats {
132
+ display: grid;
133
+ grid-template-columns: repeat(auto-fit, minmax(140px, 1fr));
134
+ gap: 12px;
135
+ margin-bottom: 24px;
136
+ padding: 16px;
137
+ background: rgba(139,105,20,0.04);
138
+ border-radius: 2px;
139
+ border: 1px solid var(--aged);
140
+ }
141
+ .stat-item {
142
+ text-align: center;
143
+ }
144
+ .stat-value {
145
+ font-family: 'IBM Plex Mono', monospace;
146
+ font-size: 1.4rem;
147
+ font-weight: 400;
148
+ color: var(--sepia);
149
+ line-height: 1;
150
+ margin-bottom: 4px;
151
+ }
152
+ .stat-label {
153
+ font-family: 'IBM Plex Mono', monospace;
154
+ font-size: 0.6rem;
155
+ letter-spacing: 1.5px;
156
+ text-transform: uppercase;
157
+ color: rgba(26,18,9,0.5);
158
+ }
159
+
160
+ /* ---- Controls ---- */
161
+ .controls-row {
162
+ display: flex;
163
+ gap: 24px;
164
+ margin-bottom: 20px;
165
+ flex-wrap: wrap;
166
+ }
167
+ .control-group {
168
+ display: flex;
169
+ flex-direction: column;
170
+ gap: 6px;
171
+ flex: 1;
172
+ min-width: 120px;
173
+ }
174
+ label {
175
+ font-family: 'IBM Plex Mono', monospace;
176
+ font-size: 0.68rem;
177
+ letter-spacing: 2px;
178
+ text-transform: uppercase;
179
+ color: var(--sepia);
180
+ font-weight: 400;
181
+ }
182
+ input[type="range"] {
183
+ -webkit-appearance: none;
184
+ width: 100%;
185
+ height: 2px;
186
+ background: var(--aged);
187
+ outline: none;
188
+ cursor: pointer;
189
+ }
190
+ input[type="range"]::-webkit-slider-thumb {
191
+ -webkit-appearance: none;
192
+ width: 14px; height: 14px;
193
+ border-radius: 50%;
194
+ background: var(--sepia);
195
+ border: 2px solid var(--paper);
196
+ box-shadow: 0 0 0 1px var(--sepia);
197
+ transition: transform 0.15s;
198
+ }
199
+ input[type="range"]:hover::-webkit-slider-thumb { transform: scale(1.3); }
200
+ input[type="range"]::-moz-range-thumb {
201
+ width: 14px; height: 14px;
202
+ border-radius: 50%;
203
+ background: var(--sepia);
204
+ border: 2px solid var(--paper);
205
+ box-shadow: 0 0 0 1px var(--sepia);
206
+ cursor: pointer;
207
+ }
208
+ .range-val {
209
+ font-family: 'IBM Plex Mono', monospace;
210
+ font-size: 0.75rem;
211
+ color: var(--ink);
212
+ font-weight: 400;
213
+ opacity: 0.7;
214
+ }
215
+
216
+ /* ---- Prompt area ---- */
217
+ .prompt-wrap {
218
+ position: relative;
219
+ margin-bottom: 20px;
220
+ }
221
+ .prompt-label {
222
+ font-family: 'IBM Plex Mono', monospace;
223
+ font-size: 0.68rem;
224
+ letter-spacing: 2px;
225
+ text-transform: uppercase;
226
+ color: var(--sepia);
227
+ margin-bottom: 8px;
228
+ display: block;
229
+ }
230
+ textarea {
231
+ width: 100%;
232
+ min-height: 90px;
233
+ resize: vertical;
234
+ background: transparent;
235
+ border: none;
236
+ border-bottom: 1px solid var(--aged);
237
+ font-family: 'Playfair Display', serif;
238
+ font-size: 1.05rem;
239
+ color: var(--ink);
240
+ line-height: 1.7;
241
+ padding: 8px 0;
242
+ outline: none;
243
+ transition: border-color 0.2s;
244
+ }
245
+ textarea::placeholder { color: rgba(26,18,9,0.3); font-style: italic; }
246
+ textarea:focus { border-bottom-color: var(--sepia); }
247
+
248
+ /* ---- Button ---- */
249
+ .btn-row { display: flex; gap: 12px; align-items: center; flex-wrap: wrap; }
250
+ button {
251
+ font-family: 'IBM Plex Mono', monospace;
252
+ font-size: 0.75rem;
253
+ letter-spacing: 3px;
254
+ text-transform: uppercase;
255
+ padding: 12px 32px;
256
+ border: 1.5px solid var(--ink);
257
+ background: var(--ink);
258
+ color: var(--paper);
259
+ cursor: pointer;
260
+ transition: all 0.18s;
261
+ border-radius: 1px;
262
+ }
263
+ button:hover:not(:disabled) {
264
+ background: var(--sepia);
265
+ border-color: var(--sepia);
266
+ }
267
+ button:disabled { opacity: 0.4; cursor: not-allowed; }
268
+ .btn-clear {
269
+ background: transparent;
270
+ color: var(--ink);
271
+ padding: 12px 20px;
272
+ font-size: 0.68rem;
273
+ }
274
+ .btn-clear:hover:not(:disabled) {
275
+ background: transparent;
276
+ color: var(--rust);
277
+ border-color: var(--rust);
278
+ }
279
+
280
+ /* ---- Output ---- */
281
+ .output-section { margin-top: 32px; }
282
+ .output-header {
283
+ display: flex;
284
+ justify-content: space-between;
285
+ align-items: baseline;
286
+ margin-bottom: 12px;
287
+ border-bottom: 1px solid var(--aged);
288
+ padding-bottom: 8px;
289
+ }
290
+ .output-title {
291
+ font-family: 'IBM Plex Mono', monospace;
292
+ font-size: 0.68rem;
293
+ letter-spacing: 2px;
294
+ text-transform: uppercase;
295
+ color: var(--sepia);
296
+ }
297
+ .meta-chips {
298
+ display: flex;
299
+ gap: 12px;
300
+ font-family: 'IBM Plex Mono', monospace;
301
+ font-size: 0.65rem;
302
+ color: rgba(26,18,9,0.45);
303
+ flex-wrap: wrap;
304
+ }
305
+ #output {
306
+ font-size: 1.05rem;
307
+ line-height: 1.85;
308
+ min-height: 80px;
309
+ color: var(--ink);
310
+ white-space: pre-wrap;
311
+ word-break: break-word;
312
+ }
313
+ #output .prompt-part { color: rgba(26,18,9,0.5); }
314
+ #output .gen-part { color: var(--ink); }
315
+
316
+ /* Typewriter cursor */
317
+ .cursor {
318
+ display: inline-block;
319
+ width: 2px;
320
+ height: 1.1em;
321
+ background: var(--sepia);
322
+ vertical-align: text-bottom;
323
+ margin-left: 2px;
324
+ animation: blink 0.9s step-end infinite;
325
+ }
326
+ @keyframes blink { 50% { opacity: 0; } }
327
+
328
+ /* ---- Spinner ---- */
329
+ .spinner {
330
+ display: none;
331
+ width: 16px; height: 16px;
332
+ border: 2px solid var(--aged);
333
+ border-top-color: var(--sepia);
334
+ border-radius: 50%;
335
+ animation: spin 0.7s linear infinite;
336
+ margin-left: 8px;
337
+ }
338
+ @keyframes spin { to { transform: rotate(360deg); } }
339
+
340
+ /* ---- Error ---- */
341
+ .error-msg {
342
+ display: none;
343
+ font-family: 'IBM Plex Mono', monospace;
344
+ font-size: 0.8rem;
345
+ color: var(--rust);
346
+ margin-top: 12px;
347
+ padding: 10px 14px;
348
+ border-left: 3px solid var(--rust);
349
+ background: rgba(192,57,43,0.05);
350
+ }
351
+
352
+ /* ---- Example prompts ---- */
353
+ .examples {
354
+ margin-top: 28px;
355
+ padding-top: 20px;
356
+ border-top: 1px dashed var(--aged);
357
+ }
358
+ .ex-label {
359
+ font-family: 'IBM Plex Mono', monospace;
360
+ font-size: 0.65rem;
361
+ letter-spacing: 2px;
362
+ text-transform: uppercase;
363
+ color: rgba(139,105,20,0.6);
364
+ margin-bottom: 10px;
365
+ }
366
+ .ex-pills {
367
+ display: flex;
368
+ flex-wrap: wrap;
369
+ gap: 8px;
370
+ }
371
+ .ex-pill {
372
+ font-family: 'Playfair Display', serif;
373
+ font-size: 0.82rem;
374
+ font-style: italic;
375
+ padding: 5px 14px;
376
+ border: 1px solid var(--aged);
377
+ border-radius: 2px;
378
+ cursor: pointer;
379
+ color: rgba(26,18,9,0.6);
380
+ transition: all 0.15s;
381
+ background: transparent;
382
+ letter-spacing: 0;
383
+ text-transform: none;
384
+ }
385
+ .ex-pill:hover {
386
+ border-color: var(--sepia);
387
+ color: var(--sepia);
388
+ background: rgba(139,105,20,0.04);
389
+ }
390
+
391
+ /* ---- Footer ---- */
392
+ footer {
393
+ margin-top: 48px;
394
+ font-family: 'IBM Plex Mono', monospace;
395
+ font-size: 0.63rem;
396
+ letter-spacing: 1.5px;
397
+ text-transform: uppercase;
398
+ color: rgba(26,18,9,0.3);
399
+ text-align: center;
400
+ }
401
+ footer span { color: var(--sepia); }
402
+
403
+ /* ---- Mobile responsiveness ---- */
404
+ @media (max-width: 640px) {
405
+ .controls-row { flex-direction: column; }
406
+ .perf-stats { grid-template-columns: 1fr 1fr; }
407
+ .status-badge { top: 10px; right: 10px; font-size: 0.6rem; }
408
+ }
409
+ </style>
410
+ </head>
411
+ <body>
412
+
413
+ <!-- Status Badge -->
414
+ <div class="status-badge disconnected" id="status-badge">
415
+ <div class="status-dot"></div>
416
+ <span id="status-text">Disconnected</span>
417
+ </div>
418
+
419
+ <header>
420
+ <h1 class="masthead">The Story <em>Engine</em></h1>
421
+ <p class="subtitle">Custom SLM &nbsp;·&nbsp; C++ CPU Inference &nbsp;·&nbsp; GPT-2 Architecture</p>
422
+ </header>
423
+
424
+ <div class="card">
425
+
426
+ <!-- Performance Stats -->
427
+ <div class="perf-stats" id="perf-stats" style="display:none">
428
+ <div class="stat-item">
429
+ <div class="stat-value" id="stat-throughput">—</div>
430
+ <div class="stat-label">Tokens/Sec</div>
431
+ </div>
432
+ <div class="stat-item">
433
+ <div class="stat-value" id="stat-latency">—</div>
434
+ <div class="stat-label">ms/Token</div>
435
+ </div>
436
+ <div class="stat-item">
437
+ <div class="stat-value" id="stat-total">0</div>
438
+ <div class="stat-label">Total Tokens</div>
439
+ </div>
440
+ </div>
441
+
442
+ <div class="controls-row">
443
+ <div class="control-group">
444
+ <label>Max Tokens <span class="range-val" id="max-tokens-val">100</span></label>
445
+ <input type="range" id="max-tokens" min="20" max="400" value="100" step="10">
446
+ </div>
447
+ <div class="control-group">
448
+ <label>Temperature <span class="range-val" id="temp-val">0.8</span></label>
449
+ <input type="range" id="temperature" min="0.1" max="1.5" value="0.8" step="0.05">
450
+ </div>
451
+ <div class="control-group">
452
+ <label>Top-K <span class="range-val" id="topk-val">40</span></label>
453
+ <input type="range" id="topk" min="1" max="100" value="40" step="1">
454
+ </div>
455
+ </div>
456
+
457
+ <div class="prompt-wrap">
458
+ <span class="prompt-label">Your Prompt</span>
459
+ <textarea id="prompt" rows="3"
460
+ placeholder="Once upon a time, in a small village near the forest…"></textarea>
461
+ </div>
462
+
463
+ <div class="btn-row">
464
+ <button id="generate-btn" onclick="generate()">Generate</button>
465
+ <button class="btn-clear" onclick="clearOutput()">Clear</button>
466
+ <div class="spinner" id="spinner"></div>
467
+ </div>
468
+
469
+ <div class="error-msg" id="error-msg"></div>
470
+
471
+ <div class="output-section" id="output-section" style="display:none">
472
+ <div class="output-header">
473
+ <span class="output-title">Generated Story</span>
474
+ <div class="meta-chips">
475
+ <span id="meta-tokens"></span>
476
+ <span id="meta-latency"></span>
477
+ <span id="meta-speed"></span>
478
+ </div>
479
+ </div>
480
+ <div id="output"></div>
481
+ </div>
482
+
483
+ <div class="examples">
484
+ <p class="ex-label">Try these prompts</p>
485
+ <div class="ex-pills">
486
+ <button class="ex-pill" onclick="setPrompt(this)">Once upon a time, there was a little</button>
487
+ <button class="ex-pill" onclick="setPrompt(this)">The big dog was very angry because</button>
488
+ <button class="ex-pill" onclick="setPrompt(this)">Sara and Tom went to the park to</button>
489
+ <button class="ex-pill" onclick="setPrompt(this)">One day, a tiny dragon found a</button>
490
+ <button class="ex-pill" onclick="setPrompt(this)">The old wizard smiled and said,</button>
491
+ </div>
492
+ </div>
493
+
494
+ </div>
495
+
496
+ <footer>
497
+ Built with &nbsp;<span>C++ Inference Engine</span>&nbsp; + &nbsp;<span>FastAPI</span>&nbsp; + &nbsp;<span>tiktoken</span>
498
+ </footer>
499
+
500
+ <script>
501
+ const API_BASE = "";;
502
+
503
+ // ---- Performance tracking ----
504
+ let totalTokensGenerated = 0;
505
+ let avgThroughput = 0;
506
+ let avgLatencyPerToken = 0;
507
+ let numGenerations = 0;
508
+
509
+ // ---- Check server status on load ----
510
/**
 * Probe the backend's /health endpoint and mirror the result in the
 * status badge. A non-2xx response, a network failure, or an unparsable
 * body are all reported as "Disconnected".
 */
async function checkHealth() {
  try {
    const response = await fetch(`${API_BASE}/health`);
    if (!response.ok) {
      updateStatus(false);
      return;
    }
    // The parsed body is forwarded so updateStatus can log model info.
    updateStatus(true, await response.json());
  } catch {
    updateStatus(false);
  }
}
523
+
524
/**
 * Switch the fixed status badge between its connected / disconnected looks.
 *
 * @param {boolean} connected  Whether the last health probe succeeded.
 * @param {?Object} data       Parsed /health response, if any. When it carries
 *                             a `model_config` object, its dimensions are
 *                             logged to the console.
 */
function updateStatus(connected, data = null) {
  const badgeEl = document.getElementById('status-badge');
  const labelEl = document.getElementById('status-text');

  if (!connected) {
    badgeEl.className = 'status-badge disconnected';
    labelEl.textContent = 'Disconnected';
    return;
  }

  badgeEl.className = 'status-badge connected';
  labelEl.textContent = 'Connected';

  // Model details are optional; only log them when the server sent them.
  const cfg = data && data.model_config;
  if (cfg) {
    console.log(`Model: ${cfg.n_layer}L/${cfg.n_head}H/${cfg.n_embd}D, Vocab: ${cfg.vocab_size}`);
  }
}
542
+
543
+ // Check health on load and every 30s
544
+ checkHealth();
545
+ setInterval(checkHealth, 30000);
546
+
547
+ // ---- Sync sliders ----
548
+ document.getElementById('max-tokens').addEventListener('input', e => {
549
+ document.getElementById('max-tokens-val').textContent = e.target.value;
550
+ });
551
+ document.getElementById('temperature').addEventListener('input', e => {
552
+ document.getElementById('temp-val').textContent = parseFloat(e.target.value).toFixed(2);
553
+ });
554
+ document.getElementById('topk').addEventListener('input', e => {
555
+ document.getElementById('topk-val').textContent = e.target.value;
556
+ });
557
+
558
+ // ---- Generate ----
559
/**
 * Read the prompt and sampling controls, POST them to /generate, then render
 * the story and refresh the performance panel.
 *
 * Failures are shown in the inline error box; the loading state is always
 * cleared afterwards. The call is a no-op while a previous request is still
 * in flight, so the Ctrl/Cmd+Enter shortcut cannot start overlapping requests
 * (the button's `disabled` flag, set by setLoading, is the in-flight marker).
 */
async function generate() {
  if (document.getElementById('generate-btn').disabled) return;

  const prompt = document.getElementById('prompt').value.trim();
  if (!prompt) { showError("Please enter a prompt first."); return; }

  const maxTokens = parseInt(document.getElementById('max-tokens').value, 10);
  const temperature = parseFloat(document.getElementById('temperature').value);
  const topK = parseInt(document.getElementById('topk').value, 10);

  setLoading(true);
  hideError();

  try {
    let res;
    try {
      res = await fetch(`${API_BASE}/generate`, {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({
          prompt,
          max_tokens: maxTokens,
          temperature,
          top_k: topK,
        }),
      });
    } catch {
      // fetch() only rejects on network-level failure; the rejection text
      // differs per browser, so do not pattern-match on it.
      throw new Error('Cannot connect to server. Is uvicorn running on port 8000?');
    }

    if (!res.ok) {
      let detail = '';
      try {
        const err = await res.json();
        // Validation errors carry a non-string `detail`; only show plain text.
        if (err && typeof err.detail === 'string') detail = err.detail;
      } catch {
        // Error body was not JSON; fall back to the status code below.
      }
      throw new Error(detail || `Server error: ${res.status}`);
    }

    const data = await res.json();
    renderOutput(data);
    updatePerfStats(data);

  } catch (e) {
    showError(e.message);
  } finally {
    setLoading(false);
  }
}
600
+
601
+ // ---- Update performance stats ----
602
/**
 * Fold one generation's timing into the session-wide running averages and
 * refresh the performance panel.
 *
 * Responses with no generated tokens or a non-positive latency are ignored:
 * they would produce NaN/Infinity and permanently corrupt the averages.
 *
 * @param {{tokens_out: number, latency_ms: number}} data  /generate response.
 */
function updatePerfStats(data) {
  const tokensOut = Number(data.tokens_out);
  const latencyMs = Number(data.latency_ms);
  if (!(tokensOut > 0) || !(latencyMs > 0)) return;

  totalTokensGenerated += tokensOut;
  numGenerations++;

  const throughput = tokensOut / (latencyMs / 1000);
  const latencyPerToken = latencyMs / tokensOut;

  // Incremental mean on unrounded values; rounding happens only for display.
  avgThroughput += (throughput - avgThroughput) / numGenerations;
  avgLatencyPerToken += (latencyPerToken - avgLatencyPerToken) / numGenerations;

  document.getElementById('stat-throughput').textContent = avgThroughput.toFixed(1);
  document.getElementById('stat-latency').textContent = avgLatencyPerToken.toFixed(2);
  document.getElementById('stat-total').textContent = totalTokensGenerated;
  document.getElementById('perf-stats').style.display = 'grid';
}
618
+
619
+ // ---- Typewriter render ----
620
/**
 * Show a /generate response: fill the meta chips, print the prompt, and
 * reveal the generated text with a typewriter animation.
 *
 * Starting a new render cancels any animation still running from a previous
 * call. The animation holds direct references to its own output and cursor
 * elements, so a stale timer can never touch a newer render's nodes.
 *
 * @param {Object} data  Response with `prompt`, `generated_text`,
 *                       `tokens_in`, `tokens_out` and `latency_ms`.
 */
function renderOutput(data) {
  const section = document.getElementById('output-section');
  const out = document.getElementById('output');

  // Stop the previous typewriter loop, if one is still pending.
  if (renderOutput.timer) {
    clearTimeout(renderOutput.timer);
    renderOutput.timer = null;
  }

  section.style.display = 'block';

  const latencyMs = Number(data.latency_ms);
  const rate = data.tokens_out / (latencyMs / 1000);
  // A zero or missing latency would otherwise display "Infinity" or "NaN".
  const tokensPerSec = Number.isFinite(rate) ? rate.toFixed(1) : '—';

  document.getElementById('meta-tokens').textContent =
    `${data.tokens_in} in · ${data.tokens_out} out`;
  document.getElementById('meta-latency').textContent =
    Number.isFinite(latencyMs) ? `${latencyMs.toFixed(0)} ms` : '— ms';
  document.getElementById('meta-speed').textContent =
    `${tokensPerSec} tok/s`;

  const genText = data.generated_text == null ? '' : String(data.generated_text);
  out.innerHTML =
    `<span class="prompt-part">${escHtml(data.prompt)}</span>` +
    `<span class="gen-part" id="typewriter"></span>` +
    `<span class="cursor" id="cursor"></span>`;

  const typed = out.querySelector('.gen-part');
  const cursor = out.querySelector('.cursor');
  // Aim for ~3 s total, clamped to 10–40 ms per character.
  const speed = Math.max(10, Math.min(40, 3000 / genText.length));
  let i = 0;

  function tick() {
    if (i < genText.length) {
      typed.textContent += genText[i++];
      renderOutput.timer = setTimeout(tick, speed);
    } else {
      renderOutput.timer = null;
      cursor.remove();
    }
  }
  tick();
}
656
+
657
/** Hide the output panel, drop its contents, and dismiss any error. */
function clearOutput() {
  const panel = document.getElementById('output-section');
  const body = document.getElementById('output');
  panel.style.display = 'none';
  body.innerHTML = '';
  hideError();
}
662
+
663
/**
 * Copy an example pill's text into the prompt box and focus it.
 * @param {Element} el  The clicked example element.
 */
function setPrompt(el) {
  const promptBox = document.getElementById('prompt');
  promptBox.value = el.textContent;
  promptBox.focus();
}
667
+
668
/**
 * Toggle the busy state: disables the Generate button and shows the spinner.
 * @param {boolean} on  True while a request is in flight.
 */
function setLoading(on) {
  const button = document.getElementById('generate-btn');
  const spinner = document.getElementById('spinner');
  button.disabled = on;
  if (on) {
    spinner.style.display = 'inline-block';
  } else {
    spinner.style.display = 'none';
  }
}
672
+
673
/**
 * Display a message in the inline error box.
 * @param {string} msg  Text to show (set as plain text, never as HTML).
 */
function showError(msg) {
  const box = document.getElementById('error-msg');
  box.textContent = msg;
  box.style.display = 'block';
}
678
/** Hide the inline error box (its text is left in place). */
function hideError() {
  const box = document.getElementById('error-msg');
  box.style.display = 'none';
}
681
+
682
/**
 * Escape text for safe insertion into HTML.
 *
 * Handles the five significant characters (& < > " '), so the result is safe
 * in both element content and quoted attribute values. Non-string input is
 * coerced with String(); null and undefined become the empty string.
 *
 * @param {*} s  Value to escape.
 * @returns {string} HTML-safe text.
 */
function escHtml(s) {
  const text = s == null ? '' : String(s);
  return text
    .replace(/&/g, '&amp;')   // must run first so later entities survive
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;')
    .replace(/'/g, '&#39;');
}
685
+
686
+ // Keyboard shortcut: Ctrl/Cmd + Enter to generate
687
+ document.getElementById('prompt').addEventListener('keydown', e => {
688
+ if ((e.ctrlKey || e.metaKey) && e.key === 'Enter') generate();
689
+ });
690
+ </script>
691
+
692
+ </body>
693
+ </html>
inference.cpp ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * OPTIMIZED SLM 50M INFERENCE ENGINE
3
+ * Target: i3 11th Gen | Windows 11 | 8GB RAM
4
+ * OpenMP Parallel + AVX2 Auto Vectorized
5
+ */
6
+
7
+ #include <stdio.h>
8
+ #include <stdlib.h>
9
+ #include <math.h>
10
+ #include <string.h>
11
+ #include <time.h>
12
+ #include <vector>
13
+ #include <algorithm>
14
+ #include <immintrin.h> // REQUIRED FOR AVX2 SIMD
15
+
16
+ #ifdef _OPENMP
17
+ #include <omp.h>
18
+ #endif
19
+
20
+ // ---------------------------------------------------------------------------
21
+ // Config & Structures
22
+ // ---------------------------------------------------------------------------
23
+
24
// Model hyper-parameters. main() fills this struct with a single
// fread(&cfg, sizeof(int), 5, f) from the start of model.bin, so the field
// order here must match the order of the five ints in the file header.
typedef struct {
    int n_layer;     // number of transformer layers iterated by forward()
    int n_head;      // attention heads per layer
    int n_embd;      // embedding width (C); per-head width is n_embd / n_head
    int block_size;  // number of positions in wpe and in each layer's KV cache
    int vocab_size;  // number of rows in wte and length of the logits vector
} Config;
31
+
32
// Pointers into the single weight buffer loaded from model.bin; the buffer is
// carved up by map_weights(). Single pointers are global tensors; float**
// members are per-layer tables indexed by layer number. Sizes below are the
// element counts map_weights() advances by (C = n_embd).
typedef struct {
    float* wte; float* wpe;                  // token [vocab_size*C] and position [block_size*C] embeddings
    float** ln1_w; float** ln1_b;            // pre-attention layer norm scale/shift [C]
    float** c_attn_w; float** c_attn_b;      // fused QKV projection [3C*C] and bias [3C]
    float** c_proj_w; float** c_proj_b;      // attention output projection [C*C] and bias [C]
    float** ln2_w; float** ln2_b;            // pre-MLP layer norm scale/shift [C]
    float** fc_w; float** fc_b;              // MLP expansion [4C*C] and bias [4C]
    float** mlp_proj_w; float** mlp_proj_b;  // MLP contraction [C*4C] and bias [C]
    float* ln_f_w; float* ln_f_b;            // final layer norm scale/shift [C]
    float* lm_head_w;                        // output head; forward() reads vocab_size*C floats
} Weights;
43
+
44
// Cached attention keys and values. forward() indexes each buffer as
// [layer][position][n_embd], i.e. layer * block_size * C + pos * C.
typedef struct { float* k_cache; float* v_cache; } KVCache;
45
+
46
+ static Config cfg;
47
+ static Weights W;
48
+ static float* model_data_buffer = NULL;
49
+
50
+ // ---------------------------------------------------------------------------
51
+ // Math Kernels
52
+ // ---------------------------------------------------------------------------
53
+
54
// Layer normalisation over one vector of `size` floats:
//   out[i] = (x[i] - mean) / sqrt(var + 1e-5) * w[i] + b[i]
// where mean and var are the mean and (biased) variance of x.
static void layer_norm(float* out, const float* x, const float* w, const float* b, int size) {
    // First pass: arithmetic mean.
    float mean = 0.0f;
    for (int idx = 0; idx < size; ++idx) {
        mean += x[idx];
    }
    mean /= size;

    // Second pass: variance around that mean.
    float var = 0.0f;
    for (int idx = 0; idx < size; ++idx) {
        const float centered = x[idx] - mean;
        var += centered * centered;
    }
    var /= size;

    const float scale = 1.0f / sqrtf(var + 1e-5f);

    // Normalise, then apply the learned scale and shift.
    for (int idx = 0; idx < size; ++idx) {
        out[idx] = (x[idx] - mean) * scale * w[idx] + b[idx];
    }
}
71
+
72
// Matrix-vector product: out[i] = dot(mat[i*K .. i*K+K), x) for i in [0, M).
//
//   out : M floats, overwritten.
//   mat : M x K matrix stored row-major.
//   x   : K floats.
//
// Rows are distributed across OpenMP threads. When the compiler targets
// AVX2 + FMA, each row is accumulated 8 floats at a time with fused
// multiply-add; otherwise a plain scalar loop is used, so the file still
// builds without -mavx2 -mfma or on non-x86 targets.
static void matmul_vec(float* out, const float* mat, const float* x, int M, int K) {

    #pragma omp parallel for
    for (int i = 0; i < M; i++) {
        // 64-bit offset: i * K can exceed INT_MAX for the larger matrices.
        const float* row = mat + (long long)i * K;

        int j = 0;
        float sum = 0.0f;

#if defined(__AVX2__) && defined(__FMA__)
        __m256 sum_vec = _mm256_setzero_ps();

        // Process 8 floats per iteration; unaligned loads because neither
        // the weight buffer nor x is guaranteed 32-byte aligned.
        for (; j <= K - 8; j += 8) {
            __m256 m_val = _mm256_loadu_ps(&row[j]);
            __m256 x_val = _mm256_loadu_ps(&x[j]);
            sum_vec = _mm256_fmadd_ps(m_val, x_val, sum_vec);
        }

        // Horizontal reduction of the 8 partial sums.
        float sum_arr[8];
        _mm256_storeu_ps(sum_arr, sum_vec);
        sum = sum_arr[0] + sum_arr[1] + sum_arr[2] + sum_arr[3] +
              sum_arr[4] + sum_arr[5] + sum_arr[6] + sum_arr[7];
#endif

        // Scalar tail (K not a multiple of 8) or the whole row on the
        // fallback path.
        for (; j < K; j++) {
            sum += row[j] * x[j];
        }

        out[i] = sum;
    }
}
107
+
108
// Element-wise in-place bias add: x[i] += b[i] for the first N elements.
static void add_bias(float* x, const float* b, int N) {
    #pragma omp parallel for
    for (int idx = 0; idx < N; ++idx) {
        x[idx] = x[idx] + b[idx];
    }
}
113
+
114
// Residual connection: accumulate y into x in place (x[i] += y[i], i < N).
static void residual_add(float* x, const float* y, int N) {
    #pragma omp parallel for
    for (int idx = 0; idx < N; ++idx) {
        x[idx] = x[idx] + y[idx];
    }
}
119
+
120
// In-place GELU using the tanh approximation:
//   gelu(v) = 0.5 * v * (1 + tanh(c * (v + 0.044715 * v^3)))
// with c = 0.7978845608 (the constant used by the original code).
static void gelu_inplace(float* x, int N) {
    const float c = 0.7978845608f;

    #pragma omp parallel for
    for (int idx = 0; idx < N; ++idx) {
        const float v = x[idx];
        const float inner = c * (v + 0.044715f * v * v * v);
        const float t = tanhf(inner);
        x[idx] = 0.5f * v * (1.0f + t);
    }
}
130
+
131
// In-place softmax over the first N entries of x. The maximum is subtracted
// before exponentiation so expf never sees a large positive argument.
// N must be at least 1 (x[0] is read unconditionally).
static void softmax_inplace(float* x, int N) {

    // Locate the largest entry.
    float peak = x[0];
    for (int idx = 1; idx < N; ++idx) {
        if (x[idx] > peak) {
            peak = x[idx];
        }
    }

    // Exponentiate the shifted values and accumulate the normaliser.
    float total = 0.0f;
    for (int idx = 0; idx < N; ++idx) {
        const float e = expf(x[idx] - peak);
        x[idx] = e;
        total += e;
    }

    // Normalise so the entries sum to one.
    for (int idx = 0; idx < N; ++idx) {
        x[idx] /= total;
    }
}
146
+
147
+ // ---------------------------------------------------------------------------
148
+ // Transformer Forward
149
+ // ---------------------------------------------------------------------------
150
+
151
+ static void forward(
152
+ int token_id,
153
+ int pos,
154
+ KVCache* kv,
155
+ float* x,
156
+ float* buf,
157
+ float* qkv_buf,
158
+ float* attn_buf,
159
+ float* ff_buf,
160
+ float* logits
161
+ ) {
162
+ const int C = cfg.n_embd;
163
+ const int H = cfg.n_head;
164
+ const int hs = C / H;
165
+
166
+ float* content_row = W.wte + (long long)token_id * C;
167
+ float* pos_row = W.wpe + (long long)pos * C;
168
+
169
+ #pragma omp parallel for
170
+ for (int i = 0; i < C; i++)
171
+ x[i] = content_row[i] + pos_row[i];
172
+
173
+ for (int l = 0; l < cfg.n_layer; l++) {
174
+
175
+ layer_norm(buf, x, W.ln1_w[l], W.ln1_b[l], C);
176
+
177
+ matmul_vec(qkv_buf, W.c_attn_w[l], buf, 3 * C, C);
178
+ add_bias(qkv_buf, W.c_attn_b[l], 3 * C);
179
+
180
+ float* q = qkv_buf;
181
+ float* k = qkv_buf + C;
182
+ float* v = qkv_buf + 2 * C;
183
+
184
+ float* k_cache = kv->k_cache + (long long)l * cfg.block_size * C;
185
+ float* v_cache = kv->v_cache + (long long)l * cfg.block_size * C;
186
+
187
+ memcpy(k_cache + (long long)pos * C, k, C * sizeof(float));
188
+ memcpy(v_cache + (long long)pos * C, v, C * sizeof(float));
189
+
190
+ #pragma omp parallel for
191
+ for (int h = 0; h < H; h++) {
192
+
193
+ float* q_h = q + h * hs;
194
+ float scale = 1.0f / sqrtf((float)hs);
195
+
196
+ // Give each thread its own slice of the attention buffer
197
+ float* local_attn = attn_buf + h * cfg.block_size;
198
+
199
+ for (int t = 0; t <= pos; t++) {
200
+ float* k_h = k_cache + (long long)t * C + h * hs;
201
+ float dot = 0.0f;
202
+ for (int d = 0; d < hs; d++)
203
+ dot += q_h[d] * k_h[d];
204
+
205
+ local_attn[t] = dot * scale;
206
+ }
207
+
208
+ softmax_inplace(local_attn, pos + 1);
209
+
210
+ float* out_h = buf + h * hs;
211
+ memset(out_h, 0, hs * sizeof(float));
212
+
213
+ for (int t = 0; t <= pos; t++) {
214
+ float* v_h = v_cache + (long long)t * C + h * hs;
215
+ float a = local_attn[t];
216
+ for (int d = 0; d < hs; d++)
217
+ out_h[d] += a * v_h[d];
218
+ }
219
+ }
220
+
221
+ float* attn_out = qkv_buf;
222
+ matmul_vec(attn_out, W.c_proj_w[l], buf, C, C);
223
+ add_bias(attn_out, W.c_proj_b[l], C);
224
+ residual_add(x, attn_out, C);
225
+
226
+ layer_norm(buf, x, W.ln2_w[l], W.ln2_b[l], C);
227
+
228
+ matmul_vec(ff_buf, W.fc_w[l], buf, 4 * C, C);
229
+ add_bias(ff_buf, W.fc_b[l], 4 * C);
230
+ gelu_inplace(ff_buf, 4 * C);
231
+
232
+ matmul_vec(buf, W.mlp_proj_w[l], ff_buf, C, 4 * C);
233
+ add_bias(buf, W.mlp_proj_b[l], C);
234
+ residual_add(x, buf, C);
235
+ }
236
+
237
+ layer_norm(buf, x, W.ln_f_w, W.ln_f_b, C);
238
+ matmul_vec(logits, W.lm_head_w, buf, cfg.vocab_size, C);
239
+ }
240
+
241
+ // ---------------------------------------------------------------------------
242
+ // Weight Mapping
243
+ // ---------------------------------------------------------------------------
244
+
245
+ static void map_weights(float* data) {
246
+
247
+ float* ptr = data;
248
+ const int C = cfg.n_embd;
249
+ const int L = cfg.n_layer;
250
+
251
+ W.wte = ptr; ptr += (long long)cfg.vocab_size * C;
252
+ W.wpe = ptr; ptr += (long long)cfg.block_size * C;
253
+
254
+ W.ln1_w = (float**)malloc(L * sizeof(float*));
255
+ W.ln1_b = (float**)malloc(L * sizeof(float*));
256
+ W.c_attn_w = (float**)malloc(L * sizeof(float*));
257
+ W.c_attn_b = (float**)malloc(L * sizeof(float*));
258
+ W.c_proj_w = (float**)malloc(L * sizeof(float*));
259
+ W.c_proj_b = (float**)malloc(L * sizeof(float*));
260
+ W.ln2_w = (float**)malloc(L * sizeof(float*));
261
+ W.ln2_b = (float**)malloc(L * sizeof(float*));
262
+ W.fc_w = (float**)malloc(L * sizeof(float*));
263
+ W.fc_b = (float**)malloc(L * sizeof(float*));
264
+ W.mlp_proj_w = (float**)malloc(L * sizeof(float*));
265
+ W.mlp_proj_b = (float**)malloc(L * sizeof(float*));
266
+
267
+ for (int l = 0; l < L; l++) {
268
+ W.ln1_w[l] = ptr; ptr += C;
269
+ W.ln1_b[l] = ptr; ptr += C;
270
+
271
+ W.c_attn_w[l] = ptr; ptr += 3LL * C * C;
272
+ W.c_attn_b[l] = ptr; ptr += 3LL * C;
273
+
274
+ W.c_proj_w[l] = ptr; ptr += 1LL * C * C;
275
+ W.c_proj_b[l] = ptr; ptr += C;
276
+
277
+ W.ln2_w[l] = ptr; ptr += C;
278
+ W.ln2_b[l] = ptr; ptr += C;
279
+
280
+ W.fc_w[l] = ptr; ptr += 4LL * C * C;
281
+ W.fc_b[l] = ptr; ptr += 4LL * C;
282
+
283
+ W.mlp_proj_w[l] = ptr; ptr += 1LL * C * 4 * C;
284
+ W.mlp_proj_b[l] = ptr; ptr += C;
285
+ }
286
+
287
+ W.ln_f_w = ptr; ptr += C;
288
+ W.ln_f_b = ptr; ptr += C;
289
+
290
+ W.lm_head_w = ptr;
291
+ }
292
+
293
+ // ---------------------------------------------------------------------------
294
+ // MAIN
295
+ // ---------------------------------------------------------------------------
296
+
297
+ int main(int argc, char* argv[]) {
298
+
299
+ if (argc < 3) {
300
+ printf("ERROR_ARGS");
301
+ return 1;
302
+ }
303
+
304
+ FILE* f = fopen("model.bin", "rb");
305
+ if (!f) {
306
+ printf("ERROR_MODEL_NOT_FOUND");
307
+ return 1;
308
+ }
309
+
310
+ fread(&cfg, sizeof(int), 5, f);
311
+ fseek(f, 0, SEEK_END);
312
+ long file_size = ftell(f);
313
+ fseek(f, 5 * sizeof(int), SEEK_SET);
314
+
315
+ model_data_buffer = (float*)malloc(file_size - 5 * sizeof(int));
316
+ fread(model_data_buffer, 1, file_size - 5 * sizeof(int), f);
317
+ fclose(f);
318
+
319
+ map_weights(model_data_buffer);
320
+
321
+ std::vector<int> input_ids;
322
+ char* token = strtok(argv[1], ",");
323
+ while (token) {
324
+ input_ids.push_back(atoi(token));
325
+ token = strtok(NULL, ",");
326
+ }
327
+
328
+ if (input_ids.size() >= (size_t)cfg.block_size)
329
+ input_ids.resize(cfg.block_size - 1);
330
+
331
+ int max_new_tokens = atoi(argv[2]);
332
+
333
+ float temperature = (argc > 3) ? atof(argv[3]) : 0.8f;
334
+ int top_k = (argc > 4) ? atoi(argv[4]) : 40;
335
+ if (temperature < 0.01f) temperature = 0.01f;
336
+ if (top_k < 1) top_k = 1;
337
+ if (top_k > cfg.vocab_size) top_k = cfg.vocab_size;
338
+
339
+ srand((unsigned int)time(NULL));
340
+
341
+ const int C = cfg.n_embd;
342
+
343
+ KVCache kv;
344
+ kv.k_cache = (float*)calloc((long long)cfg.n_layer * cfg.block_size * C, sizeof(float));
345
+ kv.v_cache = (float*)calloc((long long)cfg.n_layer * cfg.block_size * C, sizeof(float));
346
+
347
+ float* x = (float*)malloc(C * sizeof(float));
348
+ float* buf = (float*)malloc(C * sizeof(float));
349
+ float* qkv_buf = (float*)malloc(3 * C * sizeof(float));
350
+
351
+ // Allocate enough space for ALL heads to process simultaneously
352
+ float* attn_buf = (float*)malloc(cfg.n_head * cfg.block_size * sizeof(float));
353
+
354
+ float* ff_buf = (float*)malloc(4 * C * sizeof(float));
355
+ float* logits = (float*)malloc(cfg.vocab_size * sizeof(float));
356
+
357
+ for (int i = 0; i < (int)input_ids.size(); i++)
358
+ forward(input_ids[i], i, &kv, x, buf, qkv_buf, attn_buf, ff_buf, logits);
359
+
360
+ int pos = input_ids.size();
361
+
362
+ for (int i = 0; i < max_new_tokens; i++) {
363
+
364
+ if (pos >= cfg.block_size)
365
+ break;
366
+
367
+ for (int v = 0; v < cfg.vocab_size; v++)
368
+ logits[v] /= temperature;
369
+
370
+ std::vector<std::pair<float, int>> pairs(cfg.vocab_size);
371
+ for (int v = 0; v < cfg.vocab_size; v++)
372
+ pairs[v] = {logits[v], v};
373
+
374
+ std::partial_sort(pairs.begin(), pairs.begin() + top_k, pairs.end(),
375
+ [](const std::pair<float,int>& a, const std::pair<float,int>& b) {
376
+ return a.first > b.first;
377
+ });
378
+
379
+ float sum = 0.0f;
380
+ for (int j = 0; j < top_k; j++) {
381
+ pairs[j].first = expf(pairs[j].first);
382
+ sum += pairs[j].first;
383
+ }
384
+ for (int j = 0; j < top_k; j++)
385
+ pairs[j].first /= sum;
386
+
387
+ float r = (float)rand() / ((float)RAND_MAX + 1.0f);
388
+ float cum = 0.0f;
389
+ int best = pairs[0].second;
390
+ for (int j = 0; j < top_k; j++) {
391
+ cum += pairs[j].first;
392
+ if (r < cum) {
393
+ best = pairs[j].second;
394
+ break;
395
+ }
396
+ }
397
+
398
+ printf("%d ", best);
399
+
400
+ if (best == 50256)
401
+ break;
402
+
403
+ forward(best, pos, &kv, x, buf, qkv_buf, attn_buf, ff_buf, logits);
404
+ pos++;
405
+ }
406
+
407
+ free(model_data_buffer);
408
+ return 0;
409
+ }
main.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py - SLM Inference Server
2
+ from fastapi import FastAPI, HTTPException
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from pydantic import BaseModel
5
+ import subprocess
6
+ import tiktoken
7
+ import os
8
+ import time
9
+
10
# ASGI application object; the route handlers below register themselves on it.
app = FastAPI()

# CORS policy: every origin, HTTP method and request header is accepted.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
18
+
19
class GenerateRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint.

    Only ``prompt`` is required; the remaining fields fall back to the
    defaults declared below when the client omits them.
    """

    # Text to be tokenized and handed to the inference engine.
    prompt: str
    # Forwarded to the engine as its max-new-tokens argument.
    max_tokens: int = 100
    # Forwarded to the engine as the sampling temperature.
    # NOTE(review): no range check here -- confirm the engine rejects values <= 0.
    temperature: float = 0.8
    # Forwarded to the engine as the top-k cutoff.
    # NOTE(review): no range check here -- confirm the engine bounds this value.
    top_k: int = 40
24
+
25
# Tokenizer setup.
# `enc` is the module-level GPT-2 tokenizer used by the request handlers; it is
# left as None when loading fails so the handlers can report the problem
# instead of the whole module failing to import.
enc = None
try:
    enc = tiktoken.get_encoding("gpt2")
    print("✅ Tokenizer loaded successfully.")
except Exception as load_error:
    print(f"❌ Warning: tiktoken not found. Error: {load_error}")
    # Reset explicitly: any failure inside the try block must leave `enc` unset.
    enc = None
32
+
33
+
34
@app.get("/health")
async def health_check():
    """Report whether the engine binary and model file sit next to this script.

    Returns:
        dict with a constant ``status`` of ``"ok"``, one boolean per required
        file (``inference_exe_found``, ``model_bin_found``) and the directory
        that was inspected (``working_directory``).
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))

    # Probe both required artifacts in one pass, keyed by file name.
    present = {
        file_name: os.path.exists(os.path.join(base_dir, file_name))
        for file_name in ("inference.exe", "model.bin")
    }

    return {
        "status": "ok",
        "inference_exe_found": present["inference.exe"],
        "model_bin_found": present["model.bin"],
        "working_directory": base_dir,
    }
46
+
47
+
48
@app.post("/generate")
async def generate_text(req: GenerateRequest):
    """Generate a continuation of ``req.prompt`` with the external C++ engine.

    The prompt is encoded with the module-level GPT-2 tokenizer (``enc``). The
    token IDs are passed to ``inference.exe`` as one comma-separated argument,
    followed by ``max_tokens``, ``temperature`` and ``top_k``. The engine's
    stdout is read as whitespace-separated token IDs and decoded back to text.

    Args:
        req: Validated request body (prompt plus sampling settings).

    Returns:
        dict with keys ``prompt``, ``generated_text``, ``tokens_in``,
        ``tokens_out``, ``latency_ms`` and ``tokens_per_sec``.

    Raises:
        HTTPException: status 500 when the tokenizer is unavailable, the
            executable or model file is missing, the engine cannot be started
            or reports an error, or its output cannot be decoded.
    """
    # Function-scope import: stdlib only, needed to offload the blocking
    # subprocess call from the event loop.
    import asyncio

    # 0. Tokenizer check
    if enc is None:
        raise HTTPException(
            status_code=500,
            detail="Tokenizer not loaded. Run: pip install tiktoken"
        )

    # 1. Encode prompt
    # disallowed_special=() makes tiktoken treat special-token text such as
    # "<|endoftext|>" inside user input as ordinary text; with the default
    # setting encode() raises ValueError on such prompts.
    input_tokens = enc.encode(req.prompt, disallowed_special=())
    token_str = ",".join(map(str, input_tokens))

    # 2. Path setup: the engine and the model are expected next to this script.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    exe_path = os.path.join(current_dir, "inference.exe")
    model_path = os.path.join(current_dir, "model.bin")

    print(f"DEBUG: exe -> {exe_path} exists={os.path.exists(exe_path)}")
    print(f"DEBUG: model -> {model_path} exists={os.path.exists(model_path)}")

    # 3. File existence checks
    if not os.path.exists(exe_path):
        raise HTTPException(
            status_code=500,
            detail=f"inference.exe nahi mili: {exe_path} — Pehle C++ compile karo!"
        )

    if not os.path.exists(model_path):
        raise HTTPException(
            status_code=500,
            detail=f"model.bin nahi mili: {model_path} — Model file same folder mein rakhni hai!"
        )

    # 4. Run C++ engine
    # temperature and top_k are forwarded as the 4th and 5th arguments.
    # NOTE(review): the values are passed through unvalidated -- confirm the
    # engine's argument parsing bounds temperature (> 0) and top_k.
    def _run_engine():
        # Argument list (no shell), so the prompt cannot inject shell syntax.
        return subprocess.run(
            [
                exe_path,
                token_str,
                str(req.max_tokens),
                str(req.temperature),
                str(req.top_k),
            ],
            capture_output=True,
            text=True,
            cwd=current_dir,
        )

    try:
        start_time = time.perf_counter()

        # subprocess.run blocks until the engine exits; running it in the
        # default executor keeps the event loop free to serve other requests
        # (e.g. /health) while a generation is in progress.
        loop = asyncio.get_running_loop()
        process = await loop.run_in_executor(None, _run_engine)

        elapsed_ms = (time.perf_counter() - start_time) * 1000

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Execution failed: {str(e)}") from e

    # 5. Error check
    stdout_msg = process.stdout.strip() if process.stdout else ""
    stderr_msg = process.stderr.strip() if process.stderr else ""

    # The error markers are looked up before the "empty stdout" test: the
    # previous ordering only searched stdout after establishing it was empty,
    # so these branches could never fire. Which stream the engine writes the
    # markers to is not visible from this file, so both are searched.
    engine_output = f"{stdout_msg}\n{stderr_msg}"

    if "ERROR_MODEL_NOT_FOUND" in engine_output:
        raise HTTPException(status_code=500, detail="model.bin nahi mili! Same folder mein rakho.")
    if "ERROR_ARGS" in engine_output:
        raise HTTPException(status_code=500, detail="C++ engine ko arguments galat mile.")

    # A non-zero exit code is tolerated when tokens were still produced
    # (best-effort, as before); it is only fatal when stdout is empty.
    if process.returncode != 0 and not stdout_msg:
        raise HTTPException(
            status_code=500,
            detail=f"C++ Error | stdout: '{stdout_msg}' | stderr: '{stderr_msg}'"
        )

    # 6. Decode output token IDs
    try:
        generated_ids = []
        # str.split() on an empty string yields no items, so empty output
        # naturally results in an empty ID list.
        for piece in stdout_msg.split():
            try:
                generated_ids.append(int(piece))
            except ValueError:
                # Anything that is not an integer (stray log text) is ignored.
                print(f"DEBUG: skipping non-integer token: '{piece}'")

        generated_text = enc.decode(generated_ids) if generated_ids else ""

        tokens_out = len(generated_ids)
        tokens_per_sec = round(tokens_out / (elapsed_ms / 1000), 2) if elapsed_ms > 0 else 0

        print(f"✅ Generated {tokens_out} tokens in {elapsed_ms:.2f}ms ({tokens_per_sec} tok/s)")

        return {
            "prompt": req.prompt,
            "generated_text": generated_text,
            "tokens_in": len(input_tokens),
            "tokens_out": tokens_out,
            "latency_ms": round(elapsed_ms, 2),
            "tokens_per_sec": tokens_per_sec
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Decoding error: {str(e)}") from e
tokenizer.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80bb8ed25d76fd80db81de4faafb69cdeb7547c2aad716400347f10a6ab265c2
3
+ size 521859