Spaces:
Sleeping
Sleeping
import logging
import os
import secrets
from typing import Optional

from fastapi import FastAPI, HTTPException, Header, Depends, Request
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
# Basic stdout logging for the service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# slowapi rate limiter keyed by the client's remote IP address.
limiter = Limiter(key_func=get_remote_address)

app = FastAPI(title="Panoptifi Embeddings API")
# Wire the limiter into the app so slowapi decorators and the 429 handler work.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Shared secret for request auth; when unset/empty, auth is disabled entirely
# (see verify_api_key).
API_KEY = os.environ.get("API_KEY", "")
def verify_api_key(x_api_key: Optional[str] = Header(None, alias="X-API-Key")):
    """FastAPI dependency validating the X-API-Key header against API_KEY.

    When the API_KEY env var is unset or empty, authentication is a no-op
    (open access). Raises HTTPException 401 on a missing or wrong key.

    Returns:
        True when the request is authorized.
    """
    # secrets.compare_digest gives a constant-time comparison, avoiding a
    # timing side-channel on the API key (plain `!=` short-circuits on the
    # first differing byte).
    if API_KEY and not secrets.compare_digest(x_api_key or "", API_KEY):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True
# Load the sentence-transformer once at import time so requests don't pay the
# model-load cost; all-MiniLM-L6-v2 produces 384-dim vectors (see /health).
logger.info("Loading embedding model...")
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
logger.info("Model loaded")
class EmbedInput(BaseModel):
    """Request body for single-text embedding."""

    # Text to embed; truncated to 2000 chars by the handler.
    text: str
class BatchEmbedInput(BaseModel):
    """Request body for batch embedding (handler caps the batch at 100)."""

    # Texts to embed; each is truncated to 2000 chars by the handler.
    texts: list[str]
class EmbeddingResult(BaseModel):
    """Response for a single embedding: the vector and its length."""

    embedding: list[float]
    dimensions: int
class BatchEmbeddingResult(BaseModel):
    """Response for a batch: one vector per input text, plus vector length."""

    embeddings: list[list[float]]
    dimensions: int
# @limiter.limit("60/minute")
def health(request: Request):
    """Liveness probe: report service status, served model, and vector size.

    NOTE(review): no @app route decorator is visible in this file — confirm
    this handler is actually registered as a route, otherwise it is never
    mounted.
    """
    payload = {
        "status": "healthy",
        "model": "all-MiniLM-L6-v2",
        "dimensions": 384,
    }
    return payload
# @limiter.limit("60/minute")
def embed(request: Request, input: EmbedInput, _: bool = Depends(verify_api_key)):
    """Embed one text (truncated to 2000 chars) into a dense vector.

    Raises HTTPException 400 when the text is empty or whitespace-only.
    NOTE(review): no @app route decorator is visible in this file — confirm
    this handler is registered as a route.
    """
    text = input.text
    if not text.strip():
        raise HTTPException(400, "Text cannot be empty")
    vector = model.encode(text[:2000]).tolist()
    return EmbeddingResult(embedding=vector, dimensions=len(vector))
# @limiter.limit("20/minute")
def embed_batch(request: Request, input: BatchEmbedInput, _: bool = Depends(verify_api_key)):
    """Embed up to 100 texts (each truncated to 2000 chars) in one call.

    Returns one embedding per input text, in input order.

    Raises:
        HTTPException 400: when the batch exceeds 100 texts or is empty.

    NOTE(review): no @app route decorator is visible in this file — confirm
    this handler is registered as a route.
    """
    if len(input.texts) > 100:
        raise HTTPException(400, "Max 100 texts per batch")
    # Reject an empty batch up front: encoding [] would yield an empty result
    # and there is nothing meaningful to return.
    if not input.texts:
        raise HTTPException(400, "Texts cannot be empty")
    # Truncate but do NOT drop blank entries (the previous filter silently
    # removed them, misaligning the output list with the caller's input
    # indices — caller could no longer tell which embedding matched which
    # text). The model encodes blank strings without error.
    texts = [t[:2000] for t in input.texts]
    embeddings = model.encode(texts).tolist()
    # Measure dimensions from the result (consistent with `embed`) instead of
    # hard-coding 384.
    return BatchEmbeddingResult(embeddings=embeddings, dimensions=len(embeddings[0]))