File size: 2,451 Bytes
61bd74f
 
 
 
 
e82be7a
61bd74f
 
e82be7a
61bd74f
 
e82be7a
 
 
 
 
a99846a
 
 
 
 
61bd74f
 
e82be7a
 
 
a99846a
61bd74f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a99846a
 
 
 
 
61bd74f
 
a99846a
61bd74f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# infra/hf_spaces/embedder/app.py
# Serves BAAI/bge-small-en-v1.5 embeddings over HTTP.
# Model is loaded from /app/model_cache (baked into the Docker image at build time).

import asyncio
import threading
from contextlib import asynccontextmanager
from typing import Annotated, Any

from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

# Request-size limits enforced by EmbedRequest validation: at most 64 texts of
# at most 2000 characters each, i.e. a ceiling of 128,000 characters per call.
# Per the original author, this keeps the free-tier Space within its 16GB RAM
# budget even for the largest expected retrieval batch (top-20).
_MAX_TEXTS, _MAX_TEXT_LEN = 64, 2000

# Instruction string the BGE model card recommends prepending to *queries* in
# asymmetric retrieval. Passages/documents are encoded WITHOUT it; the /embed
# endpoint only adds it when the caller sets is_query=True.
_BGE_QUERY_PREFIX = "Represent this sentence for searching relevant passages: "


# Element type for EmbedRequest.texts: a string capped at _MAX_TEXT_LEN chars.
_BoundedText = Annotated[str, Field(max_length=_MAX_TEXT_LEN)]


class EmbedRequest(BaseModel):
    # Body of POST /embed. Validation rejects more than _MAX_TEXTS items, or
    # any single item longer than _MAX_TEXT_LEN characters.
    texts: list[_BoundedText] = Field(..., max_length=_MAX_TEXTS)
    # When true, _BGE_QUERY_PREFIX is prepended to every text before encoding
    # (query side of asymmetric retrieval); leave false for passages.
    is_query: bool = False


# A single embedding vector.
_Vector = list[float]


class EmbedResponse(BaseModel):
    # Response of POST /embed: one vector per input text.
    embeddings: list[_Vector]


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: build the encoder from the cache directory baked into the image.
    # NOTE(review): depending on the sentence-transformers / huggingface_hub
    # version, the Hub may still be contacted unless offline mode
    # (e.g. HF_HUB_OFFLINE=1) is set in the container — confirm.
    encoder = SentenceTransformer(
        "BAAI/bge-small-en-v1.5",
        cache_folder="/app/model_cache",
    )
    # Inference only: put the underlying torch module into eval mode.
    encoder.eval()
    app.state.model = encoder
    yield
    # Shutdown: release the model reference held on app.state.
    app.state.model = None


# ASGI application. Swagger UI, ReDoc and the OpenAPI schema are all disabled,
# so no documentation or schema endpoints are exposed.
app = FastAPI(
    docs_url=None,
    redoc_url=None,
    openapi_url=None,
    lifespan=lifespan,
    title="PersonaBot Embedder",
)


@app.get("/health")
async def health() -> dict[str, str]:
    """Report whether the embedding model is ready to serve requests.

    Returns:
        ``{"status": "ok"}`` when a model is stored on ``app.state``;
        ``{"status": "loading"}`` when it is missing or has been cleared.
    """
    # Starlette's State raises AttributeError for attributes that were never
    # assigned (e.g. the lifespan handler has not run yet), so read it with a
    # default instead of plain attribute access, which would surface as a 500.
    if getattr(app.state, "model", None) is None:
        return {"status": "loading"}
    return {"status": "ok"}


# Serialises access to the shared model. Calls now run on worker threads, and
# Hugging Face fast tokenizers can fail ("Already borrowed") when one instance
# is used from several threads at once. This keeps the original one-batch-at-
# a-time behaviour without blocking the event loop.
_ENCODE_LOCK = threading.Lock()


@app.post("/embed", response_model=EmbedResponse)
async def embed(request: EmbedRequest) -> EmbedResponse:
    """Embed a batch of texts with the loaded BGE model.

    Args:
        request: Texts to embed. When ``request.is_query`` is true, each text
            is prefixed with the BGE query instruction before encoding.

    Returns:
        One normalised embedding per input text; an empty list for no texts.
    """
    if not request.texts:
        return EmbedResponse(embeddings=[])
    texts = (
        [_BGE_QUERY_PREFIX + t for t in request.texts]
        if request.is_query
        else request.texts
    )
    model = app.state.model

    def _encode() -> Any:
        # Runs on a worker thread; returns a numpy array of shape (N, 384).
        with _ENCODE_LOCK:
            return model.encode(
                texts,
                batch_size=32,
                normalize_embeddings=True,
                show_progress_bar=False,
            )

    # encode() is synchronous and CPU-bound. Calling it directly inside this
    # coroutine would stall the event loop, and every other request including
    # /health, for the whole batch, so it is offloaded to a thread instead.
    vectors: Any = await asyncio.to_thread(_encode)
    return EmbedResponse(embeddings=vectors.tolist())