File size: 2,037 Bytes
76b352a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""Embedding server for multilingual-e5-small on HF Spaces."""

import asyncio
from contextlib import asynccontextmanager
import logging
import threading

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)

# Hugging Face model id handed to SentenceTransformer when loading.
MODEL_NAME: str = "intfloat/multilingual-e5-small"

# Assigned by _load_model() from a background thread; None until then.
model: "SentenceTransformer | None" = None

# Set only after `model` has been assigned; handlers check it before encoding.
model_ready: threading.Event = threading.Event()


def _load_model():
    """Load the sentence-transformer into the module-level ``model``.

    Runs in the background thread started by ``lifespan``. ``model`` is
    assigned before ``model_ready`` is set, so any handler that sees the event
    set also sees a non-None model. If loading raises, the event stays unset.
    """
    global model
    logger.info(f"Loading {MODEL_NAME}...")
    loaded = SentenceTransformer(MODEL_NAME)
    model = loaded
    model_ready.set()
    logger.info("Model loaded successfully")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start loading the model without blocking application startup.

    Loading happens on a daemon thread, so the server begins serving at once
    and reports "loading" until ``model_ready`` is set.
    """
    threading.Thread(target=_load_model, daemon=True).start()
    # No shutdown work: the loader is a daemon thread.
    yield


app = FastAPI(title="Embedding Server", lifespan=lifespan)

# Let browser clients on any origin call the API.
# Credentials are deliberately disabled, for two reasons:
# - None of the endpoints in this file read cookies or auth headers.
# - A wildcard origin must not be combined with allow_credentials=True.
#   Starlette handles that combination by echoing the caller's Origin,
#   which would let any site send credentialed (e.g. cookie-bearing)
#   requests here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request body for POST /embed.
class EmbedRequest(BaseModel):
    # Raw input text. The /embed handler adds the "query: " prefix itself;
    # an empty string produces an empty embedding.
    text: str


# Response body for POST /embed.
class EmbedResponse(BaseModel):
    # Embedding vector produced with normalize_embeddings=True; an empty list
    # when the request text was empty.
    embedding: list[float]


# Serializes model.encode calls. Inference used to run directly on the event
# loop, so calls never overlapped. Keep that guarantee now that it runs in
# worker threads.
# NOTE(review): Hugging Face fast tokenizers are reported to fail under
# concurrent use from several threads; confirm before removing this lock.
_encode_lock = threading.Lock()


def _encode_query(text: str) -> list[float]:
    """Return the normalized embedding of ``text`` as a list of floats.

    Blocking and CPU-bound. Call it from a worker thread, not the event loop.
    """
    # Query-side inputs are encoded with the "query: " prefix.
    prefixed = f"query: {text}"
    with _encode_lock:
        vectors = model.encode([prefixed], normalize_embeddings=True)
    return vectors[0].tolist()


# POST /embed: embed one text.
# - Returns 503 while the model is still loading.
# - Returns an empty embedding for empty text.
# - Inference runs in a worker thread so the event loop keeps serving other
#   requests (e.g. /health) while a batch is being encoded.
@app.post("/embed", response_model=EmbedResponse)
async def embed(request: EmbedRequest) -> EmbedResponse:
    if not model_ready.is_set():
        raise HTTPException(status_code=503, detail="Model still loading")
    if not request.text:
        return EmbedResponse(embedding=[])
    embedding = await asyncio.to_thread(_encode_query, request.text)
    return EmbedResponse(embedding=embedding)


# GET /health: "ok" once the background model load has completed,
# "loading" before that, plus the configured model id.
@app.get("/health")
async def health():
    if model_ready.is_set():
        status = "ok"
    else:
        status = "loading"
    return dict(status=status, model=MODEL_NAME)


# GET /: service descriptor with the model id, readiness flag, and a short
# description of each public endpoint.
@app.get("/")
async def root():
    endpoint_docs = {
        "POST /embed": "Generate embeddings",
        "GET /health": "Health check",
    }
    info = {"service": "Embedding Server", "model": MODEL_NAME}
    info["ready"] = model_ready.is_set()
    info["endpoints"] = endpoint_docs
    return info