# GitHub Action
# Sync from GitHub (f3f1952651810a9db23087a4db658bc927992b2d)
# 142cce7
# ml_service/main.py
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import List

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastembed import SparseTextEmbedding
from pydantic import BaseModel, Field, constr
from sentence_transformers import SentenceTransformer
# -----------------------------
# Configuration
# -----------------------------
# Each setting may be overridden by an environment variable of the same name;
# the literals below are the defaults used when the variable is unset or blank.


def positive_int_from_env(name: str, default: int) -> int:
    """Return environment variable *name* parsed as a positive int.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or blank.

    Returns:
        The parsed value, or *default*.

    Raises:
        ValueError: if the variable is set but is not an integer, or is <= 0.
            Failing at import time surfaces a bad deployment immediately.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        value = int(raw)
    except ValueError:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from None
    if value <= 0:
        raise ValueError(f"{name} must be a positive integer, got {value}")
    return value


MAX_TEXT_LENGTH = positive_int_from_env("MAX_TEXT_LENGTH", 5000)
MAX_BATCH_SIZE = positive_int_from_env("MAX_BATCH_SIZE", 32)
DENSE_MODEL_NAME = os.getenv("DENSE_MODEL_NAME") or "nomic-ai/nomic-embed-text-v1.5"
SPARSE_MODEL_NAME = os.getenv("SPARSE_MODEL_NAME") or "prithivida/Splade_PP_en_v1"
# -----------------------------
# Structured Logging Setup
# -----------------------------
# One pipe-delimited line per record: timestamp | level | logger name | message.
_LOG_FORMAT = "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("athena.vector_engine")
# -----------------------------
# Lifespan Management
# -----------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the embedding models before serving and detach them on shutdown.

    Startup does the following:

    * Selects ``"cuda"`` when ``torch.cuda.is_available()`` is true and
      ``"cpu"`` otherwise.
    * Loads the dense ``SentenceTransformer`` and the sparse
      ``SparseTextEmbedding`` model.
    * Runs one warmup embedding through each model.
    * Publishes ``dense_model``, ``sparse_model``, ``device`` and
      ``start_time`` on ``app.state`` for the request handlers.

    Raises:
        Exception: anything raised while loading or warming up the models is
            logged and re-raised unchanged, so the server fails to start.
    """
    # The only attributes this lifespan owns on app.state; shutdown removes
    # exactly these and nothing else.
    state_keys = ("dense_model", "sparse_model", "device", "start_time")

    logger.info("🧠 Booting Vector Engine...")
    load_started = time.time()
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Using device: %s", device)

        # Load dense model
        dense_model = SentenceTransformer(
            DENSE_MODEL_NAME,
            trust_remote_code=True,  # permits running code shipped in the model repo
            device=device,
        )
        # Load sparse model
        sparse_model = SparseTextEmbedding(model_name=SPARSE_MODEL_NAME)

        # Warmup (prevents cold-start latency spike on the first request)
        logger.info("🔥 Warming up models...")
        dense_model.encode("warmup", normalize_embeddings=True)
        list(sparse_model.embed(["warmup"]))  # list() drains the iterator embed() returns

        # Attach to app state
        app.state.dense_model = dense_model
        app.state.sparse_model = sparse_model
        app.state.device = device
        app.state.start_time = time.time()

        logger.info(
            "✅ Models loaded successfully in %.2fs", time.time() - load_started
        )
    except Exception:
        # Only startup failures land here: the yield is outside this try, so
        # errors raised while the app is serving are not mislabelled.
        logger.exception("❌ Failed during startup")
        raise

    try:
        yield
    finally:
        logger.info("🛑 Shutting down Vector Engine...")
        # Do not clear app.state.__dict__. Starlette's State keeps its values in
        # an internal dict stored there, and wiping it breaks every later
        # attribute access on app.state. Delete only our own keys.
        for key in state_keys:
            try:
                delattr(app.state, key)
            except (AttributeError, KeyError):  # missing key may surface as either
                pass
# -----------------------------
# FastAPI App
# -----------------------------
# Model loading and teardown are delegated to ``lifespan``; every route below
# is registered on this instance.
app = FastAPI(
    lifespan=lifespan,
    version="2.0.0",
    title="Athena Vector Engine",
    description="Production-grade ML microservice for dense + sparse embeddings",
)
# -----------------------------
# Schemas
# -----------------------------
# A single input text: non-empty and at most MAX_TEXT_LENGTH characters.
_BoundedText = constr(min_length=1, max_length=MAX_TEXT_LENGTH)


