File size: 5,440 Bytes
fbbd988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
eduai-embedder — tiny embedding microservice.

One process, one model, three routes. Deployed as a Docker Space on
HuggingFace and called by `eduai_platform` (and any other EduAI service
that needs embeddings) so individual developers don't have to install
torch + sentence-transformers locally.

Endpoints
---------
    GET  /health          → {status, model, dim}   (also served at GET /)
    POST /embed           → {embeddings: [[float]], model, dim}
    POST /embed_one       → {embedding: [float], model, dim}

Authentication
--------------
If the `EMBEDDER_API_KEY` env var is set, the `/embed` and `/embed_one`
routes require an `X-API-Key` header that matches it; the health probes
(`/` and `/health`) are always public. Leave it unset only for local dev
(the default in `.env.example` makes you set one).

Configuration (env vars)
------------------------
    EMBEDDER_MODEL_NAME    sentence-transformers model id (default: all-MiniLM-L6-v2)
    EMBEDDER_API_KEY       shared secret; if set, required on /embed* routes
    EMBEDDER_MAX_BATCH     reject batches larger than this (default: 128)
    EMBEDDER_MAX_TEXT_LEN  reject texts longer than this many characters (default: 8000)
    EMBEDDER_CORS          comma-separated allow-origins (default: *)
