# Author: ibrahimdaud — feat: FastAPI embedding service for eduai_platform (commit fbbd988)
"""
eduai-embedder — tiny embedding microservice.
One process, one model, three routes. Deployed as a Docker Space on
HuggingFace and called by `eduai_platform` (and any other EduAI service
that needs embeddings) so individual developers don't have to install
torch + sentence-transformers locally.
Endpoints
---------
GET /health → {status, model, dim}
POST /embed → {embeddings: [[float]], model, dim}
POST /embed_one → {embedding: [float], model, dim}
Authentication
--------------
If the `EMBEDDER_API_KEY` env var is set, all routes except /health
require an `X-API-Key` header that matches it. Leave it unset only for
local dev (the default in `.env.example` makes you set one).
Configuration (env vars)
------------------------
EMBEDDER_MODEL_NAME sentence-transformers model id (default: all-MiniLM-L6-v2)
EMBEDDER_API_KEY shared secret; if set, required on /embed* routes
EMBEDDER_MAX_BATCH reject batches larger than this (default: 128)
EMBEDDER_MAX_TEXT_LEN reject texts longer than this many characters (default: 8000)
EMBEDDER_CORS comma-separated allow-origins (default: *)
"""
import logging
import os
import secrets
from typing import List, Optional

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
# ----------------------------------------------------------------------------- config
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
log = logging.getLogger("eduai-embedder")
MODEL_NAME = os.getenv("EMBEDDER_MODEL_NAME", "all-MiniLM-L6-v2")
API_KEY = os.getenv("EMBEDDER_API_KEY", "")
MAX_BATCH = int(os.getenv("EMBEDDER_MAX_BATCH", "128"))
MAX_TEXT_LEN = int(os.getenv("EMBEDDER_MAX_TEXT_LEN", "8000"))
CORS_ORIGINS = [o.strip() for o in os.getenv("EMBEDDER_CORS", "*").split(",") if o.strip()]
# ----------------------------------------------------------------------------- model
log.info("Loading sentence-transformers model: %s ...", MODEL_NAME)
_model = SentenceTransformer(MODEL_NAME)
DIM = _model.get_sentence_embedding_dimension()
log.info("Model loaded (dim=%d, normalize_embeddings=True)", DIM)
# ----------------------------------------------------------------------------- app
app = FastAPI(
title="eduai-embedder",
description="Tiny embedding microservice for the EduAI platform.",
version="0.1.0",
)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ORIGINS,
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
# ----------------------------------------------------------------------------- schemas
class EmbedBatchIn(BaseModel):
texts: List[str] = Field(..., min_length=1, description="Texts to embed.")
class EmbedOneIn(BaseModel):
text: str = Field(..., min_length=1)
class EmbedOut(BaseModel):
embeddings: List[List[float]]
model: str
dim: int
class EmbedOneOut(BaseModel):
embedding: List[float]
model: str
dim: int
class HealthOut(BaseModel):
status: str
model: str
dim: int
# ----------------------------------------------------------------------------- auth
def require_api_key(x_api_key: Optional[str] = Header(default=None, alias="X-API-Key")) -> None:
"""Reject requests if EMBEDDER_API_KEY is set and the header doesn't match."""
if not API_KEY:
return # open mode (intended for local dev only)
if x_api_key != API_KEY:
raise HTTPException(status_code=401, detail="Invalid or missing API key.")
# ----------------------------------------------------------------------------- routes
@app.get("/", response_model=HealthOut, tags=["health"])
@app.get("/health", response_model=HealthOut, tags=["health"])
def health() -> HealthOut:
"""Liveness probe. Always public; HF Spaces' built-in checks rely on this."""
return HealthOut(status="ok", model=MODEL_NAME, dim=DIM)
@app.post(
"/embed",
response_model=EmbedOut,
tags=["embeddings"],
dependencies=[Depends(require_api_key)],
)
def embed_batch(body: EmbedBatchIn) -> EmbedOut:
"""Embed a batch of texts. Vectors are L2-normalized for cosine similarity."""
if len(body.texts) > MAX_BATCH:
raise HTTPException(status_code=400, detail=f"Batch too large (max {MAX_BATCH}).")
for i, text in enumerate(body.texts):
if len(text) > MAX_TEXT_LEN:
raise HTTPException(
status_code=400,
detail=f"Text at index {i} too long (max {MAX_TEXT_LEN} characters).",
)
vectors = _model.encode(
body.texts,
normalize_embeddings=True,
batch_size=64,
).tolist()
return EmbedOut(embeddings=vectors, model=MODEL_NAME, dim=DIM)
@app.post(
"/embed_one",
response_model=EmbedOneOut,
tags=["embeddings"],
dependencies=[Depends(require_api_key)],
)
def embed_one(body: EmbedOneIn) -> EmbedOneOut:
"""Embed a single text — convenience for chat query embeddings."""
if len(body.text) > MAX_TEXT_LEN:
raise HTTPException(
status_code=400,
detail=f"Text too long (max {MAX_TEXT_LEN} characters).",
)
vector = _model.encode(body.text, normalize_embeddings=True).tolist()
return EmbedOneOut(embedding=vector, model=MODEL_NAME, dim=DIM)