""" OpenELM OpenAI & Anthropic API - Background Loading Version This version loads the model in the background AFTER the app starts, preventing Hugging Face Spaces timeout issues. Key Features: - App starts immediately (no timeout) - Model loads in background thread - Health endpoint works from start - Proper SSE configuration - Returns 503 during loading with Retry-After header """ import asyncio import uuid import sys import threading import time from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from typing import AsyncIterator, List, Optional, Dict, Any import torch from fastapi import FastAPI, HTTPException from fastapi.responses import JSONResponse, StreamingResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerFast from transformers import TextIteratorStreamer import os # ==================== Global State ==================== class ModelState: """Track model loading state.""" NOT_LOADED = "not_loaded" LOADING = "loading" READY = "ready" FAILED = "failed" # Global variables model = None tokenizer = None model_state = ModelState.NOT_LOADED model_load_error = None model_load_start_time = None model_load_end_time = None # ==================== Background Model Loading ==================== def load_model_sync(): """ Synchronous model loading function. This runs in a separate thread to not block the event loop. """ global model, tokenizer, model_state, model_load_error, model_load_end_time print("=" * 50) print("BACKGROUND: Starting model load...") print("=" * 50) try: model_id = "apple/OpenELM-450M-Instruct" model_load_start_time = time.time() # Load tokenizer print("BACKGROUND: Loading tokenizer...") try: tokenizer = AutoTokenizer.from_pretrained( model_id, trust_remote_code=True, use_fast=False ) except Exception as e: print(f"BACKGROUND: Tokenizer warning: {e}") tokenizer = PreTrainedTokenizerFast( bos_token="", eos_token="", unk_token="", pad_token="" ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Load model print("BACKGROUND: Loading model (this may take several minutes)...") model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.float32, use_safetensors=True, trust_remote_code=True, low_cpu_mem_usage=True ) model.eval() model_load_end_time = time.time() load_duration = model_load_end_time - model_load_start_time model_state = ModelState.READY print("=" * 50) print(f"BACKGROUND: Model loaded successfully in {load_duration:.1f} seconds!") print(f"BACKGROUND: Model device: {next(model.parameters()).device}") print("=" * 50) except Exception as e: model_load_error = str(e) model_state = ModelState.FAILED print("=" * 50) print(f"BACKGROUND: Model loading FAILED: {e}") print("=" * 50) import traceback traceback.print_exc() def start_background_model_loading(): """Start model loading in a background thread.""" global model_state print("SCHEDULING: Model loading in background thread...") model_state = ModelState.LOADING # Run in separate thread to not block event loop thread = threading.Thread(target=load_model_sync, daemon=True) thread.start() return thread # ==================== FastAPI App ==================== @asynccontextmanager async def lifespan(app: FastAPI): """Application lifespan - start server immediately, load model in background.""" global model_state print("=" * 50) print("STARTING: OpenELM API Server") print("=" * 50) print("Server starting immediately...") print("Model will load in background...") print("=" * 50) # Start background model loading (non-blocking) start_background_model_loading() # Yield control to start the server yield # Cleanup on shutdown print("SHUTDOWN: Cleaning up...") if model is not None: del model if tokenizer is not None: del tokenizer torch.cuda.empty_cache() if torch.cuda.is_available() else None print("SHUTDOWN: Complete") # Create FastAPI app app = FastAPI( title="OpenELM OpenAI API", description="OpenAI and Anthropic API compatible wrapper for OpenELM models", version="4.0.0", lifespan=lifespan, docs_url="/docs" if os.environ.get("DEBUG") else None, redoc_url="/redoc" if os.environ.get("DEBUG") else None, ) # Add CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ==================== Pydantic Models ==================== class ChatMessage(BaseModel): role: str content: str name: Optional[str] = None class ChatCompletionRequest(BaseModel): model: str = "openelm-450m-instruct" messages: List[ChatMessage] temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0) top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0) max_tokens: Optional[int] = Field(default=None, ge=1, le=4096) stream: Optional[bool] = False class ChatCompletionChoice(BaseModel): index: int message: ChatMessage finish_reason: Optional[str] = None class ChatCompletionUsage(BaseModel): prompt_tokens: int completion_tokens: int total_tokens: int class ChatCompletionResponse(BaseModel): id: str object: str = "chat.completion" created: int model: str choices: List[ChatCompletionChoice] usage: ChatCompletionUsage class HealthResponse(BaseModel): status: str model_state: str model_loaded: bool load_time_seconds: Optional[float] = None error: Optional[str] = None # ==================== Helper Functions ==================== def check_model_ready(): """Check if model is ready, raise if not.""" global model_state, model_load_error, model_load_start_time, model_load_end_time if model_state == ModelState.NOT_LOADED: raise HTTPException( status_code=503, detail="Model has not started loading yet. Please wait a moment and retry.", headers={"Retry-After": "10"} ) if model_state == ModelState.LOADING: raise HTTPException( status_code=503, detail="Model is still loading. Please wait a few moments and retry.", headers={"Retry-After": "30"} ) if model_state == ModelState.FAILED: raise HTTPException( status_code=503, detail=f"Model loading failed: {model_load_error}", headers={"Retry-After": "0"} ) def generate_with_model(prompt: str, max_tokens: int = 1024, temperature: float = 0.7) -> str: """Generate text using the loaded model.""" global model, tokenizer, model_state # Check state if model_state != ModelState.READY: raise HTTPException( status_code=503, detail="Model is not ready yet. Please retry later.", headers={"Retry-After": "30"} ) # Tokenize inputs = tokenizer(prompt, return_tensors="pt") input_tokens = len(inputs.input_ids[0]) # Move to model device if hasattr(model, 'device'): inputs = {k: v.to(model.device) for k, v in inputs.items()} # Prepare generation parameters gen_params = { "max_new_tokens": max_tokens, "do_sample": temperature > 0, } if temperature > 0 and temperature != 0.7: gen_params["temperature"] = temperature if temperature == 0: gen_params["do_sample"] = False # Generate with torch.no_grad(): outputs = model.generate( **inputs, **gen_params, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, ) # Decode generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) # Extract response response_text = extract_assistant_response(generated_text) return response_text, input_tokens def extract_assistant_response(generated_text: str) -> str: """Extract assistant response from generated text.""" if "Assistant:" in generated_text: return generated_text.split("Assistant:")[-1].strip() lines = generated_text.split("\n") response_parts = [] in_assistant = False for line in lines: if line.startswith("Assistant:"): in_assistant = True response_parts.append(line.replace("Assistant:", "").strip()) elif in_assistant and not line.startswith("User:") and not line.startswith("System:"): response_parts.append(line) elif line.startswith("User:") or line.startswith("System:"): in_assistant = False return "\n".join(response_parts).strip() # ==================== API Endpoints ==================== @app.get("/", tags=["Root"]) async def root(): """Root endpoint with API information.""" global model_state, model_load_start_time, model_load_end_time load_time = None if model_load_end_time and model_load_start_time: load_time = model_load_end_time - model_load_start_time return { "name": "OpenELM OpenAI API", "version": "4.0.0", "status": "ready" if model_state == ModelState.READY else "loading", "model_state": model_state, "model_loaded": model_state == ModelState.READY, "load_time_seconds": load_time, "endpoints": { "chat": "POST /v1/chat/completions", "messages": "POST /v1/messages", "health": "GET /health" }, "note": "Model loads in background for fast startup" } @app.get("/health", response_model=HealthResponse, tags=["Health"]) async def health_check(): """ Health check endpoint. IMPORTANT: This endpoint always returns 200 so Hugging Face doesn't timeout during model loading. """ global model_state, model_load_error, model_load_start_time, model_load_end_time load_time = None if model_load_end_time and model_load_start_time: load_time = model_load_end_time - model_load_start_time return HealthResponse( status="healthy" if model_state in [ModelState.READY, ModelState.LOADING] else "unhealthy", model_state=model_state, model_loaded=model_state == ModelState.READY, load_time_seconds=load_time, error=model_load_error ) @app.get("/ready", tags=["Readiness"]) async def readiness_check(): """ Readiness check for load balancers. Returns 200 only when model is ready. """ global model_state if model_state == ModelState.READY: return {"ready": True} raise HTTPException( status_code=503, detail=f"Model not ready (state: {model_state})", headers={"Retry-After": "30"} ) @app.post("/v1/chat/completions", tags=["OpenAI"]) async def create_chat_completion(request: ChatCompletionRequest): """ Create chat completion (OpenAI API format). Returns 503 if model is still loading. """ global model_state # Check if model is ready if model_state != ModelState.READY: if model_state == ModelState.LOADING: raise HTTPException( status_code=503, detail="Model is still loading. Please retry in 30 seconds.", headers={"Retry-After": "30"} ) elif model_state == ModelState.NOT_LOADED: raise HTTPException( status_code=503, detail="Model loading has not started yet. Please wait.", headers={"Retry-After": "10"} ) else: raise HTTPException( status_code=503, detail="Model failed to load. Please restart the Space.", headers={"Retry-After": "0"} ) try: # Build prompt from messages system_msg = None user_msgs = [] for msg in request.messages: if msg.role == "system" and system_msg is None: system_msg = msg.content else: user_msgs.append(msg) # Build prompt prompt_parts = [] if system_msg: prompt_parts.append(f"[System: {system_msg}]") for msg in user_msgs: if msg.role == "user": prompt_parts.append(f"User: {msg.content}") elif msg.role == "assistant": prompt_parts.append(f"Assistant: {msg.content}") prompt_parts.append("Assistant:") prompt = "\n\n".join(prompt_parts) # Generate max_tokens = request.max_tokens or 1024 temperature = request.temperature if request.temperature is not None else 0.7 response_text, input_tokens = generate_with_model(prompt, max_tokens, temperature) # Estimate output tokens output_tokens = max(1, len(response_text.split())) # Build response response_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" timestamp = int(uuid.uuid1().time) return ChatCompletionResponse( id=response_id, created=timestamp, model="openelm-450m-instruct", choices=[ ChatCompletionChoice( index=0, message=ChatMessage(role="assistant", content=response_text), finish_reason="stop" ) ], usage=ChatCompletionUsage( prompt_tokens=input_tokens, completion_tokens=output_tokens, total_tokens=input_tokens + output_tokens ) ) except HTTPException: raise except Exception as e: import traceback traceback.print_exc() raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") @app.post("/v1/messages", tags=["Anthropic"]) async def create_message(request: dict): """ Create message (Anthropic API format). Returns 503 if model is still loading. """ global model_state # Check if model is ready if model_state != ModelState.READY: raise HTTPException( status_code=503, detail="Model is still loading. Please retry in 30 seconds.", headers={"Retry-After": "30"} ) try: # Extract parameters messages = request.get("messages", []) system = request.get("system", None) max_tokens = request.get("max_tokens", 1024) temperature = request.get("temperature", 0.7) # Build prompt prompt_parts = [] if system: prompt_parts.append(f"[System: {system}]") for msg in messages: role = msg.get("role", "user") content = msg.get("content", "") if isinstance(content, list): content = "".join(c.get("text", "") for c in content if isinstance(c, dict)) if role == "user": prompt_parts.append(f"User: {content}") elif role == "assistant": prompt_parts.append(f"Assistant: {content}") prompt_parts.append("Assistant:") prompt = "\n\n".join(prompt_parts) # Generate response_text, input_tokens = generate_with_model(prompt, max_tokens, temperature) # Estimate output tokens output_tokens = max(1, len(response_text.split())) # Build response return { "id": f"msg_{uuid.uuid4().hex[:8]}", "type": "message", "role": "assistant", "content": [{"type": "text", "text": response_text}], "model": "openelm-450m-instruct", "stop_reason": "end_turn", "usage": { "input_tokens": input_tokens, "output_tokens": output_tokens, "total_tokens": input_tokens + output_tokens } } except HTTPException: raise except Exception as e: import traceback traceback.print_exc() raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") # ==================== Streaming Endpoint ==================== @app.post("/v1/chat/completions/stream", tags=["OpenAI"]) async def create_chat_completion_stream(request: ChatCompletionRequest): """ Create streaming chat completion. Returns 503 if model is still loading. """ global model, tokenizer, model_state # Check if model is ready if model_state != ModelState.READY: raise HTTPException( status_code=503, detail="Model is still loading. Please retry in 30 seconds.", headers={"Retry-After": "30"} ) async def generate_stream(): """Generate streaming response.""" try: # Build prompt system_msg = None user_msgs = [] for msg in request.messages: if msg.role == "system" and system_msg is None: system_msg = msg.content else: user_msgs.append(msg) prompt_parts = [] if system_msg: prompt_parts.append(f"[System: {system_msg}]") for msg in user_msgs: if msg.role == "user": prompt_parts.append(f"User: {msg.content}") elif msg.role == "assistant": prompt_parts.append(f"Assistant: {msg.content}") prompt_parts.append("Assistant:") prompt = "\n\n".join(prompt_parts) # Tokenize inputs = tokenizer(prompt, return_tensors="pt") input_tokens = len(inputs.input_ids[0]) if hasattr(model, 'device'): inputs = {k: v.to(model.device) for k, v in inputs.items()} # Prepare generation max_tokens = request.max_tokens or 1024 temperature = request.temperature if request.temperature is not None else 0.7 gen_params = {"max_new_tokens": max_tokens} if temperature == 0: gen_params["do_sample"] = False else: gen_params["temperature"] = temperature gen_params["do_sample"] = True # Set up streaming from transformers import TextIteratorStreamer from threading import Thread streamer = TextIteratorStreamer( tokenizer, skip_prompt=True, skip_special_tokens=True ) gen_params["streamer"] = streamer # Run in thread def generate(): with torch.no_grad(): model.generate(**inputs, **gen_params) thread = Thread(target=generate) thread.start() # Send response start chunk_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" timestamp = int(uuid.uuid1().time) yield f"data: {{\"id\":\"{chunk_id}\",\"object\":\"chat.completion.chunk\",\"created\":{timestamp},\"model\":\"openelm-450m-instruct\",\"choices\":[{{\"index\":0,\"delta\":{{\"role\":\"assistant\"}},\"finish_reason\":null}}]}}\n\n" # Stream tokens full_text = "" for text in streamer: full_text += text chunk_data = { "id": chunk_id, "object": "chat.completion.chunk", "created": timestamp, "model": "openelm-450m-instruct", "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": None}] } yield f"data: {chunk_data}\n\n" # Send stop stop_chunk = { "id": chunk_id, "object": "chat.completion.chunk", "created": timestamp, "model": "openelm-450m-instruct", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}] } yield f"data: {stop_chunk}\n\n" # Send usage output_tokens = len(full_text.split()) + 1 usage_data = { "id": chunk_id, "object": "chat.completion", "created": timestamp, "model": "openelm-450m-instruct", "choices": [{"index": 0, "message": {"role": "assistant", "content": full_text}, "finish_reason": "stop"}], "usage": { "prompt_tokens": input_tokens, "completion_tokens": output_tokens, "total_tokens": input_tokens + output_tokens } } yield f"data: {usage_data}\n\n" yield "data: [DONE]\n\n" thread.join() except Exception as e: yield f"data: {{\"error\": {{\"message\": \"{str(e)}\"}}, \"type\": \"server_error\"}}\n\n" return StreamingResponse( generate_stream(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", } ) # ==================== Main Entry Point ==================== if __name__ == "__main__": import uvicorn port = int(os.environ.get("PORT", 8000)) host = os.environ.get("HOST", "0.0.0.0") print("=" * 50) print("OpenELM API Server v4.0") print("=" * 50) print(f"Starting on {host}:{port}") print("Model will load in background") print("=" * 50) uvicorn.run( "app_v4:app", host=host, port=port, reload=False, workers=1, log_level="info" )