PyxiLabs committed on
Commit
401c156
·
verified ·
1 Parent(s): 96769eb

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +51 -0
  2. index (2).html +901 -0
  3. requirements (2).txt +15 -0
  4. server.py +367 -0
Dockerfile ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ╔══════════════════════════════════════════════════════════════╗
# ║ Granite 4.0 ONNX Inference Server                            ║
# ║ Model: onnx-community/granite-4.0-h-350m-ONNX                ║
# ║ Runtime: ONNX Runtime CPU · FastAPI · Beautiful UI           ║
# ╚══════════════════════════════════════════════════════════════╝

FROM python:3.11-slim

# ── System dependencies ───────────────────────────────────────────────────────
# --no-install-recommends keeps the layer small; the apt lists are removed in
# the same layer so they never land in the image. curl is required at runtime
# by the HEALTHCHECK below — do not drop it.
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    curl \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# ── Create non-root user (HuggingFace Spaces requirement) ─────────────────────
RUN useradd -m -u 1000 user
USER user

# NOTE: TRANSFORMERS_CACHE is deprecated in recent transformers releases in
# favor of HF_HOME; it is kept here for backward compatibility with the
# pinned transformers version.
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH \
    HF_HOME=/home/user/.cache/huggingface \
    TRANSFORMERS_CACHE=/home/user/.cache/huggingface \
    # Prevents OMP issues on CPU
    OMP_NUM_THREADS=4 \
    MKL_NUM_THREADS=4

WORKDIR /app

# ── Install Python dependencies ───────────────────────────────────────────────
COPY --chown=user requirements.txt .
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# ── Copy application files ────────────────────────────────────────────────────
# NOTE(review): the repository uploads are named "requirements (2).txt" and
# "index (2).html" — confirm they are renamed to requirements.txt and
# static/index.html before build, or these COPY steps will fail.
COPY --chown=user server.py .
COPY --chown=user static/ ./static/

# ── Expose port (HF Spaces uses 7860) ────────────────────────────────────────
EXPOSE 7860

# ── Health check ─────────────────────────────────────────────────────────────
# Long start-period: first boot downloads the ONNX model before /health is up.
HEALTHCHECK --interval=30s --timeout=10s --start-period=120s --retries=3 \
    CMD curl -f http://localhost:7860/health || exit 1

# ── Launch server ─────────────────────────────────────────────────────────────
# Single worker: the model is held in process memory; more workers would
# load it once per process.
CMD ["uvicorn", "server:app", \
     "--host", "0.0.0.0", \
     "--port", "7860", \
     "--workers", "1", \
     "--log-level", "info"]
