# Hugging Face Space page header captured with this source: "Spaces",
# status "Runtime error", author "MiniMax Agent", commit 44ffe48 —
# "Add v4 with background model loading - prevents timeout by loading
# model after server starts".
"""
OpenELM OpenAI & Anthropic API - Background Loading Version

This version loads the model in the background AFTER the app starts,
preventing Hugging Face Spaces timeout issues.

Key Features:
- App starts immediately (no timeout)
- Model loads in background thread
- Health endpoint works from start
- Proper SSE configuration
- Returns 503 during loading with Retry-After header
"""
import asyncio
import json
import os
import sys
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Dict, List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizerFast,
    TextIteratorStreamer,
)
| # ==================== Global State ==================== | |
class ModelState:
    """Lifecycle states of the background-loaded model.

    Plain string constants (not ``enum.Enum``) so the values drop straight
    into JSON responses without conversion.
    """

    # not_loaded: loader not scheduled yet; loading: daemon thread running;
    # ready: model/tokenizer usable; failed: see model_load_error.
    NOT_LOADED, LOADING, READY, FAILED = (
        "not_loaded",
        "loading",
        "ready",
        "failed",
    )
# Global variables
# Shared mutable state: written by the background loader thread,
# read by the request handlers.
model = None  # transformers model instance once READY
tokenizer = None  # tokenizer instance once READY
model_state: str = ModelState.NOT_LOADED  # one of the ModelState constants
model_load_error: Optional[str] = None  # str(exception) when FAILED
model_load_start_time: Optional[float] = None  # time.time() when loading began
model_load_end_time: Optional[float] = None  # time.time() when loading finished
| # ==================== Background Model Loading ==================== | |
def load_model_sync():
    """Load tokenizer and model synchronously (runs in a worker thread).

    Populates the module globals ``model``/``tokenizer`` and advances
    ``model_state`` to READY, or records the failure in
    ``model_load_error`` and sets FAILED.  Runs off the event loop so
    the server stays responsive during the (minutes-long) download.
    """
    # BUG FIX: model_load_start_time was missing from the global statement,
    # so the assignment below only bound a local and the module-level value
    # stayed None (root/health could never report a load time).
    global model, tokenizer, model_state, model_load_error
    global model_load_start_time, model_load_end_time
    print("=" * 50)
    print("BACKGROUND: Starting model load...")
    print("=" * 50)
    try:
        model_id = "apple/OpenELM-450M-Instruct"
        model_load_start_time = time.time()
        # Load tokenizer; OpenELM ships custom code, hence trust_remote_code.
        print("BACKGROUND: Loading tokenizer...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True,
                use_fast=False
            )
        except Exception as e:
            # Best-effort fallback: a bare tokenizer keeps the API alive,
            # though generation quality will suffer.
            print(f"BACKGROUND: Tokenizer warning: {e}")
            tokenizer = PreTrainedTokenizerFast(
                bos_token="<s>",
                eos_token="</s>",
                unk_token="<unk>",
                pad_token="<pad>"
            )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Load the model in float32 with low_cpu_mem_usage for CPU hosts.
        print("BACKGROUND: Loading model (this may take several minutes)...")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            use_safetensors=True,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )
        model.eval()
        model_load_end_time = time.time()
        load_duration = model_load_end_time - model_load_start_time
        model_state = ModelState.READY
        print("=" * 50)
        print(f"BACKGROUND: Model loaded successfully in {load_duration:.1f} seconds!")
        print(f"BACKGROUND: Model device: {next(model.parameters()).device}")
        print("=" * 50)
    except Exception as e:
        model_load_error = str(e)
        model_state = ModelState.FAILED
        print("=" * 50)
        print(f"BACKGROUND: Model loading FAILED: {e}")
        print("=" * 50)
        import traceback
        traceback.print_exc()
def start_background_model_loading():
    """Spawn the model loader on a daemon thread and return the thread.

    The thread is a daemon, so shutdown never waits on a half-finished
    download; callers are not expected to join it.
    """
    global model_state
    print("SCHEDULING: Model loading in background thread...")
    model_state = ModelState.LOADING
    loader = threading.Thread(target=load_model_sync, daemon=True)
    loader.start()
    return loader
