"""
OpenELM OpenAI & Anthropic API Compatible Wrapper - v5

Minimal lazy-loading architecture for instant startup: heavy imports (torch,
transformers) are deferred to a background thread, so the HTTP server can
answer requests (e.g. /health) while the model is still loading.

Environment variables:
    PORT                  Port to listen on when run as a script (default 7860).
    OPENELM_MODEL_ID      Hugging Face model repo
                          (default "apple/OpenELM-450M-Instruct").
    OPENELM_TOKENIZER_ID  Tokenizer repo (defaults to OPENELM_MODEL_ID).
"""

import asyncio
import os
import sys
import threading
import time
import traceback
import uuid
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Dict, List, NamedTuple, Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

# ==================== Configuration ====================

# Model name reported to API clients.
MODEL_NAME = "openelm-450m-instruct"
# Hugging Face repo the weights are loaded from.
MODEL_ID = os.environ.get("OPENELM_MODEL_ID", "apple/OpenELM-450M-Instruct")
# NOTE(review): OpenELM's model card documents using a Llama-2 tokenizer;
# confirm the model repo ships one, otherwise set OPENELM_TOKENIZER_ID.
TOKENIZER_ID = os.environ.get("OPENELM_TOKENIZER_ID", MODEL_ID)
# Context window (prompt + completion tokens) assumed for the model.
MAX_CONTEXT_TOKENS = 2048
# Floor for the prompt budget, so a very large max_tokens request cannot
# squeeze the prompt down to nothing; max_new_tokens is clamped instead.
MIN_PROMPT_TOKENS = 256
# Completion length used when an OpenAI request omits max_tokens.
DEFAULT_MAX_TOKENS = 1024
# Unix timestamp reported as the model's "created" time in /v1/models.
MODEL_CREATED = int(time.time())

# Global state for lazy loading.
# This allows the server to respond immediately while the model loads in the
# background. status: INITIALIZING -> LOADING -> READY, or ERROR on failure.
global_state: Dict[str, Any] = {
    "status": "INITIALIZING",
    "model": None,
    "tokenizer": None,
    "error": None,
}

# Generation runs in worker threads (so the event loop stays responsive);
# this lock keeps it to one model.generate call at a time.
_generation_lock = threading.Lock()