class VectorRequest(BaseModel):
    # Texts to embed. The list length is not constrained here; the
    # /vectorize handler enforces MAX_BATCH_SIZE itself.
    texts: List[_BoundedText] = Field(
        default=...,
        description="List of input texts to embed",
    )


class SparseData(BaseModel):
    # One sparse vector in coordinate form: indices[i] pairs with values[i].
    indices: List[int]
    values: List[float]


class VectorResponse(BaseModel):
    # One dense vector and one sparse vector per input text.
    dense_vectors: List[List[float]]
    sparse_vectors: List[SparseData]
# -----------------------------
# Embedding Endpoint
# -----------------------------
def _dense_embed(model, texts):
    """Return normalized dense embeddings for *texts* as nested Python lists."""
    # Prefix required for Nomic retrieval queries.
    # NOTE(review): Nomic also defines a "search_document: " prefix for text
    # being indexed; confirm callers only send queries to this endpoint.
    prefixed = [f"search_query: {text}" for text in texts]
    return model.encode(
        prefixed,
        normalize_embeddings=True,
        batch_size=len(prefixed),  # whole request in one batch (capped upstream)
    ).tolist()


def _sparse_embed(model, texts):
    """Return sparse embeddings for *texts* as ``{"indices", "values"}`` dicts."""
    return [
        {"indices": vec.indices.tolist(), "values": vec.values.tolist()}
        for vec in model.embed(texts)
    ]


# A plain ``def`` (not ``async def``) lets FastAPI run this in its threadpool,
# so the blocking model inference does not stall the event loop.
@app.post("/vectorize", response_model=VectorResponse)
def generate_vectors(req: VectorRequest, request: Request):
    """Embed a batch of texts as dense and sparse vectors.

    Returns one dense and one sparse vector per input text, in input order.
    An empty batch yields empty lists. Responds 400 when the batch holds more
    than the configured maximum number of texts, and 500 when inference fails.
    """
    texts = req.texts
    if len(texts) > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"Batch size exceeds maximum limit of {MAX_BATCH_SIZE}",
        )
    # The schema allows an empty list. Answer it directly instead of calling
    # encode() with batch_size=0.
    if not texts:
        return {"dense_vectors": [], "sparse_vectors": []}

    dense_model = request.app.state.dense_model
    sparse_model = request.app.state.sparse_model

    start_time = time.perf_counter()
    try:
        dense_results = _dense_embed(dense_model, texts)
        sparse_results = _sparse_embed(sparse_model, texts)
    except Exception as exc:
        logger.exception("🔥 Vectorization failed")
        raise HTTPException(
            status_code=500,
            detail="Failed to generate embeddings",
        ) from exc
    duration = time.perf_counter() - start_time

    logger.info("Vectorized batch_size=%d latency=%.4fs", len(texts), duration)
    return {
        "dense_vectors": dense_results,
        "sparse_vectors": sparse_results,
    }
# -----------------------------
# Health Endpoints
# -----------------------------
@app.api_route("/health/live", methods=["GET", "HEAD"])
async def liveness():
    # Answers whenever the process is serving requests. It deliberately does
    # not inspect the models; the readiness endpoint does that.
    return dict(status="alive")
@app.api_route("/health/ready", methods=["GET", "HEAD"])
async def readiness(request: Request):
    """Report whether both embedding models are loaded.

    Responds 200 with ``{"ready": true}`` once ``dense_model`` and
    ``sparse_model`` are present on ``app.state``. Otherwise it responds 503
    with ``{"ready": false}``.
    """
    state = request.app.state
    ready = hasattr(state, "dense_model") and hasattr(state, "sparse_model")
    # Readiness probes typically judge by status code alone, so a 200 that
    # carries {"ready": false} would still be treated as ready.
    return JSONResponse(
        status_code=200 if ready else 503,
        content={"ready": ready},
    )
# -----------------------------
# Metadata Endpoint
# -----------------------------
@app.get("/info")
async def model_info(request: Request):
    # The report covers:
    #   - which models are served and on which device;
    #   - the dense embedding size;
    #   - whole seconds elapsed since the lifespan stored start_time;
    #   - the request limits enforced by /vectorize.
    state = request.app.state
    dense = state.dense_model
    device = state.device
    dimension = dense.get_sentence_embedding_dimension()
    uptime_seconds = int(time.time() - state.start_time)
    info = {
        "dense_model": DENSE_MODEL_NAME,
        "sparse_model": SPARSE_MODEL_NAME,
        "embedding_dimension": dimension,
        "device": device,
        "uptime_seconds": uptime_seconds,
        "max_batch_size": MAX_BATCH_SIZE,
        "max_text_length": MAX_TEXT_LENGTH,
    }
    return info