"""

import hmac
import logging
import os
from typing import List, Optional

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

# ----------------------------------------------------------------------------- config

# Process-wide logging: INFO and above, timestamped, tagged with logger name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    level=logging.INFO,
)
log = logging.getLogger("eduai-embedder")

# Every tunable is read once from the environment at import time.
MODEL_NAME = os.environ.get("EMBEDDER_MODEL_NAME", "all-MiniLM-L6-v2")
# Empty string (the default) leaves the service in unauthenticated "open" mode.
API_KEY = os.environ.get("EMBEDDER_API_KEY", "")
# Request-size limits enforced by the /embed and /embed_one handlers.
MAX_BATCH = int(os.environ.get("EMBEDDER_MAX_BATCH", "128"))
MAX_TEXT_LEN = int(os.environ.get("EMBEDDER_MAX_TEXT_LEN", "8000"))
# Comma-separated origin list, whitespace-trimmed, blank entries dropped.
CORS_ORIGINS = [
    origin
    for origin in (piece.strip() for piece in os.environ.get("EMBEDDER_CORS", "*").split(","))
    if origin
]

# ----------------------------------------------------------------------------- model

# The model is loaded once, at import time, and the single instance is shared
# by the /embed and /embed_one handlers.
log.info("Loading sentence-transformers model: %s ...", MODEL_NAME)
_model = SentenceTransformer(MODEL_NAME)
# Embedding width as reported by the model; returned as `dim` in every response.
# NOTE(review): if the library can return None here for some models, the "%d"
# format below and the `dim: int` response fields would not accept it — confirm.
DIM = _model.get_sentence_embedding_dimension()
# "normalize_embeddings=True" in this message mirrors the flag the handlers
# pass to encode(); nothing is configured on the model object here.
log.info("Model loaded (dim=%d, normalize_embeddings=True)", DIM)

# ----------------------------------------------------------------------------- app

# The application object that the route decorators in this module attach to.
# title/description/version are metadata passed straight to FastAPI.
app = FastAPI(
    title="eduai-embedder",
    description="Tiny embedding microservice for the EduAI platform.",
    version="0.1.0",
)

# CORS policy. Allowed origins come from EMBEDDER_CORS (default "*", i.e. any
# origin); tighten that variable for deployments called from known front-ends.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_methods=["GET", "POST"],  # the only HTTP verbs this module's routes use
    allow_headers=["*"],  # no header restriction, so X-API-Key is permitted
)

# ----------------------------------------------------------------------------- schemas


class EmbedBatchIn(BaseModel):
    # Request body for POST /embed.
    # `texts` is required (`...`) and constrained with min_length=1; the upper
    # limits (MAX_BATCH items, MAX_TEXT_LEN characters each) are checked in
    # the handler, not here.
    # NOTE(review): the min_length constraint sits on the list, so individual
    # empty strings appear to be accepted here, unlike EmbedOneIn.text —
    # confirm that is intended.
    texts: List[str] = Field(..., min_length=1, description="Texts to embed.")


class EmbedOneIn(BaseModel):
    # Request body for POST /embed_one.
    # `text` is required and must be at least one character long; the
    # MAX_TEXT_LEN upper bound is checked in the handler.
    text: str = Field(..., min_length=1)


class EmbedOut(BaseModel):
    # Response body for POST /embed.
    # One vector per input text, taken from a single encode() call on the
    # request's `texts` list.
    embeddings: List[List[float]]
    # Model id the service was started with (MODEL_NAME).
    model: str
    # Embedding width reported by the loaded model (DIM).
    dim: int


class EmbedOneOut(BaseModel):
    # Response body for POST /embed_one.
    # The single vector for the submitted text.
    embedding: List[float]
    # Model id the service was started with (MODEL_NAME).
    model: str
    # Embedding width reported by the loaded model (DIM).
    dim: int


class HealthOut(BaseModel):
    # Response body for GET / and GET /health.
    # The handler always sets this to "ok".
    status: str
    # Model id the service was started with (MODEL_NAME).
    model: str
    # Embedding width reported by the loaded model (DIM).
    dim: int


# ----------------------------------------------------------------------------- auth


def require_api_key(x_api_key: Optional[str] = Header(default=None, alias="X-API-Key")) -> None:
    """Reject requests if EMBEDDER_API_KEY is set and the header doesn't match.

    Used as a FastAPI dependency on the embedding routes.

    Args:
        x_api_key: Value of the ``X-API-Key`` request header, or ``None`` when
            the header is absent.

    Raises:
        HTTPException: 401 when a key is configured and the header is missing
            or does not match it.
    """
    if not API_KEY:
        return  # open mode (intended for local dev only)
    # Constant-time comparison: a plain `!=` stops at the first differing
    # character, which lets a caller recover the key through response timing.
    # Both sides are encoded to bytes because compare_digest() raises
    # TypeError on non-ASCII str input, and header values are caller-controlled.
    # A missing header is treated as the empty string, which can never equal
    # the (non-empty) configured key.
    supplied = (x_api_key or "").encode("utf-8")
    if not hmac.compare_digest(supplied, API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid or missing API key.")


# ----------------------------------------------------------------------------- routes


@app.get("/", response_model=HealthOut, tags=["health"])
@app.get("/health", response_model=HealthOut, tags=["health"])
def health() -> HealthOut:
    """Liveness probe, served without authentication at `/` and `/health`.

    Always reports status "ok" together with the configured model id and its
    embedding dimension. Kept public because HF Spaces' built-in checks rely
    on it.
    """
    report = {"status": "ok", "model": MODEL_NAME, "dim": DIM}
    return HealthOut(**report)


@app.post(
    "/embed",
    response_model=EmbedOut,
    tags=["embeddings"],
    dependencies=[Depends(require_api_key)],
)
def embed_batch(body: EmbedBatchIn) -> EmbedOut:
    """Embed a batch of texts; vectors are L2-normalized for cosine similarity.

    Responds with HTTP 400 when the batch holds more than MAX_BATCH texts, or
    when any text exceeds MAX_TEXT_LEN characters (the first offending index
    is named in the error detail).
    """
    texts = body.texts
    if len(texts) > MAX_BATCH:
        raise HTTPException(status_code=400, detail=f"Batch too large (max {MAX_BATCH}).")

    # Position of the first over-long text, or None when every text fits.
    offender = next(
        (pos for pos, item in enumerate(texts) if len(item) > MAX_TEXT_LEN),
        None,
    )
    if offender is not None:
        raise HTTPException(
            status_code=400,
            detail=f"Text at index {offender} too long (max {MAX_TEXT_LEN} characters).",
        )

    encoded = _model.encode(texts, normalize_embeddings=True, batch_size=64)
    return EmbedOut(embeddings=encoded.tolist(), model=MODEL_NAME, dim=DIM)


@app.post(
    "/embed_one",
    response_model=EmbedOneOut,
    tags=["embeddings"],
    dependencies=[Depends(require_api_key)],
)
def embed_one(body: EmbedOneIn) -> EmbedOneOut:
    """Embed a single text — a convenience route for chat query embeddings.

    Responds with HTTP 400 when the text exceeds MAX_TEXT_LEN characters.
    The returned vector is encoded with normalize_embeddings=True.
    """
    text = body.text
    if len(text) > MAX_TEXT_LEN:
        too_long = f"Text too long (max {MAX_TEXT_LEN} characters)."
        raise HTTPException(status_code=400, detail=too_long)

    encoded = _model.encode(text, normalize_embeddings=True)
    return EmbedOneOut(embedding=encoded.tolist(), model=MODEL_NAME, dim=DIM)