| | """ |
| | ╔══════════════════════════════════════════════════════════════╗ |
| | ║ Granite 4.0 ONNX Inference Server ║ |
| | ║ Model: onnx-community/granite-4.0-h-350m-ONNX ║ |
| | ╚══════════════════════════════════════════════════════════════╝ |
| | """ |
| |
|
| | import asyncio |
| | import time |
| | import uuid |
| | import threading |
| | from collections import deque |
| | from contextlib import asynccontextmanager |
| | from typing import AsyncGenerator, List, Optional |
| |
|
| | import numpy as np |
| | import onnxruntime |
| | from fastapi import FastAPI, HTTPException |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from fastapi.responses import HTMLResponse, StreamingResponse |
| | from fastapi.staticfiles import StaticFiles |
| | from huggingface_hub import snapshot_download |
| | from pydantic import BaseModel |
| | from transformers import AutoConfig, AutoTokenizer |
| |
|
| | |
| | MODEL_ID = "onnx-community/granite-4.0-h-350m-ONNX" |
| | MODEL_FILENAME = "model_q4" |
| |
|
| | decoder_session = None |
| | tokenizer = None |
| | config = None |
| |
|
| | |
| | metrics = { |
| | "total_requests": 0, |
| | "active_requests": 0, |
| | "total_tokens_generated": 0, |
| | "total_prompt_tokens": 0, |
| | "request_latencies": deque(maxlen=100), |
| | "tokens_per_second_history": deque(maxlen=50), |
| | "errors": 0, |
| | "start_time": time.time(), |
| | "last_tps": 0.0, |
| | "model_loaded": False, |
| | "model_loading": True, |
| | } |
| | metrics_lock = threading.Lock() |
| |
|
| |
|
| | |
| | class Message(BaseModel): |
| | role: str |
| | content: str |
| |
|
| |
|
| | class ChatRequest(BaseModel): |
| | messages: List[Message] |
| | max_new_tokens: int = 512 |
| | temperature: float = 1.0 |
| | stream: bool = False |
| |
|
| |
|
| | class ChatResponse(BaseModel): |
| | id: str |
| | content: str |
| | prompt_tokens: int |
| | completion_tokens: int |
| | total_tokens: int |
| | latency_ms: float |
| | tokens_per_second: float |
| |
|
| |
|
| | |
| | def load_model(): |
| | global decoder_session, tokenizer, config |
| | print(f"[INFO] Downloading model {MODEL_ID}...") |
| |
|
| | try: |
| | model_dir = snapshot_download( |
| | MODEL_ID, |
| | ignore_patterns=["*.msgpack", "*.h5", "flax_model*", |
| | "model.onnx", "model_fp16.onnx", "model_q4f16.onnx"], |
| | ) |
| | import os |
| | model_path = os.path.join(model_dir, "onnx", f"{MODEL_FILENAME}.onnx") |
| |
|
| | print(f"[INFO] Loading ONNX session from {model_path}...") |
| | sess_options = onnxruntime.SessionOptions() |
| | sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL |
| | sess_options.intra_op_num_threads = 4 |
| |
|
| | decoder_session = onnxruntime.InferenceSession( |
| | model_path, |
| | sess_options=sess_options, |
| | providers=["CPUExecutionProvider"], |
| | ) |
| |
|
| | print("[INFO] Loading tokenizer and config...") |
| | tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) |
| | config = AutoConfig.from_pretrained(MODEL_ID) |
| |
|
| | with metrics_lock: |
| | metrics["model_loaded"] = True |
| | metrics["model_loading"] = False |
| |
|
| | print("[INFO] ✅ Model loaded successfully!") |
| |
|
| | except Exception as e: |
| | with metrics_lock: |
| | metrics["model_loading"] = False |
| | metrics["errors"] += 1 |
| | print(f"[ERROR] Failed to load model: {e}") |
| | raise |
| |
|
| |
|
| | |
| | def init_cache(batch_size: int, dtype=np.float32): |
| | cache = {} |
| | head_dim = config.hidden_size // config.num_attention_heads |
| | d_conv = config.mamba_d_conv |
| | mamba_expand = config.mamba_expand |
| | mamba_n_groups = config.mamba_n_groups |
| | mamba_d_state = config.mamba_d_state |
| | conv_d_inner = (mamba_expand * config.hidden_size) + (2 * mamba_n_groups * mamba_d_state) |
| |
|
| | for i, layer_type in enumerate(config.layer_types): |
| | if layer_type == "attention": |
| | for kv in ("key", "value"): |
| | cache[f"past_key_values.{i}.{kv}"] = np.zeros( |
| | [batch_size, config.num_key_value_heads, 0, head_dim], dtype=dtype |
| | ) |
| | elif layer_type == "mamba": |
| | cache[f"past_conv.{i}"] = np.zeros( |
| | [batch_size, conv_d_inner, d_conv], dtype=dtype |
| | ) |
| | cache[f"past_ssm.{i}"] = np.zeros( |
| | [batch_size, config.mamba_n_heads, config.mamba_d_head, mamba_d_state], dtype=dtype |
| | ) |
| | return cache |
| |
|
| |
|
| | |
| | def generate_tokens(input_ids: np.ndarray, attention_mask: np.ndarray, |
| | max_new_tokens: int = 512) -> AsyncGenerator: |
| | """Synchronous token generation — yields (token_str, is_done)""" |
| | dtype = np.float32 |
| | cache = init_cache(batch_size=1, dtype=dtype) |
| | output_names = [o.name for o in decoder_session.get_outputs()] |
| | eos_token_id = config.eos_token_id if not isinstance( |
| | config.eos_token_id, list) else config.eos_token_id[0] |
| |
|
| | generated = [] |
| | t_start = time.perf_counter() |
| |
|
| | for step in range(max_new_tokens): |
| | feed_dict = {"input_ids": input_ids, "attention_mask": attention_mask} |
| | outputs = decoder_session.run(None, feed_dict | cache) |
| | named_outputs = dict(zip(output_names, outputs)) |
| |
|
| | next_token = outputs[0][:, -1].argmax(-1, keepdims=True) |
| | attention_mask = np.concatenate( |
| | [attention_mask, np.ones_like(next_token, dtype=np.int64)], axis=-1 |
| | ) |
| | input_ids = next_token |
| |
|
| | for name in cache: |
| | new_name = name.replace("past_key_values", "present").replace("past_", "present_") |
| | cache[name] = named_outputs[new_name] |
| |
|
| | token_id = int(next_token[0, 0]) |
| | generated.append(token_id) |
| |
|
| | token_str = tokenizer.decode([token_id], skip_special_tokens=True) |
| | elapsed = time.perf_counter() - t_start |
| | tps = (step + 1) / elapsed if elapsed > 0 else 0 |
| |
|
| | is_done = token_id == eos_token_id |
| | yield token_str, is_done, tps |
| |
|
| | if is_done: |
| | break |
| |
|
| | return generated |
| |
|
| |
|
| | |
| | @asynccontextmanager |
| | async def lifespan(app: FastAPI): |
| | loop = asyncio.get_event_loop() |
| | await loop.run_in_executor(None, load_model) |
| | yield |
| |
|
| |
|
| | |
| | app = FastAPI( |
| | title="Granite 4.0 ONNX Server", |
| | description="High-performance inference server for granite-4.0-h-350m-ONNX", |
| | version="1.0.0", |
| | lifespan=lifespan, |
| | ) |
| |
|
| | app.add_middleware( |
| | CORSMiddleware, |
| | allow_origins=["*"], |
| | allow_methods=["*"], |
| | allow_headers=["*"], |
| | ) |
| |
|
| |
|
| | |
| | @app.get("/health") |
| | def health(): |
| | with metrics_lock: |
| | return { |
| | "status": "ready" if metrics["model_loaded"] else "loading", |
| | "model": MODEL_ID, |
| | "uptime_seconds": round(time.time() - metrics["start_time"], 1), |
| | } |
| |
|
| |
|
| | @app.get("/metrics") |
| | def get_metrics(): |
| | with metrics_lock: |
| | uptime = time.time() - metrics["start_time"] |
| | avg_latency = ( |
| | sum(metrics["request_latencies"]) / len(metrics["request_latencies"]) |
| | if metrics["request_latencies"] else 0 |
| | ) |
| | return { |
| | "uptime_seconds": round(uptime, 1), |
| | "total_requests": metrics["total_requests"], |
| | "active_requests": metrics["active_requests"], |
| | "total_tokens_generated": metrics["total_tokens_generated"], |
| | "total_prompt_tokens": metrics["total_prompt_tokens"], |
| | "average_latency_ms": round(avg_latency, 2), |
| | "last_tokens_per_second": round(metrics["last_tps"], 2), |
| | "tps_history": list(metrics["tokens_per_second_history"]), |
| | "errors": metrics["errors"], |
| | "model_loaded": metrics["model_loaded"], |
| | "model_loading": metrics["model_loading"], |
| | "requests_per_minute": round(metrics["total_requests"] / max(uptime / 60, 1), 2), |
| | } |
| |
|
| |
|
| | @app.post("/chat", response_model=ChatResponse) |
| | async def chat(req: ChatRequest): |
| | if not metrics["model_loaded"]: |
| | raise HTTPException(status_code=503, detail="Model still loading, please wait...") |
| |
|
| | with metrics_lock: |
| | metrics["total_requests"] += 1 |
| | metrics["active_requests"] += 1 |
| |
|
| | t0 = time.perf_counter() |
| | request_id = str(uuid.uuid4())[:8] |
| |
|
| | try: |
| | messages = [{"role": m.role, "content": m.content} for m in req.messages] |
| | loop = asyncio.get_event_loop() |
| |
|
| | inputs = await loop.run_in_executor( |
| | None, |
| | lambda: tokenizer.apply_chat_template( |
| | messages, add_generation_prompt=True, |
| | tokenize=True, return_dict=True, return_tensors="np" |
| | ) |
| | ) |
| |
|
| | input_ids = inputs["input_ids"] |
| | attention_mask = inputs["attention_mask"] |
| | prompt_tokens = int(input_ids.shape[1]) |
| |
|
| | full_text = "" |
| | final_tps = 0.0 |
| | completion_tokens = 0 |
| |
|
| | def run_generation(): |
| | nonlocal full_text, final_tps, completion_tokens |
| | for token_str, is_done, tps in generate_tokens( |
| | input_ids, attention_mask, req.max_new_tokens |
| | ): |
| | full_text += token_str |
| | completion_tokens += 1 |
| | final_tps = tps |
| | if is_done: |
| | break |
| |
|
| | await loop.run_in_executor(None, run_generation) |
| |
|
| | latency_ms = (time.perf_counter() - t0) * 1000 |
| |
|
| | with metrics_lock: |
| | metrics["active_requests"] -= 1 |
| | metrics["total_tokens_generated"] += completion_tokens |
| | metrics["total_prompt_tokens"] += prompt_tokens |
| | metrics["request_latencies"].append(latency_ms) |
| | metrics["tokens_per_second_history"].append(round(final_tps, 2)) |
| | metrics["last_tps"] = final_tps |
| |
|
| | return ChatResponse( |
| | id=request_id, |
| | content=full_text, |
| | prompt_tokens=prompt_tokens, |
| | completion_tokens=completion_tokens, |
| | total_tokens=prompt_tokens + completion_tokens, |
| | latency_ms=round(latency_ms, 2), |
| | tokens_per_second=round(final_tps, 2), |
| | ) |
| |
|
| | except Exception as e: |
| | with metrics_lock: |
| | metrics["active_requests"] -= 1 |
| | metrics["errors"] += 1 |
| | raise HTTPException(status_code=500, detail=str(e)) |
| |
|
| |
|
| | @app.post("/chat/stream") |
| | async def chat_stream(req: ChatRequest): |
| | if not metrics["model_loaded"]: |
| | raise HTTPException(status_code=503, detail="Model still loading...") |
| |
|
| | with metrics_lock: |
| | metrics["total_requests"] += 1 |
| | metrics["active_requests"] += 1 |
| |
|
| | messages = [{"role": m.role, "content": m.content} for m in req.messages] |
| | inputs = tokenizer.apply_chat_template( |
| | messages, add_generation_prompt=True, |
| | tokenize=True, return_dict=True, return_tensors="np" |
| | ) |
| |
|
| | input_ids = inputs["input_ids"] |
| | attention_mask = inputs["attention_mask"] |
| |
|
| | async def event_stream(): |
| | completion_tokens = 0 |
| | try: |
| | loop = asyncio.get_event_loop() |
| | gen = generate_tokens(input_ids, attention_mask, req.max_new_tokens) |
| |
|
| | def next_token(): |
| | return next(gen, None) |
| |
|
| | while True: |
| | result = await loop.run_in_executor(None, next_token) |
| | if result is None: |
| | break |
| | token_str, is_done, tps = result |
| | completion_tokens += 1 |
| | yield f"data: {token_str}\n\n" |
| | if is_done: |
| | break |
| |
|
| | yield f"data: [DONE]\n\n" |
| | finally: |
| | with metrics_lock: |
| | metrics["active_requests"] -= 1 |
| | metrics["total_tokens_generated"] += completion_tokens |
| |
|
| | return StreamingResponse(event_stream(), media_type="text/event-stream") |
| |
|
| |
|
| | @app.get("/", response_class=HTMLResponse) |
| | async def ui(): |
| | with open("/app/static/index.html") as f: |
| | return f.read() |
| |
|