"""
╔══════════════════════════════════════════════════════════════╗
║            Granite 4.0 ONNX Inference Server                 ║
║        Model: onnx-community/granite-4.0-h-350m-ONNX         ║
╚══════════════════════════════════════════════════════════════╝
"""
import asyncio
import os
import time
import uuid
import threading
from collections import deque
from contextlib import asynccontextmanager
from typing import Generator, List

import numpy as np
import onnxruntime
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, StreamingResponse
from huggingface_hub import snapshot_download
from pydantic import BaseModel
from transformers import AutoConfig, AutoTokenizer
# ── Global model state ────────────────────────────────────────────────────────
MODEL_ID = "onnx-community/granite-4.0-h-350m-ONNX"
MODEL_FILENAME = "model_q4" # use quantized for speed
decoder_session = None
tokenizer = None
config = None
# ── Metrics state ─────────────────────────────────────────────────────────────
metrics = {
    "total_requests": 0,
    "active_requests": 0,
    "total_tokens_generated": 0,
    "total_prompt_tokens": 0,
    "request_latencies": deque(maxlen=100),
    "tokens_per_second_history": deque(maxlen=50),
    "errors": 0,
    "start_time": time.time(),
    "last_tps": 0.0,
    "model_loaded": False,
    "model_loading": True,
}
metrics_lock = threading.Lock()
# ── Pydantic models ───────────────────────────────────────────────────────────
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]
    max_new_tokens: int = 512
    temperature: float = 1.0  # accepted for API compatibility; decoding below is greedy
    stream: bool = False


class ChatResponse(BaseModel):
    id: str
    content: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    latency_ms: float
    tokens_per_second: float
# ── Model loader ──────────────────────────────────────────────────────────────
def load_model():
    global decoder_session, tokenizer, config
    print(f"[INFO] Downloading model {MODEL_ID}...")
    try:
        # Skip non-ONNX weights and the ONNX variants we don't serve.
        model_dir = snapshot_download(
            MODEL_ID,
            ignore_patterns=[
                "*.msgpack", "*.h5", "flax_model*",
                "model.onnx", "model_fp16.onnx", "model_q4f16.onnx",
            ],
        )
        model_path = os.path.join(model_dir, "onnx", f"{MODEL_FILENAME}.onnx")
        print(f"[INFO] Loading ONNX session from {model_path}...")
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4
        decoder_session = onnxruntime.InferenceSession(
            model_path,
            sess_options=sess_options,
            providers=["CPUExecutionProvider"],
        )
        print("[INFO] Loading tokenizer and config...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        config = AutoConfig.from_pretrained(MODEL_ID)
        with metrics_lock:
            metrics["model_loaded"] = True
            metrics["model_loading"] = False
        print("[INFO] ✅ Model loaded successfully!")
    except Exception as e:
        with metrics_lock:
            metrics["model_loading"] = False
            metrics["errors"] += 1
        print(f"[ERROR] Failed to load model: {e}")
        raise
# ── Cache initializer ─────────────────────────────────────────────────────────
def init_cache(batch_size: int, dtype=np.float32):
    """Build empty past-state inputs for the hybrid attention/Mamba decoder."""
    cache = {}
    head_dim = config.hidden_size // config.num_attention_heads
    d_conv = config.mamba_d_conv
    mamba_expand = config.mamba_expand
    mamba_n_groups = config.mamba_n_groups
    mamba_d_state = config.mamba_d_state
    # Conv state width: expanded hidden size plus the B/C group projections.
    conv_d_inner = (mamba_expand * config.hidden_size) + (2 * mamba_n_groups * mamba_d_state)
    for i, layer_type in enumerate(config.layer_types):
        if layer_type == "attention":
            # Attention layers use a standard KV cache, empty along the sequence axis.
            for kv in ("key", "value"):
                cache[f"past_key_values.{i}.{kv}"] = np.zeros(
                    [batch_size, config.num_key_value_heads, 0, head_dim], dtype=dtype
                )
        elif layer_type == "mamba":
            # Mamba layers carry fixed-size convolution and SSM states instead.
            cache[f"past_conv.{i}"] = np.zeros(
                [batch_size, conv_d_inner, d_conv], dtype=dtype
            )
            cache[f"past_ssm.{i}"] = np.zeros(
                [batch_size, config.mamba_n_heads, config.mamba_d_head, mamba_d_state], dtype=dtype
            )
    return cache
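# For illustration only: with a hypothetical config.layer_types of
# ["mamba", "attention"], init_cache(1) would produce inputs named
#   past_conv.0, past_ssm.0                          (Mamba layer 0)
#   past_key_values.1.key, past_key_values.1.value   (attention layer 1)
# matching the decoder's past-state input names one-to-one.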
# ── Core generation ───────────────────────────────────────────────────────────
def generate_tokens(input_ids: np.ndarray, attention_mask: np.ndarray,
                    max_new_tokens: int = 512) -> Generator:
    """Synchronous greedy generation — yields (token_str, is_done, tokens_per_sec)."""
    dtype = np.float32
    cache = init_cache(batch_size=1, dtype=dtype)
    output_names = [o.name for o in decoder_session.get_outputs()]
    eos_token_id = (
        config.eos_token_id[0]
        if isinstance(config.eos_token_id, list)
        else config.eos_token_id
    )
    t_start = time.perf_counter()
    for step in range(max_new_tokens):
        feed_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
        outputs = decoder_session.run(None, feed_dict | cache)
        named_outputs = dict(zip(output_names, outputs))
        # Greedy decoding: take the argmax over the last position's logits.
        next_token = outputs[0][:, -1].argmax(-1, keepdims=True)
        attention_mask = np.concatenate(
            [attention_mask, np.ones_like(next_token, dtype=np.int64)], axis=-1
        )
        input_ids = next_token
        # Feed each layer's "present" output back in as the next step's "past" input.
        for name in cache:
            new_name = name.replace("past_key_values", "present").replace("past_", "present_")
            cache[name] = named_outputs[new_name]
        token_id = int(next_token[0, 0])
        token_str = tokenizer.decode([token_id], skip_special_tokens=True)
        elapsed = time.perf_counter() - t_start
        tps = (step + 1) / elapsed if elapsed > 0 else 0.0
        is_done = token_id == eos_token_id
        yield token_str, is_done, tps
        if is_done:
            break
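# Minimal sketch of calling generate_tokens directly (assumes load_model() has
# already run; the prompt text is illustrative only):
#
#   enc = tokenizer.apply_chat_template(
#       [{"role": "user", "content": "Hello"}],
#       add_generation_prompt=True, tokenize=True,
#       return_dict=True, return_tensors="np",
#   )
#   for token_str, is_done, tps in generate_tokens(
#       enc["input_ids"], enc["attention_mask"], max_new_tokens=32
#   ):
#       print(token_str, end="", flush=True)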
# ── Lifespan ──────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model in a worker thread so startup does not block the event loop.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, load_model)
    yield
# ── FastAPI app ───────────────────────────────────────────────────────────────
app = FastAPI(
    title="Granite 4.0 ONNX Server",
    description="High-performance inference server for granite-4.0-h-350m-ONNX",
    version="1.0.0",
    lifespan=lifespan,
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── API Routes ────────────────────────────────────────────────────────────────
@app.get("/health")
def health():
with metrics_lock:
return {
"status": "ready" if metrics["model_loaded"] else "loading",
"model": MODEL_ID,
"uptime_seconds": round(time.time() - metrics["start_time"], 1),
}
@app.get("/metrics")
def get_metrics():
with metrics_lock:
uptime = time.time() - metrics["start_time"]
avg_latency = (
sum(metrics["request_latencies"]) / len(metrics["request_latencies"])
if metrics["request_latencies"] else 0
)
return {
"uptime_seconds": round(uptime, 1),
"total_requests": metrics["total_requests"],
"active_requests": metrics["active_requests"],
"total_tokens_generated": metrics["total_tokens_generated"],
"total_prompt_tokens": metrics["total_prompt_tokens"],
"average_latency_ms": round(avg_latency, 2),
"last_tokens_per_second": round(metrics["last_tps"], 2),
"tps_history": list(metrics["tokens_per_second_history"]),
"errors": metrics["errors"],
"model_loaded": metrics["model_loaded"],
"model_loading": metrics["model_loading"],
"requests_per_minute": round(metrics["total_requests"] / max(uptime / 60, 1), 2),
}
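# Quick checks against the two read-only endpoints above (host/port are
# assumptions, not fixed anywhere in this file):
#
#   curl http://localhost:8000/health
#   curl http://localhost:8000/metrics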
@app.post("/chat", response_model=ChatResponse)
async def chat(req: ChatRequest):
if not metrics["model_loaded"]:
raise HTTPException(status_code=503, detail="Model still loading, please wait...")
with metrics_lock:
metrics["total_requests"] += 1
metrics["active_requests"] += 1
t0 = time.perf_counter()
request_id = str(uuid.uuid4())[:8]
try:
messages = [{"role": m.role, "content": m.content} for m in req.messages]
loop = asyncio.get_event_loop()
inputs = await loop.run_in_executor(
None,
lambda: tokenizer.apply_chat_template(
messages, add_generation_prompt=True,
tokenize=True, return_dict=True, return_tensors="np"
)
)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
prompt_tokens = int(input_ids.shape[1])
full_text = ""
final_tps = 0.0
completion_tokens = 0
def run_generation():
nonlocal full_text, final_tps, completion_tokens
for token_str, is_done, tps in generate_tokens(
input_ids, attention_mask, req.max_new_tokens
):
full_text += token_str
completion_tokens += 1
final_tps = tps
if is_done:
break
await loop.run_in_executor(None, run_generation)
latency_ms = (time.perf_counter() - t0) * 1000
with metrics_lock:
metrics["active_requests"] -= 1
metrics["total_tokens_generated"] += completion_tokens
metrics["total_prompt_tokens"] += prompt_tokens
metrics["request_latencies"].append(latency_ms)
metrics["tokens_per_second_history"].append(round(final_tps, 2))
metrics["last_tps"] = final_tps
return ChatResponse(
id=request_id,
content=full_text,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
latency_ms=round(latency_ms, 2),
tokens_per_second=round(final_tps, 2),
)
except Exception as e:
with metrics_lock:
metrics["active_requests"] -= 1
metrics["errors"] += 1
raise HTTPException(status_code=500, detail=str(e))
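# Example non-streaming request (localhost:8000 is an assumption):
#
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hi"}], "max_new_tokens": 64}'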
@app.post("/chat/stream")
async def chat_stream(req: ChatRequest):
if not metrics["model_loaded"]:
raise HTTPException(status_code=503, detail="Model still loading...")
with metrics_lock:
metrics["total_requests"] += 1
metrics["active_requests"] += 1
messages = [{"role": m.role, "content": m.content} for m in req.messages]
inputs = tokenizer.apply_chat_template(
messages, add_generation_prompt=True,
tokenize=True, return_dict=True, return_tensors="np"
)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
async def event_stream():
completion_tokens = 0
try:
loop = asyncio.get_event_loop()
gen = generate_tokens(input_ids, attention_mask, req.max_new_tokens)
def next_token():
return next(gen, None)
while True:
result = await loop.run_in_executor(None, next_token)
if result is None:
break
token_str, is_done, tps = result
completion_tokens += 1
yield f"data: {token_str}\n\n"
if is_done:
break
yield f"data: [DONE]\n\n"
finally:
with metrics_lock:
metrics["active_requests"] -= 1
metrics["total_tokens_generated"] += completion_tokens
return StreamingResponse(event_stream(), media_type="text/event-stream")
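# Example streaming request; -N disables curl's buffering so tokens appear as
# they are generated (host/port again assumed):
#
#   curl -N -X POST http://localhost:8000/chat/stream \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hi"}], "stream": true}'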
@app.get("/", response_class=HTMLResponse)
async def ui():
with open("/app/static/index.html") as f:
return f.read()
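# Optional local entry point, a sketch only: the original deployment may launch
# the app differently (e.g. `uvicorn server:app`), and port 8000 is an assumption.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)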