# finbert / main.py
# Author: dien2112
# Last commit: 2178543 "Update main.py" (verified)
import logging

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
app = FastAPI(title="FinBERT API")

# --------- Load the model once, globally, at import time ---------
model_name = "ProsusAI/finbert"
print(f"Loading model {model_name}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
    # NOTE: the global is historically named `cryptobert` although it wraps
    # FinBERT; the endpoints look it up under this name, so it is kept.
    # max_length/truncation/padding are tokenization options handed to the
    # pipeline; with truncation=True, inputs longer than 64 tokens are cut off
    # before classification (presumably forwarded to the tokenizer — confirm
    # against the installed transformers version).
    cryptobert = TextClassificationPipeline(
        model=model,
        tokenizer=tokenizer,
        max_length=64,
        truncation=True,
        padding='max_length',
    )
    print("Model loaded successfully!")
except Exception as e:
    # Deliberately best-effort: the API still starts (and reports
    # model_loaded=False), but log the full traceback rather than only str(e)
    # so the cause of the failure can actually be diagnosed.
    logging.getLogger(__name__).exception("Error loading model %s: %s", model_name, e)
    cryptobert = None
# --------- Schema definitions ---------
class AnalyzeRequest(BaseModel):
    """Request body for the sentiment endpoint: a batch of texts to classify."""

    # Texts to classify; an empty list is accepted and yields an empty result.
    texts: list[str]
class AnalyzeResult(BaseModel):
    """Classification outcome for a single input text."""

    # The input text, echoed back unchanged.
    text: str
    # Label predicted by the classification pipeline.
    label: str
    # The pipeline's score for `label`, as a float.
    score: float
    # Signed score: +score for a positive label, -score for a negative one,
    # 0.0 otherwise (see calculate_numeric_score).
    numeric_score: float
class AnalyzeResponse(BaseModel):
    """Response body of the sentiment endpoint."""

    # One result per input text, in the same order as the request.
    results: list[AnalyzeResult]
    # Mean of the results' numeric_score values; 0.0 when there are none.
    avg_score: float
# --------- Helper Function ---------
def calculate_numeric_score(label: str, score: float) -> float:
    """Map a classifier label and its score to a signed sentiment value.

    Args:
        label: Predicted class name. Matched case-insensitively, so
            "positive", "Positive" and "POSITIVE" are all treated alike.
        score: The classifier's score for ``label``.

    Returns:
        ``score`` for a positive label, ``-score`` for a negative label, and
        ``0.0`` for anything else (neutral or an unrecognized label).
    """
    # Normalize case so a model whose labels use different capitalization is
    # not silently scored as neutral.
    normalized = label.lower()
    if normalized == 'positive':
        return score
    if normalized == 'negative':
        return -score
    return 0.0
# --------- API Endpoints ---------
@app.get("/")
def read_root():
    """Health check: report that the service is up and whether the model loaded."""
    model_ready = cryptobert is not None
    payload = {
        "status": "ok",
        "message": "FinBERT API is running",
        "model_loaded": model_ready,
    }
    return payload
@app.post("/api/sentiment", response_model=AnalyzeResponse)
def analyze_sentiment(req: AnalyzeRequest):
    """Classify each text in the request and aggregate a signed sentiment score.

    Args:
        req: Request body; ``req.texts`` is the batch of texts to classify.

    Returns:
        A dict matching ``AnalyzeResponse``: one entry per input text (in
        input order) holding the predicted ``label``, its ``score``, and the
        signed ``numeric_score`` from ``calculate_numeric_score``; plus
        ``avg_score``, the mean of the numeric scores (0.0 when empty).

    Raises:
        HTTPException: 500 if the model failed to load at startup, or if
            running the classification pipeline raises.
    """
    # Explicit None check, matching the model_loaded flag in read_root.
    if cryptobert is None:
        raise HTTPException(status_code=500, detail="Model is not loaded properly.")
    if not req.texts:
        return {"results": [], "avg_score": 0.0}

    # Guard only the inference call; post-processing bugs should surface as
    # real errors rather than being folded into this handler.
    try:
        preds = cryptobert(req.texts)  # batch prediction, one dict per text
    except Exception as e:
        # Keep the traceback server-side; chain the cause for debugging.
        # NOTE(review): str(e) exposes internal error text to clients; kept
        # for backward compatibility of the response detail.
        logging.getLogger(__name__).exception("Sentiment inference failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    results = []
    for text, pred in zip(req.texts, preds):
        label = pred['label']
        score = float(pred['score'])
        results.append({
            "text": text,
            "label": label,
            "score": score,
            "numeric_score": calculate_numeric_score(label, score),
        })

    # Guard kept in case the pipeline returns fewer predictions than texts.
    avg_score = (
        sum(item["numeric_score"] for item in results) / len(results)
        if results
        else 0.0
    )
    return {"results": results, "avg_score": avg_score}