vichter commited on
Commit
da6e6ee
·
verified ·
1 Parent(s): 2dd8299

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Header, Depends, Request
2
+ from pydantic import BaseModel
3
+ from transformers import pipeline
4
+ from slowapi import Limiter, _rate_limit_exceeded_handler
5
+ from slowapi.util import get_remote_address
6
+ from slowapi.errors import RateLimitExceeded
7
+ import logging
8
+ import os
9
+
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ limiter = Limiter(key_func=get_remote_address)
14
+
15
+ app = FastAPI(title="Panoptifi Topics API")
16
+ app.state.limiter = limiter
17
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
18
+
19
+ API_KEY = os.environ.get("API_KEY", "")
20
+
21
+
22
+ def verify_api_key(x_api_key: str = Header(None, alias="X-API-Key")):
23
+ if API_KEY and x_api_key != API_KEY:
24
+ raise HTTPException(status_code=401, detail="Invalid API key")
25
+ return True
26
+
27
+
28
+ logger.info("Loading zero-shot classifier...")
29
+ classifier = pipeline(
30
+ "zero-shot-classification",
31
+ model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0"
32
+ )
33
+ logger.info("Model loaded")
34
+
35
+ DEFAULT_LABELS = [
36
+ "monetary policy",
37
+ "earnings",
38
+ "mergers and acquisitions",
39
+ "regulation",
40
+ "layoffs",
41
+ "product launch",
42
+ "legal issues",
43
+ "market sentiment",
44
+ "cryptocurrency",
45
+ "economic data"
46
+ ]
47
+
48
+
49
+ class TopicInput(BaseModel):
50
+ text: str
51
+ labels: list[str] | None = None
52
+ multi_label: bool = False
53
+
54
+
55
+ class TopicScore(BaseModel):
56
+ label: str
57
+ score: float
58
+
59
+
60
+ class TopicResult(BaseModel):
61
+ labels: list[TopicScore]
62
+
63
+
64
+ @app.get("/health")
65
+ @limiter.limit("60/minute")
66
+ def health(request: Request):
67
+ return {"status": "healthy", "model": "ModernBERT-large-zeroshot-v2.0"}
68
+
69
+
70
+ @app.post("/classify", response_model=TopicResult)
71
+ @limiter.limit("30/minute")
72
+ def classify_topic(request: Request, input: TopicInput, _: bool = Depends(verify_api_key)):
73
+ if not input.text.strip():
74
+ raise HTTPException(400, "Text cannot be empty")
75
+
76
+ labels = input.labels or DEFAULT_LABELS
77
+ result = classifier(
78
+ input.text[:2000],
79
+ labels,
80
+ multi_label=input.multi_label
81
+ )
82
+
83
+ return TopicResult(labels=[
84
+ TopicScore(label=label, score=score)
85
+ for label, score in zip(result["labels"], result["scores"])
86
+ ])