"""Embedding server for multilingual-e5-small on HF Spaces.""" from contextlib import asynccontextmanager import threading from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from sentence_transformers import SentenceTransformer import logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) MODEL_NAME = "intfloat/multilingual-e5-small" model = None model_ready = threading.Event() def _load_model(): global model logger.info(f"Loading {MODEL_NAME}...") model = SentenceTransformer(MODEL_NAME) model_ready.set() logger.info("Model loaded successfully") @asynccontextmanager async def lifespan(app: FastAPI): thread = threading.Thread(target=_load_model, daemon=True) thread.start() yield app = FastAPI(title="Embedding Server", lifespan=lifespan) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) class EmbedRequest(BaseModel): text: str class EmbedResponse(BaseModel): embedding: list[float] @app.post("/embed", response_model=EmbedResponse) async def embed(request: EmbedRequest) -> EmbedResponse: if not model_ready.is_set(): raise HTTPException(status_code=503, detail="Model still loading") if not request.text: return EmbedResponse(embedding=[]) prefixed = f"query: {request.text}" embedding = model.encode([prefixed], normalize_embeddings=True)[0].tolist() return EmbedResponse(embedding=embedding) @app.get("/health") async def health(): return { "status": "ok" if model_ready.is_set() else "loading", "model": MODEL_NAME, } @app.get("/") async def root(): return { "service": "Embedding Server", "model": MODEL_NAME, "ready": model_ready.is_set(), "endpoints": { "POST /embed": "Generate embeddings", "GET /health": "Health check", }, }