index (2).html ADDED
@@ -0,0 +1,901 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
+ <title>Granite 4.0 · ONNX Inference</title>
7
+ <link rel="preconnect" href="https://fonts.googleapis.com" />
8
+ <link href="https://fonts.googleapis.com/css2?family=Space+Mono:ital,wght@0,400;0,700;1,400&family=Syne:wght@400;600;700;800&display=swap" rel="stylesheet" />
9
+ <style>
10
+ :root {
11
+ --bg: #060810;
12
+ --surface: #0d1117;
13
+ --surface2: #131922;
14
+ --border: #1e2d3d;
15
+ --accent: #00d4ff;
16
+ --accent2: #7c3aed;
17
+ --accent3: #10b981;
18
+ --warn: #f59e0b;
19
+ --danger: #ef4444;
20
+ --text: #e2e8f0;
21
+ --muted: #4a5568;
22
+ --dim: #718096;
23
+ }
24
+
25
+ * { margin: 0; padding: 0; box-sizing: border-box; }
26
+
27
+ body {
28
+ background: var(--bg);
29
+ color: var(--text);
30
+ font-family: 'Space Mono', monospace;
31
+ min-height: 100vh;
32
+ overflow-x: hidden;
33
+ }
34
+
35
+ /* ── Animated background ── */
36
+ body::before {
37
+ content: '';
38
+ position: fixed;
39
+ inset: 0;
40
+ background:
41
+ radial-gradient(ellipse 80% 60% at 10% 20%, rgba(0,212,255,0.04) 0%, transparent 60%),
42
+ radial-gradient(ellipse 60% 80% at 90% 80%, rgba(124,58,237,0.04) 0%, transparent 60%);
43
+ pointer-events: none;
44
+ z-index: 0;
45
+ }
46
+
47
+ /* ── Grid lines ── */
48
+ body::after {
49
+ content: '';
50
+ position: fixed;
51
+ inset: 0;
52
+ background-image:
53
+ linear-gradient(rgba(0,212,255,0.03) 1px, transparent 1px),
54
+ linear-gradient(90deg, rgba(0,212,255,0.03) 1px, transparent 1px);
55
+ background-size: 40px 40px;
56
+ pointer-events: none;
57
+ z-index: 0;
58
+ }
59
+
60
+ .app {
61
+ position: relative;
62
+ z-index: 1;
63
+ display: grid;
64
+ grid-template-rows: auto 1fr;
65
+ min-height: 100vh;
66
+ max-width: 1400px;
67
+ margin: 0 auto;
68
+ padding: 0 20px;
69
+ }
70
+
71
+ /* ── Header ── */
72
+ header {
73
+ padding: 24px 0 20px;
74
+ border-bottom: 1px solid var(--border);
75
+ display: flex;
76
+ align-items: center;
77
+ justify-content: space-between;
78
+ gap: 16px;
79
+ flex-wrap: wrap;
80
+ }
81
+
82
+ .logo-block {
83
+ display: flex;
84
+ align-items: center;
85
+ gap: 14px;
86
+ }
87
+
88
+ .logo-icon {
89
+ width: 42px;
90
+ height: 42px;
91
+ border: 1px solid var(--accent);
92
+ border-radius: 8px;
93
+ display: flex;
94
+ align-items: center;
95
+ justify-content: center;
96
+ font-family: 'Syne', sans-serif;
97
+ font-weight: 800;
98
+ font-size: 18px;
99
+ color: var(--accent);
100
+ box-shadow: 0 0 20px rgba(0,212,255,0.2), inset 0 0 20px rgba(0,212,255,0.05);
101
+ animation: pulse-border 3s ease-in-out infinite;
102
+ }
103
+
104
+ @keyframes pulse-border {
105
+ 0%, 100% { box-shadow: 0 0 20px rgba(0,212,255,0.2), inset 0 0 20px rgba(0,212,255,0.05); }
106
+ 50% { box-shadow: 0 0 30px rgba(0,212,255,0.4), inset 0 0 20px rgba(0,212,255,0.1); }
107
+ }
108
+
109
+ .logo-text {
110
+ font-family: 'Syne', sans-serif;
111
+ }
112
+ .logo-text h1 {
113
+ font-size: 20px;
114
+ font-weight: 800;
115
+ letter-spacing: -0.5px;
116
+ color: var(--text);
117
+ }
118
+ .logo-text p {
119
+ font-size: 11px;
120
+ color: var(--dim);
121
+ margin-top: 2px;
122
+ }
123
+
124
+ #status-badge {
125
+ display: flex;
126
+ align-items: center;
127
+ gap: 8px;
128
+ font-size: 12px;
129
+ padding: 6px 14px;
130
+ border-radius: 999px;
131
+ border: 1px solid var(--border);
132
+ background: var(--surface);
133
+ transition: all 0.3s;
134
+ }
135
+ #status-dot {
136
+ width: 8px; height: 8px;
137
+ border-radius: 50%;
138
+ background: var(--warn);
139
+ animation: blink 1s infinite;
140
+ }
141
+ @keyframes blink { 0%,100%{opacity:1} 50%{opacity:0.3} }
142
+ #status-badge.ready { border-color: var(--accent3); }
143
+ #status-badge.ready #status-dot { background: var(--accent3); animation: none; }
144
+ #status-badge.error { border-color: var(--danger); }
145
+ #status-badge.error #status-dot { background: var(--danger); animation: none; }
146
+
147
+ /* ── Main layout ── */
148
+ main {
149
+ display: grid;
150
+ grid-template-columns: 1fr 340px;
151
+ grid-template-rows: 1fr;
152
+ gap: 20px;
153
+ padding: 20px 0 20px;
154
+ height: calc(100vh - 100px);
155
+ }
156
+
157
+ /* ── Chat panel ── */
158
+ .chat-panel {
159
+ display: flex;
160
+ flex-direction: column;
161
+ gap: 16px;
162
+ min-height: 0;
163
+ }
164
+
165
+ .messages-container {
166
+ flex: 1;
167
+ overflow-y: auto;
168
+ display: flex;
169
+ flex-direction: column;
170
+ gap: 12px;
171
+ padding-right: 6px;
172
+ scroll-behavior: smooth;
173
+ }
174
+
175
+ .messages-container::-webkit-scrollbar { width: 4px; }
176
+ .messages-container::-webkit-scrollbar-track { background: transparent; }
177
+ .messages-container::-webkit-scrollbar-thumb { background: var(--border); border-radius: 2px; }
178
+
179
+ .message {
180
+ display: flex;
181
+ gap: 12px;
182
+ animation: fade-in 0.3s ease;
183
+ }
184
+ @keyframes fade-in { from { opacity:0; transform:translateY(8px); } to { opacity:1; transform:none; } }
185
+
186
+ .message.user { flex-direction: row-reverse; }
187
+
188
+ .avatar {
189
+ width: 32px; height: 32px;
190
+ border-radius: 8px;
191
+ display: flex;
192
+ align-items: center;
193
+ justify-content: center;
194
+ font-size: 13px;
195
+ font-weight: 700;
196
+ flex-shrink: 0;
197
+ font-family: 'Syne', sans-serif;
198
+ }
199
+ .message.user .avatar { background: var(--accent2); color: white; }
200
+ .message.assistant .avatar {
201
+ background: linear-gradient(135deg, rgba(0,212,255,0.2), rgba(0,212,255,0.05));
202
+ border: 1px solid rgba(0,212,255,0.3);
203
+ color: var(--accent);
204
+ }
205
+
206
+ .bubble {
207
+ max-width: 75%;
208
+ padding: 12px 16px;
209
+ border-radius: 12px;
210
+ font-size: 13px;
211
+ line-height: 1.7;
212
+ }
213
+ .message.user .bubble {
214
+ background: rgba(124,58,237,0.15);
215
+ border: 1px solid rgba(124,58,237,0.3);
216
+ color: var(--text);
217
+ border-top-right-radius: 2px;
218
+ }
219
+ .message.assistant .bubble {
220
+ background: var(--surface);
221
+ border: 1px solid var(--border);
222
+ color: var(--text);
223
+ border-top-left-radius: 2px;
224
+ }
225
+
226
+ .bubble-meta {
227
+ margin-top: 6px;
228
+ font-size: 10px;
229
+ color: var(--muted);
230
+ display: flex;
231
+ gap: 10px;
232
+ }
233
+ .bubble-meta span { display: flex; align-items: center; gap: 3px; }
234
+
235
+ .typing-indicator {
236
+ display: flex;
237
+ gap: 5px;
238
+ padding: 4px 0;
239
+ align-items: center;
240
+ }
241
+ .typing-indicator span {
242
+ width: 6px; height: 6px;
243
+ background: var(--accent);
244
+ border-radius: 50%;
245
+ animation: bounce 1.2s infinite;
246
+ }
247
+ .typing-indicator span:nth-child(2) { animation-delay: 0.2s; }
248
+ .typing-indicator span:nth-child(3) { animation-delay: 0.4s; }
249
+ @keyframes bounce { 0%,80%,100%{transform:scale(0.8);opacity:0.5} 40%{transform:scale(1.2);opacity:1} }
250
+
251
+ /* ── Input area ── */
252
+ .input-area {
253
+ display: flex;
254
+ gap: 10px;
255
+ align-items: flex-end;
256
+ }
257
+
258
+ textarea {
259
+ flex: 1;
260
+ background: var(--surface);
261
+ border: 1px solid var(--border);
262
+ border-radius: 10px;
263
+ color: var(--text);
264
+ font-family: 'Space Mono', monospace;
265
+ font-size: 13px;
266
+ padding: 12px 14px;
267
+ resize: none;
268
+ min-height: 46px;
269
+ max-height: 120px;
270
+ outline: none;
271
+ transition: border-color 0.2s;
272
+ line-height: 1.5;
273
+ }
274
+ textarea:focus { border-color: var(--accent); }
275
+ textarea::placeholder { color: var(--muted); }
276
+
277
+ .send-btn {
278
+ width: 46px; height: 46px;
279
+ border-radius: 10px;
280
+ border: 1px solid var(--accent);
281
+ background: rgba(0,212,255,0.1);
282
+ color: var(--accent);
283
+ cursor: pointer;
284
+ display: flex;
285
+ align-items: center;
286
+ justify-content: center;
287
+ transition: all 0.2s;
288
+ flex-shrink: 0;
289
+ }
290
+ .send-btn:hover { background: rgba(0,212,255,0.2); box-shadow: 0 0 15px rgba(0,212,255,0.3); }
291
+ .send-btn:disabled { opacity: 0.4; cursor: not-allowed; }
292
+ .send-btn svg { width: 18px; height: 18px; }
293
+
294
+ .input-hint {
295
+ font-size: 10px;
296
+ color: var(--muted);
297
+ margin-top: 4px;
298
+ padding-left: 2px;
299
+ }
300
+
301
+ /* ── Right sidebar ── */
302
+ .sidebar {
303
+ display: flex;
304
+ flex-direction: column;
305
+ gap: 14px;
306
+ overflow-y: auto;
307
+ }
308
+
309
+ .sidebar::-webkit-scrollbar { width: 4px; }
310
+ .sidebar::-webkit-scrollbar-thumb { background: var(--border); border-radius: 2px; }
311
+
312
+ .card {
313
+ background: var(--surface);
314
+ border: 1px solid var(--border);
315
+ border-radius: 12px;
316
+ padding: 16px;
317
+ }
318
+
319
+ .card-title {
320
+ font-family: 'Syne', sans-serif;
321
+ font-size: 11px;
322
+ font-weight: 700;
323
+ letter-spacing: 1.5px;
324
+ text-transform: uppercase;
325
+ color: var(--dim);
326
+ margin-bottom: 14px;
327
+ display: flex;
328
+ align-items: center;
329
+ gap: 6px;
330
+ }
331
+ .card-title::before {
332
+ content: '';
333
+ display: block;
334
+ width: 3px;
335
+ height: 12px;
336
+ background: var(--accent);
337
+ border-radius: 2px;
338
+ }
339
+
340
+ /* ── Metric rows ── */
341
+ .metric-row {
342
+ display: flex;
343
+ justify-content: space-between;
344
+ align-items: center;
345
+ padding: 8px 0;
346
+ border-bottom: 1px solid rgba(30,45,61,0.5);
347
+ font-size: 12px;
348
+ }
349
+ .metric-row:last-child { border-bottom: none; padding-bottom: 0; }
350
+ .metric-label { color: var(--dim); }
351
+ .metric-value {
352
+ font-weight: 700;
353
+ font-size: 14px;
354
+ color: var(--text);
355
+ }
356
+ .metric-value.accent { color: var(--accent); }
357
+ .metric-value.green { color: var(--accent3); }
358
+ .metric-value.warn { color: var(--warn); }
359
+
360
+ /* ── Big TPS display ── */
361
+ .tps-display {
362
+ text-align: center;
363
+ padding: 16px 0 8px;
364
+ }
365
+ .tps-number {
366
+ font-family: 'Syne', sans-serif;
367
+ font-size: 48px;
368
+ font-weight: 800;
369
+ color: var(--accent);
370
+ line-height: 1;
371
+ text-shadow: 0 0 30px rgba(0,212,255,0.5);
372
+ transition: all 0.3s;
373
+ }
374
+ .tps-label {
375
+ font-size: 11px;
376
+ color: var(--dim);
377
+ letter-spacing: 2px;
378
+ text-transform: uppercase;
379
+ margin-top: 4px;
380
+ }
381
+
382
+ /* ── Mini sparkline ── */
383
+ .sparkline-wrap {
384
+ margin-top: 12px;
385
+ height: 40px;
386
+ position: relative;
387
+ }
388
+ canvas#sparkline {
389
+ width: 100%;
390
+ height: 100%;
391
+ }
392
+
393
+ /* ── Model info ── */
394
+ .model-tag {
395
+ display: inline-flex;
396
+ align-items: center;
397
+ gap: 6px;
398
+ background: rgba(0,212,255,0.08);
399
+ border: 1px solid rgba(0,212,255,0.2);
400
+ border-radius: 6px;
401
+ padding: 5px 10px;
402
+ font-size: 11px;
403
+ color: var(--accent);
404
+ word-break: break-all;
405
+ line-height: 1.4;
406
+ margin-top: 2px;
407
+ }
408
+
409
+ /* ── Settings sliders ── */
410
+ .slider-row {
411
+ padding: 8px 0;
412
+ }
413
+ .slider-label {
414
+ display: flex;
415
+ justify-content: space-between;
416
+ font-size: 11px;
417
+ color: var(--dim);
418
+ margin-bottom: 6px;
419
+ }
420
+ .slider-label span:last-child { color: var(--text); font-weight: 700; }
421
+ input[type="range"] {
422
+ width: 100%;
423
+ accent-color: var(--accent);
424
+ cursor: pointer;
425
+ }
426
+
427
+ /* ── Loading overlay ── */
428
+ #loading-overlay {
429
+ position: fixed;
430
+ inset: 0;
431
+ background: rgba(6,8,16,0.9);
432
+ z-index: 100;
433
+ display: flex;
434
+ flex-direction: column;
435
+ align-items: center;
436
+ justify-content: center;
437
+ gap: 20px;
438
+ backdrop-filter: blur(8px);
439
+ }
440
+ #loading-overlay.hidden { display: none; }
441
+
442
+ .loading-logo {
443
+ font-family: 'Syne', sans-serif;
444
+ font-size: 32px;
445
+ font-weight: 800;
446
+ color: var(--accent);
447
+ text-shadow: 0 0 40px rgba(0,212,255,0.5);
448
+ }
449
+ .loading-spinner {
450
+ width: 48px; height: 48px;
451
+ border: 2px solid var(--border);
452
+ border-top-color: var(--accent);
453
+ border-radius: 50%;
454
+ animation: spin 0.8s linear infinite;
455
+ }
456
+ @keyframes spin { to { transform: rotate(360deg); } }
457
+ .loading-text { font-size: 13px; color: var(--dim); }
458
+
459
+ .welcome-msg {
460
+ text-align: center;
461
+ padding: 40px 20px;
462
+ color: var(--muted);
463
+ }
464
+ .welcome-msg h2 {
465
+ font-family: 'Syne', sans-serif;
466
+ font-size: 20px;
467
+ font-weight: 700;
468
+ color: var(--dim);
469
+ margin-bottom: 8px;
470
+ }
471
+ .welcome-msg p { font-size: 12px; line-height: 1.8; }
472
+
473
+ @media (max-width: 900px) {
474
+ main { grid-template-columns: 1fr; grid-template-rows: 1fr auto; }
475
+ .sidebar { display: grid; grid-template-columns: 1fr 1fr; }
476
+ }
477
+ </style>
478
+ </head>
479
+ <body>
480
+
481
+ <div id="loading-overlay">
482
+ <div class="loading-logo">GRANITE</div>
483
+ <div class="loading-spinner"></div>
484
+ <div class="loading-text" id="loading-msg">Loading model — this may take a minute...</div>
485
+ </div>
486
+
487
+ <div class="app">
488
+ <header>
489
+ <div class="logo-block">
490
+ <div class="logo-icon">G4</div>
491
+ <div class="logo-text">
492
+ <h1>Granite 4.0 · ONNX</h1>
493
+ <p>granite-4.0-h-350m · CPU Inference Server</p>
494
+ </div>
495
+ </div>
496
+ <div id="status-badge">
497
+ <div id="status-dot"></div>
498
+ <span id="status-text">Initializing...</span>
499
+ </div>
500
+ </header>
501
+
502
+ <main>
503
+ <!-- ── Chat ── -->
504
+ <div class="chat-panel">
505
+ <div class="messages-container" id="messages">
506
+ <div class="welcome-msg">
507
+ <h2>Ready to chat</h2>
508
+ <p>IBM Granite 4.0 Hybrid · 350M params<br/>Running on ONNX Runtime · CPU</p>
509
+ </div>
510
+ </div>
511
+
512
+ <div>
513
+ <div class="input-area">
514
+ <textarea
515
+ id="user-input"
516
+ placeholder="Send a message... (Shift+Enter for newline)"
517
+ rows="1"
518
+ ></textarea>
519
+ <button class="send-btn" id="send-btn" title="Send">
520
+ <svg fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
521
+ <path stroke-linecap="round" stroke-linejoin="round" d="M5 12h14M12 5l7 7-7 7"/>
522
+ </svg>
523
+ </button>
524
+ </div>
525
+ <div class="input-hint">Enter to send · Shift+Enter for newline · streaming enabled</div>
526
+ </div>
527
+ </div>
528
+
529
+ <!-- ── Sidebar ── -->
530
+ <div class="sidebar">
531
+
532
+ <!-- TPS card -->
533
+ <div class="card">
534
+ <div class="card-title">Live Performance</div>
535
+ <div class="tps-display">
536
+ <div class="tps-number" id="tps-big">—</div>
537
+ <div class="tps-label">tokens / second</div>
538
+ </div>
539
+ <div class="sparkline-wrap">
540
+ <canvas id="sparkline"></canvas>
541
+ </div>
542
+ </div>
543
+
544
+ <!-- Server metrics -->
545
+ <div class="card">
546
+ <div class="card-title">Server Metrics</div>
547
+ <div class="metric-row">
548
+ <span class="metric-label">Uptime</span>
549
+ <span class="metric-value" id="m-uptime">—</span>
550
+ </div>
551
+ <div class="metric-row">
552
+ <span class="metric-label">Total Requests</span>
553
+ <span class="metric-value accent" id="m-requests">0</span>
554
+ </div>
555
+ <div class="metric-row">
556
+ <span class="metric-label">Active</span>
557
+ <span class="metric-value green" id="m-active">0</span>
558
+ </div>
559
+ <div class="metric-row">
560
+ <span class="metric-label">Tokens Generated</span>
561
+ <span class="metric-value" id="m-tokens">0</span>
562
+ </div>
563
+ <div class="metric-row">
564
+ <span class="metric-label">Avg Latency</span>
565
+ <span class="metric-value warn" id="m-latency">—</span>
566
+ </div>
567
+ <div class="metric-row">
568
+ <span class="metric-label">Errors</span>
569
+ <span class="metric-value" id="m-errors">0</span>
570
+ </div>
571
+ </div>
572
+
573
+ <!-- Settings -->
574
+ <div class="card">
575
+ <div class="card-title">Generation Settings</div>
576
+ <div class="slider-row">
577
+ <div class="slider-label">
578
+ <span>Max Tokens</span>
579
+ <span id="val-max-tokens">256</span>
580
+ </div>
581
+ <input type="range" id="max-tokens" min="64" max="1024" step="64" value="256" />
582
+ </div>
583
+ <div class="slider-row">
584
+ <div class="slider-label">
585
+ <span>Temperature</span>
586
+ <span id="val-temp">0.7</span>
587
+ </div>
588
+ <input type="range" id="temperature" min="0.1" max="2.0" step="0.1" value="0.7" />
589
+ </div>
590
+ </div>
591
+
592
+ <!-- Model info -->
593
+ <div class="card">
594
+ <div class="card-title">Model Info</div>
595
+ <div class="metric-row">
596
+ <span class="metric-label">Format</span>
597
+ <span class="metric-value green">ONNX Q4</span>
598
+ </div>
599
+ <div class="metric-row">
600
+ <span class="metric-label">Params</span>
601
+ <span class="metric-value">350M</span>
602
+ </div>
603
+ <div class="metric-row">
604
+ <span class="metric-label">Architecture</span>
605
+ <span class="metric-value">Hybrid MoE</span>
606
+ </div>
607
+ <div class="metric-row">
608
+ <span class="metric-label">Device</span>
609
+ <span class="metric-value accent">CPU</span>
610
+ </div>
611
+ <div style="margin-top:10px">
612
+ <div class="model-tag">onnx-community/granite-4.0-h-350m-ONNX</div>
613
+ </div>
614
+ </div>
615
+
616
+ </div>
617
+ </main>
618
+ </div>
619
+
620
+ <script>
621
+ // ── State ─────────────────────────────────────────────────────────────────
622
+ const conversationHistory = [];
623
+ let isGenerating = false;
624
+ const tpsHistory = [];
625
+
626
+ // ── DOM refs ──────────────────────────────────────────────────────────────
627
+ const messagesEl = document.getElementById('messages');
628
+ const inputEl = document.getElementById('user-input');
629
+ const sendBtn = document.getElementById('send-btn');
630
+ const loadingOverlay = document.getElementById('loading-overlay');
631
+ const loadingMsg = document.getElementById('loading-msg');
632
+ const statusBadge = document.getElementById('status-badge');
633
+ const statusText = document.getElementById('status-text');
634
+ const statusDot = document.getElementById('status-dot');
635
+ const tpsBig = document.getElementById('tps-big');
636
+
637
+ // ── Metrics polling ───────────────────────────────────────────────────────
638
// Fetch /metrics and mirror the payload into the status badge, the metric
// cards, the big TPS readout and the sparkline. Called on an interval.
async function pollMetrics() {
  try {
    const res = await fetch('/metrics');
    const stats = await res.json();

    // Status badge: ready > loading > error.
    if (stats.model_loaded) {
      statusBadge.className = 'ready';
      statusText.textContent = 'Model Ready';
      loadingOverlay.classList.add('hidden');
    } else if (stats.model_loading) {
      statusText.textContent = 'Loading model...';
      loadingMsg.textContent = `Downloading & loading ONNX model — uptime ${formatUptime(stats.uptime_seconds)}`;
    } else {
      statusBadge.className = 'error';
      statusText.textContent = 'Error';
    }

    // Metric cards.
    const set = (id, value) => { document.getElementById(id).textContent = value; };
    set('m-uptime', formatUptime(stats.uptime_seconds));
    set('m-requests', stats.total_requests.toLocaleString());
    set('m-active', stats.active_requests);
    set('m-tokens', stats.total_tokens_generated.toLocaleString());
    set('m-latency', stats.average_latency_ms > 0 ? `${stats.average_latency_ms.toFixed(0)}ms` : '—');
    set('m-errors', stats.errors);

    // Headline tokens/second.
    const tps = stats.last_tokens_per_second;
    tpsBig.textContent = tps > 0 ? tps.toFixed(1) : '—';

    if (stats.tps_history && stats.tps_history.length > 0) {
      drawSparkline(stats.tps_history);
    }
  } catch (e) {
    // Server not reachable yet (still booting) — keep polling silently.
  }
}
677
+
678
// Format a second count as a compact human string: "2h 5m", "3m 12s", "47s".
function formatUptime(totalSeconds) {
  const hours = Math.floor(totalSeconds / 3600);
  const minutes = Math.floor((totalSeconds % 3600) / 60);
  const seconds = Math.floor(totalSeconds % 60);

  if (hours > 0) return `${hours}h ${minutes}m`;
  if (minutes > 0) return `${minutes}m ${seconds}s`;
  return `${seconds}s`;
}
686
+
687
+ // ── Sparkline ─────────────────────────────────────────────────────────────
688
// Render the TPS history as a gradient-filled line chart inside #sparkline,
// scaled for high-DPI displays.
function drawSparkline(points) {
  const canvas = document.getElementById('sparkline');
  const ctx = canvas.getContext('2d');
  const dpr = window.devicePixelRatio || 1;
  const width = canvas.offsetWidth;
  const height = canvas.offsetHeight;

  // Resizing the backing store clears the canvas and resets the transform,
  // so the dpr scale below does not accumulate across redraws.
  canvas.width = width * dpr;
  canvas.height = height * dpr;
  ctx.scale(dpr, dpr);
  ctx.clearRect(0, 0, width, height);

  if (points.length < 2) return;

  const peak = Math.max(...points, 1);          // guard: never divide by 0
  const stepX = width / (points.length - 1);
  const yFor = (v) => height - (v / peak) * height * 0.9 - 2;

  // Filled area under the curve.
  const grad = ctx.createLinearGradient(0, 0, 0, height);
  grad.addColorStop(0, 'rgba(0,212,255,0.3)');
  grad.addColorStop(1, 'rgba(0,212,255,0)');

  ctx.beginPath();
  points.forEach((v, i) => {
    if (i === 0) {
      ctx.moveTo(0, yFor(v));
    } else {
      ctx.lineTo(i * stepX, yFor(v));
    }
  });
  ctx.lineTo(width, height);
  ctx.lineTo(0, height);
  ctx.closePath();
  ctx.fillStyle = grad;
  ctx.fill();

  // The line itself.
  ctx.beginPath();
  points.forEach((v, i) => {
    if (i === 0) {
      ctx.moveTo(0, yFor(v));
    } else {
      ctx.lineTo(i * stepX, yFor(v));
    }
  });
  ctx.strokeStyle = '#00d4ff';
  ctx.lineWidth = 2;
  ctx.stroke();
}
732
+
733
+ // ── Auto-resize textarea ──────────────────────────────────────────────────
734
// Grow the textarea with its content, capped at 120px.
inputEl.addEventListener('input', () => {
  inputEl.style.height = 'auto';
  inputEl.style.height = `${Math.min(inputEl.scrollHeight, 120)}px`;
});

