"""
eduai-embedder — tiny embedding microservice.

One process, one model, three endpoints. Deployed as a Docker Space on
HuggingFace and called by `eduai_platform` (and any other EduAI service that
needs embeddings) so individual developers don't have to install
torch + sentence-transformers locally.

Endpoints
---------
    GET  /health     → {status, model, dim}   (also served at GET /)
    POST /embed      → {embeddings: [[float]], model, dim}
    POST /embed_one  → {embedding: [float], model, dim}

Authentication
--------------
If the `EMBEDDER_API_KEY` env var is set, all routes except /health (and its
alias /) require an `X-API-Key` header that matches it. Leave it unset only
for local dev (the default in `.env.example` makes you set one).

Configuration (env vars)
------------------------
    EMBEDDER_MODEL_NAME    sentence-transformers model id (default: all-MiniLM-L6-v2)
    EMBEDDER_API_KEY       shared secret; if set, required on /embed* routes
    EMBEDDER_MAX_BATCH     reject batches larger than this (default: 128; must be >= 1)
    EMBEDDER_MAX_TEXT_LEN  reject texts longer than this many characters (default: 8000; must be >= 1)
    EMBEDDER_CORS          comma-separated allow-origins (default: *)
"""

import logging
import os
from typing import List, Optional

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

# ----------------------------------------------------------------------------- config

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
log = logging.getLogger("eduai-embedder")


def _env_int(name: str, default: int) -> int:
    """Return env var *name* parsed as a positive int, or *default* if unset.

    Raises:
        ValueError: if the variable is set but is not an integer, or is < 1.
            The message names the variable so a misconfigured deployment
            fails at startup with an actionable error (a limit of 0 would
            otherwise silently reject every request).
    """
    raw = os.getenv(name, str(default))
    try:
        value = int(raw)
    except ValueError as err:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from err
    if value < 1:
        raise ValueError(f"{name} must be >= 1, got {value}")
    return value


MODEL_NAME = os.getenv("EMBEDDER_MODEL_NAME", "all-MiniLM-L6-v2")
# Empty string means "no key configured" → open mode (see require_api_key).
API_KEY = os.getenv("EMBEDDER_API_KEY", "")
MAX_BATCH = _env_int("EMBEDDER_MAX_BATCH", 128)
MAX_TEXT_LEN = _env_int("EMBEDDER_MAX_TEXT_LEN", 8000)
# Blank entries (e.g. a trailing comma) are dropped.
CORS_ORIGINS: List[str] = [
    o.strip() for o in os.getenv("EMBEDDER_CORS", "*").split(",") if o.strip()
]

# ----------------------------------------------------------------------------- model
log.info("Loading sentence-transformers model: %s ...", MODEL_NAME)
_model = SentenceTransformer(MODEL_NAME)

DIM = _model.get_sentence_embedding_dimension()
if DIM is None:
    # get_sentence_embedding_dimension() is Optional[int]: it returns None when
    # the model's modules don't declare an output size. Probe once so the log
    # line below and the ``dim: int`` response fields always get a real int.
    DIM = int(_model.encode("dimension probe", normalize_embeddings=True).shape[-1])
log.info("Model loaded (dim=%d, normalize_embeddings=True)", DIM)

# ----------------------------------------------------------------------------- app

app = FastAPI(
    title="eduai-embedder",
    description="Tiny embedding microservice for the EduAI platform.",
    version="0.1.0",
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

# ----------------------------------------------------------------------------- schemas


class EmbedBatchIn(BaseModel):
    """Request body for POST /embed (at least one text required)."""

    texts: List[str] = Field(..., min_length=1, description="Texts to embed.")


class EmbedOneIn(BaseModel):
    """Request body for POST /embed_one (text must be non-empty)."""

    text: str = Field(..., min_length=1)


class EmbedOut(BaseModel):
    """Response for POST /embed: one vector per input text, in input order."""

    embeddings: List[List[float]]
    model: str
    dim: int


class EmbedOneOut(BaseModel):
    """Response for POST /embed_one: a single vector."""

    embedding: List[float]
    model: str
    dim: int


class HealthOut(BaseModel):
    """Response for GET /health and GET /."""

    status: str
    model: str
    dim: int


# ----------------------------------------------------------------------------- auth


def require_api_key(x_api_key: Optional[str] = Header(default=None, alias="X-API-Key")) -> None:
    """Reject requests if EMBEDDER_API_KEY is set and the header doesn't match.

    The comparison uses ``hmac.compare_digest`` so response timing does not
    reveal how much of a guessed key was correct. Both sides are UTF-8
    encoded first because ``compare_digest`` raises TypeError for ``str``
    arguments containing non-ASCII characters; without that, a crafted
    header would produce a 500 instead of a 401.

    Raises:
        HTTPException: 401 when a key is configured and the header is missing
            or does not match.
    """
    import hmac  # function-scope import keeps this block self-contained

    if not API_KEY:
        return  # open mode (intended for local dev only)
    if x_api_key is None or not hmac.compare_digest(
        x_api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid or missing API key.")


# ----------------------------------------------------------------------------- routes


@app.get("/", response_model=HealthOut, tags=["health"])
@app.get("/health", response_model=HealthOut, tags=["health"])
def health() -> HealthOut:
    """Liveness probe.

    Always public; HF Spaces' built-in checks rely on this."""
    return HealthOut(status="ok", model=MODEL_NAME, dim=DIM)


@app.post(
    "/embed",
    response_model=EmbedOut,
    tags=["embeddings"],
    dependencies=[Depends(require_api_key)],
)
def embed_batch(body: EmbedBatchIn) -> EmbedOut:
    """Embed a batch of texts. Vectors are L2-normalized for cosine similarity.

    Raises:
        HTTPException: 400 if the batch exceeds MAX_BATCH, or if any text
            exceeds MAX_TEXT_LEN characters (the first offending index is
            reported).
    """
    if len(body.texts) > MAX_BATCH:
        raise HTTPException(status_code=400, detail=f"Batch too large (max {MAX_BATCH}).")
    for i, text in enumerate(body.texts):
        if len(text) > MAX_TEXT_LEN:
            raise HTTPException(
                status_code=400,
                detail=f"Text at index {i} too long (max {MAX_TEXT_LEN} characters).",
            )
    vectors = _model.encode(
        body.texts,
        normalize_embeddings=True,
        batch_size=64,  # internal encode chunk size; independent of MAX_BATCH
    ).tolist()
    return EmbedOut(embeddings=vectors, model=MODEL_NAME, dim=DIM)


@app.post(
    "/embed_one",
    response_model=EmbedOneOut,
    tags=["embeddings"],
    dependencies=[Depends(require_api_key)],
)
def embed_one(body: EmbedOneIn) -> EmbedOneOut:
    """Embed a single text — convenience for chat query embeddings.

    Raises:
        HTTPException: 400 if the text exceeds MAX_TEXT_LEN characters.
    """
    if len(body.text) > MAX_TEXT_LEN:
        raise HTTPException(
            status_code=400,
            detail=f"Text too long (max {MAX_TEXT_LEN} characters).",
        )
    # A single str input makes encode() return one vector, not a batch.
    vector = _model.encode(body.text, normalize_embeddings=True).tolist()
    return EmbedOneOut(embedding=vector, model=MODEL_NAME, dim=DIM)