# ml_service/main.py
"""Athena Vector Engine: FastAPI microservice serving dense + sparse embeddings.

Dense vectors come from a SentenceTransformer model and sparse vectors from
fastembed's ``SparseTextEmbedding``. Both models are loaded once in the
application lifespan and shared by all requests through ``app.state``.
"""
import logging
import time
import os
import torch
from typing import List
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, constr
from sentence_transformers import SentenceTransformer
from fastembed import SparseTextEmbedding

# -----------------------------
# Configuration
# -----------------------------
# Each setting may be overridden through an ATHENA_* environment variable;
# the defaults are the service's standard settings. A non-integer value for
# a numeric setting fails fast at import time.
MAX_TEXT_LENGTH = int(os.getenv("ATHENA_MAX_TEXT_LENGTH", "5000"))
MAX_BATCH_SIZE = int(os.getenv("ATHENA_MAX_BATCH_SIZE", "32"))
DENSE_MODEL_NAME = os.getenv("ATHENA_DENSE_MODEL_NAME", "nomic-ai/nomic-embed-text-v1.5")
SPARSE_MODEL_NAME = os.getenv("ATHENA_SPARSE_MODEL_NAME", "prithivida/Splade_PP_en_v1")

# app.state attributes holding the loaded models; readiness keys off these.
_MODEL_STATE_ATTRS = ("dense_model", "sparse_model")

# -----------------------------
# Structured Logging Setup
# -----------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("athena.vector_engine")


