# Backend / app.py
# Hugging Face page header captured along with the file (not part of the code):
#   SandeepMM's picture
#   Upload folder using huggingface_hub
#   1e9e683 verified
import sys
import joblib
import pandas as pd
import numpy as np
from flask import Flask, request, jsonify
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import OneHotEncoder, StandardScaler, LabelEncoder
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor # Included for compatibility if you switch models
class FeatureEngineer(BaseEstimator, TransformerMixin):
    """Derive model features from raw product/store rows.

    ``transform`` adds ``Product_Id_Cd`` (first two characters of
    ``Product_Id``), ``Product_Sugar_Content_Corr`` (``'reg'`` rewritten to
    ``'Regular'``), ``Operation_Years`` and ``Store`` (label-encoded
    ``Store_Id``), then drops the raw columns it no longer needs.

    Instances of this class are restored from a pickled pipeline, so the
    instance attributes (``le_prod``, ``le_store``) and the columns produced
    by ``transform`` are kept exactly as they were; only the duplicated
    derivation code and the redundant fitting were consolidated.
    """

    # Single source of truth for the year ``Operation_Years`` is measured
    # from. ``fit`` previously used 2025 on a throw-away copy while
    # ``transform`` used 2013; only the ``transform`` value ever reached the
    # model, so that is the one retained.
    _REFERENCE_YEAR = 2013

    # Columns that ``transform`` tries to push through ``le_prod``.
    _ENCODED_COLUMNS = (
        'Product_Sugar_Content_Corr',
        'Store_Size',
        'Store_Location_City_Type',
        'Store_Type',
        'Product_Id_Cd',
    )

    # Raw columns removed from the output of ``transform``.
    _DROPPED_COLUMNS = (
        'Product_Id',
        'Store_Id',
        'Product_Sugar_Content',
        'Product_Type',
        'Store_Establishment_Year',
    )

    def __init__(self):
        self.le_prod = LabelEncoder()
        self.le_store = LabelEncoder()

    def _add_derived_columns(self, X):
        """Return a copy of ``X`` with the three derived columns appended."""
        derived = X.copy()
        derived['Product_Id_Cd'] = derived['Product_Id'].apply(lambda x: x[:2])
        derived['Product_Sugar_Content_Corr'] = derived['Product_Sugar_Content'].str.replace(
            'reg', 'Regular', regex=True
        )
        derived['Operation_Years'] = self._REFERENCE_YEAR - derived['Store_Establishment_Year']
        return derived

    def fit(self, X, y=None):
        """Fit the two label encoders and return ``self``.

        ``le_prod`` is fitted on ``Product_Id_Cd`` and ``le_store`` on
        ``Store_Id``. The original code also fitted ``le_prod`` on every
        column in ``_ENCODED_COLUMNS`` in turn, but each ``fit`` call
        replaces the previous one and the last column in that list is
        ``Product_Id_Cd``, so the resulting state is the same.
        """
        derived = self._add_derived_columns(X)
        self.le_prod.fit(derived['Product_Id_Cd'])
        self.le_store.fit(derived['Store_Id'])
        return self

    def transform(self, X):
        """Return a new DataFrame of engineered features; ``X`` is not modified."""
        features = self._add_derived_columns(X)

        # NOTE(review): ``le_prod`` only holds the classes of ``Product_Id_Cd``,
        # yet it is applied to five different columns. Presumably the first
        # column already contains values the encoder has not seen, in which
        # case the ValueError branch runs, ``Product_Id_Cd`` becomes -1 and the
        # other columns stay as they were. This looks unintended, but it is
        # what any already-trained pipeline saw during fitting, so the logic
        # is preserved. Giving each column its own encoder requires retraining
        # and re-exporting the model.
        try:
            for column in self._ENCODED_COLUMNS:
                features[column] = self.le_prod.transform(features[column])
        except ValueError:
            features['Product_Id_Cd'] = -1

        # Stores the encoder cannot handle are mapped to the sentinel -1.
        try:
            features['Store'] = self.le_store.transform(features['Store_Id'])
        except ValueError:
            features['Store'] = -1

        # list(...) matters: pandas would treat a tuple as one single label.
        return features.drop(columns=list(self._DROPPED_COLUMNS))
# Register FeatureEngineer on the __main__ module so that joblib's pickle can
# find the class reference it saved during training.
# NOTE(review): presumably the pipeline was pickled from a script/notebook in
# which FeatureEngineer was defined in __main__ - confirm. This line must stay
# above the joblib.load call below.
sys.modules['__main__'].FeatureEngineer = FeatureEngineer
# Initialize the Flask app with a name
app = Flask("SuperKart Sales Predictor")
# Load the trained sales prediction pipeline once, at import time.
# The path is relative, so it is resolved against the current working directory.
model = joblib.load("XGBoostRegressor_BEST_Pipeline.joblib")
# Root route: plain-text greeting for the API
@app.get('/')
def home():
    """Return the welcome message served at the root URL."""
    welcome_message = "Welcome to the SuperKart Sales Prediction API"
    return welcome_message
# Define an endpoint that predicts sales for a single product/store record
@app.post('/v1/product')
def predict_sales():
    """Predict sales for the single record supplied as a JSON object.

    The request body must be a JSON object containing every key listed in
    ``required_keys`` with a non-null value.

    Returns:
        ``{'Prediction': <value>}`` on success, or ``({'error': ...}, 400)``
        when the body is not a JSON object or a required key is missing/null.
    """
    required_keys = ['Product_Id', 'Product_Weight', 'Product_Sugar_Content', 'Product_Allocated_Area',
                     'Product_Type', 'Product_MRP', 'Store_Id', 'Store_Establishment_Year',
                     'Store_Size', 'Store_Location_City_Type', 'Store_Type']

    # Get JSON data from the request
    payload = request.get_json()

    # Valid JSON is not necessarily an object: `null`, a list or a bare string
    # has no usable `.get()` and previously surfaced as an unhandled 500.
    if not isinstance(payload, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    # Build the feature record in one pass; a key that is absent or explicitly
    # null is reported as missing (first offending key wins).
    sample = {}
    for key in required_keys:
        value = payload.get(key)
        if value is None:
            return jsonify({'error': f'Missing key: {key}'}), 400
        sample[key] = value

    # One-row DataFrame whose columns follow the order of required_keys
    input_data = pd.DataFrame([sample])

    # Make a sales prediction using the trained model; tolist() converts the
    # result to plain Python values so it can be JSON-serialized.
    prediction = model.predict(input_data).tolist()[0]

    # Return the prediction as a JSON response
    return jsonify({'Prediction': prediction})
# Run Flask's built-in server when this file is executed directly
if __name__ == '__main__':
    import os  # local import: only this entry point needs it

    # Debug mode remains on by default, exactly as before. Setting
    # FLASK_DEBUG to 0/false/no/off (or empty) turns it off, so a direct
    # `python app.py` run outside local development does not have to expose
    # the interactive debugger.
    debug_enabled = os.environ.get('FLASK_DEBUG', '1').strip().lower() not in ('', '0', 'false', 'no', 'off')
    app.run(debug=debug_enabled)