Spaces:
Running
Running
| import os | |
| import logging | |
| import asyncio | |
| import multiprocessing | |
| from contextlib import asynccontextmanager | |
| from concurrent.futures import ThreadPoolExecutor | |
| from typing import Union, List, Optional, Any | |
| from fastapi import FastAPI, HTTPException, Security, Depends | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| # Import the new MultiEmbeddingService | |
| from model_service import MultiEmbeddingService | |
| # ============================================================================ | |
| # LOGGING | |
| # ============================================================================ | |
# Root logging is configured once at import time: INFO level, with each
# record rendered as "timestamp - logger name - level - message".
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by the lifespan manager and the endpoints.
logger = logging.getLogger("EmbedAPI")
| # ============================================================================ | |
| # CONFIGURATION | |
| # ============================================================================ | |
def _parse_origins(raw: str) -> List[str]:
    """Split a comma-separated origin list into clean entries.

    Whitespace around each entry is trimmed and empty entries are dropped,
    so a value such as "https://a.example, https://b.example" yields two
    usable origins instead of one with a leading space.
    """
    return [origin.strip() for origin in raw.split(",") if origin.strip()]


# Bearer token required by verify_token; when unset (or empty) the token
# check is skipped entirely.
AUTH_TOKEN = os.getenv("AUTH_TOKEN", None)

# Origins handed to the CORS middleware; defaults to the wildcard "*".
ALLOWED_ORIGINS = _parse_origins(os.getenv("ALLOWED_ORIGINS", "*"))

