# model/app.py
# Last commit by JigneshPrajapati18: "Update app.py" (c9a29bf, verified)
import time
import joblib
import numpy as np
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from fastapi.templating import Jinja2Templates
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
# Module-level ASGI application; the routes and middleware in this file
# register themselves on it through decorators.
app = FastAPI()

# Jinja2 environment for the HTML pages. The directory is a relative path,
# so it is resolved against the process working directory, not this file.
templates = Jinja2Templates(directory="templates")

# Expose ./static (also working-directory relative) under the /static URL
# prefix; `name` lets templates build URLs with url_for("static", ...).
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Attach the server-side handling time to every HTTP response.

    The time from entering this middleware until the downstream handler
    returns a response is measured with ``time.perf_counter`` (seconds). It
    is written as a decimal string into the ``time`` response header.

    NOTE(review): ``X-Process-Time`` would be a more conventional header
    name, but renaming it would break any client that reads ``time``.
    """
    t0 = time.perf_counter()
    response = await call_next(request)
    response.headers["time"] = str(time.perf_counter() - t0)
    return response
# Registry key -> serialized model file. The paths are relative, so they are
# resolved against the process working directory at import time.
model_paths = {
    "logistic_regression": "logistic_regression.pkl",
    "random_forest_model": "random_forest_model.pkl",
    "decision_tree": "DecisionTreeClassifier.pkl",
    "svm": "SVM_model.pkl",
    "knn": "KNeighborsClassifier_model.pkl",
    "naive_bayes": "Naive_Bayes_model.pkl",
    "ann": "ANN_model.pkl",
}

# Eagerly deserialize every model once at import, in registry order, so a
# missing or corrupt file fails fast at startup rather than per request.
# NOTE(review): joblib.load unpickles arbitrary objects; only ever point these
# paths at trusted, locally built artifacts.
models = {key: joblib.load(model_paths[key]) for key in model_paths}
class MusicFeatures(BaseModel):
    """JSON request body shared by every ``/predict/*`` endpoint.

    Pydantic validates and coerces the incoming JSON into these typed fields.
    ``make_prediction`` passes them to the models in a fixed column order, so
    any field added or removed here must be mirrored there.
    """
    # NOTE(review): no value ranges are enforced. The names look like
    # Spotify-style audio features (0-1 scores, key as pitch class, mode as
    # major/minor flag) — confirm against the training data.
    danceability: float
    energy: float
    key: int
    loudness: float
    mode: int
    speechiness: float
    acousticness: float
    instrumentalness: float
    liveness: float
    valence: float
    tempo: float
    duration_ms: int  # presumably track length in milliseconds — confirm
    time_signature: int
def make_prediction(model, features: MusicFeatures):
    """Predict a label for a single track and return it as a plain ``int``.

    The feature values are arranged into a 1 x 13 array in the column order
    below. The array is passed to ``model.predict``, and the first prediction
    is cast to ``int`` so it serializes cleanly to JSON.

    NOTE(review): the column order must match the order the models were
    trained on; it is hard-coded here and not checked — confirm against the
    training pipeline.
    """
    column_order = (
        "danceability", "energy", "key",
        "loudness", "mode", "speechiness",
        "acousticness", "instrumentalness", "liveness",
        "valence", "tempo", "duration_ms", "time_signature",
    )
    row = [getattr(features, column) for column in column_order]
    predictions = model.predict(np.array([row]))
    return int(predictions[0])
@app.get("/", response_class=HTMLResponse)
async def read_home(request: Request):
    """Serve the landing page rendered from ``index.html``."""
    # This (name, context) call form expects the request in the context.
    # NOTE(review): newer Starlette versions deprecate it in favour of
    # TemplateResponse(request, name); check the pinned version before moving.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.post("/predict/logistic_regression")
def predict_logistic(features: MusicFeatures):
    """Classify the posted track with the logistic regression model."""
    label = make_prediction(models["logistic_regression"], features)
    return dict(model="Logistic Regression", prediction=label)
@app.post("/predict/random_forest")
def predict_rf(features: MusicFeatures):
    """Classify the posted track with the random forest model."""
    # The URL says "random_forest" but the registry key carries a "_model" suffix.
    label = make_prediction(models["random_forest_model"], features)
    return dict(model="Random Forest", prediction=label)
@app.post("/predict/decision_tree")
def predict_dt(features: MusicFeatures):
    """Classify the posted track with the decision tree model."""
    label = make_prediction(models["decision_tree"], features)
    return dict(model="Decision Tree", prediction=label)
@app.post("/predict/svm")
def predict_svm(features: MusicFeatures):
    """Classify the posted track with the support vector machine model."""
    label = make_prediction(models["svm"], features)
    return dict(model="SVM", prediction=label)
@app.post("/predict/knn")
def predict_knn(features: MusicFeatures):
    """Classify the posted track with the k-nearest-neighbours model."""
    label = make_prediction(models["knn"], features)
    return dict(model="K-Nearest Neighbors", prediction=label)
@app.post("/predict/naive_bayes")
def predict_nb(features: MusicFeatures):
    """Classify the posted track with the naive Bayes model."""
    label = make_prediction(models["naive_bayes"], features)
    return dict(model="Naive Bayes", prediction=label)
@app.post("/predict/ann")
def predict_ann(features: MusicFeatures):
    """Classify the posted track with the artificial neural network model."""
    label = make_prediction(models["ann"], features)
    return dict(model="Artificial Neural Network", prediction=label)