File size: 2,665 Bytes
"""FastAPI embeddings sidecar.
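Exposes three embedding endpoints backed by `fastembed`, plus GET /health:
POST /embed/dense -> single vectors (sentence-transformers/all-MiniLM-L6-v2)
POST /embed/colbert -> per-token matrices (colbert-ir/colbertv2.0)
POST /embed/colbert/query -> per-token matrices via the ColBERT model's query encoder
The model names above are defaults; RAG_DENSE_MODEL and RAG_LATE_MODEL override them.
Both models are loaded once at application startup (see `lifespan`), cached, and
reused for the lifetime of the process. The Next.js app calls this service via
plain HTTP.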
"""
from __future__ import annotations
import os
from contextlib import asynccontextmanager
from pathlib import Path
from typing import List
from dotenv import load_dotenv
# Load shared env from the parent project so HF_TOKEN, RAG_*_MODEL, etc. flow
# through without requiring the user to export them by hand.
# This has to run before the os.environ lookups further down the module so
# that values coming from the file are visible to them.
# The file is expected one level above the directory that contains this module.
_PARENT_ENV = Path(__file__).resolve().parent.parent / ".env.local"
# A missing file is not an error: the service then relies on the process
# environment (or the built-in defaults) alone.
if _PARENT_ENV.is_file():
    load_dotenv(_PARENT_ENV)
from fastapi import FastAPI
from pydantic import BaseModel
from fastembed import TextEmbedding, LateInteractionTextEmbedding
# Model identifiers handed to fastembed. Each can be overridden through the
# environment; the fallbacks are the models named in the module docstring.
DENSE_MODEL = os.environ.get("RAG_DENSE_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
LATE_MODEL = os.environ.get("RAG_LATE_MODEL", "colbert-ir/colbertv2.0")
# Process-wide cache of constructed model objects, filled by _dense() and
# _colbert() under the keys "dense" and "colbert" respectively.
_models: dict[str, object] = {}
def _dense() -> TextEmbedding:
    """Return the shared dense embedding model, constructing it on first use.

    The instance lives in the module-level ``_models`` cache under the
    ``"dense"`` key; once it is stored there, later calls return that same
    object instead of building a new ``TextEmbedding``.
    """
    try:
        model = _models["dense"]
    except KeyError:
        # First call: build the model for DENSE_MODEL and remember it.
        model = _models["dense"] = TextEmbedding(model_name=DENSE_MODEL)
    return model  # type: ignore[return-value]
def _colbert() -> LateInteractionTextEmbedding:
    """Return the shared late-interaction model, constructing it on first use.

    The instance lives in the module-level ``_models`` cache under the
    ``"colbert"`` key; once it is stored there, later calls return that same
    object instead of building a new ``LateInteractionTextEmbedding``.
    """
    try:
        model = _models["colbert"]
    except KeyError:
        # First call: build the model for LATE_MODEL and remember it.
        model = _models["colbert"] = LateInteractionTextEmbedding(model_name=LATE_MODEL)
    return model  # type: ignore[return-value]
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Pre-load both embedding models when the application starts.

    Everything before the ``yield`` runs when the context is entered, so the
    dense model and then the ColBERT model are built and cached up front
    rather than on the first request that needs them. There is no teardown
    step after the ``yield``.

    Args:
        _app: The FastAPI application; not used.
    """
    # Same order as the cache keys are documented: dense first, then ColBERT.
    for load_model in (_dense, _colbert):
        load_model()
    yield
# FastAPI application instance; `lifespan` is registered so that both models
# are loaded when the app starts up.
app = FastAPI(lifespan=lifespan)
# Request body shared by the /embed/dense, /embed/colbert and
# /embed/colbert/query endpoints.
class EmbedRequest(BaseModel):
    # Input strings to embed; handed to the fastembed model as given.
    texts: List[str]
# Response body of POST /embed/dense.
class DenseResponse(BaseModel):
    # One flat list of floats per embedding produced by the dense model.
    vectors: List[List[float]]
    # Identifier of the model that produced the vectors (DENSE_MODEL).
    model: str
# Response body of POST /embed/colbert and POST /embed/colbert/query.
class ColbertResponse(BaseModel):
    # One matrix (a list of float lists) per embedding produced by the
    # late-interaction model.
    vectors: List[List[List[float]]]
    # Identifier of the model that produced the matrices (LATE_MODEL).
    model: str
@app.get("/health")
def health():
    # Health probe: reports that the service is up together with the model
    # identifiers it is configured with. Does not touch the models themselves.
    # (Documented with comments rather than a docstring so the generated
    # OpenAPI description stays as it was.)
    return dict(ok=True, dense_model=DENSE_MODEL, late_model=LATE_MODEL)
@app.post("/embed/dense", response_model=DenseResponse)
def embed_dense(req: EmbedRequest):
    # Embed every string in `req.texts` with the shared dense model and return
    # the results as plain nested lists plus the model identifier.
    # (Documented with comments rather than a docstring so the generated
    # OpenAPI description stays as it was.)
    embeddings = _dense().embed(req.texts)
    # Each item yielded by the model is converted with .tolist() so that it can
    # be serialised as JSON.
    as_lists = [embedding.tolist() for embedding in embeddings]
    return dict(vectors=as_lists, model=DENSE_MODEL)
@app.post("/embed/colbert", response_model=ColbertResponse)
def embed_colbert(req: EmbedRequest):
    # Embed every string in `req.texts` with the shared late-interaction model
    # via its `embed` method and return the results as plain nested lists plus
    # the model identifier.
    # (Documented with comments rather than a docstring so the generated
    # OpenAPI description stays as it was.)
    matrices = _colbert().embed(req.texts)
    # Each item yielded by the model is converted with .tolist() so that it can
    # be serialised as JSON.
    as_lists = [matrix.tolist() for matrix in matrices]
    return dict(vectors=as_lists, model=LATE_MODEL)
@app.post("/embed/colbert/query", response_model=ColbertResponse)
def embed_colbert_query(req: EmbedRequest):
    # Same as /embed/colbert, except that the strings are encoded with the
    # model's `query_embed` method instead of `embed`.
    # (Documented with comments rather than a docstring so the generated
    # OpenAPI description stays as it was.)
    matrices = _colbert().query_embed(req.texts)
    # Each item yielded by the model is converted with .tolist() so that it can
    # be serialised as JSON.
    as_lists = [matrix.tolist() for matrix in matrices]
    return dict(vectors=as_lists, model=LATE_MODEL)
|