"""Panoptifi Topics API: zero-shot topic classification served over HTTP.

Exposes a health check plus single-text and batch classification endpoints
backed by a Hugging Face zero-shot-classification pipeline. The model is
loaded once, at import time.
"""

import logging
import os
import secrets

from fastapi import Depends, FastAPI, Header, HTTPException, Request
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from transformers import pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face model identifier; the health endpoint reports only the part
# after the organisation prefix.
MODEL_ID = "MoritzLaurer/ModernBERT-large-zeroshot-v2.0"
MODEL_DISPLAY_NAME = MODEL_ID.rsplit("/", 1)[-1]

# Texts are cut to this many characters before being passed to the model.
MAX_TEXT_CHARS = 2000
# Largest number of texts accepted by a single /classify/batch request.
MAX_BATCH_SIZE = 50

limiter = Limiter(key_func=get_remote_address)
app = FastAPI(title="Panoptifi Topics API")
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

API_KEY = os.environ.get("API_KEY", "")
if not API_KEY:
    # An empty key makes verify_api_key accept every request, so make that
    # visible in the logs instead of failing open silently.
    logger.warning("API_KEY is not set; API key authentication is disabled")


def verify_api_key(
    x_api_key: str | None = Header(None, alias="X-API-Key"),
) -> bool:
    """Check the ``X-API-Key`` request header against the configured key.

    When ``API_KEY`` is empty the check is skipped and every request passes.

    Args:
        x_api_key: Value of the ``X-API-Key`` header, or ``None`` if absent.

    Returns:
        ``True`` when the request is allowed.

    Raises:
        HTTPException: 401 when a key is configured and the header is
            missing or does not match.
    """
    if not API_KEY:
        return True

    # Constant-time comparison so response timing does not leak how much of
    # the key matched. Compare bytes: compare_digest rejects non-ASCII str.
    supplied = (x_api_key or "").encode("utf-8")
    if not secrets.compare_digest(supplied, API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True


logger.info("Loading zero-shot classifier...")
classifier = pipeline(
    "zero-shot-classification",
    model=MODEL_ID,
)
logger.info("Model loaded")

# Candidate labels used when a request does not supply its own.
DEFAULT_LABELS = [
    "monetary policy",
    "earnings",
    "mergers and acquisitions",
    "regulation",
    "layoffs",
    "product launch",
    "legal issues",
    "market sentiment",
    "cryptocurrency",
    "economic data",
]


class TopicInput(BaseModel):
    """Request body for ``POST /classify``."""

    # Text to classify; only the first MAX_TEXT_CHARS characters are used.
    text: str
    # Candidate labels; DEFAULT_LABELS is used when omitted or empty.
    labels: list[str] | None = None
    # Forwarded unchanged to the pipeline's ``multi_label`` argument.
    multi_label: bool = False


class TopicScore(BaseModel):
    """A single candidate label paired with the score the model gave it."""

    label: str
    score: float


class TopicResult(BaseModel):
    """Classification outcome for one text: one ``TopicScore`` per label."""

    labels: list[TopicScore]


class BatchTopicInput(BaseModel):
    """Request body for ``POST /classify/batch``."""

    # Texts to classify; at most MAX_BATCH_SIZE per request.
    texts: list[str]
    # Candidate labels; DEFAULT_LABELS is used when omitted or empty.
    labels: list[str] | None = None
    # Forwarded unchanged to the pipeline's ``multi_label`` argument.
    multi_label: bool = False


def _classify_text(text: str, labels: list[str], multi_label: bool) -> TopicResult:
    """Run the classifier on one (truncated) text and wrap its label scores."""
    output = classifier(text[:MAX_TEXT_CHARS], labels, multi_label=multi_label)
    return TopicResult(
        labels=[
            TopicScore(label=label, score=score)
            for label, score in zip(output["labels"], output["scores"])
        ]
    )


# NOTE: the rate-limit decorators below are currently disabled. The otherwise
# unused ``request`` parameters are kept so they can be re-enabled (slowapi
# requires the endpoint to accept a ``request`` argument).
@app.get("/health")
# @limiter.limit("60/minute")
def health(request: Request):
    """Report service status and the name of the loaded model."""
    return {"status": "healthy", "model": MODEL_DISPLAY_NAME}


@app.post("/classify", response_model=TopicResult)
# @limiter.limit("30/minute")
def classify_topic(
    request: Request,
    input: TopicInput,
    _: bool = Depends(verify_api_key),
):
    """Classify a single text against the given (or default) labels.

    Raises:
        HTTPException: 400 when the text is empty or whitespace-only.
    """
    if not input.text.strip():
        raise HTTPException(400, "Text cannot be empty")

    labels = input.labels or DEFAULT_LABELS
    return _classify_text(input.text, labels, input.multi_label)


@app.post("/classify/batch", response_model=list[TopicResult])
# @limiter.limit("10/minute")
def classify_batch(
    request: Request,
    input: BatchTopicInput,
    _: bool = Depends(verify_api_key),
):
    """Classify up to ``MAX_BATCH_SIZE`` texts against the same labels.

    Blank or whitespace-only texts are skipped, so the response may contain
    fewer entries than the request.

    Raises:
        HTTPException: 400 when more than ``MAX_BATCH_SIZE`` texts are sent.
    """
    if len(input.texts) > MAX_BATCH_SIZE:
        raise HTTPException(400, "Max 50 texts per batch")

    labels = input.labels or DEFAULT_LABELS
    # NOTE(review): skipping blank texts means results cannot be matched to
    # inputs by position when any text is blank. Kept as-is because changing
    # the response shape could break existing callers; confirm with consumers.
    return [
        _classify_text(text, labels, input.multi_label)
        for text in input.texts
        if text.strip()
    ]