"""Lightweight SentenceTransformer-based embedding service (FastAPI).

Loads a Hugging Face SentenceTransformer once at startup and exposes:
  GET  /       — health/info (model load state + Hub version tag)
  POST /embed  — embed a string or batch of strings
"""

import os
from contextlib import asynccontextmanager
from typing import List, Optional, Union

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from huggingface_hub import HfApi
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# Load environment variables
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_REPO = os.getenv("MODEL_REPO")  # Example: "sentence-transformers/all-MiniLM-L6-v2"

# Fail fast at import time — the service is useless without these.
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN is missing in environment variables")
if MODEL_REPO is None:
    raise ValueError("MODEL_REPO is missing in environment variables")


# Request / Response Schemas
class EmbeddingInput(BaseModel):
    # A single string or a batch of strings to embed.
    text: Union[str, List[str]]


class EmbeddingOutput(BaseModel):
    embeddings: List[List[float]]
    model_version: str


# Global model instance — populated during app startup (see `lifespan`).
# NOTE: annotation fixed from `SentenceTransformer = None` (invalid: None
# assigned to a non-Optional annotation).
model: Optional[SentenceTransformer] = None
api = HfApi()

# Cache of the successful Hub tag lookup; failures are NOT cached so a
# transient network error does not pin a bogus version for the process.
_model_version: Optional[str] = None


def _fetch_model_version(default: str) -> str:
    """Return the newest tag of MODEL_REPO on the Hub, or `default` on failure.

    The original code hit the Hub on every request just to report a version
    string; this caches the first successful lookup per process.
    """
    global _model_version
    if _model_version is None:
        try:
            refs = api.list_repo_refs(repo_id=MODEL_REPO, repo_type="model")
            _model_version = refs.tags[0].name
        except Exception:
            # Covers network/auth errors and repos with no tags (IndexError).
            return default
    return _model_version


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the embedding model once at startup; log on shutdown."""
    global model
    print("🔄 Loading embedding model...")
    model = SentenceTransformer(
        MODEL_REPO,
        token=HF_TOKEN,
        device="cpu",  # Force CPU for stability on 16GB RAM host
    )
    # Put model into eval mode (more correct, avoids any unnecessary training state)
    model.eval()
    print("✅ Model loaded successfully.")
    yield
    print("🛑 Shutting down.")


# FastAPI App
app = FastAPI(
    title="Custom Embedding Model API",
    description="Lightweight SentenceTransformer-based embedding service.",
    version="1.0.0",
    lifespan=lifespan,
)


# Health / Info Route
@app.get("/")
def root():
    """Report service liveness, model load state, and the model's Hub tag."""
    return {
        "message": "Embedding API is running.",
        "model_loaded": model is not None,
        "model_version": _fetch_model_version("unknown"),
    }


# Embedding Route
@app.post("/embed", response_model=EmbeddingOutput)
def generate_embedding(input_data: EmbeddingInput):
    """Embed one string or a batch of strings.

    Returns L2-normalized vectors plus the model's Hub version tag.
    Raises 503 if the model has not finished loading, 500 if encoding fails.
    """
    if model is None:
        # Startup not finished (or failed): signal unavailability instead of
        # letting `None.encode` surface as a misleading 500.
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")

    # Normalize input to list
    texts = input_data.text if isinstance(input_data.text, list) else [input_data.text]

    # Keep the try body minimal: only the encode call can meaningfully fail here.
    try:
        # Compute embeddings (in batch for performance)
        vectors = model.encode(
            texts,
            batch_size=32,
            show_progress_bar=False,
            normalize_embeddings=True,  # helps with vector similarity use cases
        ).tolist()
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Embedding generation failed: {str(e)}",
        )

    return EmbeddingOutput(
        embeddings=vectors,
        model_version=_fetch_model_version("latest"),
    )