Spaces:
Sleeping
Sleeping
| import os | |
| from typing import List, Union | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from contextlib import asynccontextmanager | |
| from huggingface_hub import HfApi | |
| from sentence_transformers import SentenceTransformer | |
| from dotenv import load_dotenv | |
# Load environment variables from a local .env file, if one is present.
load_dotenv()


def _require_env(var: str) -> str:
    """Return the value of environment variable *var*, failing fast if unset."""
    value = os.getenv(var)
    if value is None:
        raise ValueError(f"{var} is missing in environment variables")
    return value


HF_TOKEN = _require_env("HF_TOKEN")
MODEL_REPO = _require_env("MODEL_REPO")  # Example: "sentence-transformers/all-MiniLM-L6-v2"
| # Request / Response Schemas | |
class EmbeddingInput(BaseModel):
    """Request schema: a single string or a batch of strings to embed."""

    text: Union[str, List[str]]
class EmbeddingOutput(BaseModel):
    """Response schema: one embedding vector per input text, plus model version."""

    embeddings: List[List[float]]
    model_version: str
# Global model instance, populated by the lifespan handler at startup.
# Annotated as Optional (Union with None): the original annotation
# `SentenceTransformer = None` claimed a type the initial value does not satisfy.
model: Union[SentenceTransformer, None] = None
# Hub client used to look up the model repo's latest tag for version reporting.
api = HfApi()
| # Lifespan (FastAPI startup/shutdown) | |
@asynccontextmanager  # required: FastAPI's `lifespan=` expects an async context manager,
                      # not a bare async generator — without this the app fails at startup
async def lifespan(app: FastAPI):
    """Load the embedding model at startup; log on shutdown.

    Code before ``yield`` runs once at application startup, code after it
    runs at shutdown. The loaded model is published via the module-level
    ``model`` global so route handlers can use it.
    """
    global model
    print("π Loading embedding model...")
    model = SentenceTransformer(
        MODEL_REPO,
        token=HF_TOKEN,
        device="cpu"  # Force CPU for stability on 16GB RAM host
    )
    # Put model into eval mode (more correct, avoids any unnecessary training state)
    model.eval()
    print("β Model loaded successfully.")
    yield
    print("π Shutting down.")
| # FastAPI App | |
# FastAPI application; the lifespan handler loads the model at startup.
app = FastAPI(
    title="Custom Embedding Model API",
    version="1.0.0",
    description="Lightweight SentenceTransformer-based embedding service.",
    lifespan=lifespan,
)
| # Health / Info Route | |
@app.get("/")  # fix: handler was never registered as a route (comment said "Health / Info Route")
def root():
    """Health/info endpoint.

    Returns a JSON object with a status message, whether the embedding
    model has finished loading, and the model repo's latest tag
    ("unknown" when the Hub lookup fails, e.g. offline or an untagged repo).
    """
    # Get the repo's latest commit hash/tag (best-effort; the Hub call can
    # fail offline or raise IndexError on a repo with no tags).
    try:
        latest_tag = api.list_repo_refs(repo_id=MODEL_REPO, repo_type="model").tags[0].name
    except Exception:
        latest_tag = "unknown"
    return {
        "message": "Embedding API is running.",
        # Explicit None check: truthiness of a model object is unreliable
        # (container-like models may define __len__), and `is not None` is
        # what "loaded" actually means here.
        "model_loaded": model is not None,
        "model_version": latest_tag,
    }
| # Embedding Route | |
@app.post("/embed", response_model=EmbeddingOutput)  # fix: handler was never registered as a
# route (comment said "Embedding Route"); path "/embed" chosen here — confirm against clients
def generate_embedding(input_data: EmbeddingInput):
    """Compute L2-normalized embeddings for one or many texts.

    Args:
        input_data: Request body holding a single string or a list of strings.

    Returns:
        EmbeddingOutput with one embedding vector per input text and the
        model repo's latest tag ("latest" when the Hub lookup fails).

    Raises:
        HTTPException: 500 when embedding generation fails for any reason.
    """
    try:
        # Normalize input to a list so single strings and batches share one path.
        texts = input_data.text if isinstance(input_data.text, list) else [input_data.text]
        # Compute embeddings (in batch for performance)
        vectors = model.encode(
            texts,
            batch_size=32,
            show_progress_bar=False,
            normalize_embeddings=True  # helps with vector similarity use cases
        ).tolist()
        # Fetch model version (best-effort; falls back to a placeholder).
        try:
            latest_tag = api.list_repo_refs(repo_id=MODEL_REPO, repo_type="model").tags[0].name
        except Exception:
            latest_tag = "latest"
        return EmbeddingOutput(
            embeddings=vectors,
            model_version=latest_tag
        )
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(
            status_code=500,
            detail=f"Embedding generation failed: {str(e)}"
        ) from e