Spaces:
Running
Running
File size: 2,300 Bytes
7492d83 926f542 92b802f e76d903 926f542 92b802f e76d903 04826fc 926f542 92b802f 2d07eaf 92b802f 926f542 92b802f 3700b4a 7ce4e16 92b802f 3700b4a 92b802f 7ce4e16 92b802f 3700b4a 92b802f 7ce4e16 3700b4a 92b802f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.orm import Session
from datetime import datetime, timezone
from .database import get_db, Base, engine
from .models import InferenceLog
from .schemas import SentimentRequest, SentimentResponse
from . import ml_model
# Create any tables declared on Base that do not exist yet. This runs once,
# when the module is imported; existing tables are left as they are.
Base.metadata.create_all(bind=engine, checkfirst=True)
# The FastAPI application. Swagger UI is served at /docs and ReDoc at the
# site root ("/").
app = FastAPI(
    title="FinBERT Sentiment Analyzer API",
    version="1.0.0",
    description="An API for analyzing the sentiment of financial news articles using FinBERT.",
    redoc_url="/",
    docs_url="/docs",
)
# Browser origins allowed to call this API cross-origin.
# NOTE(review): the "*" entry lets every origin through, so the two explicit
# URLs below are redundant. The Starlette CORS docs also say that
# allow_credentials=True requires explicit (non-"*") origins, methods and
# headers. Confirm the intended policy: drop "*" or drop credentials.
origins = [
    "http://localhost:3000",
    "https://portfolio-frontend-livid.vercel.app",
    "*",
]
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=origins,
    allow_headers=["*"],  # wildcard: accept any request header
    allow_methods=["*"],  # wildcard: accept any HTTP method
)
@app.post('/predict', response_model=SentimentResponse)
def predict_sentiment(request: SentimentRequest, db: Session = Depends(get_db)):
    """Predict the sentiment of the submitted text, log the inference, and return it.

    The label and confidence come from ``ml_model.predict``. An ``InferenceLog``
    row with the input, the prediction and a UTC timestamp is committed before
    the response is built.

    Args:
        request: Request body; its ``text`` field is the text to classify.
        db: Database session injected by ``get_db``.

    Returns:
        A ``SentimentResponse`` echoing the input text with the predicted
        sentiment, confidence and the logged timestamp.

    Raises:
        HTTPException: status 500 if prediction or persistence fails. The
            original error is logged server-side rather than sent to the
            client, so database or driver internals are not exposed.
    """
    # Local import so this block does not depend on the module's import header.
    import logging

    try:
        request_data = request.model_dump()
        text = request_data["text"]
        prediction_result = ml_model.predict(text)
        log_entry = InferenceLog(
            input_text=text,
            sentiment_prediction=prediction_result['sentiment'],
            confidence_score=prediction_result['confidence'],
            timestamp=datetime.now(timezone.utc),
        )
        db.add(log_entry)
        db.commit()
        return SentimentResponse(
            input_text=text,
            sentiment=prediction_result['sentiment'],
            confidence=prediction_result['confidence'],
            # Taken from the persisted entry so the response reports the
            # stored value.
            timestamp=log_entry.timestamp,
        )
    except Exception as exc:
        # Undo any partially applied database work before reporting failure.
        db.rollback()
        logging.getLogger(__name__).exception("Sentiment prediction request failed")
        raise HTTPException(
            status_code=500,
            detail="Internal error while processing the prediction.",
        ) from exc
@app.get('/logs')
def get_inference_logs(db: Session = Depends(get_db)):
    """Return every stored ``InferenceLog`` row, most recent timestamp first.

    NOTE(review): this endpoint is unauthenticated and unbounded. It exposes
    all logged input texts and loads the whole table on each call. Consider
    pagination or access control.
    """
    newest_first = InferenceLog.timestamp.desc()
    return db.query(InferenceLog).order_by(newest_first).all()
@app.get('/health')
def health_check():
    """Report that the service is responding, along with the current UTC time.

    Returns:
        ``{"status": "ok", "timestamp": <timezone-aware UTC datetime>}``.
    """
    return {"status": "ok", "timestamp": datetime.now(timezone.utc)}