// ── Settings sliders ──────────────────────────────────────────────────────
document.getElementById('max-tokens').addEventListener('input', (e) => {
  document.getElementById('val-max-tokens').textContent = e.target.value;
});
document.getElementById('temperature').addEventListener('input', (e) => {
  document.getElementById('val-temp').textContent = parseFloat(e.target.value).toFixed(1);
});

// ── Keyboard handler ──────────────────────────────────────────────────────
// Enter sends; Shift+Enter inserts a newline. Sends are suppressed while
// a generation is already in flight.
inputEl.addEventListener('keydown', (e) => {
  if (e.key === 'Enter' && !e.shiftKey) {
    e.preventDefault();
    if (!isGenerating) sendMessage();
  }
});
sendBtn.addEventListener('click', () => {
  if (!isGenerating) sendMessage();
});
755
+
756
+ // ── Chat functions ────────────────────────────────────────────────────────
757
// Append a chat bubble for `role` ('user' | 'assistant') containing
// `content`. `meta`, when given, is an HTML string rendered as a footer row.
// Returns { div, textNode } so streaming code can keep writing into it.
function appendMessage(role, content, meta) {
  // The placeholder welcome card disappears once real chat begins.
  const welcome = messagesEl.querySelector('.welcome-msg');
  if (welcome) welcome.remove();

  const avatar = document.createElement('div');
  avatar.className = 'avatar';
  avatar.textContent = role === 'user' ? 'U' : 'G4';

  const textNode = document.createElement('div');
  textNode.className = 'bubble-text';
  textNode.textContent = content;

  const bubble = document.createElement('div');
  bubble.className = 'bubble';
  bubble.appendChild(textNode);

  if (meta) {
    const metaDiv = document.createElement('div');
    metaDiv.className = 'bubble-meta';
    metaDiv.innerHTML = meta;
    bubble.appendChild(metaDiv);
  }

  const div = document.createElement('div');
  div.className = `message ${role}`;
  div.appendChild(avatar);
  div.appendChild(bubble);
  messagesEl.appendChild(div);
  messagesEl.scrollTop = messagesEl.scrollHeight;

  return { div, textNode };
}
791
+
792
// Show the three-dot "typing" bubble; returns the element so the caller can
// remove it once the first streamed token arrives.
function appendTyping() {
  const welcome = messagesEl.querySelector('.welcome-msg');
  if (welcome) welcome.remove();

  const avatar = document.createElement('div');
  avatar.className = 'avatar';
  avatar.textContent = 'G4';

  const bubble = document.createElement('div');
  bubble.className = 'bubble';
  bubble.innerHTML = `<div class="typing-indicator"><span></span><span></span><span></span></div>`;

  const div = document.createElement('div');
  div.className = 'message assistant';
  div.id = 'typing-msg';
  div.appendChild(avatar);
  div.appendChild(bubble);
  messagesEl.appendChild(div);
  messagesEl.scrollTop = messagesEl.scrollHeight;
  return div;
}
814
+
815
// Send the current textarea content to /chat/stream and stream the reply
// token-by-token into a new assistant bubble, then record perf metadata.
//
// Fixes over the original:
//  - response.ok is checked before reading the body;
//  - TextDecoder runs with {stream: true} so multi-byte UTF-8 characters
//    split across network chunks are not corrupted;
//  - an incomplete SSE line at a chunk boundary is buffered instead of
//    being parsed as if it were complete;
//  - "[DONE]" terminates the whole read loop, not just the inner for;
//  - generation state resets in `finally` so the UI cannot get stuck.
async function sendMessage() {
  const text = inputEl.value.trim();
  if (!text) return;

  const maxTokens = parseInt(document.getElementById('max-tokens').value);
  const temperature = parseFloat(document.getElementById('temperature').value);

  inputEl.value = '';
  inputEl.style.height = 'auto';
  isGenerating = true;
  sendBtn.disabled = true;

  appendMessage('user', text);
  conversationHistory.push({ role: 'user', content: text });

  const typingEl = appendTyping();
  const t0 = performance.now();

  try {
    const response = await fetch('/chat/stream', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        messages: conversationHistory,
        max_new_tokens: maxTokens,
        temperature,
        stream: true
      })
    });
    if (!response.ok) throw new Error(`HTTP ${response.status}`);

    typingEl.remove();
    const { div: msgDiv, textNode } = appendMessage('assistant', '');

    const reader = response.body.getReader();
    const decoder = new TextDecoder();
    let buffered = '';     // possibly-incomplete trailing SSE line
    let fullText = '';
    let tokenCount = 0;
    let finished = false;

    while (!finished) {
      const { done, value } = await reader.read();
      if (done) break;
      buffered += decoder.decode(value, { stream: true });
      const lines = buffered.split('\n');
      buffered = lines.pop();   // keep the unterminated tail for next chunk
      for (const line of lines) {
        if (!line.startsWith('data: ')) continue;
        const data = line.slice(6);
        if (data === '[DONE]') { finished = true; break; }
        fullText += data;
        tokenCount++;
        textNode.textContent = fullText;
        messagesEl.scrollTop = messagesEl.scrollHeight;
      }
    }

    const elapsed = (performance.now() - t0) / 1000;
    const tps = (tokenCount / elapsed).toFixed(1);
    tpsBig.textContent = tps;

    conversationHistory.push({ role: 'assistant', content: fullText });

    // Perf footer on the finished bubble.
    const metaDiv = document.createElement('div');
    metaDiv.className = 'bubble-meta';
    metaDiv.innerHTML = `
      <span>⚡ ${tps} t/s</span>
      <span>📝 ${tokenCount} tokens</span>
      <span>⏱ ${elapsed.toFixed(1)}s</span>
    `;
    msgDiv.querySelector('.bubble').appendChild(metaDiv);

  } catch (err) {
    // remove() on an already-detached node is a harmless no-op.
    typingEl.remove();
    appendMessage('assistant', `Error: ${err.message}`);
  } finally {
    isGenerating = false;
    sendBtn.disabled = false;
    inputEl.focus();
  }
}
895
+
896
+ // ── Boot ──────────────────────────────────────────────────────────────────
897
+ setInterval(pollMetrics, 2000);
898
+ pollMetrics();
899
+ </script>
900
+ </body>
901
+ </html>
requirements (2).txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ── Core inference ────────────────────────────────────────────
2
+ onnxruntime==1.20.1
3
+ numpy==1.26.4
4
+ transformers==4.47.0
5
+ huggingface_hub==0.26.5
6
+
7
+ # ── Web server ────────────────────────────────────────────────
8
+ fastapi==0.115.5
9
+ uvicorn[standard]==0.32.1
10
+ pydantic==2.10.1
11
+
12
+ # ── Utilities ─────────────────────────────────────────────────
13
+ accelerate==1.2.1
14
+ sentencepiece==0.2.0
15
+ protobuf==5.29.0
server.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ╔══════════════════════════════════════════════════════════════╗
3
+ ║ Granite 4.0 ONNX Inference Server ║
4
+ ║ Model: onnx-community/granite-4.0-h-350m-ONNX ║
5
+ ╚══════════════════════════════════════════════════════════════╝
6
+ """
7
+
8
import asyncio
import threading
import time
import uuid
from collections import deque
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Generator, List, Optional, Tuple

import numpy as np
import onnxruntime
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from transformers import AutoConfig, AutoTokenizer
25
+
26
# ── Global model state ────────────────────────────────────────────────────────
# Hugging Face repo id of the ONNX export served by this process.
MODEL_ID = "onnx-community/granite-4.0-h-350m-ONNX"
# Basename (without ".onnx") of the decoder graph inside the repo's onnx/ dir.
MODEL_FILENAME = "model_q4"  # use quantized for speed

# Populated once by load_model() at startup; None until loading completes.
decoder_session = None  # onnxruntime.InferenceSession for the decoder graph
tokenizer = None        # transformers tokenizer for MODEL_ID
config = None           # transformers config (layer_types, dims, eos id, ...)
33
+
34
# ── Metrics state ─────────────────────────────────────────────────────────────
# Mutable counters shared between request handlers and the loader thread.
# All reads/writes are expected to happen under `metrics_lock` below.
metrics = {
    "total_requests": 0,            # requests accepted by /chat and /chat/stream
    "active_requests": 0,           # requests currently in flight
    "total_tokens_generated": 0,    # completion tokens across all requests
    "total_prompt_tokens": 0,       # prompt tokens across all /chat requests
    "request_latencies": deque(maxlen=100),         # last 100 /chat latencies (ms)
    "tokens_per_second_history": deque(maxlen=50),  # last 50 per-request tps values
    "errors": 0,                    # load failures + request failures
    "start_time": time.time(),      # process start time, for uptime reporting
    "last_tps": 0.0,                # throughput of the most recent /chat request
    "model_loaded": False,          # set True once load_model() succeeds
    "model_loading": True,          # True while the startup load is in progress
}
# Serializes access to `metrics` (handlers run on multiple executor threads).
metrics_lock = threading.Lock()
49
+
50
+
51
# ── Pydantic models ───────────────────────────────────────────────────────────
class Message(BaseModel):
    """A single chat turn (OpenAI-style role/content pair)."""

    role: str     # e.g. "system", "user", or "assistant"
    content: str  # the message text
55
+
56
+
57
class ChatRequest(BaseModel):
    """Request body for /chat and /chat/stream."""

    messages: List[Message]    # full conversation history, oldest first
    max_new_tokens: int = 512  # upper bound on generated tokens
    # NOTE(review): temperature is accepted but never forwarded — decoding in
    # generate_tokens() is greedy argmax; confirm whether sampling is intended.
    temperature: float = 1.0
    # NOTE(review): stream is unused; clients choose behavior by endpoint.
    stream: bool = False
62
+
63
+
64
class ChatResponse(BaseModel):
    """Non-streaming /chat response with token accounting and timing."""

    id: str                   # short per-request id (first 8 chars of a uuid4)
    content: str              # full generated assistant text
    prompt_tokens: int        # tokens in the rendered prompt
    completion_tokens: int    # tokens generated
    total_tokens: int         # prompt_tokens + completion_tokens
    latency_ms: float         # end-to-end request latency in milliseconds
    tokens_per_second: float  # decode throughput for this request
72
+
73
+
74
# ── Model loader ──────────────────────────────────────────────────────────────
def load_model():
    """Download the ONNX snapshot and initialize the global decoder session,
    tokenizer, and config, then flip the metrics flags.

    Runs once at startup (off the event loop, via an executor). On failure the
    error counter is bumped and the exception is re-raised.
    """
    global decoder_session, tokenizer, config
    import os

    print(f"[INFO] Downloading model {MODEL_ID}...")

    try:
        # Skip weight formats we never load, including the other ONNX variants.
        skip_patterns = [
            "*.msgpack", "*.h5", "flax_model*",
            "model.onnx", "model_fp16.onnx", "model_q4f16.onnx",
        ]
        snapshot_dir = snapshot_download(MODEL_ID, ignore_patterns=skip_patterns)
        model_path = os.path.join(snapshot_dir, "onnx", f"{MODEL_FILENAME}.onnx")

        print(f"[INFO] Loading ONNX session from {model_path}...")
        opts = onnxruntime.SessionOptions()
        opts.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.intra_op_num_threads = 4
        decoder_session = onnxruntime.InferenceSession(
            model_path,
            sess_options=opts,
            providers=["CPUExecutionProvider"],
        )

        print("[INFO] Loading tokenizer and config...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        config = AutoConfig.from_pretrained(MODEL_ID)

        with metrics_lock:
            metrics["model_loaded"] = True
            metrics["model_loading"] = False

        print("[INFO] ✅ Model loaded successfully!")

    except Exception as e:
        with metrics_lock:
            metrics["model_loading"] = False
            metrics["errors"] += 1
        print(f"[ERROR] Failed to load model: {e}")
        raise
115
+
116
+
117
# ── Cache initializer ─────────────────────────────────────────────────────────
def init_cache(batch_size: int, dtype=np.float32):
    """Build the empty per-layer state feeds for the first decode step.

    Attention layers receive zero-length key/value tensors; Mamba layers
    receive zeroed conv and SSM state tensors. All shapes come from the
    module-global model `config`.
    """
    attn_head_dim = config.hidden_size // config.num_attention_heads
    state_dim = config.mamba_d_state
    groups = config.mamba_n_groups
    # Channel count expected by the exported graph for the conv state.
    conv_channels = (config.mamba_expand * config.hidden_size) + (2 * groups * state_dim)

    feeds = {}
    for layer_idx, kind in enumerate(config.layer_types):
        if kind == "attention":
            kv_shape = [batch_size, config.num_key_value_heads, 0, attn_head_dim]
            feeds[f"past_key_values.{layer_idx}.key"] = np.zeros(kv_shape, dtype=dtype)
            feeds[f"past_key_values.{layer_idx}.value"] = np.zeros(kv_shape, dtype=dtype)
        elif kind == "mamba":
            feeds[f"past_conv.{layer_idx}"] = np.zeros(
                [batch_size, conv_channels, config.mamba_d_conv], dtype=dtype
            )
            feeds[f"past_ssm.{layer_idx}"] = np.zeros(
                [batch_size, config.mamba_n_heads, config.mamba_d_head, state_dim],
                dtype=dtype,
            )
    return feeds
141
+
142
+
143
# ── Core generation ───────────────────────────────────────────────────────────
def generate_tokens(
    input_ids: np.ndarray,
    attention_mask: np.ndarray,
    max_new_tokens: int = 512,
) -> Generator[Tuple[str, bool, float], None, None]:
    """Greedily decode up to ``max_new_tokens`` tokens, one ONNX step at a time.

    This is a *synchronous* generator — the previous ``AsyncGenerator``
    annotation was wrong, and the docstring understated the yield: each
    iteration yields ``(token_str, is_done, tokens_per_second)``, where
    ``is_done`` is True when the EOS token was produced. Decoding is pure
    argmax; no temperature/sampling is applied here.

    Args:
        input_ids: int64 token ids, shape [1, prompt_len].
        attention_mask: int64 mask, shape [1, prompt_len]; extended by one
            column per generated token.
        max_new_tokens: hard cap on the number of generated tokens.
    """
    dtype = np.float32
    cache = init_cache(batch_size=1, dtype=dtype)
    output_names = [o.name for o in decoder_session.get_outputs()]
    # config.eos_token_id may be a single id or a list; use the first entry.
    eos_token_id = config.eos_token_id if not isinstance(
        config.eos_token_id, list) else config.eos_token_id[0]

    t_start = time.perf_counter()

    for step in range(max_new_tokens):
        feed_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
        outputs = decoder_session.run(None, feed_dict | cache)
        named_outputs = dict(zip(output_names, outputs))

        # Greedy pick from the last position's logits.
        next_token = outputs[0][:, -1].argmax(-1, keepdims=True)
        attention_mask = np.concatenate(
            [attention_mask, np.ones_like(next_token, dtype=np.int64)], axis=-1
        )
        input_ids = next_token

        # Feed this step's "present" states back in as the next step's "past".
        for name in cache:
            new_name = name.replace("past_key_values", "present").replace("past_", "present_")
            cache[name] = named_outputs[new_name]

        token_id = int(next_token[0, 0])
        token_str = tokenizer.decode([token_id], skip_special_tokens=True)
        elapsed = time.perf_counter() - t_start
        tps = (step + 1) / elapsed if elapsed > 0 else 0

        is_done = token_id == eos_token_id
        yield token_str, is_done, tps

        if is_done:
            break
185
+
186
+
187
# ── Lifespan ──────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model in a worker thread before serving; no teardown needed.

    Uses asyncio.get_running_loop() — get_event_loop() is deprecated inside a
    coroutine — so the blocking download/load does not stall the event loop.
    """
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, load_model)
    yield