# Global context container
# Filled in by the lifespan manager at startup and read by the endpoints.
ml_context = {
    "service": None,
    "executor": None,
}
| # ============================================================================ | |
| # LIFESPAN MANAGER | |
| # ============================================================================ | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifecycle manager: Loads models and thread pool.

    Startup creates a thread pool sized at twice the CPU count, instantiates
    MultiEmbeddingService, loads its models, and publishes both objects in
    ``ml_context``. Shutdown stops the thread pool and clears ``ml_context``.

    The cleanup runs in a ``finally`` block so the thread pool is released
    both when model loading fails and when the application exits with an
    error, not only on a clean shutdown.

    Raises:
        Exception: whatever MultiEmbeddingService raised while loading; it is
            logged at CRITICAL level and re-raised so startup aborts.
    """
    # --- Startup ---
    logger.info("Initializing Multi-Dimensional Embedding Service...")

    # 1. Thread Pool
    max_workers = multiprocessing.cpu_count() * 2
    executor = ThreadPoolExecutor(max_workers=max_workers)
    ml_context["executor"] = executor
    logger.info(f"Thread pool ready: {max_workers} workers")

    try:
        # 2. Load Models
        try:
            service = MultiEmbeddingService()
            service.load_all_models()  # Loads 384, 768, 1024 models
            ml_context["service"] = service
        except Exception as e:
            logger.critical(f"Critical error loading models: {e}", exc_info=True)
            # Bare raise keeps the original traceback intact.
            raise

        if AUTH_TOKEN:
            logger.info("🔒 Auth enabled.")

        yield
    finally:
        # --- Shutdown ---
        logger.info("Shutting down...")
        executor.shutdown(wait=True)
        ml_context.clear()
| # ============================================================================ | |
| # APP SETUP | |
| # ============================================================================ | |
# Application object; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(
    title="Multi-Dim Embedding API",
    version="3.0.0",
    lifespan=lifespan,
)

# CORS policy. Credentialed cross-origin requests are only enabled when the
# configured origins are explicit: combining a wildcard origin with
# allow_credentials=True would let any site make credentialed requests.
# Bearer tokens sent in the Authorization header are unaffected by this flag.
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=not any(origin.strip() == "*" for origin in ALLOWED_ORIGINS),
    allow_methods=["*"],
    allow_headers=["*"],
)

# auto_error=False lets a request without an Authorization header reach
# verify_token (as credentials=None) instead of being rejected up front,
# which is what allows the API to run with authentication disabled.
security = HTTPBearer(auto_error=False)
async def verify_token(credentials: Optional[HTTPAuthorizationCredentials] = Security(security)) -> bool:
    """Check the request's bearer token against AUTH_TOKEN.

    Args:
        credentials: Parsed Authorization header, or None when the header
            is absent.

    Returns:
        True when authentication is disabled (AUTH_TOKEN unset or empty) or
        the supplied token matches.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            the token is missing or does not match.
    """
    import hmac  # local import keeps this block self-contained

    if not AUTH_TOKEN:
        return True

    supplied = credentials.credentials if credentials else ""
    # Constant-time comparison avoids leaking how many leading characters of
    # the token matched; bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    token_ok = hmac.compare_digest(supplied.encode("utf-8"), AUTH_TOKEN.encode("utf-8"))
    if not credentials or not token_ok:
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return True
| # ============================================================================ | |
| # MODELS | |
| # ============================================================================ | |
class EmbedRequest(BaseModel):
    # Request body for embedding generation.
    # (Documented with comments rather than a docstring so the generated
    # JSON schema is left exactly as it was.)

    # A single text, or a batch of texts, to embed.
    data: Union[str, List[str]] = Field(
        default=...,
        description="Text string or list of strings",
    )
    # Requested embedding size; 768 when the client omits it.
    dimension: int = Field(
        default=768,
        description="Target dimension (384, 768, or 1024)",
    )

    # Example payload attached to the model's JSON schema.
    model_config = {
        "json_schema_extra": {
            "example": {
                "data": ["Hello world", "Machine learning is great"],
                "dimension": 768,
            },
        },
    }
class EmbedResponse(BaseModel):
    # Response body returned by create_embeddings.

    # One vector, or a list of vectors (required field).
    embeddings: Union[List[float], List[List[float]]]
    # The dimension that was requested.
    dimension: int
    # Number of input texts: 1 for a single string, else the list length.
    count: int
class DeEmbedRequest(BaseModel):
    # Request body carrying one embedding vector to decode.
    # NOTE(review): no endpoint in this view of the file uses this model.
    vector: List[float] = Field(
        default=...,
        description="The embedding vector to decode",
    )
| # ============================================================================ | |
| # ENDPOINTS | |
| # ============================================================================ | |
async def health_check():
    # Readiness probe: reports which embedding dimensions are loaded, or
    # answers 503 while the service object is not yet in ml_context.
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm how it is registered on `app`.
    loaded_service = ml_context.get("service")
    if loaded_service:
        dimensions = list(loaded_service.models.keys())
        return {"status": "healthy", "loaded_dimensions": dimensions}
    raise HTTPException(status_code=503, detail="Service not ready")
async def create_embeddings(request: EmbedRequest):
    """
    Generate embeddings for specific dimensions.
    Supported dimensions: 384, 768, 1024.
    """
    # Contract (kept out of the docstring so the text above stays as-is):
    #   - 503 when the service or thread pool is not in ml_context.
    #   - 400 when request.dimension is not a key of service.models.
    #   - 500 when inference or response construction raises.
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm how it is registered on `app`.
    service = ml_context.get("service")
    executor = ml_context.get("executor")
    if not service or not executor:
        raise HTTPException(status_code=503, detail="Service unavailable")

    if request.dimension not in service.models:
        # List what is actually loaded rather than a hard-coded set, so the
        # message cannot drift from the check above.
        available = ", ".join(str(dim) for dim in service.models)
        raise HTTPException(
            status_code=400,
            detail=f"Dimension {request.dimension} not supported. Available dimensions: {available}.",
        )

    try:
        count = 1 if isinstance(request.data, str) else len(request.data)

        # Run the blocking model call on the thread pool so the event loop
        # stays free to serve other requests.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            executor,
            service.generate_embedding,
            request.data,
            request.dimension,
        )
        return EmbedResponse(
            embeddings=embeddings,
            dimension=request.dimension,
            count=count,
        )
    except Exception as e:
        # logger.exception records the traceback, which the plain error log
        # dropped; chaining keeps the root cause attached to the HTTP error.
        logger.exception(f"Inference error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def ping():
    # Liveness probe: always answers with a fixed message and touches no
    # service state.
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm how it is registered on `app`.
    reply = {"message": "embed-api is alive!"}
    return reply
async def root():
    # Landing response: a fixed version string plus a status message.
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm how it is registered on `app`.
    banner = dict(
        version="3.0.0",
        message="Multi-Dimensional Embedding API Server is running.",
    )
    return banner