""" ╔══════════════════════════════════════════════════════════════╗ ║ Granite 4.0 ONNX Inference Server ║ ║ Model: onnx-community/granite-4.0-h-350m-ONNX ║ ╚══════════════════════════════════════════════════════════════╝ """ import asyncio import time import uuid import threading from collections import deque from contextlib import asynccontextmanager from typing import AsyncGenerator, List, Optional import numpy as np import onnxruntime from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from huggingface_hub import snapshot_download from pydantic import BaseModel from transformers import AutoConfig, AutoTokenizer # ── Global model state ──────────────────────────────────────────────────────── MODEL_ID = "onnx-community/granite-4.0-h-350m-ONNX" MODEL_FILENAME = "model_q4" # use quantized for speed decoder_session = None tokenizer = None config = None # ── Metrics state ───────────────────────────────────────────────────────────── metrics = { "total_requests": 0, "active_requests": 0, "total_tokens_generated": 0, "total_prompt_tokens": 0, "request_latencies": deque(maxlen=100), "tokens_per_second_history": deque(maxlen=50), "errors": 0, "start_time": time.time(), "last_tps": 0.0, "model_loaded": False, "model_loading": True, } metrics_lock = threading.Lock() # ── Pydantic models ─────────────────────────────────────────────────────────── class Message(BaseModel): role: str content: str class ChatRequest(BaseModel): messages: List[Message] max_new_tokens: int = 512 temperature: float = 1.0 stream: bool = False class ChatResponse(BaseModel): id: str content: str prompt_tokens: int completion_tokens: int total_tokens: int latency_ms: float tokens_per_second: float # ── Model loader ────────────────────────────────────────────────────────────── def load_model(): global decoder_session, tokenizer, config print(f"[INFO] Downloading model {MODEL_ID}...") try: model_dir = snapshot_download( MODEL_ID, ignore_patterns=["*.msgpack", "*.h5", "flax_model*", "model.onnx", "model_fp16.onnx", "model_q4f16.onnx"], ) import os model_path = os.path.join(model_dir, "onnx", f"{MODEL_FILENAME}.onnx") print(f"[INFO] Loading ONNX session from {model_path}...") sess_options = onnxruntime.SessionOptions() sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL sess_options.intra_op_num_threads = 4 decoder_session = onnxruntime.InferenceSession( model_path, sess_options=sess_options, providers=["CPUExecutionProvider"], ) print("[INFO] Loading tokenizer and config...") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) config = AutoConfig.from_pretrained(MODEL_ID) with metrics_lock: metrics["model_loaded"] = True metrics["model_loading"] = False print("[INFO] ✅ Model loaded successfully!") except Exception as e: with metrics_lock: metrics["model_loading"] = False metrics["errors"] += 1 print(f"[ERROR] Failed to load model: {e}") raise # ── Cache initializer ───────────────────────────────────────────────────────── def init_cache(batch_size: int, dtype=np.float32): cache = {} head_dim = config.hidden_size // config.num_attention_heads d_conv = config.mamba_d_conv mamba_expand = config.mamba_expand mamba_n_groups = config.mamba_n_groups mamba_d_state = config.mamba_d_state conv_d_inner = (mamba_expand * config.hidden_size) + (2 * mamba_n_groups * mamba_d_state) for i, layer_type in enumerate(config.layer_types): if layer_type == 
"attention": for kv in ("key", "value"): cache[f"past_key_values.{i}.{kv}"] = np.zeros( [batch_size, config.num_key_value_heads, 0, head_dim], dtype=dtype ) elif layer_type == "mamba": cache[f"past_conv.{i}"] = np.zeros( [batch_size, conv_d_inner, d_conv], dtype=dtype ) cache[f"past_ssm.{i}"] = np.zeros( [batch_size, config.mamba_n_heads, config.mamba_d_head, mamba_d_state], dtype=dtype ) return cache # ── Core generation ─────────────────────────────────────────────────────────── def generate_tokens(input_ids: np.ndarray, attention_mask: np.ndarray, max_new_tokens: int = 512) -> AsyncGenerator: """Synchronous token generation — yields (token_str, is_done)""" dtype = np.float32 cache = init_cache(batch_size=1, dtype=dtype) output_names = [o.name for o in decoder_session.get_outputs()] eos_token_id = config.eos_token_id if not isinstance( config.eos_token_id, list) else config.eos_token_id[0] generated = [] t_start = time.perf_counter() for step in range(max_new_tokens): feed_dict = {"input_ids": input_ids, "attention_mask": attention_mask} outputs = decoder_session.run(None, feed_dict | cache) named_outputs = dict(zip(output_names, outputs)) next_token = outputs[0][:, -1].argmax(-1, keepdims=True) attention_mask = np.concatenate( [attention_mask, np.ones_like(next_token, dtype=np.int64)], axis=-1 ) input_ids = next_token for name in cache: new_name = name.replace("past_key_values", "present").replace("past_", "present_") cache[name] = named_outputs[new_name] token_id = int(next_token[0, 0]) generated.append(token_id) token_str = tokenizer.decode([token_id], skip_special_tokens=True) elapsed = time.perf_counter() - t_start tps = (step + 1) / elapsed if elapsed > 0 else 0 is_done = token_id == eos_token_id yield token_str, is_done, tps if is_done: break return generated # ── Lifespan ────────────────────────────────────────────────────────────────── @asynccontextmanager async def lifespan(app: FastAPI): loop = asyncio.get_event_loop() await loop.run_in_executor(None, load_model) yield # ── FastAPI app ─────────────────────────────────────────────────────────────── app = FastAPI( title="Granite 4.0 ONNX Server", description="High-performance inference server for granite-4.0-h-350m-ONNX", version="1.0.0", lifespan=lifespan, ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) # ── API Routes ──────────────────────────────────────────────────────────────── @app.get("/health") def health(): with metrics_lock: return { "status": "ready" if metrics["model_loaded"] else "loading", "model": MODEL_ID, "uptime_seconds": round(time.time() - metrics["start_time"], 1), } @app.get("/metrics") def get_metrics(): with metrics_lock: uptime = time.time() - metrics["start_time"] avg_latency = ( sum(metrics["request_latencies"]) / len(metrics["request_latencies"]) if metrics["request_latencies"] else 0 ) return { "uptime_seconds": round(uptime, 1), "total_requests": metrics["total_requests"], "active_requests": metrics["active_requests"], "total_tokens_generated": metrics["total_tokens_generated"], "total_prompt_tokens": metrics["total_prompt_tokens"], "average_latency_ms": round(avg_latency, 2), "last_tokens_per_second": round(metrics["last_tps"], 2), "tps_history": list(metrics["tokens_per_second_history"]), "errors": metrics["errors"], "model_loaded": metrics["model_loaded"], "model_loading": metrics["model_loading"], "requests_per_minute": round(metrics["total_requests"] / max(uptime / 60, 1), 2), } @app.post("/chat", response_model=ChatResponse) async def 
async def chat(req: ChatRequest):
    if not metrics["model_loaded"]:
        raise HTTPException(status_code=503, detail="Model still loading, please wait...")

    with metrics_lock:
        metrics["total_requests"] += 1
        metrics["active_requests"] += 1

    t0 = time.perf_counter()
    request_id = str(uuid.uuid4())[:8]

    try:
        messages = [{"role": m.role, "content": m.content} for m in req.messages]
        loop = asyncio.get_running_loop()
        inputs = await loop.run_in_executor(
            None,
            lambda: tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="np",
            ),
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        prompt_tokens = int(input_ids.shape[1])

        full_text = ""
        final_tps = 0.0
        completion_tokens = 0

        def run_generation():
            nonlocal full_text, final_tps, completion_tokens
            for token_str, is_done, tps in generate_tokens(
                input_ids, attention_mask, req.max_new_tokens
            ):
                full_text += token_str
                completion_tokens += 1
                final_tps = tps
                if is_done:
                    break

        # Run the blocking generation loop in a worker thread.
        await loop.run_in_executor(None, run_generation)

        latency_ms = (time.perf_counter() - t0) * 1000
        with metrics_lock:
            metrics["active_requests"] -= 1
            metrics["total_tokens_generated"] += completion_tokens
            metrics["total_prompt_tokens"] += prompt_tokens
            metrics["request_latencies"].append(latency_ms)
            metrics["tokens_per_second_history"].append(round(final_tps, 2))
            metrics["last_tps"] = final_tps

        return ChatResponse(
            id=request_id,
            content=full_text,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            latency_ms=round(latency_ms, 2),
            tokens_per_second=round(final_tps, 2),
        )
    except Exception as e:
        with metrics_lock:
            metrics["active_requests"] -= 1
            metrics["errors"] += 1
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/chat/stream")
async def chat_stream(req: ChatRequest):
    if not metrics["model_loaded"]:
        raise HTTPException(status_code=503, detail="Model still loading...")

    with metrics_lock:
        metrics["total_requests"] += 1
        metrics["active_requests"] += 1

    messages = [{"role": m.role, "content": m.content} for m in req.messages]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="np",
    )
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    async def event_stream():
        completion_tokens = 0
        try:
            loop = asyncio.get_running_loop()
            gen = generate_tokens(input_ids, attention_mask, req.max_new_tokens)

            def next_token():
                return next(gen, None)

            # Pull tokens off the synchronous generator one at a time without blocking the event loop.
            while True:
                result = await loop.run_in_executor(None, next_token)
                if result is None:
                    break
                token_str, is_done, tps = result
                completion_tokens += 1
                yield f"data: {token_str}\n\n"
                if is_done:
                    break
            yield "data: [DONE]\n\n"
        finally:
            with metrics_lock:
                metrics["active_requests"] -= 1
                metrics["total_tokens_generated"] += completion_tokens

    return StreamingResponse(event_stream(), media_type="text/event-stream")


@app.get("/", response_class=HTMLResponse)
async def ui():
    with open("/app/static/index.html") as f:
        return f.read()
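
# ── Entrypoint (example) ──────────────────────────────────────────────────────
# Illustrative launch/usage sketch, not part of the original file: it assumes the
# server is started directly with uvicorn on port 8000 (a container command such
# as `uvicorn main:app --host 0.0.0.0 --port 8000` would work equally well).
#
# Example non-streaming request (hypothetical host/port):
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "max_new_tokens": 64}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)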