def model_loader_thread():
    """Load the tokenizer and model in a background thread.

    Mutates ``global_state`` in place: ``status`` becomes LOADING once the
    heavy libraries are imported, then READY after both tokenizer and model
    are stored. On any exception, ``error`` holds the message and ``status``
    becomes ERROR.
    """
    try:
        # Import heavy libraries INSIDE the thread so startup stays instant.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        global_state["status"] = "LOADING"

        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID, trust_remote_code=True)

        # Fill in missing special tokens. eos is fixed up first so that the
        # pad-token fallback below never copies a None.
        if tokenizer.bos_token is None:
            tokenizer.bos_token = ""
        if tokenizer.eos_token is None:
            tokenizer.eos_token = ""
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        global_state["tokenizer"] = tokenizer
        print("Tokenizer loaded")

        print("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float32,
            use_safetensors=True,
            trust_remote_code=True,
        )
        model.eval()
        global_state["model"] = model
        # READY is set last: endpoints only touch model/tokenizer once they see it.
        global_state["status"] = "READY"
        print(f"Model loaded successfully! Device: {next(model.parameters()).device}")
    except Exception as e:
        global_state["error"] = str(e)
        global_state["status"] = "ERROR"
        print(f"Error loading model: {e}")
        traceback.print_exc()


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan: start the background loader, yield, then clean up."""
    print("=" * 60)
    print("OpenELM API v5 - Starting with background model loader")
    print("=" * 60)
    print("Server will respond immediately. Model loads in background.")
    print("Endpoints:")
    print("  POST /v1/chat/completions - OpenAI format")
    print("  POST /v1/messages - Anthropic format")
    print("  GET  /health - Check model status")
    print("=" * 60)

    # Daemon thread: it must not keep the process alive if shutdown happens mid-load.
    loader_thread = threading.Thread(
        target=model_loader_thread, name="model-loader", daemon=True
    )
    loader_thread.start()

    yield

    # Cleanup on shutdown. Drop references (keeping the keys) so memory can be
    # reclaimed and late requests get a 503 instead of a KeyError.
    global_state["model"] = None
    global_state["tokenizer"] = None
    if "torch" in sys.modules:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Create FastAPI app
# Note: No heavy imports at module level - only fastapi and pydantic
app = FastAPI(
    title="OpenELM OpenAI API",
    description="OpenAI and Anthropic API compatible wrapper for OpenELM models",
    version="5.0.0",
    lifespan=lifespan,
)

# Add CORS (wide open: any origin, method and header is accepted).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ==================== Pydantic Models ====================

class MessageContent(BaseModel):
    type: str = "text"
    text: str


class Message(BaseModel):
    role: str
    content: str | List[MessageContent]
    name: Optional[str] = None


class Usage(BaseModel):
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0


class ContentBlock(BaseModel):
    type: str = "text"
    text: str


class MessageResponse(BaseModel):
    id: str
    type: str = "message"
    role: str = "assistant"
    content: List[ContentBlock]
    model: str
    stop_reason: Optional[str] = None
    stop_sequence: Optional[str] = None
    usage: Usage


class MessageCreateParams(BaseModel):
    model: str = "openelm-450m-instruct"
    messages: List[Message]
    system: Optional[str] = None
    max_tokens: int = Field(default=1024, ge=1, le=4096)
    temperature: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    stream: Optional[bool] = False


class ChatMessage(BaseModel):
    role: str
    content: str
    name: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str = "openelm-450m-instruct"
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    max_tokens: Optional[int] = Field(default=None, ge=1, le=4096)
    stream: Optional[bool] = False


class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None


class ChatCompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: ChatCompletionUsage


class OpenAIModelInfo(BaseModel):
    id: str
    object: str = "model"
    created: int = 0
    owned_by: str = "openelm"
    permission: List[Any] = Field(default_factory=list)


class OpenAIModelListResponse(BaseModel):
    object: str = "list"
    data: List[OpenAIModelInfo]


# ==================== Helper Functions ====================

# Line prefixes that mark the start of a new turn in the prompt format below.
_TURN_PREFIXES = ("User:", "Assistant:", "System:", "[System:")


def format_prompt_for_openelm(messages: List[Message], system: Optional[str] = None) -> str:
    """Format messages into a plain-text prompt for OpenELM.

    Turns are separated by blank lines: an optional ``[System: ...]`` line,
    then ``User: ...`` / ``Assistant: ...`` (other roles as ``<role>: ...``),
    ending with a bare ``Assistant:`` cue for the model to continue.
    List-form content is flattened by concatenating the text blocks.
    """
    prompt_parts = []
    if system:
        prompt_parts.append(f"[System: {system}]")

    for msg in messages:
        role = msg.role.lower()
        content = msg.content
        if isinstance(content, list):
            content = "".join(b.text for b in content if hasattr(b, "text"))
        elif not isinstance(content, str):
            content = str(content)

        if role == "user":
            prompt_parts.append(f"User: {content}")
        elif role == "assistant":
            prompt_parts.append(f"Assistant: {content}")
        else:
            prompt_parts.append(f"{role}: {content}")

    prompt_parts.append("Assistant:")
    return "\n\n".join(prompt_parts)


def count_tokens(text: str, tokenizer) -> int:
    """Count tokens in *text* using *tokenizer*.

    Falls back to a rough estimate (one token per four characters, minimum 1)
    if encoding fails for any reason.
    """
    try:
        return len(tokenizer.encode(text))
    except Exception:
        return max(1, len(text) // 4)


def truncate_prompt(prompt: str, max_tokens: int, tokenizer, system: Optional[str] = None) -> str:
    """Drop the oldest parts of *prompt* so it fits in *max_tokens* tokens.

    The prompt is split on blank lines (the separator used by
    ``format_prompt_for_openelm``). A leading ``[System: ...]`` fragment and
    the newest fragment (normally the ``Assistant:`` cue) are always kept;
    the fragments in between are kept newest-first for as long as they fit.
    If even that minimum exceeds the budget it is returned anyway, so callers
    must still check the final length.

    Args:
        prompt: Prompt text, usually from ``format_prompt_for_openelm``.
        max_tokens: Token budget for the prompt.
        tokenizer: Tokenizer used by ``count_tokens``.
        system: Unused; kept for backward compatibility (the system fragment
            is detected from its ``[System:`` prefix instead).

    Returns:
        The prompt, possibly with its oldest fragments removed.
    """
    if count_tokens(prompt, tokenizer) <= max_tokens:
        return prompt

    fragments = prompt.split("\n\n")
    head: List[str] = []
    if fragments and fragments[0].startswith("[System:"):
        head, fragments = fragments[:1], fragments[1:]

    kept: List[str] = []
    for fragment in reversed(fragments):
        candidate = "\n\n".join(head + [fragment] + kept)
        # Keep the newest fragment unconditionally; stop at the first older
        # fragment that would push the prompt over budget.
        if kept and count_tokens(candidate, tokenizer) > max_tokens:
            break
        kept.insert(0, fragment)
    return "\n\n".join(head + kept)


def _trim_at_next_turn(text: str) -> str:
    """Return *text* cut just before the first line that starts a new chat turn.

    A leading ``Assistant:`` label is removed first. The model may continue
    the transcript after answering ("User: ..."); such invented turns are
    dropped. Note this also cuts a reply that legitimately contains a line
    starting with one of ``_TURN_PREFIXES``.
    """
    body = text.strip().removeprefix("Assistant:")
    kept = []
    for line in body.split("\n"):
        if line.lstrip().startswith(_TURN_PREFIXES):
            break
        kept.append(line)
    return "\n".join(kept).strip()


def extract_assistant_response(generated_text: str) -> str:
    """Extract the assistant reply from a prompt + completion transcript.

    Takes the text after the last ``Assistant:`` marker (or the whole text if
    there is none) and drops any turns the model invented after its reply.
    """
    return _trim_at_next_turn(generated_text.rpartition("Assistant:")[2])


class _PromptTooLongError(ValueError):
    """Raised when the prompt leaves no room in the context window for output."""


class _Completion(NamedTuple):
    """Result of one generation call."""

    text: str  # assistant reply, trimmed at any invented follow-up turn
    prompt_tokens: int  # length of the tokenized prompt fed to the model
    completion_tokens: int  # number of tokens the model generated
    hit_limit: bool  # True if generation stopped at max_new_tokens, not EOS


def _require_model():
    """Return ``(model, tokenizer)``, or raise HTTP 503 if they are not usable."""
    status = global_state["status"]
    if status == "ERROR":
        raise HTTPException(status_code=503, detail="Model failed to load")
    model = global_state["model"]
    tokenizer = global_state["tokenizer"]
    if status != "READY" or model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is still loading. Please retry.")
    return model, tokenizer


def _build_prompt(messages: List[Message], system: Optional[str], max_new_tokens: int, tokenizer) -> str:
    """Format *messages* and trim old turns to leave room for *max_new_tokens*."""
    prompt = format_prompt_for_openelm(messages, system)
    budget = max(MAX_CONTEXT_TOKENS - max_new_tokens, MIN_PROMPT_TOKENS)
    return truncate_prompt(prompt, budget, tokenizer, system)


def _generate_completion(
    model,
    tokenizer,
    prompt: str,
    max_tokens: int,
    temperature: Optional[float],
    top_p: Optional[float],
) -> _Completion:
    """Run ``model.generate`` on *prompt*; blocking, so call it off the event loop.

    ``temperature == 0`` selects greedy decoding, any other value enables
    sampling at that temperature; ``None`` leaves the model defaults.
    ``max_tokens`` is clamped so prompt + completion fits the context window.

    Raises:
        _PromptTooLongError: if the prompt leaves no room for any output.
    """
    import torch  # already imported by the loader thread once status is READY

    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_tokens = int(inputs["input_ids"].shape[-1])
    max_new_tokens = min(max_tokens, MAX_CONTEXT_TOKENS - prompt_tokens)
    if max_new_tokens < 1:
        raise _PromptTooLongError(
            f"Prompt is {prompt_tokens} tokens, leaving no room for a completion "
            f"in the {MAX_CONTEXT_TOKENS}-token context window."
        )
    if hasattr(model, "device"):
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

    gen_params: Dict[str, Any] = {"max_new_tokens": max_new_tokens}
    if temperature is not None:
        if temperature == 0:
            gen_params["do_sample"] = False
        else:
            gen_params["do_sample"] = True
            gen_params["temperature"] = temperature
    if top_p is not None:
        gen_params["top_p"] = top_p

    with _generation_lock, torch.no_grad():
        outputs = model.generate(
            **inputs,
            **gen_params,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + completion for this decoder-only model;
    # decode only the new tokens so prompt text can never leak into the reply.
    new_ids = outputs[0][prompt_tokens:]
    completion_tokens = int(new_ids.shape[-1])
    stopped_on_eos = completion_tokens > 0 and int(new_ids[-1]) == tokenizer.eos_token_id
    text = _trim_at_next_turn(tokenizer.decode(new_ids, skip_special_tokens=True))
    hit_limit = completion_tokens >= max_new_tokens and not stopped_on_eos
    return _Completion(text, prompt_tokens, completion_tokens, hit_limit)


# ==================== API Endpoints ====================

@app.get("/", tags=["Root"])
async def root():
    """Root endpoint with API information."""
    return {
        "name": "OpenELM OpenAI API v5",
        "version": "5.0.0",
        "status": global_state["status"],
        "model_loaded": global_state["status"] == "READY",
        "endpoints": {
            "chat": "POST /v1/chat/completions",
            "messages": "POST /v1/messages",
            "health": "GET /health",
        },
        "note": "Model loads in background for instant startup",
    }


@app.get("/health", tags=["Health"])
async def health_check():
    """Health check: 200 when the model is ready, otherwise 503 with the reason."""
    status = global_state["status"]
    if status == "READY":
        return {"status": "healthy", "model_loaded": True}
    if status == "ERROR":
        raise HTTPException(
            status_code=503,
            detail=f"Model failed to load: {global_state['error'] or 'Unknown error'}",
        )
    raise HTTPException(
        status_code=503,
        detail="Model is still loading. Please retry in a few moments.",
    )


@app.get("/v1/models", response_model=OpenAIModelListResponse, tags=["Models"])
async def list_models():
    """List available models (OpenAI format)."""
    return OpenAIModelListResponse(
        data=[
            OpenAIModelInfo(
                id=MODEL_NAME,
                owned_by="apple",
                created=MODEL_CREATED,
            )
        ]
    )


@app.post("/v1/chat/completions", tags=["OpenAI"])
async def create_chat_completion(request: ChatCompletionRequest):
    """Create a chat completion (OpenAI API format).

    The first ``system`` message becomes the system prompt; later system
    messages are passed through as ordinary turns. ``stream`` is accepted but
    ignored: a single non-streaming response is always returned.

    Raises:
        HTTPException: 503 while the model is unavailable, 400 if the prompt
            leaves no room for output, 500 if generation fails.
    """
    model, tokenizer = _require_model()
    max_tokens = request.max_tokens or DEFAULT_MAX_TOKENS

    try:
        system_message = None
        formatted_messages = []
        for msg in request.messages:
            if msg.role == "system" and system_message is None:
                system_message = msg.content
            else:
                formatted_messages.append(Message(role=msg.role, content=msg.content))

        prompt = _build_prompt(formatted_messages, system_message, max_tokens, tokenizer)
        # Run generation in a worker thread so other requests are still served.
        result = await asyncio.to_thread(
            _generate_completion,
            model,
            tokenizer,
            prompt,
            max_tokens,
            request.temperature,
            request.top_p,
        )
    except _PromptTooLongError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e

    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
        created=int(time.time()),  # OpenAI's "created" is Unix seconds
        model=MODEL_NAME,
        choices=[
            ChatCompletionChoice(
                index=0,
                message=ChatMessage(role="assistant", content=result.text),
                finish_reason="length" if result.hit_limit else "stop",
            )
        ],
        usage=ChatCompletionUsage(
            prompt_tokens=result.prompt_tokens,
            completion_tokens=result.completion_tokens,
            total_tokens=result.prompt_tokens + result.completion_tokens,
        ),
    )


@app.post("/v1/messages", response_model=MessageResponse, tags=["Messages"])
async def create_message(params: MessageCreateParams):
    """Create a message (Anthropic API format).

    ``stream`` is accepted but ignored: a single non-streaming response is
    always returned.

    Raises:
        HTTPException: 503 while the model is unavailable, 400 if the prompt
            leaves no room for output, 500 if generation fails.
    """
    model, tokenizer = _require_model()

    try:
        # format_prompt_for_openelm flattens list-form content itself.
        prompt = _build_prompt(params.messages, params.system, params.max_tokens, tokenizer)
        result = await asyncio.to_thread(
            _generate_completion,
            model,
            tokenizer,
            prompt,
            params.max_tokens,
            params.temperature,
            params.top_p,
        )
    except _PromptTooLongError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e

    return MessageResponse(
        id=f"msg_{uuid.uuid4().hex[:8]}",
        role="assistant",
        content=[ContentBlock(type="text", text=result.text)],
        model=MODEL_NAME,
        stop_reason="max_tokens" if result.hit_limit else "end_turn",
        usage=Usage(
            input_tokens=result.prompt_tokens,
            output_tokens=result.completion_tokens,
            total_tokens=result.prompt_tokens + result.completion_tokens,
        ),
    )


# ==================== Main Entry Point ====================

if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    print(f"\nStarting OpenELM API v5 on port {port}...")
    print("The server will respond immediately while the model loads in background.\n")
    # Pass the app object (not "app:app") so this works whatever the file is
    # named; allowed because reload is off and there is a single worker.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        reload=False,
        workers=1,
    )