193
+
194
+
195
# ── FastAPI app ───────────────────────────────────────────────────────────────
app = FastAPI(
    title="Granite 4.0 ONNX Server",
    description="High-performance inference server for granite-4.0-h-350m-ONNX",
    version="1.0.0",
    lifespan=lifespan,  # blocks serving until the model is loaded at startup
)

# Wide-open CORS so the bundled UI and any external client can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
209
+
210
+
211
# ── API Routes ────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Liveness/readiness probe: load state, model id, and uptime."""
    with metrics_lock:
        ready = metrics["model_loaded"]
        uptime = round(time.time() - metrics["start_time"], 1)
    return {
        "status": "ready" if ready else "loading",
        "model": MODEL_ID,
        "uptime_seconds": uptime,
    }
220
+
221
+
222
@app.get("/metrics")
def get_metrics():
    """Snapshot of the server counters, polled by the UI dashboard."""
    with metrics_lock:
        uptime = time.time() - metrics["start_time"]
        latencies = metrics["request_latencies"]
        avg_latency = sum(latencies) / len(latencies) if latencies else 0
        return {
            "uptime_seconds": round(uptime, 1),
            "total_requests": metrics["total_requests"],
            "active_requests": metrics["active_requests"],
            "total_tokens_generated": metrics["total_tokens_generated"],
            "total_prompt_tokens": metrics["total_prompt_tokens"],
            "average_latency_ms": round(avg_latency, 2),
            "last_tokens_per_second": round(metrics["last_tps"], 2),
            "tps_history": list(metrics["tokens_per_second_history"]),
            "errors": metrics["errors"],
            "model_loaded": metrics["model_loaded"],
            "model_loading": metrics["model_loading"],
            "requests_per_minute": round(metrics["total_requests"] / max(uptime / 60, 1), 2),
        }