# -----------------------------
# Lifespan Management
# -----------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load and warm up both embedding models, then release them on shutdown.

    Startup errors are logged and re-raised, so the server refuses to start
    instead of serving requests without models.
    """
    logger.info("🧠 Booting Vector Engine...")
    boot_started = time.perf_counter()

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Using device: %s", device)

        # trust_remote_code lets the model repo's custom modeling code run;
        # the default Nomic model needs it.
        dense_model = SentenceTransformer(
            DENSE_MODEL_NAME,
            trust_remote_code=True,
            device=device,
        )
        sparse_model = SparseTextEmbedding(model_name=SPARSE_MODEL_NAME)

        # Warmup (prevents a cold-start latency spike on the first request).
        logger.info("🔥 Warming up models...")
        dense_model.encode("warmup", normalize_embeddings=True)
        # embed() returns an iterable; consume it so the work actually runs.
        list(sparse_model.embed(["warmup"]))
    except Exception:
        logger.exception("❌ Failed during startup")
        raise

    app.state.dense_model = dense_model
    app.state.sparse_model = sparse_model
    app.state.device = device
    # Monotonic clock so reported uptime cannot jump when wall time changes.
    app.state.start_time = time.monotonic()
    logger.info(
        "✅ Models loaded successfully in %.2fs",
        time.perf_counter() - boot_started,
    )

    # yield lives outside the startup try/except so errors propagated at
    # shutdown are not mislogged as startup failures.
    try:
        yield
    finally:
        logger.info("🛑 Shutting down Vector Engine...")
        # Remove only the attributes this service set. Clearing
        # app.state.__dict__ would also delete Starlette's internal `_state`
        # mapping, after which every app.state lookup raises RecursionError.
        for attr in _MODEL_STATE_ATTRS:
            if hasattr(app.state, attr):
                delattr(app.state, attr)
        # Drop this frame's references too, so the models can be freed.
        del dense_model, sparse_model
        if device == "cuda":
            torch.cuda.empty_cache()


# -----------------------------
# FastAPI App
# -----------------------------
app = FastAPI(
    title="Athena Vector Engine",
    description="Production-grade ML microservice for dense + sparse embeddings",
    version="2.0.0",
    lifespan=lifespan,
)


# -----------------------------
# Schemas
# -----------------------------
class VectorRequest(BaseModel):
    """Request body: the texts to embed (each 1..MAX_TEXT_LENGTH characters)."""

    texts: List[constr(min_length=1, max_length=MAX_TEXT_LENGTH)] = Field(
        ..., description="List of input texts to embed"
    )


class SparseData(BaseModel):
    """One sparse vector as parallel lists of token indices and weights."""

    indices: List[int]
    values: List[float]


class VectorResponse(BaseModel):
    """Dense and sparse vectors, each list in the same order as the input texts."""

    dense_vectors: List[List[float]]
    sparse_vectors: List[SparseData]


# -----------------------------
# Embedding Endpoint
# -----------------------------
@app.post("/vectorize", response_model=VectorResponse)
def generate_vectors(req: VectorRequest, request: Request):
    """Embed a batch of texts into normalized dense vectors and sparse vectors.

    Declared as a plain ``def`` so FastAPI runs it in its threadpool and the
    blocking model calls do not stall the event loop.

    Raises:
        HTTPException(400): the batch exceeds MAX_BATCH_SIZE.
        HTTPException(503): the models are not loaded (e.g. during shutdown).
        HTTPException(500): either model failed while embedding.
    """
    texts = req.texts
    if len(texts) > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"Batch size exceeds maximum limit of {MAX_BATCH_SIZE}",
        )

    # Nothing to embed. Answer directly instead of calling encode() with
    # batch_size=0, which fails inside the library and surfaces as a 500.
    if not texts:
        return {"dense_vectors": [], "sparse_vectors": []}

    dense_model = getattr(request.app.state, "dense_model", None)
    sparse_model = getattr(request.app.state, "sparse_model", None)
    if dense_model is None or sparse_model is None:
        raise HTTPException(status_code=503, detail="Models are not loaded")

    try:
        started = time.perf_counter()

        # Prefix required for Nomic retrieval queries.
        # NOTE(review): if this endpoint also embeds documents for indexing,
        # those presumably need the "search_document: " prefix instead.
        prefixed_texts = [f"search_query: {text}" for text in texts]

        # Dense embeddings: the whole request is encoded as a single batch.
        dense_results = dense_model.encode(
            prefixed_texts,
            normalize_embeddings=True,
            batch_size=len(prefixed_texts),
        ).tolist()

        # Sparse embeddings take the raw, unprefixed texts.
        sparse_results = [
            {"indices": vec.indices.tolist(), "values": vec.values.tolist()}
            for vec in sparse_model.embed(texts)
        ]

        logger.info(
            "Vectorized batch_size=%d latency=%.4fs",
            len(texts),
            time.perf_counter() - started,
        )
    except Exception as exc:
        logger.exception("🔥 Vectorization failed")
        raise HTTPException(
            status_code=500,
            detail="Failed to generate embeddings",
        ) from exc

    return {
        "dense_vectors": dense_results,
        "sparse_vectors": sparse_results,
    }


# -----------------------------
# Health Endpoints
# -----------------------------
@app.api_route("/health/live", methods=["GET", "HEAD"])
async def liveness():
    """Liveness probe: the process is up and answering HTTP."""
    return {"status": "alive"}


@app.api_route("/health/ready", methods=["GET", "HEAD"])
async def readiness(request: Request):
    """Readiness probe: 200 when both models are loaded, 503 otherwise.

    Status-code based probes (e.g. an HTTP GET probe) ignore the body, so a
    not-ready instance must not answer 200.
    """
    state = request.app.state
    ready = all(hasattr(state, attr) for attr in _MODEL_STATE_ATTRS)
    if not ready:
        return JSONResponse(status_code=503, content={"ready": False})
    return {"ready": True}


# -----------------------------
# Metadata Endpoint
# -----------------------------
@app.get("/info")
async def model_info(request: Request):
    """Report model names, embedding dimension, device, uptime and limits.

    Raises:
        HTTPException(503): the models are not loaded.
    """
    state = request.app.state
    dense_model = getattr(state, "dense_model", None)
    if dense_model is None:
        raise HTTPException(status_code=503, detail="Models are not loaded")

    return {
        "dense_model": DENSE_MODEL_NAME,
        "sparse_model": SPARSE_MODEL_NAME,
        "embedding_dimension": dense_model.get_sentence_embedding_dimension(),
        "device": state.device,
        # start_time comes from time.monotonic() in lifespan.
        "uptime_seconds": int(time.monotonic() - state.start_time),
        "max_batch_size": MAX_BATCH_SIZE,
        "max_text_length": MAX_TEXT_LENGTH,
    }