File size: 1,924 Bytes
84f612c
 
 
 
 
bcbf797
84f612c
43e5c86
84f612c
 
093285e
bcbf797
ea72b54
 
 
093285e
 
 
bcbf797
 
 
84f612c
 
 
 
bcbf797
 
 
 
84f612c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcbf797
84f612c
 
bcbf797
 
 
 
 
 
 
 
 
84f612c
 
bcbf797
84f612c
 
 
 
bcbf797
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import asyncio
import os
import secrets
import threading
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from transformers import pipeline

MODEL_NAME = "will702/indo-roBERTa-financial-sentiment-v2"
# Optional shared secret. When unset or empty, /predict skips the
# X-API-Key check entirely.
API_KEY = os.getenv("API_KEY")

# Maps a (lower-cased) model output label to a sentiment name.
# Index order for this model is 0=positive, 1=neutral, 2=negative; the
# original author marked this order as "flipped" (presumably relative to an
# earlier model/labeling — confirm). Bare sentiment names map to themselves
# so readable labels from the model config are accepted as well.
_SENTIMENTS = ("positive", "neutral", "negative")
LABEL_MAP = {f"label_{idx}": name for idx, name in enumerate(_SENTIMENTS)}
LABEL_MAP.update((name, name) for name in _SENTIMENTS)

# Set by lifespan() at startup; None until the pipeline has been loaded.
classifier = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the sentiment pipeline at startup and release it at shutdown.

    Code before ``yield`` runs once when the application starts; code after it
    runs when the application shuts down.

    Args:
        app: The FastAPI application (unused; required by the lifespan API).
    """
    global classifier
    print(f"Loading model: {MODEL_NAME}")
    classifier = pipeline("text-classification", model=MODEL_NAME)
    print("Model loaded.")
    try:
        yield
    finally:
        # Drop the global reference so the model can be garbage-collected and
        # /health stops reporting model_loaded=True once shutdown begins.
        classifier = None


# Application instance; lifespan() loads the model before requests are served.
app = FastAPI(lifespan=lifespan, title="StockPro Sentiment")


class PredictRequest(BaseModel):
    # Request body for POST /predict: a JSON object {"texts": [...]}.
    # The non-empty and max-count limits are checked in predict() (as 400
    # errors), not by this model.
    texts: list[str]


# Largest number of texts accepted in one /predict call.
MAX_TEXTS_PER_REQUEST = 20
# Token limit passed to the pipeline; longer inputs are truncated, not rejected.
MAX_TOKEN_LENGTH = 512

# Hugging Face fast tokenizers can fail ("Already borrowed") when the same
# instance is used from several threads at once, so inference calls are
# serialized even though they run off the event loop.
_inference_lock = threading.Lock()


def _require_api_key(request: Request) -> None:
    """Raise HTTP 401 unless the X-API-Key header matches API_KEY.

    Does nothing when API_KEY is unset or empty (authentication disabled).
    """
    if not API_KEY:
        return
    supplied = request.headers.get("X-API-Key")
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Bytes are compared because compare_digest raises
    # TypeError for str arguments containing non-ASCII characters.
    if supplied is None or not secrets.compare_digest(
        supplied.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")


def _classify(model, texts: list[str]):
    """Run the pipeline on *texts* while holding the inference lock."""
    with _inference_lock:
        return model(texts, truncation=True, max_length=MAX_TOKEN_LENGTH)


@app.post("/predict")
async def predict(body: PredictRequest, request: Request):
    """Classify the sentiment of each text in the request body.

    Returns ``{"results": [{"text", "sentiment", "score"}, ...]}`` in input
    order. Responds 401 on a bad API key, 400 for an empty or oversized batch,
    and 503 while the model is not loaded.
    """
    _require_api_key(request)

    texts = body.texts
    if not texts:
        raise HTTPException(status_code=400, detail="texts must not be empty")
    if len(texts) > MAX_TEXTS_PER_REQUEST:
        raise HTTPException(
            status_code=400,
            detail=f"Maximum {MAX_TEXTS_PER_REQUEST} texts per request",
        )

    # Read the global once so the None check and the call use the same object.
    model = classifier
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # The pipeline call is synchronous and CPU-bound; running it in a worker
    # thread keeps the event loop (and /health) responsive during inference.
    predictions = await asyncio.to_thread(_classify, model, texts)

    results = [
        {
            "text": text,
            # Labels missing from LABEL_MAP fall back to "neutral".
            "sentiment": LABEL_MAP.get(pred["label"].lower(), "neutral"),
            "score": round(pred["score"], 4),
        }
        for text, pred in zip(texts, predictions)
    ]
    return {"results": results}


@app.get("/health")
def health():
    # Always reports status "ok"; model_loaded is False until lifespan() has
    # assigned the classifier (while it is None, /predict answers 503).
    model_ready = classifier is not None
    return dict(status="ok", model_loaded=model_ready)