244
+
245
+
246
@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    """Non-streaming chat completion.

    Renders the chat template, runs blocking greedy decoding in a worker
    thread, and returns the full text with token/latency stats. Raises 503
    while the model is still loading and 500 on generation failure.

    NOTE(review): req.temperature is accepted but not forwarded —
    generate_tokens() decodes greedily; confirm whether sampling is intended.
    """
    # Read the flag under the same lock that all writers use.
    with metrics_lock:
        loaded = metrics["model_loaded"]
    if not loaded:
        raise HTTPException(status_code=503, detail="Model still loading, please wait...")

    with metrics_lock:
        metrics["total_requests"] += 1
        metrics["active_requests"] += 1

    t0 = time.perf_counter()
    request_id = str(uuid.uuid4())[:8]

    try:
        messages = [{"role": m.role, "content": m.content} for m in req.messages]
        # get_running_loop(), not the deprecated get_event_loop(): we are
        # inside a coroutine, so a running loop is guaranteed.
        loop = asyncio.get_running_loop()

        # Tokenization can be slow for long histories; keep it off the loop.
        inputs = await loop.run_in_executor(
            None,
            lambda: tokenizer.apply_chat_template(
                messages, add_generation_prompt=True,
                tokenize=True, return_dict=True, return_tensors="np"
            )
        )

        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        prompt_tokens = int(input_ids.shape[1])

        full_text = ""
        final_tps = 0.0
        completion_tokens = 0

        def run_generation():
            # Drain the blocking token generator in a worker thread.
            nonlocal full_text, final_tps, completion_tokens
            for token_str, is_done, tps in generate_tokens(
                input_ids, attention_mask, req.max_new_tokens
            ):
                full_text += token_str
                completion_tokens += 1
                final_tps = tps
                if is_done:
                    break

        await loop.run_in_executor(None, run_generation)

        latency_ms = (time.perf_counter() - t0) * 1000

        with metrics_lock:
            metrics["active_requests"] -= 1
            metrics["total_tokens_generated"] += completion_tokens
            metrics["total_prompt_tokens"] += prompt_tokens
            metrics["request_latencies"].append(latency_ms)
            metrics["tokens_per_second_history"].append(round(final_tps, 2))
            metrics["last_tps"] = final_tps

        return ChatResponse(
            id=request_id,
            content=full_text,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            latency_ms=round(latency_ms, 2),
            tokens_per_second=round(final_tps, 2),
        )

    except Exception as e:
        with metrics_lock:
            metrics["active_requests"] -= 1
            metrics["errors"] += 1
        # Chain the cause so the original traceback survives in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
