Spaces:
Sleeping
Sleeping
import hmac
import logging
import os

from fastapi import FastAPI, HTTPException, Header, Depends, Request
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from transformers import pipeline
# Emit INFO-level records so the startup messages logged by this module show up.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Rate limiter keyed on the caller's remote address.
limiter = Limiter(key_func=get_remote_address)

app = FastAPI(title="Panoptifi Topics API")
# Attach the limiter to the application state and register slowapi's handler
# for RateLimitExceeded errors.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Shared secret compared against the X-API-Key request header. When the
# environment variable is unset or empty, verify_api_key accepts every request.
API_KEY = os.environ.get("API_KEY", "")
if not API_KEY:
    # Make the open-access configuration visible instead of silent.
    logger.warning("API_KEY is not set; API key verification is disabled")
def verify_api_key(x_api_key: str = Header(None, alias="X-API-Key")):
    """Check the ``X-API-Key`` request header against the configured key.

    Args:
        x_api_key: Value of the ``X-API-Key`` header, or ``None`` when the
            header is absent.

    Returns:
        ``True`` when the request is allowed through.

    Raises:
        HTTPException: 401 when ``API_KEY`` is configured and the supplied
            header is missing or does not match it.
    """
    # An empty API_KEY means no key is configured: every request is accepted.
    if not API_KEY:
        return True

    # Compare as bytes in constant time so response timing does not reveal how
    # much of a guessed key is correct; bytes also allow non-ASCII header
    # values, which compare_digest rejects for str arguments.
    supplied = (x_api_key or "").encode("utf-8")
    if not hmac.compare_digest(supplied, API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True
# The zero-shot pipeline is built once, when this module is imported, and the
# resulting `classifier` callable is shared by the request handlers below.
logger.info("Loading zero-shot classifier...")
classifier = pipeline(
    "zero-shot-classification",
    model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
)
logger.info("Model loaded")
# Candidate labels used when a request does not supply its own `labels`.
DEFAULT_LABELS = [
    "monetary policy",
    "earnings",
    "mergers and acquisitions",
    "regulation",
    "layoffs",
    "product launch",
    "legal issues",
    "market sentiment",
    "cryptocurrency",
    "economic data",
]
# Request body for classifying a single text.
class TopicInput(BaseModel):
    # Text to classify.
    text: str
    # Candidate labels; when omitted or empty the handler falls back to
    # DEFAULT_LABELS.
    labels: list[str] | None = None
    # Forwarded unchanged to the classifier's `multi_label` argument.
    multi_label: bool = False
# One candidate label paired with the score the classifier assigned to it.
class TopicScore(BaseModel):
    label: str
    score: float
# Classification result for one text: a list of label/score pairs.
class TopicResult(BaseModel):
    labels: list[TopicScore]
# Request body for classifying several texts against one shared label set.
class BatchTopicInput(BaseModel):
    # Texts to classify.
    texts: list[str]
    # Candidate labels applied to every text; when omitted or empty the
    # handler falls back to DEFAULT_LABELS.
    labels: list[str] | None = None
    # Forwarded unchanged to the classifier's `multi_label` argument.
    multi_label: bool = False
# NOTE(review): the rate-limit decorator below is commented out and no route
# decorator is visible in this view; confirm how this handler is registered.
# @limiter.limit("60/minute")
def health(request: Request):
    """Return a static health payload naming the served model.

    Args:
        request: Incoming request; not read by the body (presumably kept for
            the rate-limit decorator; confirm).

    Returns:
        A dict with ``"status"`` and ``"model"`` keys.
    """
    return {"status": "healthy", "model": "ModernBERT-large-zeroshot-v2.0"}
# NOTE(review): the rate-limit decorator below is commented out and no route
# decorator is visible in this view; confirm how this handler is registered.
# @limiter.limit("30/minute")
def classify_topic(request: Request, input: TopicInput, _: bool = Depends(verify_api_key)):
    """Score a single text against the candidate labels.

    Args:
        request: Incoming request; not read by the body (presumably kept for
            the rate-limit decorator; confirm).
        input: Text, optional candidate labels and the multi-label flag.
        _: Result of the API-key dependency; only its side effect matters.

    Returns:
        A ``TopicResult`` pairing each label with its score, in the order the
        classifier returned them.

    Raises:
        HTTPException: 400 when the text is empty or whitespace-only.
    """
    if not input.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # An omitted or empty label list falls back to the module defaults.
    candidate_labels = input.labels or DEFAULT_LABELS

    # Only the first 2000 characters of the text are sent to the model.
    prediction = classifier(
        input.text[:2000],
        candidate_labels,
        multi_label=input.multi_label,
    )
    return TopicResult(
        labels=[
            TopicScore(label=label, score=score)
            for label, score in zip(prediction["labels"], prediction["scores"])
        ]
    )
# NOTE(review): the rate-limit decorator below is commented out and no route
# decorator is visible in this view; confirm how this handler is registered.
# @limiter.limit("10/minute")
def classify_batch(request: Request, input: BatchTopicInput, _: bool = Depends(verify_api_key)):
    """Score each non-blank text in a batch against the candidate labels.

    Args:
        request: Incoming request; not read by the body (presumably kept for
            the rate-limit decorator; confirm).
        input: Texts, optional candidate labels and the multi-label flag.
        _: Result of the API-key dependency; only its side effect matters.

    Returns:
        A list of ``TopicResult`` objects, one per non-blank text, in input
        order.

    Raises:
        HTTPException: 400 when more than 50 texts are submitted.
    """
    if len(input.texts) > 50:
        raise HTTPException(status_code=400, detail="Max 50 texts per batch")

    # An omitted or empty label list falls back to the module defaults.
    candidate_labels = input.labels or DEFAULT_LABELS

    results = []
    for text in input.texts:
        # NOTE(review): blank texts are skipped without a placeholder, so the
        # output can be shorter than the input and callers cannot map results
        # back to positions. Confirm whether that is intended.
        if not text.strip():
            continue
        # Only the first 2000 characters of each text are sent to the model.
        prediction = classifier(
            text[:2000],
            candidate_labels,
            multi_label=input.multi_label,
        )
        results.append(
            TopicResult(
                labels=[
                    TopicScore(label=label, score=score)
                    for label, score in zip(prediction["labels"], prediction["scores"])
                ]
            )
        )
    return results