"""FinBERT sentiment API.

Loads the ``ProsusAI/finbert`` sequence-classification model once at import
time and exposes it through a small FastAPI application:

* ``GET /``               -- health check, reports whether the model loaded.
* ``POST /api/sentiment`` -- classifies a batch of texts and returns the
  per-text label/score plus the average signed score.
"""

import logging

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TextClassificationPipeline,
)

logger = logging.getLogger(__name__)

app = FastAPI(title="FinBERT API")

# --------- Load the global model ---------
model_name = "ProsusAI/finbert"
print(f"Loading model {model_name}...")

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=3
    )
    # NOTE: the name `cryptobert` is kept as-is for backward compatibility even
    # though the pipeline wraps the FinBERT checkpoint named above.
    cryptobert = TextClassificationPipeline(
        model=model,
        tokenizer=tokenizer,
        max_length=64,
        truncation=True,
        padding='max_length',
    )
    print("Model loaded successfully!")
except Exception as e:
    # Deliberate best-effort: the app still starts so that the health endpoint
    # can report `model_loaded: false`; the sentiment endpoint answers 500.
    print(f"Error loading model: {e}")
    cryptobert = None


# --------- Schema definitions ---------
class AnalyzeRequest(BaseModel):
    """Request body: the batch of texts to classify."""

    texts: list[str]


class AnalyzeResult(BaseModel):
    """Classification outcome for a single input text."""

    text: str
    label: str
    score: float
    # Signed score derived from label and score by calculate_numeric_score().
    numeric_score: float


class AnalyzeResponse(BaseModel):
    """Response body: per-text results and the mean of their numeric scores."""

    results: list[AnalyzeResult]
    avg_score: float


# --------- Helper function ---------
def calculate_numeric_score(label: str, score: float) -> float:
    """Convert a (label, score) prediction into a signed score.

    The label is matched case-insensitively (and ignoring surrounding
    whitespace) so that checkpoints whose label names are capitalised
    differently are still scored correctly.

    Args:
        label: Predicted class name.
        score: Score reported by the classifier for that class.

    Returns:
        ``score`` for a positive label, ``-score`` for a negative label and
        ``0.0`` for anything else (e.g. neutral).
    """
    normalized = label.strip().lower()
    if normalized == 'positive':
        return score
    if normalized == 'negative':
        return -score
    return 0.0


# --------- API endpoints ---------
@app.get("/")
def read_root():
    """Health check reporting whether the model pipeline is available."""
    return {
        "status": "ok",
        "message": "FinBERT API is running",
        "model_loaded": cryptobert is not None,
    }


@app.post("/api/sentiment", response_model=AnalyzeResponse)
def analyze_sentiment(req: AnalyzeRequest):
    """Classify every text in the request and average the signed scores.

    Args:
        req: Request body holding the list of texts.

    Returns:
        The per-text results (in input order) and their average numeric
        score. An empty input list yields empty results and ``0.0``.

    Raises:
        HTTPException: 500 when the model failed to load at startup or when
            inference raises.
    """
    # Explicit None check, consistent with read_root(); avoids relying on the
    # truthiness of the pipeline object.
    if cryptobert is None:
        raise HTTPException(status_code=500, detail="Model is not loaded properly.")

    if not req.texts:
        return AnalyzeResponse(results=[], avg_score=0.0)

    try:
        # Run predictions in batch.
        preds = cryptobert(req.texts)

        results = []
        for text, pred in zip(req.texts, preds):
            label = pred['label']
            score = float(pred['score'])
            results.append(
                AnalyzeResult(
                    text=text,
                    label=label,
                    score=score,
                    numeric_score=calculate_numeric_score(label, score),
                )
            )

        # Guard against an empty prediction list to avoid dividing by zero.
        total_numeric = sum(item.numeric_score for item in results)
        avg_score = total_numeric / len(results) if results else 0.0

        return AnalyzeResponse(results=results, avg_score=avg_score)
    except Exception as exc:
        # Keep the traceback server-side; the client only receives the message.
        logger.exception("Sentiment inference failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc