vichter's picture
Update app.py
4b9ad61 verified
import logging
import os
import secrets

from fastapi import FastAPI, HTTPException, Header, Depends, Request
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Rate limiter keyed on the client address (slowapi). It is registered on the
# app here; limits only apply to routes decorated with @limiter.limit(...).
limiter = Limiter(key_func=get_remote_address)

app = FastAPI(title="Panoptifi Embeddings API")
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Shared secret expected in the X-API-Key header. An empty value disables
# key checking entirely (verify_api_key only enforces it when API_KEY is truthy).
API_KEY = os.environ.get("API_KEY", "")
if not API_KEY:
    # Fail-open auth is easy to deploy by accident; make it visible in the logs.
    logger.warning("API_KEY is not set; API key authentication is DISABLED")
def verify_api_key(x_api_key: str = Header(None, alias="X-API-Key")):
    """FastAPI dependency that enforces the ``X-API-Key`` request header.

    When the module-level ``API_KEY`` is empty, authentication is disabled and
    every request is accepted. Otherwise the header must equal ``API_KEY``.

    Args:
        x_api_key: Value of the ``X-API-Key`` header, or None if absent.

    Returns:
        True when the request is allowed.

    Raises:
        HTTPException: 401 if a key is configured and the header is missing
            or does not match.
    """
    if not API_KEY:
        return True
    # compare_digest runs in time independent of where the inputs differ, so
    # response latency doesn't leak how much of a guessed key was correct.
    # Compare bytes so non-ASCII header values can't raise TypeError.
    if x_api_key is None or not secrets.compare_digest(
        x_api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True
def _load_model():
    """Instantiate the sentence-transformers model, logging around the load."""
    logger.info("Loading embedding model...")
    loaded = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    logger.info("Model loaded")
    return loaded


# Loaded once at import time and shared by every request handler.
model = _load_model()
class EmbedInput(BaseModel):
    """Request body for ``POST /embed``: one text to embed."""

    text: str  # raw input text; the endpoint truncates it before encoding
class BatchEmbedInput(BaseModel):
    """Request body for ``POST /embed/batch``: several texts to embed at once."""

    texts: list[str]  # the endpoint rejects batches longer than 100 entries
class EmbeddingResult(BaseModel):
    """Response body carrying a single embedding vector."""

    embedding: list[float]  # the embedding vector
    dimensions: int  # number of components in ``embedding``
class BatchEmbeddingResult(BaseModel):
    """Response body carrying one embedding vector per embedded text."""

    embeddings: list[list[float]]  # one vector per embedded text
    dimensions: int  # length of each vector as reported by the endpoint
@app.get("/health")
# @limiter.limit("60/minute")
def health(request: Request):
    """Report service status plus the served model name and embedding size.

    The model name and dimension are static values, not read from the
    loaded model.
    """
    payload = dict(status="healthy", model="all-MiniLM-L6-v2", dimensions=384)
    return payload
@app.post("/embed", response_model=EmbeddingResult)
# @limiter.limit("60/minute")
def embed(request: Request, input: EmbedInput, _: bool = Depends(verify_api_key)):
    """Embed a single text and return the vector with its length.

    Access is gated by the ``verify_api_key`` dependency. Only the first
    2000 characters of the text are passed to the model.

    Raises:
        HTTPException: 400 when the text is empty or whitespace-only.
    """
    raw = input.text
    if raw.strip() == "":
        raise HTTPException(400, "Text cannot be empty")
    vector = model.encode(raw[:2000])
    values = vector.tolist()
    return EmbeddingResult(dimensions=len(values), embedding=values)
@app.post("/embed/batch", response_model=BatchEmbeddingResult)
# @limiter.limit("20/minute")
def embed_batch(request: Request, input: BatchEmbedInput, _: bool = Depends(verify_api_key)):
    """Embed up to 100 texts in one call.

    Access is gated by the ``verify_api_key`` dependency. Each text is
    truncated to its first 2000 characters before encoding. The returned
    ``embeddings`` list is index-aligned with ``input.texts``.

    Raises:
        HTTPException: 400 if more than 100 texts are sent, or if any text is
            empty/whitespace-only (same rule as ``/embed``).
    """
    if len(input.texts) > 100:
        raise HTTPException(400, "Max 100 texts per batch")
    # Reject blanks instead of silently dropping them: dropping shifts every
    # later embedding onto the wrong input index, and the caller can't tell.
    blank = [i for i, t in enumerate(input.texts) if not t.strip()]
    if blank:
        raise HTTPException(400, f"Text cannot be empty (indices: {blank})")
    if not input.texts:
        # Nothing to encode; 384 is the output size of all-MiniLM-L6-v2.
        return BatchEmbeddingResult(embeddings=[], dimensions=384)
    texts = [t[:2000] for t in input.texts]
    embeddings = model.encode(texts).tolist()
    # Report the size of the vectors actually produced, as /embed does.
    return BatchEmbeddingResult(embeddings=embeddings, dimensions=len(embeddings[0]))