316
+
317
+
318
@app.post("/chat/stream")
async def chat_stream(req: ChatRequest):
    """Streaming chat completion over Server-Sent Events.

    Each generated token is emitted as a ``data: <token>`` SSE frame, with a
    final ``data: [DONE]`` sentinel. Each decode step runs in a worker thread
    so the event loop stays responsive between frames.

    NOTE(review): tokens are interpolated into the frame unescaped — a token
    containing a newline would break the ``data:`` framing that the bundled
    UI parses line-by-line; confirm whether this tokenizer can emit such
    tokens before relying on the stream format.
    """
    # NOTE(review): this flag is read without metrics_lock, unlike every
    # writer; benign on CPython but inconsistent with the rest of the file.
    if not metrics["model_loaded"]:
        raise HTTPException(status_code=503, detail="Model still loading...")

    with metrics_lock:
        metrics["total_requests"] += 1
        metrics["active_requests"] += 1

    # Render the prompt with the model's chat template (numpy tensors for ORT).
    messages = [{"role": m.role, "content": m.content} for m in req.messages]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True,
        tokenize=True, return_dict=True, return_tensors="np"
    )

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    async def event_stream():
        completion_tokens = 0
        try:
            loop = asyncio.get_event_loop()
            gen = generate_tokens(input_ids, attention_mask, req.max_new_tokens)

            def next_token():
                # next() with a default so exhaustion returns None instead of
                # raising StopIteration across the executor boundary.
                return next(gen, None)

            while True:
                # One blocking decode step per executor hop.
                result = await loop.run_in_executor(None, next_token)
                if result is None:
                    break
                token_str, is_done, tps = result
                completion_tokens += 1
                yield f"data: {token_str}\n\n"
                if is_done:
                    break

            yield f"data: [DONE]\n\n"
        finally:
            # Runs even if the client disconnects mid-stream.
            with metrics_lock:
                metrics["active_requests"] -= 1
                metrics["total_tokens_generated"] += completion_tokens

    return StreamingResponse(event_stream(), media_type="text/event-stream")
362
+
363
+
364
@app.get("/", response_class=HTMLResponse)
async def ui():
    """Serve the bundled single-page chat UI (path fixed by the image layout)."""
    with open("/app/static/index.html") as page:
        html = page.read()
    return html