| # ==================== FastAPI App ==================== | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: start serving immediately, load the model later.

    BUG FIX: FastAPI's ``lifespan=`` parameter expects an async context
    manager factory; the bare async generator function was missing the
    ``@asynccontextmanager`` decorator (it was imported but never used),
    which breaks the app at startup.
    """
    # BUG FIX: without this global statement, ``del model`` below made
    # ``model`` a local for the whole function, so the earlier read raised
    # UnboundLocalError during shutdown.
    global model, tokenizer
    print("=" * 50)
    print("STARTING: OpenELM API Server")
    print("=" * 50)
    print("Server starting immediately...")
    print("Model will load in background...")
    print("=" * 50)
    # Non-blocking: returns as soon as the loader thread is spawned.
    start_background_model_loading()
    # Hand control to the running application.
    yield
    # Cleanup on shutdown.
    print("SHUTDOWN: Cleaning up...")
    if model is not None:
        del model
    if tokenizer is not None:
        del tokenizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("SHUTDOWN: Complete")
# Create FastAPI app.
# Interactive docs are only exposed when the DEBUG env var is set.
_debug = bool(os.environ.get("DEBUG"))
app = FastAPI(
    title="OpenELM OpenAI API",
    description="OpenAI and Anthropic API compatible wrapper for OpenELM models",
    version="4.0.0",
    lifespan=lifespan,
    docs_url="/docs" if _debug else None,
    redoc_url="/redoc" if _debug else None,
)
# Add CORS (wide open: any origin, method, header).
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers per the CORS spec (wildcard origin cannot carry
# credentials) — confirm whether credentialed requests are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # ==================== Pydantic Models ==================== | |
class ChatMessage(BaseModel):
    """A single chat message (OpenAI wire format)."""
    role: str  # "system", "user", or "assistant"
    content: str
    name: Optional[str] = None  # optional author name (OpenAI extension)
class ChatCompletionRequest(BaseModel):
    """Request body for POST /v1/chat/completions (OpenAI format)."""
    model: str = "openelm-450m-instruct"  # accepted but only one model is served
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)  # None -> handler default 0.7
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)  # accepted but not forwarded by the handlers
    max_tokens: Optional[int] = Field(default=None, ge=1, le=4096)  # None -> handler default 1024
    stream: Optional[bool] = False
class ChatCompletionChoice(BaseModel):
    """One completion choice in an OpenAI-format response."""
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None  # "stop" in this implementation
class ChatCompletionUsage(BaseModel):
    """Token accounting block of an OpenAI-format response.

    completion_tokens is a whitespace-word-count estimate here, not a
    true tokenizer count (see the handlers).
    """
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
class ChatCompletionResponse(BaseModel):
    """Response body for POST /v1/chat/completions (OpenAI format)."""
    id: str  # "chatcmpl-<12 hex chars>"
    object: str = "chat.completion"
    created: int  # Unix timestamp
    model: str
    choices: List[ChatCompletionChoice]
    usage: ChatCompletionUsage
class HealthResponse(BaseModel):
    """Response body for the health endpoint."""
    status: str  # "healthy" while READY or LOADING, "unhealthy" otherwise
    model_state: str  # one of the ModelState constants
    model_loaded: bool  # True only when READY
    load_time_seconds: Optional[float] = None  # wall-clock load duration once loaded
    error: Optional[str] = None  # loader failure message, if any
    # NOTE(review): field names starting with "model_" trigger Pydantic v2's
    # protected-namespace warning — confirm which pydantic major version runs here.
| # ==================== Helper Functions ==================== | |
def check_model_ready():
    """Raise HTTP 503 unless the model has finished loading successfully.

    The Retry-After header tells well-behaved clients how long to back
    off: 10s before loading starts, 30s while loading, 0s after failure.
    Returns None (falls through) when the model is READY.

    NOTE: not invoked by the handlers visible in this file, which inline
    the same checks.
    """
    if model_state == ModelState.READY:
        return
    if model_state == ModelState.FAILED:
        detail = f"Model loading failed: {model_load_error}"
        retry_after = "0"
    elif model_state == ModelState.LOADING:
        detail = "Model is still loading. Please wait a few moments and retry."
        retry_after = "30"
    else:  # NOT_LOADED
        detail = "Model has not started loading yet. Please wait a moment and retry."
        retry_after = "10"
    raise HTTPException(
        status_code=503,
        detail=detail,
        headers={"Retry-After": retry_after},
    )
def generate_with_model(prompt: str, max_tokens: int = 1024, temperature: float = 0.7) -> "tuple[str, int]":
    """Generate a completion for *prompt* with the loaded model.

    Returns:
        (response_text, input_token_count).  BUG FIX: the original
        annotation claimed ``str`` but the function has always returned
        this tuple (both callers unpack it).

    Raises:
        HTTPException: 503 with Retry-After when the model is not READY.
    """
    global model, tokenizer, model_state
    if model_state != ModelState.READY:
        raise HTTPException(
            status_code=503,
            detail="Model is not ready yet. Please retry later.",
            headers={"Retry-After": "30"}
        )
    # Tokenize and move tensors onto the model's device.
    inputs = tokenizer(prompt, return_tensors="pt")
    input_tokens = len(inputs.input_ids[0])
    if hasattr(model, 'device'):
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # BUG FIX: temperature was previously forwarded only when it differed
    # from 0.7, so the default 0.7 silently fell back to the model's own
    # generation config (typically 1.0).  Forward it whenever sampling.
    gen_params = {
        "max_new_tokens": max_tokens,
        "do_sample": temperature > 0,  # temperature 0 means greedy decoding
    }
    if temperature > 0:
        gen_params["temperature"] = temperature
    # Inference only — no gradients.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            **gen_params,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Strip the echoed prompt, keeping only the assistant's reply.
    response_text = extract_assistant_response(generated_text)
    return response_text, input_tokens
def extract_assistant_response(generated_text: str) -> str:
    """Return the text after the final "Assistant:" marker, or "".

    BUG FIX / simplification: the original fell through to a line-by-line
    scan when "Assistant:" was absent, but that scan could only set its
    ``in_assistant`` flag on a line starting with "Assistant:" — impossible
    in that branch — so it always produced "".  The dead loop is replaced
    by an explicit empty-string return with identical behavior.
    """
    if "Assistant:" in generated_text:
        return generated_text.split("Assistant:")[-1].strip()
    return ""
| # ==================== API Endpoints ==================== | |
@app.get("/")
async def root():
    """Root endpoint with API information.

    BUG FIX: the handler was never registered on the app (no route
    decorator anywhere in this revision), so GET / returned 404; it is
    now mounted at "/".
    """
    load_time = None
    if model_load_end_time and model_load_start_time:
        load_time = model_load_end_time - model_load_start_time
    return {
        "name": "OpenELM OpenAI API",
        "version": "4.0.0",
        "status": "ready" if model_state == ModelState.READY else "loading",
        "model_state": model_state,
        "model_loaded": model_state == ModelState.READY,
        "load_time_seconds": load_time,
        "endpoints": {
            "chat": "POST /v1/chat/completions",
            "messages": "POST /v1/messages",
            "health": "GET /health"
        },
        "note": "Model loads in background for fast startup"
    }
@app.get("/health")
async def health_check():
    """Health check endpoint.

    BUG FIX: registered at GET /health (route decorator was missing).
    Always returns HTTP 200 — even mid-load — so the Hugging Face Spaces
    health probe does not kill the container while the model downloads.
    """
    load_time = None
    if model_load_end_time and model_load_start_time:
        load_time = model_load_end_time - model_load_start_time
    return HealthResponse(
        status="healthy" if model_state in [ModelState.READY, ModelState.LOADING] else "unhealthy",
        model_state=model_state,
        model_loaded=model_state == ModelState.READY,
        load_time_seconds=load_time,
        error=model_load_error
    )
@app.get("/ready")  # path assumed for a readiness probe — TODO confirm
async def readiness_check():
    """Readiness check for load balancers.

    Registered at GET /ready (the route decorator was missing in this
    revision; the path is an assumption to confirm against the deploy
    config).  Returns 200 only once the model is READY, else 503 with
    Retry-After.
    """
    if model_state == ModelState.READY:
        return {"ready": True}
    raise HTTPException(
        status_code=503,
        detail=f"Model not ready (state: {model_state})",
        headers={"Retry-After": "30"}
    )
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Create a chat completion (OpenAI API format).

    BUG FIX: registered at POST /v1/chat/completions (route decorator was
    missing).  Returns 503 with Retry-After while the model is loading.

    NOTE(review): ``request.stream`` is not dispatched here; streaming is
    implemented separately in ``create_chat_completion_stream`` — confirm
    how the two are meant to be wired together.
    """
    # Reject early while the model is not usable.
    if model_state != ModelState.READY:
        if model_state == ModelState.LOADING:
            raise HTTPException(
                status_code=503,
                detail="Model is still loading. Please retry in 30 seconds.",
                headers={"Retry-After": "30"}
            )
        elif model_state == ModelState.NOT_LOADED:
            raise HTTPException(
                status_code=503,
                detail="Model loading has not started yet. Please wait.",
                headers={"Retry-After": "10"}
            )
        else:
            raise HTTPException(
                status_code=503,
                detail="Model failed to load. Please restart the Space.",
                headers={"Retry-After": "0"}
            )
    try:
        # Fold messages into a flat "User:/Assistant:" transcript; the
        # first system message becomes a bracketed preamble.
        system_msg = None
        user_msgs = []
        for msg in request.messages:
            if msg.role == "system" and system_msg is None:
                system_msg = msg.content
            else:
                user_msgs.append(msg)
        prompt_parts = []
        if system_msg:
            prompt_parts.append(f"[System: {system_msg}]")
        for msg in user_msgs:
            if msg.role == "user":
                prompt_parts.append(f"User: {msg.content}")
            elif msg.role == "assistant":
                prompt_parts.append(f"Assistant: {msg.content}")
        prompt_parts.append("Assistant:")
        prompt = "\n\n".join(prompt_parts)
        # Generate.
        max_tokens = request.max_tokens or 1024
        temperature = request.temperature if request.temperature is not None else 0.7
        response_text, input_tokens = generate_with_model(prompt, max_tokens, temperature)
        # Whitespace word count is only an estimate of completion tokens.
        output_tokens = max(1, len(response_text.split()))
        response_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        # BUG FIX: ``created`` must be a Unix timestamp; uuid1().time is a
        # count of 100-ns intervals since 1582-10-15 and produced garbage.
        timestamp = int(time.time())
        return ChatCompletionResponse(
            id=response_id,
            created=timestamp,
            model="openelm-450m-instruct",
            choices=[
                ChatCompletionChoice(
                    index=0,
                    message=ChatMessage(role="assistant", content=response_text),
                    finish_reason="stop"
                )
            ],
            usage=ChatCompletionUsage(
                prompt_tokens=input_tokens,
                completion_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens
            )
        )
    except HTTPException:
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
@app.post("/v1/messages")
async def create_message(request: dict):
    """Create a message (Anthropic API format).

    BUG FIX: registered at POST /v1/messages (route decorator was
    missing).  Returns 503 with Retry-After while the model is loading.
    Accepts a raw dict so both string and content-block-list message
    bodies are tolerated.
    """
    if model_state != ModelState.READY:
        raise HTTPException(
            status_code=503,
            detail="Model is still loading. Please retry in 30 seconds.",
            headers={"Retry-After": "30"}
        )
    try:
        # Extract parameters with Anthropic defaults.
        messages = request.get("messages", [])
        system = request.get("system", None)
        max_tokens = request.get("max_tokens", 1024)
        temperature = request.get("temperature", 0.7)
        # Build a flat transcript prompt (same scheme as the OpenAI route).
        prompt_parts = []
        if system:
            prompt_parts.append(f"[System: {system}]")
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            # Anthropic allows content as a list of blocks; join text blocks.
            if isinstance(content, list):
                content = "".join(c.get("text", "") for c in content if isinstance(c, dict))
            if role == "user":
                prompt_parts.append(f"User: {content}")
            elif role == "assistant":
                prompt_parts.append(f"Assistant: {content}")
        prompt_parts.append("Assistant:")
        prompt = "\n\n".join(prompt_parts)
        # Generate.
        response_text, input_tokens = generate_with_model(prompt, max_tokens, temperature)
        # Whitespace word count is only an estimate of output tokens.
        output_tokens = max(1, len(response_text.split()))
        return {
            "id": f"msg_{uuid.uuid4().hex[:8]}",
            "type": "message",
            "role": "assistant",
            "content": [{"type": "text", "text": response_text}],
            "model": "openelm-450m-instruct",
            "stop_reason": "end_turn",
            "usage": {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens
            }
        }
    except HTTPException:
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
| # ==================== Streaming Endpoint ==================== | |
async def create_chat_completion_stream(request: ChatCompletionRequest):
    """Create a streaming chat completion (OpenAI SSE chunk format).

    BUG FIXES: SSE chunk payloads were previously emitted via f-string
    interpolation of a dict, i.e. Python ``repr`` with single quotes —
    invalid JSON that breaks every OpenAI-compatible client; all payloads
    now go through ``json.dumps``.  ``created`` now uses a Unix timestamp
    instead of ``uuid.uuid1().time``.

    NOTE(review): this handler is not registered on any route in this
    revision, and the non-streaming handler owns /v1/chat/completions —
    confirm how stream=True requests are meant to reach it.
    """
    if model_state != ModelState.READY:
        raise HTTPException(
            status_code=503,
            detail="Model is still loading. Please retry in 30 seconds.",
            headers={"Retry-After": "30"}
        )

    async def generate_stream():
        """Yield SSE 'data:' frames carrying JSON chunk payloads."""
        try:
            # Build the prompt exactly as the non-streaming handler does.
            system_msg = None
            user_msgs = []
            for msg in request.messages:
                if msg.role == "system" and system_msg is None:
                    system_msg = msg.content
                else:
                    user_msgs.append(msg)
            prompt_parts = []
            if system_msg:
                prompt_parts.append(f"[System: {system_msg}]")
            for msg in user_msgs:
                if msg.role == "user":
                    prompt_parts.append(f"User: {msg.content}")
                elif msg.role == "assistant":
                    prompt_parts.append(f"Assistant: {msg.content}")
            prompt_parts.append("Assistant:")
            prompt = "\n\n".join(prompt_parts)
            # Tokenize and move onto the model's device.
            inputs = tokenizer(prompt, return_tensors="pt")
            input_tokens = len(inputs.input_ids[0])
            if hasattr(model, 'device'):
                inputs = {k: v.to(model.device) for k, v in inputs.items()}
            # Generation parameters.
            max_tokens = request.max_tokens or 1024
            temperature = request.temperature if request.temperature is not None else 0.7
            gen_params = {"max_new_tokens": max_tokens}
            if temperature == 0:
                gen_params["do_sample"] = False
            else:
                gen_params["temperature"] = temperature
                gen_params["do_sample"] = True
            # TextIteratorStreamer carries decoded text across the thread
            # boundary while model.generate runs in a worker thread.
            streamer = TextIteratorStreamer(
                tokenizer,
                skip_prompt=True,
                skip_special_tokens=True
            )
            gen_params["streamer"] = streamer

            def generate():
                with torch.no_grad():
                    model.generate(**inputs, **gen_params)

            thread = threading.Thread(target=generate)
            thread.start()
            chunk_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
            timestamp = int(time.time())
            # Opening chunk: announces the assistant role.
            first_chunk = {
                "id": chunk_id,
                "object": "chat.completion.chunk",
                "created": timestamp,
                "model": "openelm-450m-instruct",
                "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}]
            }
            yield f"data: {json.dumps(first_chunk)}\n\n"
            # Stream tokens.
            # NOTE(review): iterating the streamer here blocks the event
            # loop between tokens — consider asyncio.to_thread.
            full_text = ""
            for text in streamer:
                full_text += text
                chunk_data = {
                    "id": chunk_id,
                    "object": "chat.completion.chunk",
                    "created": timestamp,
                    "model": "openelm-450m-instruct",
                    "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": None}]
                }
                yield f"data: {json.dumps(chunk_data)}\n\n"
            # Terminal chunk with finish_reason.
            stop_chunk = {
                "id": chunk_id,
                "object": "chat.completion.chunk",
                "created": timestamp,
                "model": "openelm-450m-instruct",
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
            }
            yield f"data: {json.dumps(stop_chunk)}\n\n"
            # Usage summary (word-count estimate for completion tokens).
            output_tokens = len(full_text.split()) + 1
            usage_data = {
                "id": chunk_id,
                "object": "chat.completion",
                "created": timestamp,
                "model": "openelm-450m-instruct",
                "choices": [{"index": 0, "message": {"role": "assistant", "content": full_text}, "finish_reason": "stop"}],
                "usage": {
                    "prompt_tokens": input_tokens,
                    "completion_tokens": output_tokens,
                    "total_tokens": input_tokens + output_tokens
                }
            }
            yield f"data: {json.dumps(usage_data)}\n\n"
            yield "data: [DONE]\n\n"
            thread.join()
        except Exception as e:
            # json.dumps also escapes quotes/newlines in the error message,
            # which the old hand-built f-string frame did not.
            yield f"data: {json.dumps({'error': {'message': str(e)}, 'type': 'server_error'})}\n\n"

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # disable nginx buffering for SSE
        }
    )
# ==================== Main Entry Point ====================
if __name__ == "__main__":
    import uvicorn

    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", 8000))
    banner = "=" * 50
    print(banner)
    print("OpenELM API Server v4.0")
    print(banner)
    print(f"Starting on {host}:{port}")
    print("Model will load in background")
    print(banner)
    # NOTE(review): "app_v4:app" assumes this module is saved as app_v4.py —
    # confirm the filename matches the deployment.
    uvicorn.run(
        "app_v4:app",
        host=host,
        port=port,
        reload=False,
        workers=1,
        log_level="info"
    )