"""
╔══════════════════════════════════════════════════════════════╗
║   Granite 4.0 ONNX Inference Server                         ║
║   Model: onnx-community/granite-4.0-h-350m-ONNX             ║
╚══════════════════════════════════════════════════════════════╝
"""

import asyncio
import os
import threading
import time
import uuid
from collections import deque
from contextlib import asynccontextmanager
from typing import Generator, List, Tuple

import numpy as np
import onnxruntime
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from transformers import AutoConfig, AutoTokenizer

# ── Global model state ────────────────────────────────────────────────────────
MODEL_ID = "onnx-community/granite-4.0-h-350m-ONNX"
MODEL_FILENAME = "model_q4"  # use quantized for speed

decoder_session = None
tokenizer = None
config = None

# ── Metrics state ─────────────────────────────────────────────────────────────
metrics = {
    "total_requests": 0,
    "active_requests": 0,
    "total_tokens_generated": 0,
    "total_prompt_tokens": 0,
    "request_latencies": deque(maxlen=100),
    "tokens_per_second_history": deque(maxlen=50),
    "errors": 0,
    "start_time": time.time(),
    "last_tps": 0.0,
    "model_loaded": False,
    "model_loading": True,
}
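# Model loading and generation run in worker threads (run_in_executor), so all
# reads and writes of `metrics` go through this lock.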
metrics_lock = threading.Lock()


# ── Pydantic models ───────────────────────────────────────────────────────────
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]
    max_new_tokens: int = 512
    temperature: float = 1.0  # accepted but currently unused: decoding in generate_tokens() is greedy
    stream: bool = False


class ChatResponse(BaseModel):
    id: str
    content: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    latency_ms: float
    tokens_per_second: float


# ── Model loader ──────────────────────────────────────────────────────────────
def load_model():
    global decoder_session, tokenizer, config
    print(f"[INFO] Downloading model {MODEL_ID}...")

    try:
        model_dir = snapshot_download(
            MODEL_ID,
            ignore_patterns=["*.msgpack", "*.h5", "flax_model*",
                             "model.onnx", "model_fp16.onnx", "model_q4f16.onnx"],
        )
        model_path = os.path.join(model_dir, "onnx", f"{MODEL_FILENAME}.onnx")

        print(f"[INFO] Loading ONNX session from {model_path}...")
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4

        decoder_session = onnxruntime.InferenceSession(
            model_path,
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )

        print("[INFO] Loading tokenizer and config...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        config = AutoConfig.from_pretrained(MODEL_ID)

        with metrics_lock:
            metrics["model_loaded"] = True
            metrics["model_loading"] = False

        print("[INFO] ✅ Model loaded successfully!")

    except Exception as e:
        with metrics_lock:
            metrics["model_loading"] = False
            metrics["errors"] += 1
        print(f"[ERROR] Failed to load model: {e}")
        raise


# ── Cache initializer ─────────────────────────────────────────────────────────
def init_cache(batch_size: int, dtype=np.float32):
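    """Build the empty per-layer cache for the hybrid attention/Mamba stack.

    Attention layers get zero-length past key/value tensors; Mamba layers get a
    convolution state (past_conv, width mamba_expand * hidden_size +
    2 * mamba_n_groups * mamba_d_state) and an SSM state (past_ssm), all sized
    from the model config.
    """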
    cache = {}
    head_dim = config.hidden_size // config.num_attention_heads
    d_conv = config.mamba_d_conv
    mamba_expand = config.mamba_expand
    mamba_n_groups = config.mamba_n_groups
    mamba_d_state = config.mamba_d_state
    conv_d_inner = (mamba_expand * config.hidden_size) + (2 * mamba_n_groups * mamba_d_state)

    for i, layer_type in enumerate(config.layer_types):
        if layer_type == "attention":
            for kv in ("key", "value"):
                cache[f"past_key_values.{i}.{kv}"] = np.zeros(
                    [batch_size, config.num_key_value_heads, 0, head_dim], dtype=dtype
                )
        elif layer_type == "mamba":
            cache[f"past_conv.{i}"] = np.zeros(
                [batch_size, conv_d_inner, d_conv], dtype=dtype
            )
            cache[f"past_ssm.{i}"] = np.zeros(
                [batch_size, config.mamba_n_heads, config.mamba_d_head, mamba_d_state], dtype=dtype
            )
    return cache


# ── Core generation ───────────────────────────────────────────────────────────
def generate_tokens(input_ids: np.ndarray, attention_mask: np.ndarray,
                    max_new_tokens: int = 512) -> Generator[Tuple[str, bool, float], None, None]:
    """Synchronous greedy decoding loop; yields (token_str, is_done, tokens_per_second)."""
    dtype = np.float32
    cache = init_cache(batch_size=1, dtype=dtype)
    output_names = [o.name for o in decoder_session.get_outputs()]
    eos_token_id = config.eos_token_id if not isinstance(
        config.eos_token_id, list) else config.eos_token_id[0]

    generated = []
    t_start = time.perf_counter()

    for step in range(max_new_tokens):
        feed_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
        outputs = decoder_session.run(None, feed_dict | cache)
        named_outputs = dict(zip(output_names, outputs))

        next_token = outputs[0][:, -1].argmax(-1, keepdims=True)
        attention_mask = np.concatenate(
            [attention_mask, np.ones_like(next_token, dtype=np.int64)], axis=-1
        )
        input_ids = next_token

        for name in cache:
            new_name = name.replace("past_key_values", "present").replace("past_", "present_")
            cache[name] = named_outputs[new_name]

        token_id = int(next_token[0, 0])
        generated.append(token_id)

        token_str = tokenizer.decode([token_id], skip_special_tokens=True)
        elapsed = time.perf_counter() - t_start
        tps = (step + 1) / elapsed if elapsed > 0 else 0

        is_done = token_id == eos_token_id
        yield token_str, is_done, tps

        if is_done:
            break

    return generated


# ── Lifespan ──────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, load_model)
    yield


# ── FastAPI app ───────────────────────────────────────────────────────────────
app = FastAPI(
    title="Granite 4.0 ONNX Server",
    description="High-performance inference server for granite-4.0-h-350m-ONNX",
    version="1.0.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# ── API Routes ────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    with metrics_lock:
        return {
            "status": "ready" if metrics["model_loaded"] else "loading",
            "model": MODEL_ID,
            "uptime_seconds": round(time.time() - metrics["start_time"], 1),
        }


@app.get("/metrics")
def get_metrics():
    with metrics_lock:
        uptime = time.time() - metrics["start_time"]
        avg_latency = (
            sum(metrics["request_latencies"]) / len(metrics["request_latencies"])
            if metrics["request_latencies"] else 0
        )
        return {
            "uptime_seconds": round(uptime, 1),
            "total_requests": metrics["total_requests"],
            "active_requests": metrics["active_requests"],
            "total_tokens_generated": metrics["total_tokens_generated"],
            "total_prompt_tokens": metrics["total_prompt_tokens"],
            "average_latency_ms": round(avg_latency, 2),
            "last_tokens_per_second": round(metrics["last_tps"], 2),
            "tps_history": list(metrics["tokens_per_second_history"]),
            "errors": metrics["errors"],
            "model_loaded": metrics["model_loaded"],
            "model_loading": metrics["model_loading"],
            "requests_per_minute": round(metrics["total_requests"] / max(uptime / 60, 1), 2),
        }


@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
    if not metrics["model_loaded"]:
        raise HTTPException(status_code=503, detail="Model still loading, please wait...")

    with metrics_lock:
        metrics["total_requests"] += 1
        metrics["active_requests"] += 1

    t0 = time.perf_counter()
    request_id = str(uuid.uuid4())[:8]

    try:
        messages = [{"role": m.role, "content": m.content} for m in req.messages]
        loop = asyncio.get_running_loop()

        inputs = await loop.run_in_executor(
            None,
            lambda: tokenizer.apply_chat_template(
                messages, add_generation_prompt=True,
                tokenize=True, return_dict=True, return_tensors="np"
            )
        )

        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        prompt_tokens = int(input_ids.shape[1])

        full_text = ""
        final_tps = 0.0
        completion_tokens = 0

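        # The blocking ONNX decode loop runs in a worker thread so the event
        # loop stays free to serve /health and /metrics while generating.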
        def run_generation():
            nonlocal full_text, final_tps, completion_tokens
            for token_str, is_done, tps in generate_tokens(
                input_ids, attention_mask, req.max_new_tokens
            ):
                full_text += token_str
                completion_tokens += 1
                final_tps = tps
                if is_done:
                    break

        await loop.run_in_executor(None, run_generation)

        latency_ms = (time.perf_counter() - t0) * 1000

        with metrics_lock:
            metrics["active_requests"] -= 1
            metrics["total_tokens_generated"] += completion_tokens
            metrics["total_prompt_tokens"] += prompt_tokens
            metrics["request_latencies"].append(latency_ms)
            metrics["tokens_per_second_history"].append(round(final_tps, 2))
            metrics["last_tps"] = final_tps

        return ChatResponse(
            id=request_id,
            content=full_text,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            latency_ms=round(latency_ms, 2),
            tokens_per_second=round(final_tps, 2),
        )

    except Exception as e:
        with metrics_lock:
            metrics["active_requests"] -= 1
            metrics["errors"] += 1
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/chat/stream")
async def chat_stream(req: ChatRequest):
    if not metrics["model_loaded"]:
        raise HTTPException(status_code=503, detail="Model still loading...")

    with metrics_lock:
        metrics["total_requests"] += 1
        metrics["active_requests"] += 1

    messages = [{"role": m.role, "content": m.content} for m in req.messages]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True,
        tokenize=True, return_dict=True, return_tensors="np"
    )

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

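    # Server-sent events: each decoded token goes out as a "data: <token>" line
    # and the stream ends with "data: [DONE]".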
    async def event_stream():
        completion_tokens = 0
        try:
            loop = asyncio.get_running_loop()
            gen = generate_tokens(input_ids, attention_mask, req.max_new_tokens)

            def next_token():
                return next(gen, None)

            while True:
                result = await loop.run_in_executor(None, next_token)
                if result is None:
                    break
                token_str, is_done, tps = result
                completion_tokens += 1
                yield f"data: {token_str}\n\n"
                if is_done:
                    break

            yield f"data: [DONE]\n\n"
        finally:
            with metrics_lock:
                metrics["active_requests"] -= 1
                metrics["total_tokens_generated"] += completion_tokens

    return StreamingResponse(event_stream(), media_type="text/event-stream")


@app.get("/", response_class=HTMLResponse)
async def ui():
    with open("/app/static/index.html") as f:
        return f.read()
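

# Entry-point sketch. Assumption: this block is not part of the original launch
# path; the server may instead be started externally (e.g. `uvicorn <module>:app`).
# Host and port below are illustrative.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)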