"""
OpenELM OpenAI & Anthropic API - Background Loading Version
This version loads the model in the background AFTER the app starts,
preventing Hugging Face Spaces timeout issues.
Key Features:
- App starts immediately (no timeout)
- Model loads in background thread
- Health endpoint works from start
- Proper SSE configuration
- Returns 503 during loading with Retry-After header
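
Example client-side retry against the 503/Retry-After behaviour (a hedged sketch,
not executed by this module; the base URL is a placeholder for wherever the app
is deployed):

    import time, requests

    BASE_URL = "http://localhost:8000"  # placeholder
    for _ in range(30):
        resp = requests.get(f"{BASE_URL}/ready", timeout=10)
        if resp.status_code == 200:
            break
        time.sleep(int(resp.headers.get("Retry-After", "10")))
    else:
        raise RuntimeError("Model never became ready")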
"""
import json
import os
import threading
import time
import uuid
from contextlib import asynccontextmanager
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizerFast,
    TextIteratorStreamer,
)
# ==================== Global State ====================
class ModelState:
"""Track model loading state."""
NOT_LOADED = "not_loaded"
LOADING = "loading"
READY = "ready"
FAILED = "failed"
# Global variables
model = None
tokenizer = None
model_state = ModelState.NOT_LOADED
model_load_error = None
model_load_start_time = None
model_load_end_time = None
# ==================== Background Model Loading ====================
def load_model_sync():
"""
Synchronous model loading function.
This runs in a separate thread to not block the event loop.
"""
    global model, tokenizer, model_state, model_load_error, model_load_start_time, model_load_end_time
print("=" * 50)
print("BACKGROUND: Starting model load...")
print("=" * 50)
try:
model_id = "apple/OpenELM-450M-Instruct"
model_load_start_time = time.time()
# Load tokenizer
print("BACKGROUND: Loading tokenizer...")
try:
tokenizer = AutoTokenizer.from_pretrained(
model_id,
trust_remote_code=True,
use_fast=False
)
except Exception as e:
print(f"BACKGROUND: Tokenizer warning: {e}")
tokenizer = PreTrainedTokenizerFast(
bos_token="<s>",
eos_token="</s>",
unk_token="<unk>",
pad_token="<pad>"
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load model
print("BACKGROUND: Loading model (this may take several minutes)...")
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float32,
use_safetensors=True,
trust_remote_code=True,
low_cpu_mem_usage=True
)
model.eval()
model_load_end_time = time.time()
load_duration = model_load_end_time - model_load_start_time
model_state = ModelState.READY
print("=" * 50)
print(f"BACKGROUND: Model loaded successfully in {load_duration:.1f} seconds!")
print(f"BACKGROUND: Model device: {next(model.parameters()).device}")
print("=" * 50)
except Exception as e:
model_load_error = str(e)
model_state = ModelState.FAILED
print("=" * 50)
print(f"BACKGROUND: Model loading FAILED: {e}")
print("=" * 50)
import traceback
traceback.print_exc()
def start_background_model_loading():
"""Start model loading in a background thread."""
global model_state
print("SCHEDULING: Model loading in background thread...")
model_state = ModelState.LOADING
# Run in separate thread to not block event loop
thread = threading.Thread(target=load_model_sync, daemon=True)
thread.start()
return thread
# ==================== FastAPI App ====================
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan - start server immediately, load model in background."""
global model_state
print("=" * 50)
print("STARTING: OpenELM API Server")
print("=" * 50)
print("Server starting immediately...")
print("Model will load in background...")
print("=" * 50)
# Start background model loading (non-blocking)
start_background_model_loading()
# Yield control to start the server
yield
# Cleanup on shutdown
print("SHUTDOWN: Cleaning up...")
if model is not None:
del model
if tokenizer is not None:
del tokenizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
print("SHUTDOWN: Complete")
# Create FastAPI app
app = FastAPI(
title="OpenELM OpenAI API",
description="OpenAI and Anthropic API compatible wrapper for OpenELM models",
version="4.0.0",
lifespan=lifespan,
docs_url="/docs" if os.environ.get("DEBUG") else None,
redoc_url="/redoc" if os.environ.get("DEBUG") else None,
)
# Add CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ==================== Pydantic Models ====================
class ChatMessage(BaseModel):
role: str
content: str
name: Optional[str] = None
class ChatCompletionRequest(BaseModel):
model: str = "openelm-450m-instruct"
messages: List[ChatMessage]
temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
max_tokens: Optional[int] = Field(default=None, ge=1, le=4096)
stream: Optional[bool] = False
class ChatCompletionChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Optional[str] = None
class ChatCompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int
model: str
choices: List[ChatCompletionChoice]
usage: ChatCompletionUsage
class HealthResponse(BaseModel):
status: str
model_state: str
model_loaded: bool
load_time_seconds: Optional[float] = None
error: Optional[str] = None
# ==================== Helper Functions ====================
def check_model_ready():
"""Check if model is ready, raise if not."""
global model_state, model_load_error, model_load_start_time, model_load_end_time
if model_state == ModelState.NOT_LOADED:
raise HTTPException(
status_code=503,
detail="Model has not started loading yet. Please wait a moment and retry.",
headers={"Retry-After": "10"}
)
if model_state == ModelState.LOADING:
raise HTTPException(
status_code=503,
detail="Model is still loading. Please wait a few moments and retry.",
headers={"Retry-After": "30"}
)
if model_state == ModelState.FAILED:
raise HTTPException(
status_code=503,
detail=f"Model loading failed: {model_load_error}",
headers={"Retry-After": "0"}
)
def generate_with_model(prompt: str, max_tokens: int = 1024, temperature: float = 0.7) -> tuple[str, int]:
    """Generate text with the loaded model and return (response_text, prompt_token_count)."""
global model, tokenizer, model_state
# Check state
if model_state != ModelState.READY:
raise HTTPException(
status_code=503,
detail="Model is not ready yet. Please retry later.",
headers={"Retry-After": "30"}
)
# Tokenize
inputs = tokenizer(prompt, return_tensors="pt")
input_tokens = len(inputs.input_ids[0])
# Move to model device
if hasattr(model, 'device'):
inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # Prepare generation parameters: greedy decoding when temperature == 0,
    # otherwise sample with the requested temperature
    gen_params = {
        "max_new_tokens": max_tokens,
        "do_sample": temperature > 0,
    }
    if temperature > 0:
        gen_params["temperature"] = temperature
# Generate
with torch.no_grad():
outputs = model.generate(
**inputs,
**gen_params,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
)
# Decode
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract response
response_text = extract_assistant_response(generated_text)
return response_text, input_tokens
def extract_assistant_response(generated_text: str) -> str:
"""Extract assistant response from generated text."""
if "Assistant:" in generated_text:
return generated_text.split("Assistant:")[-1].strip()
lines = generated_text.split("\n")
response_parts = []
in_assistant = False
for line in lines:
if line.startswith("Assistant:"):
in_assistant = True
response_parts.append(line.replace("Assistant:", "").strip())
elif in_assistant and not line.startswith("User:") and not line.startswith("System:"):
response_parts.append(line)
elif line.startswith("User:") or line.startswith("System:"):
in_assistant = False
return "\n".join(response_parts).strip()
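# Illustration of the extraction above (values are hypothetical):
#   extract_assistant_response("User: Hi\n\nAssistant: Hello! How can I help?")
#   -> "Hello! How can I help?"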
# ==================== API Endpoints ====================
@app.get("/", tags=["Root"])
async def root():
"""Root endpoint with API information."""
global model_state, model_load_start_time, model_load_end_time
load_time = None
if model_load_end_time and model_load_start_time:
load_time = model_load_end_time - model_load_start_time
return {
"name": "OpenELM OpenAI API",
"version": "4.0.0",
"status": "ready" if model_state == ModelState.READY else "loading",
"model_state": model_state,
"model_loaded": model_state == ModelState.READY,
"load_time_seconds": load_time,
"endpoints": {
"chat": "POST /v1/chat/completions",
"messages": "POST /v1/messages",
"health": "GET /health"
},
"note": "Model loads in background for fast startup"
}
@app.get("/health", response_model=HealthResponse, tags=["Health"])
async def health_check():
"""
Health check endpoint.
    IMPORTANT: This endpoint always returns 200 so Hugging Face
    does not time out during model loading.
"""
global model_state, model_load_error, model_load_start_time, model_load_end_time
load_time = None
if model_load_end_time and model_load_start_time:
load_time = model_load_end_time - model_load_start_time
return HealthResponse(
status="healthy" if model_state in [ModelState.READY, ModelState.LOADING] else "unhealthy",
model_state=model_state,
model_loaded=model_state == ModelState.READY,
load_time_seconds=load_time,
error=model_load_error
)
@app.get("/ready", tags=["Readiness"])
async def readiness_check():
"""
Readiness check for load balancers.
Returns 200 only when model is ready.
"""
global model_state
if model_state == ModelState.READY:
return {"ready": True}
raise HTTPException(
status_code=503,
detail=f"Model not ready (state: {model_state})",
headers={"Retry-After": "30"}
)
@app.post("/v1/chat/completions", tags=["OpenAI"])
async def create_chat_completion(request: ChatCompletionRequest):
"""
Create chat completion (OpenAI API format).
Returns 503 if model is still loading.
"""
global model_state
# Check if model is ready
if model_state != ModelState.READY:
if model_state == ModelState.LOADING:
raise HTTPException(
status_code=503,
detail="Model is still loading. Please retry in 30 seconds.",
headers={"Retry-After": "30"}
)
elif model_state == ModelState.NOT_LOADED:
raise HTTPException(
status_code=503,
detail="Model loading has not started yet. Please wait.",
headers={"Retry-After": "10"}
)
else:
raise HTTPException(
status_code=503,
detail="Model failed to load. Please restart the Space.",
headers={"Retry-After": "0"}
)
try:
# Build prompt from messages
system_msg = None
user_msgs = []
for msg in request.messages:
if msg.role == "system" and system_msg is None:
system_msg = msg.content
else:
user_msgs.append(msg)
# Build prompt
prompt_parts = []
if system_msg:
prompt_parts.append(f"[System: {system_msg}]")
for msg in user_msgs:
if msg.role == "user":
prompt_parts.append(f"User: {msg.content}")
elif msg.role == "assistant":
prompt_parts.append(f"Assistant: {msg.content}")
prompt_parts.append("Assistant:")
prompt = "\n\n".join(prompt_parts)
# Generate
max_tokens = request.max_tokens or 1024
temperature = request.temperature if request.temperature is not None else 0.7
response_text, input_tokens = generate_with_model(prompt, max_tokens, temperature)
# Estimate output tokens
output_tokens = max(1, len(response_text.split()))
# Build response
response_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        timestamp = int(time.time())
return ChatCompletionResponse(
id=response_id,
created=timestamp,
model="openelm-450m-instruct",
choices=[
ChatCompletionChoice(
index=0,
message=ChatMessage(role="assistant", content=response_text),
finish_reason="stop"
)
],
usage=ChatCompletionUsage(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
total_tokens=input_tokens + output_tokens
)
)
except HTTPException:
raise
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
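# --- Example: calling this endpoint with the official OpenAI Python SDK (v1.x) ---
# A hedged sketch; base_url and api_key are placeholders (this server does not check
# API keys). Kept as a comment so it does not run on import.
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
#   completion = client.chat.completions.create(
#       model="openelm-450m-instruct",
#       messages=[{"role": "user", "content": "Say hello in one sentence."}],
#       max_tokens=64,
#   )
#   print(completion.choices[0].message.content)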
@app.post("/v1/messages", tags=["Anthropic"])
async def create_message(request: dict):
"""
Create message (Anthropic API format).
Returns 503 if model is still loading.
"""
global model_state
# Check if model is ready
if model_state != ModelState.READY:
raise HTTPException(
status_code=503,
detail="Model is still loading. Please retry in 30 seconds.",
headers={"Retry-After": "30"}
)
try:
# Extract parameters
messages = request.get("messages", [])
system = request.get("system", None)
max_tokens = request.get("max_tokens", 1024)
temperature = request.get("temperature", 0.7)
# Build prompt
prompt_parts = []
if system:
prompt_parts.append(f"[System: {system}]")
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
if isinstance(content, list):
content = "".join(c.get("text", "") for c in content if isinstance(c, dict))
if role == "user":
prompt_parts.append(f"User: {content}")
elif role == "assistant":
prompt_parts.append(f"Assistant: {content}")
prompt_parts.append("Assistant:")
prompt = "\n\n".join(prompt_parts)
# Generate
response_text, input_tokens = generate_with_model(prompt, max_tokens, temperature)
# Estimate output tokens
output_tokens = max(1, len(response_text.split()))
# Build response
return {
"id": f"msg_{uuid.uuid4().hex[:8]}",
"type": "message",
"role": "assistant",
"content": [{"type": "text", "text": response_text}],
"model": "openelm-450m-instruct",
"stop_reason": "end_turn",
"usage": {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens
}
}
except HTTPException:
raise
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
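# --- Example: Anthropic-style request against /v1/messages ---
# A minimal sketch using requests (placeholder URL); the payload mirrors the fields
# this handler reads: messages, optional system, max_tokens, temperature.
#
#   import requests
#
#   payload = {
#       "system": "You are a terse assistant.",
#       "messages": [{"role": "user", "content": "Name three prime numbers."}],
#       "max_tokens": 128,
#   }
#   print(requests.post("http://localhost:8000/v1/messages", json=payload, timeout=120).json())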
# ==================== Streaming Endpoint ====================
@app.post("/v1/chat/completions/stream", tags=["OpenAI"])
async def create_chat_completion_stream(request: ChatCompletionRequest):
"""
Create streaming chat completion.
Returns 503 if model is still loading.
"""
global model, tokenizer, model_state
# Check if model is ready
if model_state != ModelState.READY:
raise HTTPException(
status_code=503,
detail="Model is still loading. Please retry in 30 seconds.",
headers={"Retry-After": "30"}
)
async def generate_stream():
"""Generate streaming response."""
try:
# Build prompt
system_msg = None
user_msgs = []
for msg in request.messages:
if msg.role == "system" and system_msg is None:
system_msg = msg.content
else:
user_msgs.append(msg)
prompt_parts = []
if system_msg:
prompt_parts.append(f"[System: {system_msg}]")
for msg in user_msgs:
if msg.role == "user":
prompt_parts.append(f"User: {msg.content}")
elif msg.role == "assistant":
prompt_parts.append(f"Assistant: {msg.content}")
prompt_parts.append("Assistant:")
prompt = "\n\n".join(prompt_parts)
# Tokenize
inputs = tokenizer(prompt, return_tensors="pt")
input_tokens = len(inputs.input_ids[0])
if hasattr(model, 'device'):
inputs = {k: v.to(model.device) for k, v in inputs.items()}
# Prepare generation
max_tokens = request.max_tokens or 1024
temperature = request.temperature if request.temperature is not None else 0.7
gen_params = {"max_new_tokens": max_tokens}
if temperature == 0:
gen_params["do_sample"] = False
else:
gen_params["temperature"] = temperature
gen_params["do_sample"] = True
# Set up streaming
from transformers import TextIteratorStreamer
from threading import Thread
streamer = TextIteratorStreamer(
tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
gen_params["streamer"] = streamer
# Run in thread
def generate():
with torch.no_grad():
model.generate(**inputs, **gen_params)
thread = Thread(target=generate)
thread.start()
            # Send response start (use a real Unix timestamp and json.dumps so the SSE payload is valid JSON)
            chunk_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
            timestamp = int(time.time())
            first_chunk = {"id": chunk_id, "object": "chat.completion.chunk", "created": timestamp, "model": "openelm-450m-instruct", "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}]}
            yield f"data: {json.dumps(first_chunk)}\n\n"
# Stream tokens
full_text = ""
for text in streamer:
full_text += text
chunk_data = {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": timestamp,
"model": "openelm-450m-instruct",
"choices": [{"index": 0, "delta": {"content": text}, "finish_reason": None}]
}
                yield f"data: {json.dumps(chunk_data)}\n\n"
# Send stop
stop_chunk = {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": timestamp,
"model": "openelm-450m-instruct",
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
}
            yield f"data: {json.dumps(stop_chunk)}\n\n"
# Send usage
output_tokens = len(full_text.split()) + 1
usage_data = {
"id": chunk_id,
"object": "chat.completion",
"created": timestamp,
"model": "openelm-450m-instruct",
"choices": [{"index": 0, "message": {"role": "assistant", "content": full_text}, "finish_reason": "stop"}],
"usage": {
"prompt_tokens": input_tokens,
"completion_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens
}
}
            yield f"data: {json.dumps(usage_data)}\n\n"
yield "data: [DONE]\n\n"
thread.join()
        except Exception as e:
            # Serialize with json.dumps so quotes in the error message cannot break the SSE payload
            error_payload = {"error": {"message": str(e)}, "type": "server_error"}
            yield f"data: {json.dumps(error_payload)}\n\n"
return StreamingResponse(
generate_stream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
}
)
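# --- Example: consuming the SSE stream above ---
# A hedged sketch using requests' streaming mode (placeholder URL). Each event line is
# "data: <json>"; the stream ends with "data: [DONE]".
#
#   import json, requests
#
#   payload = {"messages": [{"role": "user", "content": "Count to five."}], "max_tokens": 64}
#   with requests.post("http://localhost:8000/v1/chat/completions/stream", json=payload, stream=True) as r:
#       for line in r.iter_lines(decode_unicode=True):
#           if not line or not line.startswith("data: "):
#               continue
#           data = line[len("data: "):]
#           if data == "[DONE]":
#               break
#           chunk = json.loads(data)
#           delta = chunk.get("choices", [{}])[0].get("delta", {})
#           print(delta.get("content", ""), end="", flush=True)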
# ==================== Main Entry Point ====================
if __name__ == "__main__":
import uvicorn
port = int(os.environ.get("PORT", 8000))
host = os.environ.get("HOST", "0.0.0.0")
print("=" * 50)
print("OpenELM API Server v4.0")
print("=" * 50)
print(f"Starting on {host}:{port}")
print("Model will load in background")
print("=" * 50)
uvicorn.run(
"app_v4:app",
host=host,
port=port,
reload=False,
workers=1,
log_level="info"
)
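# Equivalent launch via the uvicorn CLI (assuming this file is named app_v4.py):
#   python -m uvicorn app_v4:app --host 0.0.0.0 --port 8000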