import os
from typing import List, Union
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from huggingface_hub import HfApi
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
# Load environment variables from a local .env file, if one is present.
load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
# Example: "sentence-transformers/all-MiniLM-L6-v2"
MODEL_REPO = os.getenv("MODEL_REPO")

# Fail fast at import time on missing configuration rather than at the
# first request; only absence (None) is rejected, empty strings pass through.
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN is missing in environment variables")
if MODEL_REPO is None:
    raise ValueError("MODEL_REPO is missing in environment variables")
# Request / Response Schemas
class EmbeddingInput(BaseModel):
    """Request body for POST /embed: one string or a batch of strings."""
    # Single text or list of texts; normalized to a list in the route handler.
    text: Union[str, List[str]]
class EmbeddingOutput(BaseModel):
    """Response body for POST /embed."""
    # One embedding vector per input text, in input order.
    embeddings: List[List[float]]
    # Latest tag of the model repo, or a fallback string when unavailable.
    model_version: str
# Global model instance; None until the lifespan handler loads it at startup.
model: Union[SentenceTransformer, None] = None
# Hub client used to look up the model repo's latest tag.
api = HfApi()
# Lifespan (FastAPI startup/shutdown)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the SentenceTransformer once at startup; log at shutdown.

    The loaded model is published through the module-level ``model`` global
    so the route handlers can reach it. Code before ``yield`` runs before
    the app serves requests; code after runs on shutdown.
    """
    global model
    # NOTE(review): original log strings contained mojibake emoji (one literal
    # was even split mid-string by encoding damage); replaced with plain ASCII.
    print("Loading embedding model...")
    model = SentenceTransformer(
        MODEL_REPO,
        token=HF_TOKEN,
        device="cpu",  # Force CPU for stability on 16GB RAM host
    )
    # Put model into eval mode (disables dropout and other training-only state).
    model.eval()
    print("Model loaded successfully.")
    yield
    print("Shutting down.")
# FastAPI App — wired to the lifespan handler above so the model loads
# before the first request is served.
app = FastAPI(
    lifespan=lifespan,
    title="Custom Embedding Model API",
    version="1.0.0",
    description="Lightweight SentenceTransformer-based embedding service.",
)
# Health / Info Route
@app.get("/")
def root():
    """Health/info endpoint: liveness, model-loaded state and repo version."""
    # Resolve the model repo's latest tag; tolerate Hub/network failures so
    # the health check never errors out.
    try:
        latest_tag = api.list_repo_refs(repo_id=MODEL_REPO, repo_type="model").tags[0].name
    except Exception:
        latest_tag = "unknown"
    return {
        "message": "Embedding API is running.",
        # `is not None` instead of truthiness: torch modules can define
        # __len__/__bool__, so a genuinely loaded model could test falsy.
        "model_loaded": model is not None,
        "model_version": latest_tag,
    }
# Embedding Route
@app.post("/embed", response_model=EmbeddingOutput)
def generate_embedding(input_data: EmbeddingInput):
    """Embed one string or a batch of strings with the loaded model.

    Returns L2-normalized vectors (one per input text, in order) plus the
    model repo's latest tag. Raises HTTP 500 when encoding fails.
    """
    # Normalize input to a list so single strings and batches share one path.
    texts = input_data.text if isinstance(input_data.text, list) else [input_data.text]

    # Keep the try body narrow: only the encode call can meaningfully fail
    # in a way we want to surface as a 500.
    try:
        # Compute embeddings in batch for performance.
        vectors = model.encode(
            texts,
            batch_size=32,
            show_progress_bar=False,
            normalize_embeddings=True,  # helps with vector similarity use cases
        ).tolist()
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Embedding generation failed: {str(e)}",
        ) from e

    # Fetch model version; fall back to "unknown" — previously "latest",
    # now consistent with the root() health endpoint.
    try:
        latest_tag = api.list_repo_refs(repo_id=MODEL_REPO, repo_type="model").tags[0].name
    except Exception:
        latest_tag = "unknown"

    return EmbeddingOutput(
        embeddings=vectors,
        model_version=latest_tag,
    )
|