# vichter's picture
# Update app.py
# 054efe9 verified
import hmac
import logging
import os

from fastapi import FastAPI, HTTPException, Header, Depends, Request
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from transformers import pipeline
# Configure logging once at import time; the rest of the module logs through
# the module-level `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Rate limiter keyed on the client's remote address. It is attached to the app
# together with the handler for RateLimitExceeded; note that the per-route
# `@limiter.limit(...)` decorators in this file are currently commented out.
limiter = Limiter(key_func=get_remote_address)
app = FastAPI(title="Panoptifi Topics API")
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Shared secret expected in the X-API-Key header. An empty value means
# verify_api_key accepts every request, so make that state visible at startup
# instead of running unauthenticated without any trace in the logs.
API_KEY = os.environ.get("API_KEY", "")
if not API_KEY:
    logger.warning("API_KEY is not set; API key authentication is disabled")
def verify_api_key(x_api_key: str = Header(None, alias="X-API-Key")):
    """FastAPI dependency that checks the ``X-API-Key`` request header.

    Authentication is enforced only when the module-level ``API_KEY`` is
    non-empty; with an empty ``API_KEY`` every request is accepted.

    Args:
        x_api_key: Value of the ``X-API-Key`` header, or ``None`` when the
            header is absent.

    Returns:
        ``True`` when the request is allowed to proceed.

    Raises:
        HTTPException: 401 when a key is configured and the supplied key is
            missing or does not match.
    """
    if not API_KEY:
        return True

    # Compare in constant time so response timing does not leak how much of
    # the key matched. Both sides are encoded to bytes because
    # hmac.compare_digest only accepts ASCII-only str values; a missing header
    # is treated as an empty key, which can never match a non-empty API_KEY.
    supplied = (x_api_key or "").encode("utf-8")
    expected = API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return True
# Build the zero-shot classification pipeline once at import time; the request
# handlers below all call this single module-level `classifier`.
_MODEL_ID = "MoritzLaurer/ModernBERT-large-zeroshot-v2.0"

logger.info("Loading zero-shot classifier...")
classifier = pipeline("zero-shot-classification", model=_MODEL_ID)
logger.info("Model loaded")
# Candidate topic labels used by the classify endpoints whenever a request
# does not provide its own (non-empty) `labels` list.
DEFAULT_LABELS = [
    "monetary policy",
    "earnings",
    "mergers and acquisitions",
    "regulation",
    "layoffs",
    "product launch",
    "legal issues",
    "market sentiment",
    "cryptocurrency",
    "economic data",
]
class TopicInput(BaseModel):
    """Request body for the single-text ``/classify`` endpoint."""

    # Text to classify.
    text: str
    # Optional candidate labels; the endpoint falls back to DEFAULT_LABELS
    # when this is None or empty.
    labels: list[str] | None = None
    # Passed through to the classifier's `multi_label` argument.
    multi_label: bool = False
class TopicScore(BaseModel):
    """One candidate label paired with the score the classifier gave it."""

    # Candidate label text.
    label: str
    # Score reported by the classifier for this label.
    score: float
class TopicResult(BaseModel):
    """Classification outcome for one text: a list of label/score pairs."""

    # Label/score pairs, kept in the order the classifier returned them.
    labels: list[TopicScore]
class BatchTopicInput(BaseModel):
    """Request body for the ``/classify/batch`` endpoint."""

    # Texts to classify; the endpoint rejects batches of more than 50.
    texts: list[str]
    # Optional candidate labels shared by every text; the endpoint falls back
    # to DEFAULT_LABELS when this is None or empty.
    labels: list[str] | None = None
    # Passed through to the classifier's `multi_label` argument.
    multi_label: bool = False
@app.get("/health")
# Rate limit currently disabled: @limiter.limit("60/minute")
def health(request: Request):
    """Report that the service is up, along with the model's short name.

    `request` is not used in the body; presumably it is kept so the
    commented-out rate-limit decorator can be re-enabled -- confirm before
    removing it.
    """
    status_payload = {
        "status": "healthy",
        "model": "ModernBERT-large-zeroshot-v2.0",
    }
    return status_payload
@app.post("/classify", response_model=TopicResult)
# Rate limit currently disabled: @limiter.limit("30/minute")
def classify_topic(request: Request, input: TopicInput, _: bool = Depends(verify_api_key)):
    """Score one text against a set of candidate topic labels.

    Args:
        request: Incoming request (unused in the body).
        input: Text to classify, optional candidate labels, and the
            multi-label flag.
        _: Result of the API-key dependency; present only to enforce it.

    Returns:
        A ``TopicResult`` holding one ``TopicScore`` per label, in the order
        the classifier returned them.

    Raises:
        HTTPException: 400 when the text that would be sent to the model is
            empty or whitespace-only.
    """
    # Only the first 2000 characters are classified. Truncate *before*
    # validating so the emptiness check sees exactly what the model receives;
    # otherwise a text with 2000+ leading whitespace characters would pass the
    # check and then be classified as pure whitespace.
    text = input.text[:2000]
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # A missing or empty label list falls back to the built-in defaults.
    labels = input.labels or DEFAULT_LABELS
    result = classifier(text, labels, multi_label=input.multi_label)

    scores = [
        TopicScore(label=label, score=score)
        for label, score in zip(result["labels"], result["scores"])
    ]
    return TopicResult(labels=scores)
@app.post("/classify/batch", response_model=list[TopicResult])
# Rate limit currently disabled: @limiter.limit("10/minute")
def classify_batch(request: Request, input: BatchTopicInput, _: bool = Depends(verify_api_key)):
    """Score each text in a batch against a shared set of candidate labels.

    The response always has one entry per input text, in input order, so
    ``results[i]`` corresponds to ``input.texts[i]``. Texts that are blank
    (after truncation to 2000 characters) are not sent to the model and yield
    a ``TopicResult`` with an empty ``labels`` list.

    Args:
        request: Incoming request (unused in the body).
        input: Texts to classify, optional candidate labels, and the
            multi-label flag.
        _: Result of the API-key dependency; present only to enforce it.

    Returns:
        A list of ``TopicResult`` objects, one per input text.

    Raises:
        HTTPException: 400 when more than 50 texts are submitted.
    """
    if len(input.texts) > 50:
        raise HTTPException(status_code=400, detail="Max 50 texts per batch")

    # A missing or empty label list falls back to the built-in defaults.
    labels = input.labels or DEFAULT_LABELS

    results = []
    for raw_text in input.texts:
        # Validate the truncated text: that is what the model would receive.
        text = raw_text[:2000]
        if not text.strip():
            # Keep a placeholder instead of dropping the entry; skipping it
            # would shift every later result and callers could no longer map
            # results back to their inputs.
            results.append(TopicResult(labels=[]))
            continue

        result = classifier(text, labels, multi_label=input.multi_label)
        results.append(TopicResult(labels=[
            TopicScore(label=label, score=score)
            for label, score in zip(result["labels"], result["scores"])
        ]))
    return results