Spaces:
Configuration error
Configuration error
| # ml_service/main.py | |
| import logging | |
| import time | |
| import os | |
| import torch | |
| from typing import List | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, HTTPException, Request | |
| from pydantic import BaseModel, Field, constr | |
| from sentence_transformers import SentenceTransformer | |
| from fastembed import SparseTextEmbedding | |
| # ----------------------------- | |
| # Configuration | |
| # ----------------------------- | |
# Longest accepted input text; applied as `max_length` in the request schema.
MAX_TEXT_LENGTH: int = 5000
# Largest number of texts accepted in a single embedding request.
MAX_BATCH_SIZE: int = 32
# Identifiers of the dense and sparse embedding models loaded at startup.
DENSE_MODEL_NAME: str = "nomic-ai/nomic-embed-text-v1.5"
SPARSE_MODEL_NAME: str = "prithivida/Splade_PP_en_v1"
| # ----------------------------- | |
| # Structured Logging Setup | |
| # ----------------------------- | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", | |
| ) | |
| logger = logging.getLogger("athena.vector_engine") | |
| # ----------------------------- | |
| # Lifespan Management | |
| # ----------------------------- | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load both embedding models before serving and detach them on shutdown.

    Startup selects CUDA when ``torch.cuda.is_available()`` and CPU otherwise,
    loads the dense and the sparse model, runs one warm-up inference through
    each, and then publishes ``dense_model``, ``sparse_model``, ``device`` and
    ``start_time`` on ``app.state``. Shutdown removes exactly those four
    attributes again.

    Args:
        app: The FastAPI application whose ``state`` receives the models.

    Yields:
        None. Control returns to the server while the application runs.

    Raises:
        Exception: Any error raised while loading or warming up the models is
            logged and re-raised unchanged; nothing is attached to
            ``app.state`` in that case.
    """
    logger.info("🧠 Booting Vector Engine...")
    boot_started = time.time()

    # Only the loading phase is guarded, so errors raised later (while the
    # application runs or shuts down) are not mislabelled as startup failures.
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Using device: %s", device)

        # NOTE: trust_remote_code=True allows code shipped with the model
        # repository to run locally while the model is loaded.
        dense_model = SentenceTransformer(
            DENSE_MODEL_NAME,
            trust_remote_code=True,
            device=device,
        )
        sparse_model = SparseTextEmbedding(model_name=SPARSE_MODEL_NAME)

        # Warmup (prevents cold-start latency spike). list() drains the result
        # so the sparse model really runs even if `embed` is lazy.
        logger.info("🔥 Warming up models...")
        dense_model.encode("warmup", normalize_embeddings=True)
        list(sparse_model.embed(["warmup"]))
    except Exception:
        logger.exception("❌ Failed during startup")
        raise

    # Attach to app state only once everything has loaded successfully.
    app.state.dense_model = dense_model
    app.state.sparse_model = sparse_model
    app.state.device = device
    app.state.start_time = time.time()
    logger.info(
        "✅ Models loaded successfully in %.2fs",
        time.time() - boot_started,
    )

    try:
        yield
    finally:
        logger.info("🛑 Shutting down Vector Engine...")
        # Remove just the attributes set above. Clearing `app.state.__dict__`
        # would also drop the private storage Starlette's `State` keeps there
        # and break every later attribute access on `app.state`.
        for attr in ("dense_model", "sparse_model", "device", "start_time"):
            if hasattr(app.state, attr):
                delattr(app.state, attr)
| # ----------------------------- | |
| # FastAPI App | |
| # ----------------------------- | |
# Application object; `lifespan` runs the startup and shutdown logic around
# the serving phase.
app = FastAPI(
    lifespan=lifespan,
    title="Athena Vector Engine",
    version="2.0.0",
    description="Production-grade ML microservice for dense + sparse embeddings",
)
| # ----------------------------- | |
| # Schemas | |
| # ----------------------------- | |
# One input text: non-empty and no longer than MAX_TEXT_LENGTH.
_BoundedText = constr(min_length=1, max_length=MAX_TEXT_LENGTH)


class VectorRequest(BaseModel):
    # Request body of the embedding endpoint: the batch of texts to embed.
    # The batch-size limit is not part of this schema; the endpoint checks it.
    texts: List[_BoundedText] = Field(
        ...,
        description="List of input texts to embed",
    )
class SparseData(BaseModel):
    # One sparse embedding as two parallel lists, filled by the endpoint from
    # the sparse model's `indices` and `values` arrays.
    indices: List[int]
    values: List[float]
class VectorResponse(BaseModel):
    # Response shape of the embedding endpoint: one dense vector and one
    # sparse vector per input text.
    dense_vectors: List[List[float]]
    sparse_vectors: List[SparseData]
| # ----------------------------- | |
| # Embedding Endpoint | |
| # ----------------------------- | |
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm that it is registered on `app`.
def generate_vectors(req: VectorRequest, request: Request):
    """Embed a batch of texts with both the dense and the sparse model.

    Each text is prefixed with ``search_query: `` before dense encoding; the
    sparse model receives the texts unchanged. Dense vectors are normalized.

    Args:
        req: Validated request body carrying the texts to embed.
        request: Incoming request, used to reach the models on ``app.state``.

    Returns:
        A dict with ``dense_vectors`` (one list of floats per text) and
        ``sparse_vectors`` (one ``indices``/``values`` dict per text). Both
        lists are empty when ``req.texts`` is empty.

    Raises:
        HTTPException: 400 when more than ``MAX_BATCH_SIZE`` texts are sent;
            500 when either model fails while embedding.
    """
    if len(req.texts) > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"Batch size exceeds maximum limit of {MAX_BATCH_SIZE}",
        )

    if not req.texts:
        # Nothing to embed; without this guard the dense encoder below would
        # be called with batch_size=0.
        return {"dense_vectors": [], "sparse_vectors": []}

    dense_model = request.app.state.dense_model
    sparse_model = request.app.state.sparse_model

    try:
        started = time.perf_counter()

        # Prefix required for Nomic retrieval queries
        prefixed_texts = [f"search_query: {text}" for text in req.texts]

        # Dense embeddings: the whole request is encoded as a single batch.
        dense_results = dense_model.encode(
            prefixed_texts,
            normalize_embeddings=True,
            batch_size=len(prefixed_texts),
        ).tolist()

        # Sparse embeddings, converted to plain lists for JSON serialization.
        sparse_results = [
            {
                "indices": vec.indices.tolist(),
                "values": vec.values.tolist(),
            }
            for vec in sparse_model.embed(req.texts)
        ]

        logger.info(
            "Vectorized batch_size=%d latency=%.4fs",
            len(req.texts),
            time.perf_counter() - started,
        )
        return {
            "dense_vectors": dense_results,
            "sparse_vectors": sparse_results,
        }
    except Exception as exc:
        # Log the full traceback but keep internal details out of the response.
        logger.exception("🔥 Vectorization failed")
        raise HTTPException(
            status_code=500,
            detail="Failed to generate embeddings",
        ) from exc
| # ----------------------------- | |
| # Health Endpoints | |
| # ----------------------------- | |
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm that it is registered on `app`.
async def liveness():
    """Report that the process is up; no model state is consulted."""
    return dict(status="alive")
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm that it is registered on `app`.
async def readiness(request: Request):
    """Report whether both embedding models are attached to the app state."""
    state = request.app.state
    required = ("dense_model", "sparse_model")
    return {"ready": all(hasattr(state, attr) for attr in required)}
| # ----------------------------- | |
| # Metadata Endpoint | |
| # ----------------------------- | |
# NOTE(review): no route decorator is visible on this handler in this view;
# confirm that it is registered on `app`.
async def model_info(request: Request):
    """Describe the loaded models, the device in use and the service limits."""
    state = request.app.state
    encoder = state.dense_model
    device = state.device

    dimension = encoder.get_sentence_embedding_dimension()
    # Whole seconds since `start_time` was recorded on the app state.
    uptime = int(time.time() - state.start_time)

    info = {
        "dense_model": DENSE_MODEL_NAME,
        "sparse_model": SPARSE_MODEL_NAME,
        "embedding_dimension": dimension,
        "device": device,
        "uptime_seconds": uptime,
        "max_batch_size": MAX_BATCH_SIZE,
        "max_text_length": MAX_TEXT_LENGTH,
    }
    return info