# Source: Hugging Face Space by mr-kush
# Commit 2897aba — "Update default latest tag in embedding generation to 'latest' for better clarity"
import os
from typing import List, Union
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from huggingface_hub import HfApi
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
# Load environment variables (.env file supported for local development).
load_dotenv()


def _require_env(name: str) -> str:
    """Return the value of environment variable *name*, or raise ValueError if unset."""
    value = os.getenv(name)
    if value is None:
        raise ValueError(f"{name} is missing in environment variables")
    return value


HF_TOKEN = _require_env("HF_TOKEN")
MODEL_REPO = _require_env("MODEL_REPO")  # Example: "sentence-transformers/all-MiniLM-L6-v2"
# Request / Response Schemas
class EmbeddingInput(BaseModel):
    """Request schema: a single string or a batch of strings to embed."""
    text: Union[str, List[str]]
class EmbeddingOutput(BaseModel):
    """Response schema: one embedding vector per input text, plus the model version."""
    # embeddings[i] is the vector for the i-th input text.
    embeddings: List[List[float]]
    # Latest tag of the model repo, or a fallback string when the lookup fails.
    model_version: str
# Global model instance; populated by the lifespan handler at startup.
# Annotated as Union[..., None] because it is None until startup completes.
model: Union[SentenceTransformer, None] = None

# Shared Hugging Face Hub client used to look up the model repo's tags.
api = HfApi()
# Lifespan (FastAPI startup/shutdown)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the SentenceTransformer once at startup and log shutdown on exit."""
    global model
    print("🔄 Loading embedding model...")
    # CPU is forced deliberately: keeps memory use predictable on a 16GB RAM host.
    loaded = SentenceTransformer(MODEL_REPO, device="cpu", token=HF_TOKEN)
    # Inference mode: avoids carrying any unnecessary training state.
    loaded.eval()
    model = loaded
    print("✅ Model loaded successfully.")
    yield
    print("🛑 Shutting down.")
# FastAPI application; `lifespan` loads the model before requests are served.
app = FastAPI(
    title="Custom Embedding Model API",
    description="Lightweight SentenceTransformer-based embedding service.",
    version="1.0.0",
    lifespan=lifespan
)
# Health / Info Route
@app.get("/")
def root():
    """Health/info endpoint: reports liveness, model load state and model version."""
    # Best-effort lookup of the newest tag on the model repo; any Hub or
    # network failure degrades to "unknown" rather than failing the probe.
    try:
        latest_tag = api.list_repo_refs(repo_id=MODEL_REPO, repo_type="model").tags[0].name
    except Exception:
        latest_tag = "unknown"
    return {
        "message": "Embedding API is running.",
        # `is not None` rather than truthiness: an explicit loaded/not-loaded check.
        "model_loaded": model is not None,
        "model_version": latest_tag,
    }
# Embedding Route
@app.post("/embed", response_model=EmbeddingOutput)
def generate_embedding(input_data: EmbeddingInput):
    """Embed one string or a batch of strings.

    Returns unit-normalized vectors plus the model repo's latest tag
    ("latest" when the tag lookup fails).

    Raises:
        HTTPException: 500 when embedding computation fails.
    """
    # Normalize input: accept both a bare string and a list of strings.
    texts = input_data.text if isinstance(input_data.text, list) else [input_data.text]
    # Keep the try body minimal: only the encode call can meaningfully fail here.
    try:
        # Batch encoding; normalized vectors help with vector-similarity use cases.
        vectors = model.encode(
            texts,
            batch_size=32,
            show_progress_bar=False,
            normalize_embeddings=True,
        ).tolist()
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Embedding generation failed: {str(e)}"
        ) from e
    # Best-effort model-version lookup; deliberately defaults to "latest"
    # (not "unknown") so the response stays usable as a version pin.
    try:
        latest_tag = api.list_repo_refs(repo_id=MODEL_REPO, repo_type="model").tags[0].name
    except Exception:
        latest_tag = "latest"
    return EmbeddingOutput(
        embeddings=vectors,
        model_version=latest_tag
    )