Spaces:
Runtime error
Runtime error
| """ | |
| OpenELM OpenAI & Anthropic API Compatible Wrapper - v5 | |
| Minimal lazy-loading architecture for instant startup. | |
| Heavy imports (torch, transformers) are deferred to a background thread. | |
| """ | |
| import uuid | |
| import os | |
| import sys | |
| import time | |
| import asyncio | |
| import threading | |
| from contextlib import asynccontextmanager | |
| from typing import AsyncIterator, List, Optional, Dict, Any | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
# Shared, process-wide model state. Request handlers read it while the
# background loader thread fills it in, so the server can answer (with 503s)
# before the model is available.
# "status" moves INITIALIZING -> LOADING -> READY, or to ERROR on failure.
global_state: Dict[str, Any] = dict(
    status="INITIALIZING",
    model=None,
    tokenizer=None,
    error=None,
)
def model_loader_thread():
    """Load the OpenELM tokenizer and model into ``global_state``.

    Meant to run on a background thread so the heavy ``torch``/``transformers``
    imports and the checkpoint download do not block server startup. Progress
    is published through ``global_state["status"]``: ``LOADING`` while work is
    in progress, ``READY`` once the model is stored, or ``ERROR`` (with the
    message in ``global_state["error"]``) if an ``Exception`` is raised.
    """
    try:
        # Import heavy libraries INSIDE the thread so module import stays fast.
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        global_state["status"] = "LOADING"
        model_id = "apple/OpenELM-450M-Instruct"

        print("Loading tokenizer...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )
        except Exception as tokenizer_error:
            # NOTE: the OpenELM checkpoints may not ship tokenizer files; Apple's
            # model card pairs them with the Llama-2 tokenizer. Fall back to a
            # configurable repo (possibly gated, i.e. needing an HF token in the
            # environment). If this also fails, the outer handler records it.
            fallback_id = os.environ.get("OPENELM_TOKENIZER_ID", "meta-llama/Llama-2-7b-hf")
            print(f"Tokenizer unavailable from {model_id} ({tokenizer_error}); trying {fallback_id}...")
            tokenizer = AutoTokenizer.from_pretrained(
                fallback_id,
                trust_remote_code=True
            )

        # Fill in missing special tokens. EOS is settled before PAD because PAD
        # reuses it; doing PAD first could leave pad_token as None.
        if tokenizer.bos_token is None:
            tokenizer.bos_token = "<s>"
        if tokenizer.eos_token is None:
            tokenizer.eos_token = "</s>"
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        global_state["tokenizer"] = tokenizer
        print("Tokenizer loaded")

        print("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            use_safetensors=True,
            trust_remote_code=True
        )
        model.eval()  # inference only
        global_state["model"] = model
        global_state["status"] = "READY"
        print(f"Model loaded successfully! Device: {next(model.parameters()).device}")
    except Exception as e:
        # Record the failure so /health and the endpoints can report it.
        global_state["error"] = str(e)
        global_state["status"] = "ERROR"
        print(f"Error loading model: {e}")
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan: start the background model loader, serve, then release.

    The model loads on a daemon thread so the server can accept requests
    immediately. On shutdown the model/tokenizer references are dropped and,
    if torch was imported and CUDA is available, the CUDA cache is emptied.
    """
    print("=" * 60)
    print("OpenELM API v5 - Starting with background model loader")
    print("=" * 60)
    print("Server will respond immediately. Model loads in background.")
    print("Endpoints:")
    print(" POST /v1/chat/completions - OpenAI format")
    print(" POST /v1/messages - Anthropic format")
    print(" GET /health - Check model status")
    print("=" * 60)

    # Start background thread to load model.
    loader_thread = threading.Thread(target=model_loader_thread, daemon=True)
    loader_thread.start()
    try:
        yield
    finally:
        # Drop references so the weights can be garbage-collected. Keys are kept
        # (set to None) so handlers reading global_state never hit a KeyError.
        global_state["model"] = None
        global_state["tokenizer"] = None
        # Only touch torch if the loader thread actually imported it.
        if "torch" in sys.modules:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
# Create the FastAPI app. Only fastapi/pydantic are imported at module level;
# torch/transformers are imported later by the background loader thread.
app = FastAPI(
    title="OpenELM OpenAI API",
    description="OpenAI and Anthropic API compatible wrapper for OpenELM models",
    version="5.0.0",
    lifespan=lifespan,
)

# Allow browser clients from any origin. Credentials stay disabled: a wildcard
# origin must not be combined with allow_credentials=True (FastAPI/Starlette
# CORS docs), and this API uses no cookies or auth anyway.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # ==================== Pydantic Models ==================== | |
class MessageContent(BaseModel):
    """One content block inside an Anthropic-style message; ``text`` is required."""

    type: str = Field(default="text")
    text: str


class Message(BaseModel):
    """A conversational turn whose content is a plain string or a list of blocks."""

    role: str
    content: str | List[MessageContent]
    name: Optional[str] = Field(default=None)
class Usage(BaseModel):
    """Token accounting attached to an Anthropic-style message response."""

    input_tokens: int = Field(default=0)
    output_tokens: int = Field(default=0)
    # Extra field; stays 0 unless the caller sets it explicitly.
    total_tokens: int = Field(default=0)


class ContentBlock(BaseModel):
    """A text block in a message response."""

    type: str = Field(default="text")
    text: str


class MessageResponse(BaseModel):
    """Response body for the Anthropic-format messages endpoint."""

    id: str
    type: str = Field(default="message")
    role: str = Field(default="assistant")
    content: List[ContentBlock]
    model: str
    stop_reason: Optional[str] = Field(default=None)
    stop_sequence: Optional[str] = Field(default=None)
    usage: Usage
class MessageCreateParams(BaseModel):
    """Request body for the Anthropic-format messages endpoint."""

    model: str = Field(default="openelm-450m-instruct")
    messages: List[Message]
    system: Optional[str] = Field(default=None)
    # Completion budget, validated to the range 1..4096.
    max_tokens: int = Field(default=1024, ge=1, le=4096)
    temperature: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    stream: Optional[bool] = Field(default=False)
class ChatMessage(BaseModel):
    """A single OpenAI-style chat turn; ``content`` must be a plain string."""

    role: str
    content: str
    name: Optional[str] = Field(default=None)


class ChatCompletionRequest(BaseModel):
    """Request body for the OpenAI-format chat completions endpoint."""

    model: str = Field(default="openelm-450m-instruct")
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    max_tokens: Optional[int] = Field(default=None, ge=1, le=4096)
    stream: Optional[bool] = Field(default=False)


class ChatCompletionChoice(BaseModel):
    """One generated alternative in a chat completion response."""

    index: int
    message: ChatMessage
    finish_reason: Optional[str] = Field(default=None)


class ChatCompletionUsage(BaseModel):
    """Prompt/completion token counts reported with a chat completion."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    """Response body for the OpenAI-format chat completions endpoint."""

    id: str
    object: str = Field(default="chat.completion")
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: ChatCompletionUsage
class OpenAIModelInfo(BaseModel):
    """One entry in the OpenAI-style model listing."""

    id: str
    object: str = Field(default="model")
    created: int = Field(default=0)
    owned_by: str = Field(default="openelm")
    # Pydantic copies mutable field defaults per instance, so this is not shared.
    permission: List[Any] = []


class OpenAIModelListResponse(BaseModel):
    """Envelope for the OpenAI-style model listing (``object`` is ``"list"``)."""

    object: str = Field(default="list")
    data: List[OpenAIModelInfo]
| # ==================== Helper Functions ==================== | |
def format_prompt_for_openelm(messages: List[Message], system: Optional[str] = None) -> str:
    """Flatten a chat transcript into the plain-text prompt fed to OpenELM.

    Sections are joined by blank lines:
    an optional ``[System: ...]`` header (only when ``system`` is truthy), one
    ``Label: content`` section per message, and a trailing ``Assistant:`` cue.
    ``user``/``assistant`` roles (case-insensitive) are labelled ``User``/
    ``Assistant``; any other role is used as-is, lowercased. List content is
    reduced to the concatenation of its blocks' ``text``.
    """
    role_labels = {"user": "User", "assistant": "Assistant"}
    sections = [f"[System: {system}]"] if system else []
    for message in messages:
        body = message.content
        if isinstance(body, list):
            body = ''.join(block.text for block in body if hasattr(block, 'text'))
        elif not isinstance(body, str):
            body = str(body)
        role = message.role.lower()
        sections.append(f"{role_labels.get(role, role)}: {body}")
    sections.append("Assistant:")
    return "\n\n".join(sections)
def count_tokens(text: str, tokenizer) -> int:
    """Return how many tokens ``tokenizer.encode`` produces for ``text``.

    If encoding fails (e.g. no usable tokenizer), fall back to a rough estimate
    of one token per four characters, with a minimum of 1.
    """
    try:
        return len(tokenizer.encode(text))
    except Exception:
        # Best-effort estimate: token accounting must never break a request.
        return max(1, len(text) // 4)


def truncate_prompt(prompt: str, max_tokens: int, tokenizer, system: Optional[str] = None) -> str:
    """Shrink ``prompt`` to at most ``max_tokens`` tokens by dropping the oldest turns.

    The prompt is treated as blank-line-separated sections (the layout produced
    by ``format_prompt_for_openelm``). A leading ``[System: ...]`` header is
    always kept; when ``system`` is given it is matched exactly, so a system
    prompt that itself contains blank lines stays intact. The newest sections
    are then kept for as long as the result still fits. The final section (the
    ``Assistant:`` cue) is always kept, even if the budget is too small for it.
    """
    if count_tokens(prompt, tokenizer) <= max_tokens:
        return prompt

    header: List[str] = []
    system_line = f"[System: {system}]" if system else ""
    if system_line and prompt.startswith(system_line + "\n\n"):
        header = [system_line]
        sections = prompt[len(system_line) + 2:].split("\n\n")
    else:
        sections = prompt.split("\n\n")
        if sections[0].startswith("[System:"):
            header = [sections.pop(0)]

    # Walk from the newest section backwards and stop at the first one that
    # would push the prompt over budget.
    kept: List[str] = []
    for section in reversed(sections):
        candidate = [section] + kept
        if kept and count_tokens("\n\n".join(header + candidate), tokenizer) > max_tokens:
            break
        kept = candidate
    return "\n\n".join(header + kept)
def extract_assistant_response(generated_text: str) -> str:
    """Pull the assistant's reply out of decoded model output.

    The reply is the text after the last ``"Assistant:"`` marker, or the whole
    text when no marker is present. Because the model may continue and invent
    the next turn, the reply is cut at the first line starting with ``User:``
    or ``System:``. Surrounding whitespace is stripped.

    NOTE(review): if an invented continuation contains another ``Assistant:``
    turn, that last turn wins; decoding only the newly generated tokens
    upstream would avoid this.
    """
    _, marker, reply = generated_text.rpartition("Assistant:")
    if not marker:
        # No marker at all: the whole text is the reply (previously this path
        # could only ever return an empty string).
        reply = generated_text
    kept_lines = []
    for line in reply.split("\n"):
        if line.startswith(("User:", "System:")):
            break
        kept_lines.append(line)
    return "\n".join(kept_lines).strip()
| # ==================== API Endpoints ==================== | |
@app.get("/")
async def root():
    """Describe the API and report whether the model has finished loading."""
    return {
        "name": "OpenELM OpenAI API v5",
        "version": "5.0.0",
        "status": global_state["status"],
        "model_loaded": global_state["status"] == "READY",
        "endpoints": {
            "chat": "POST /v1/chat/completions",
            "messages": "POST /v1/messages",
            "health": "GET /health"
        },
        "note": "Model loads in background for instant startup"
    }
@app.get("/health")
async def health_check():
    """Report model readiness.

    Returns ``{"status": "healthy", "model_loaded": True}`` once the model is
    READY. Otherwise raises HTTP 503, describing whether loading failed or is
    still in progress.
    """
    status = global_state["status"]
    if status == "READY":
        return {"status": "healthy", "model_loaded": True}
    if status == "ERROR":
        # The "error" key always exists but may still be None.
        reason = global_state.get("error") or "Unknown error"
        raise HTTPException(
            status_code=503,
            detail=f"Model failed to load: {reason}"
        )
    raise HTTPException(
        status_code=503,
        detail="Model is still loading. Please retry in a few moments."
    )
@app.get("/v1/models")
async def list_models():
    """List available models (OpenAI format)."""
    return OpenAIModelListResponse(
        data=[
            OpenAIModelInfo(
                id="openelm-450m-instruct",
                owned_by="apple",
                # Unix seconds, as OpenAI clients expect. uuid1().time counts
                # 100-ns intervals since 1582 and is not a Unix timestamp.
                created=int(time.time())
            )
        ]
    )
# Token budget shared by the prompt and the completion.
# NOTE(review): 2048 is the limit these handlers have always assumed for
# OpenELM; confirm against the loaded model's config.
_CONTEXT_WINDOW = 2048

# Generation runs in a worker thread so the event loop keeps serving /health
# and other requests. This lock keeps generations one-at-a-time, as they were
# when model.generate ran directly on the event loop.
_generation_lock = threading.Lock()


def _require_ready_model():
    """Return ``(model, tokenizer)``, or raise HTTP 503 if the model is not usable."""
    if global_state["status"] != "READY":
        if global_state["status"] == "ERROR":
            raise HTTPException(status_code=503, detail="Model failed to load")
        raise HTTPException(status_code=503, detail="Model is still loading. Please retry.")
    return global_state["model"], global_state["tokenizer"]


def _sampling_kwargs(temperature: Optional[float], top_p: Optional[float]) -> Dict[str, Any]:
    """Map API sampling options onto ``model.generate`` keyword arguments.

    ``temperature == 0`` means greedy decoding. A positive temperature and/or a
    ``top_p`` turns sampling on, since ``top_p`` only has an effect when
    sampling. With neither set, the model's own generation defaults apply.
    """
    if temperature == 0:
        return {"do_sample": False}
    kwargs: Dict[str, Any] = {}
    if temperature is not None:
        kwargs["temperature"] = temperature
    if top_p is not None:
        kwargs["top_p"] = top_p
    if kwargs:
        kwargs["do_sample"] = True
    return kwargs


def _generate_blocking(model, tokenizer, prompt: str, system: Optional[str],
                       max_tokens: int, sampling: Dict[str, Any]):
    """Truncate, tokenize and generate for ``prompt``; blocks until done.

    Returns ``(response_text, prompt_tokens, completion_tokens, hit_limit)``:
    the extracted assistant reply, the prompt length in tokens, the number of
    newly generated tokens, and whether generation used its whole allowance.
    """
    import torch

    # Reserve room for the completion; keep the prompt budget positive.
    prompt = truncate_prompt(prompt, max(_CONTEXT_WINDOW - max_tokens, 1), tokenizer, system)
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_tokens = len(inputs.input_ids[0])
    # Never request more new tokens than still fit in the context window.
    max_new_tokens = max(1, min(max_tokens, _CONTEXT_WINDOW - prompt_tokens))
    if hasattr(model, 'device'):
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Grad mode is thread-local, so no_grad must be entered in this worker thread.
    with _generation_lock, torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            **sampling,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    sequence = outputs[0]
    # For a causal LM, generate() returns prompt + completion; only the tail is new.
    completion_tokens = max(0, len(sequence) - prompt_tokens)
    generated_text = tokenizer.decode(sequence, skip_special_tokens=True)
    response_text = extract_assistant_response(generated_text)
    return response_text, prompt_tokens, completion_tokens, completion_tokens >= max_new_tokens


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Create a chat completion (OpenAI API format).

    The first ``system`` message becomes the prompt's system header; later
    system messages stay inline as ordinary turns. ``finish_reason`` is
    ``"length"`` when the token limit was reached, otherwise ``"stop"``.
    NOTE: streaming is not implemented; ``request.stream`` is ignored and a
    single JSON response is returned.
    """
    model, tokenizer = _require_ready_model()
    try:
        system_message = None
        formatted_messages = []
        for msg in request.messages:
            if msg.role == "system" and system_message is None:
                system_message = msg.content
            else:
                formatted_messages.append(Message(role=msg.role, content=msg.content))
        prompt = format_prompt_for_openelm(formatted_messages, system_message)

        response_text, prompt_tokens, completion_tokens, hit_limit = await asyncio.to_thread(
            _generate_blocking, model, tokenizer, prompt, system_message,
            request.max_tokens or 1024,
            _sampling_kwargs(request.temperature, request.top_p),
        )

        return ChatCompletionResponse(
            id=f"chatcmpl-{uuid.uuid4().hex[:12]}",
            created=int(time.time()),  # Unix seconds
            model="openelm-450m-instruct",
            choices=[
                ChatCompletionChoice(
                    index=0,
                    message=ChatMessage(role="assistant", content=response_text),
                    finish_reason="length" if hit_limit else "stop"
                )
            ],
            usage=ChatCompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens
            )
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e


@app.post("/v1/messages")
async def create_message(params: MessageCreateParams):
    """Create a message (Anthropic API format).

    ``stop_reason`` is ``"max_tokens"`` when the token limit was reached,
    otherwise ``"end_turn"``. NOTE: ``params.stream`` is ignored; a single
    JSON response is returned.
    """
    model, tokenizer = _require_ready_model()
    try:
        # format_prompt_for_openelm already flattens list-of-block content.
        prompt = format_prompt_for_openelm(params.messages, params.system)

        response_text, input_tokens, output_tokens, hit_limit = await asyncio.to_thread(
            _generate_blocking, model, tokenizer, prompt, params.system,
            params.max_tokens,
            _sampling_kwargs(params.temperature, params.top_p),
        )

        return MessageResponse(
            id=f"msg_{uuid.uuid4().hex[:8]}",
            role="assistant",
            content=[ContentBlock(type="text", text=response_text)],
            model="openelm-450m-instruct",
            stop_reason="max_tokens" if hit_limit else "end_turn",
            usage=Usage(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens
            )
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e
| # ==================== Main Entry Point ==================== | |
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    print(f"\nStarting OpenELM API v5 on port {port}...")
    print("The server will respond immediately while the model loads in background.\n")
    # Pass the app object instead of the "app:app" import string. This works
    # whatever this file is named and avoids importing the module a second time
    # (an import string is only needed for reload=True or workers > 1).
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        reload=False,
        workers=1
    )