import logging

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
|
|
app = FastAPI(title="FinBERT API")

# Hugging Face checkpoint served by this API.
model_name = "ProsusAI/finbert"
print(f"Loading model {model_name}...")

# Extra keyword arguments handed to the pipeline constructor. They truncate and
# pad every input to 64 tokens. (Assumed to be forwarded to the tokenizer on
# each call by the pipeline — confirm for the installed transformers version.)
_TOKENIZER_OPTIONS = {"max_length": 64, "truncation": True, "padding": "max_length"}

# Load once at import time. On any failure the global pipeline is set to None,
# so the endpoints can report "not loaded" instead of the app failing to start.
# NOTE: the global is named ``cryptobert`` for compatibility with the rest of
# the module; it holds the FinBERT pipeline.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
    cryptobert = TextClassificationPipeline(model=model, tokenizer=tokenizer, **_TOKENIZER_OPTIONS)
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    cryptobert = None
|
|
| |
class AnalyzeRequest(BaseModel):
    """Request body for ``POST /api/sentiment``."""

    # Texts to classify; one result is produced per entry. An empty list is
    # accepted and yields an empty result set.
    texts: list[str]
|
|
class AnalyzeResult(BaseModel):
    """Sentiment classification for a single input text."""

    text: str  # the input text, echoed back unchanged
    label: str  # top label returned by the classification pipeline
    score: float  # the pipeline's confidence for ``label``
    # Signed score from calculate_numeric_score: positive -> +score,
    # negative -> -score, anything else -> 0.0.
    numeric_score: float
|
|
class AnalyzeResponse(BaseModel):
    """Response body for ``POST /api/sentiment``."""

    results: list[AnalyzeResult]  # one entry per input text, in request order
    avg_score: float  # mean numeric_score over results; 0.0 when there are none
|
|
| |
def calculate_numeric_score(label: str, score: float) -> float:
    """Map a classifier label and its confidence to a signed sentiment score.

    Args:
        label: Predicted class name. Matched case-insensitively and ignoring
            surrounding whitespace, so checkpoints that emit ``"Positive"`` or
            ``"NEGATIVE"`` are handled the same as ``"positive"``/``"negative"``.
        score: Confidence reported for ``label``.

    Returns:
        ``score`` for a positive label, ``-score`` for a negative label, and
        ``0.0`` for anything else (e.g. ``"neutral"``).
    """
    # Normalise so a label-casing difference between models does not silently
    # turn every prediction into a neutral 0.0.
    normalized = label.strip().lower()
    if normalized == "positive":
        return score
    if normalized == "negative":
        return -score
    # Neutral and unrecognised labels contribute nothing to the average.
    return 0.0
|
|
| |
@app.get("/")
def read_root():
    """Health check: report that the service is up and whether the model loaded."""
    model_ready = cryptobert is not None
    return {
        "status": "ok",
        "message": "FinBERT API is running",
        "model_loaded": model_ready,
    }
|
|
@app.post("/api/sentiment", response_model=AnalyzeResponse)
def analyze_sentiment(req: AnalyzeRequest):
    """Classify each text in ``req.texts`` and return per-text and average sentiment.

    Each result carries the pipeline's label and confidence plus a signed
    ``numeric_score`` from ``calculate_numeric_score``. ``avg_score`` is the
    mean of those numeric scores (0.0 when there are no results).

    Raises:
        HTTPException: status 500 if the model failed to load at startup, or
            if inference / post-processing raises. In the latter case the
            exception message is used as the detail.
    """
    # The loader sets the global to None on failure; test for exactly that
    # instead of relying on the truthiness of a pipeline object.
    # NOTE(review): 503 would describe "model not loaded" more accurately;
    # 500 is kept so existing clients see the same status.
    if cryptobert is None:
        raise HTTPException(status_code=500, detail="Model is not loaded properly.")

    if not req.texts:
        return {"results": [], "avg_score": 0.0}

    try:
        # With a list input, the loop below expects one {"label", "score"}
        # dict per text, in the same order.
        preds = cryptobert(req.texts)

        results = []
        for text, pred in zip(req.texts, preds):
            label = pred["label"]
            score = float(pred["score"])
            results.append({
                "text": text,
                "label": label,
                "score": score,
                "numeric_score": calculate_numeric_score(label, score),
            })
    except Exception as e:
        # Keep the full traceback in the server log; previously it was lost.
        # The client still receives only str(e).
        logging.getLogger(__name__).exception("Sentiment inference failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    avg_score = (
        sum(r["numeric_score"] for r in results) / len(results) if results else 0.0
    )
    return {"results": results, "avg_score": avg_score}