"""
eduai-embedder — tiny embedding microservice.
One process, one model, three routes. Deployed as a Docker Space on
HuggingFace and called by `eduai_platform` (and any other EduAI service
that needs embeddings) so individual developers don't have to install
torch + sentence-transformers locally.
Endpoints
---------
GET /health → {status, model, dim}   (also served at GET /)
POST /embed → {embeddings: [[float]], model, dim}
POST /embed_one → {embedding: [float], model, dim}
Authentication
--------------
If the `EMBEDDER_API_KEY` env var is set, all routes except / and /health
require an `X-API-Key` header that matches it. Leave it unset only for
local dev (the default in `.env.example` makes you set one).
Configuration (env vars)
------------------------
EMBEDDER_MODEL_NAME sentence-transformers model id (default: all-MiniLM-L6-v2)
EMBEDDER_API_KEY shared secret; if set, required on /embed* routes
EMBEDDER_MAX_BATCH reject batches larger than this (default: 128)
EMBEDDER_MAX_TEXT_LEN reject texts longer than this many characters (default: 8000)
EMBEDDER_CORS comma-separated allow-origins (default: *)
"""
import hmac
import logging
import os
from typing import List, Optional

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
# ----------------------------------------------------------------------------- config
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
log = logging.getLogger("eduai-embedder")


def _env_int(name: str, default: int) -> int:
    """Read a positive integer from environment variable *name*.

    An unset or blank variable yields *default*. Anything else must parse as an
    integer >= 1.

    Raises:
        ValueError: if the value is not an integer or is < 1. The message names
            the variable, so a bad deploy config is obvious in the startup log.
            A limit of 0 would otherwise make every /embed* request fail
            validation.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        value = int(raw)
    except ValueError as err:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from err
    if value < 1:
        raise ValueError(f"{name} must be >= 1, got {value}")
    return value


MODEL_NAME = os.getenv("EMBEDDER_MODEL_NAME", "all-MiniLM-L6-v2")
# Empty string means "no key configured" (open mode); see require_api_key.
API_KEY = os.getenv("EMBEDDER_API_KEY", "")
MAX_BATCH = _env_int("EMBEDDER_MAX_BATCH", 128)
MAX_TEXT_LEN = _env_int("EMBEDDER_MAX_TEXT_LEN", 8000)
# Comma-separated list; blank entries (e.g. a trailing comma) are dropped.
CORS_ORIGINS = [o.strip() for o in os.getenv("EMBEDDER_CORS", "*").split(",") if o.strip()]
# ----------------------------------------------------------------------------- model
log.info("Loading sentence-transformers model: %s ...", MODEL_NAME)
_model = SentenceTransformer(MODEL_NAME)
# get_sentence_embedding_dimension() returns None when the model cannot report
# its output size. In that case, measure it from a real encoding so DIM is always
# an int. The %d log line below and the `dim: int` response models need one.
DIM = _model.get_sentence_embedding_dimension()
if DIM is None:
    DIM = len(_model.encode("dimension probe"))
log.info("Model loaded (dim=%d, normalize_embeddings=True)", DIM)
# ----------------------------------------------------------------------------- app
app = FastAPI(
    title="eduai-embedder",
    version="0.1.0",
    description="Tiny embedding microservice for the EduAI platform.",
)

# Cross-origin access: origins come from EMBEDDER_CORS. Only GET and POST, the
# verbs the routes below are registered with, are allowed. Any request header
# may be sent, including X-API-Key.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["GET", "POST"],
    allow_origins=CORS_ORIGINS,
)
# ----------------------------------------------------------------------------- schemas
# Request body for POST /embed.
# - min_length=1 rejects an empty `texts` list at validation time.
# - Individual strings carry no constraint here, so "" entries are accepted.
# - The MAX_BATCH and MAX_TEXT_LEN limits are enforced in the route handler
#   (HTTP 400), not by this model.
class EmbedBatchIn(BaseModel):
    texts: List[str] = Field(..., min_length=1, description="Texts to embed.")
# Request body for POST /embed_one. The text must be non-empty (min_length=1).
# The MAX_TEXT_LEN upper bound is checked in the route handler.
class EmbedOneIn(BaseModel):
    text: str = Field(..., min_length=1)
# Response for POST /embed: the batch's vectors, plus the model id and vector
# dimension. Field declaration order is the key order of the JSON response.
class EmbedOut(BaseModel):
    embeddings: List[List[float]]
    model: str
    dim: int
# Response for POST /embed_one: a single vector, plus the model id and vector
# dimension.
class EmbedOneOut(BaseModel):
    embedding: List[float]
    model: str
    dim: int
# Response for GET / and GET /health: a status string plus the loaded model id
# and its vector dimension.
class HealthOut(BaseModel):
    status: str
    model: str
    dim: int
# ----------------------------------------------------------------------------- auth
def require_api_key(x_api_key: Optional[str] = Header(default=None, alias="X-API-Key")) -> None:
    """FastAPI dependency guarding the /embed* routes with a shared secret.

    If ``EMBEDDER_API_KEY`` is empty or unset, the service runs in open mode and
    every request passes. Otherwise the ``X-API-Key`` request header must match
    it exactly.

    Raises:
        HTTPException: 401 when a key is configured and the header is missing
            or does not match.
    """
    if not API_KEY:
        return  # open mode (intended for local dev only)
    # Compare in constant time so response latency doesn't reveal how much of
    # the key matched. Comparing UTF-8 bytes also avoids the TypeError that
    # compare_digest raises for non-ASCII str arguments.
    if x_api_key is None or not hmac.compare_digest(
        x_api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid or missing API key.")
# ----------------------------------------------------------------------------- routes
@app.get("/", response_model=HealthOut, tags=["health"])
@app.get("/health", response_model=HealthOut, tags=["health"])
def health() -> HealthOut:
    """Liveness probe reporting the loaded model and its vector dimension.

    Served at both ``/`` and ``/health``. It carries no API-key dependency and
    stays public, because HF Spaces' built-in checks rely on it.
    """
    info = dict(model=MODEL_NAME, dim=DIM)
    return HealthOut(status="ok", **info)
@app.post(
    "/embed",
    response_model=EmbedOut,
    tags=["embeddings"],
    dependencies=[Depends(require_api_key)],
)
def embed_batch(body: EmbedBatchIn) -> EmbedOut:
    """Embed every text in the request with L2-normalized vectors.

    Normalized vectors make cosine similarity a plain dot product.

    Raises:
        HTTPException: 400 if the batch holds more than MAX_BATCH texts, or if
            any text exceeds MAX_TEXT_LEN characters. The first offending index
            is reported.
    """
    texts = body.texts
    if len(texts) > MAX_BATCH:
        raise HTTPException(status_code=400, detail=f"Batch too large (max {MAX_BATCH}).")

    first_too_long = next(
        (idx for idx, candidate in enumerate(texts) if len(candidate) > MAX_TEXT_LEN),
        None,
    )
    if first_too_long is not None:
        raise HTTPException(
            status_code=400,
            detail=f"Text at index {first_too_long} too long (max {MAX_TEXT_LEN} characters).",
        )

    matrix = _model.encode(texts, batch_size=64, normalize_embeddings=True)
    return EmbedOut(embeddings=matrix.tolist(), model=MODEL_NAME, dim=DIM)
@app.post(
    "/embed_one",
    response_model=EmbedOneOut,
    tags=["embeddings"],
    dependencies=[Depends(require_api_key)],
)
def embed_one(body: EmbedOneIn) -> EmbedOneOut:
    """Embed one text, e.g. a chat query, as an L2-normalized vector.

    Raises:
        HTTPException: 400 if the text exceeds MAX_TEXT_LEN characters.
    """
    text = body.text
    if len(text) <= MAX_TEXT_LEN:
        embedding = _model.encode(text, normalize_embeddings=True)
        return EmbedOneOut(embedding=embedding.tolist(), model=MODEL_NAME, dim=DIM)
    raise HTTPException(
        status_code=400,
        detail=f"Text too long (max {MAX_TEXT_LEN} characters).",
    )