Sample / app.py
Lokiiparihar's picture
Update Streamlit app
64529f8 verified
raw
history blame
1.96 kB
# app.py
import numpy as np
from flask import Flask, request, jsonify
import joblib
import pandas as pd
# Create the Flask application object that this module's route decorators attach to.
sales_prediction_api = Flask("Sales Predictor")

# Deserialize both trained models once, at import time, so every request reuses them.
# NOTE: joblib.load unpickles the file; only load model artifacts from a trusted source.
# NOTE(review): these paths are relative to the process working directory, not this file.
dt_model, xgb_model = (
    joblib.load(model_path)
    for model_path in ("decision_tree_model.pkl", "xgboost_model.pkl")
)
# Home page: answers the root URL with a short plain-text identification string.
@sales_prediction_api.get('/')
def home():
    """Return the API's name as a plain-text response body."""
    banner = "Sales Prediction API"
    return banner
# Feature names copied from the request into the model's input DataFrame, in column order.
# NOTE(review): names/order must match the columns the models were trained on — confirm
# against the training code.
_FEATURE_COLUMNS = (
    'Product_Weight',
    'Product_Sugar_Content',
    'Product_Allocated_Area',
    'Product_Type',
    'Product_MRP',
    'Store_Size',
    'Store_Location_City_Type',
    'Store_Type',
    'Store_Age',
)


# Define an endpoint to predict sales
@sales_prediction_api.post('/predict')
def predict():
    """Predict sales for a single record posted as a JSON object.

    Request body:
        A JSON object containing every name in ``_FEATURE_COLUMNS`` plus an
        optional ``"model"`` key: ``"dt"`` (decision tree, the default) or
        ``"xgb"`` (XGBoost).

    Returns:
        ``200`` with ``{"model_used": <str>, "prediction": <float>}`` on success.
        ``400`` with ``{"error": <str>}`` when the body is not a JSON object,
        the model name is not recognised, or any feature field is missing.
    """
    # silent=True makes get_json return None (instead of raising) for a missing,
    # non-JSON or malformed body, so every client error gets the same JSON shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    # Public model identifiers -> loaded model objects.
    models = {"dt": dt_model, "xgb": xgb_model}
    model_choice = data.get("model", "dt")
    # isinstance guard: a JSON list/object here would be unhashable for dict lookup.
    model = models.get(model_choice) if isinstance(model_choice, str) else None
    if model is None:
        return jsonify({"error": "Invalid model specified. Use 'dt' or 'xgb'"}), 400

    # Report every absent field at once instead of failing on the first KeyError (500).
    missing = [name for name in _FEATURE_COLUMNS if name not in data]
    if missing:
        return jsonify({"error": f"Missing required field(s): {', '.join(missing)}"}), 400

    # One-row DataFrame whose column names are exactly the feature names
    # (the original had a stray trailing space in 'Product_Allocated_Area ').
    sample = {name: data[name] for name in _FEATURE_COLUMNS}
    sample_df = pd.DataFrame(sample, index=[0])

    prediction = model.predict(sample_df)[0]

    # float() converts the model's (possibly NumPy) scalar into a JSON-serialisable value.
    return jsonify({
        "model_used": model_choice,
        "prediction": float(prediction),
    })
if __name__ == '__main__':
    import os

    # The Werkzeug interactive debugger permits arbitrary code execution, so it must
    # not be on by default for a server bound to all interfaces. Opt in explicitly
    # for local development with FLASK_DEBUG=1 (or true/yes/on).
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    sales_prediction_api.run(host="0.0.0.0", port=7860, debug=debug_enabled)