jflo's picture
Update app.py
be8eaf8
raw
history blame
1.06 kB
import functools

import joblib
import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel

import maven_text_preprocessing
# ASGI application object; the route decorators below register their
# endpoints on it.
app = FastAPI()
# Request body for POST /classify.
# Kept docstring-free on purpose: pydantic would publish a docstring as the
# schema description in the generated OpenAPI document.
class ClassificationRequest(BaseModel):
    # Raw text to classify; it is cleaned by maven_text_preprocessing
    # before vectorization.
    message: str
@app.get("/")
def greet_json():
    # Root endpoint: always answers with the same small JSON object.
    # (Comments rather than a docstring, so the OpenAPI description stays as-is.)
    greeting = {"Hello": "World!"}
    return greeting
# Category names, in the column order of the model's predict_proba output.
# NOTE(review): this assumes the model's ``classes_`` are ordered to match
# this tuple (e.g. integer labels 0-4 in this order). If the model was trained
# on string labels, ``classes_`` is sorted alphabetically and this mapping
# would be wrong — confirm against the training code.
_CATEGORIES = ("Politics", "Sport", "Technology", "Entertainment", "Business")


@functools.lru_cache(maxsize=1)
def _load_artifacts():
    """Load the classifier and its fitted vectorizer once per process.

    Reading both joblib files on every request is needlessly slow. lru_cache
    keeps the first successful load; a failed load is not cached, so it is
    retried on the next request.
    """
    model = joblib.load("naive_bayes.joblib")
    vectorizer = joblib.load("vectorizer.joblib")
    return model, vectorizer


@app.post("/classify")
def sentiment_analysis(payload: ClassificationRequest):
    """Return the model's probability for each category for ``payload.message``."""
    # Despite the function name, this returns per-category probabilities,
    # not a sentiment score.
    model, vectorizer = _load_artifacts()

    clean_text = maven_text_preprocessing.clean_and_normalize(
        pd.Series([payload.message])
    )
    # transform, NOT fit_transform: re-fitting on the request text would
    # discard the vocabulary the model was trained with.
    X = vectorizer.transform(clean_text)

    # A single input row was sent, so take row 0 of the probability matrix.
    probabilities = model.predict_proba(X)[0].tolist()

    # Fail loudly if the deployed model does not match the label list,
    # instead of raising an opaque IndexError or dropping extra classes.
    if len(probabilities) != len(_CATEGORIES):
        raise RuntimeError(
            f"model returned {len(probabilities)} class probabilities, "
            f"expected {len(_CATEGORIES)}"
        )

    return dict(zip(_CATEGORIES, probabilities))