Lokiiparihar commited on
Commit
8deeab2
·
verified ·
1 Parent(s): c58c951

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -45
app.py CHANGED
@@ -1,62 +1,99 @@
1
- # app.py
2
  import numpy as np
3
- from flask import Flask, request, jsonify
4
- import joblib
5
- import pandas as pd
6
 
7
- # Inititialize Flask app with name
8
- sales_prediction_api = Flask("Sales Predictor")
9
 
10
- # Load the trained model predictor model
11
- dt_model = joblib.load("decision_tree_model.pkl")
12
- xgb_model = joblib.load("xgboost_model.pkl")
13
 
14
- # Define a route for the home page
15
- @sales_prediction_api.route('/')
 
 
 
 
 
 
 
 
16
  def home():
17
- return "Sales Prediction API"
 
 
 
 
 
18
 
19
- # Define an endpoint to predict sales
20
- @sales_prediction_api.post('/predict')
21
- def predict():
22
- # Get the data from the request
23
- data = request.get_json()
 
 
 
 
 
 
24
 
25
- # Extract relevant features from the input data
 
26
  sample = {
27
- 'Product_Weight': data['Product_Weight'],
28
- 'Product_Sugar_Content': data['Product_Sugar_Content'],
29
- 'Product_Allocated_Area': data['Product_Allocated_Area'],
30
- 'Product_Type': data['Product_Type'],
31
- 'Product_MRP': data['Product_MRP'],
32
- 'Store_Size': data['Store_Size'],
33
- 'Store_Location_City_Type': data['Store_Location_City_Type'],
34
- 'Store_Type': data['Store_Type'],
35
- 'Store_Age': data['Store_Age']
 
36
  }
37
 
38
- #convert the extracted data into a dataframe
39
- sample_df = pd.DataFrame(sample, index=[0])
 
 
 
 
 
 
 
 
 
 
40
 
41
- # --------------------------------
42
- # Model selection logic (FIXED)
43
- # --------------------------------
44
- model_choice = data.get("model", "dt")
 
 
 
 
 
 
 
45
 
46
- if model_choice == "dt":
47
- prediction = dt_model.predict(sample_df)[0]
48
 
49
- else :
50
- prediction = xgb_model.predict(sample_df)[0]
51
 
 
 
52
 
53
- # --------------------------------
54
- # Response
55
- # --------------------------------
56
- return jsonify({
57
- "model_used": model_choice,
58
- "prediction": float(prediction)
59
- })
60
 
 
 
61
  if __name__ == '__main__':
62
- sales_prediction_api.run(debug=True)
 
 
1
+ # Import necessary libraries
2
  import numpy as np
3
+ import joblib # For loading the serialized model
4
+ import pandas as pd # For data manipulation
5
+ from flask import Flask, request, jsonify # For creating the Flask API
6
 
7
+ print("--- app.py: Starting Flask application setup ---")
 
8
 
9
+ # Initialize the Flask application
10
+ superkart_sales_api = Flask("SuperKart Sales Predictor")
11
+ print("--- app.py: Flask app initialized ---")
12
 
13
+ # Load the trained machine learning model
14
+ try:
15
+ model = joblib.load("superkart_sales_model.pkl")
16
+ print("--- app.py: Model loaded successfully ---")
17
+ except Exception as e:
18
+ print(f"--- app.py: ERROR loading model: {e} ---")
19
+ raise # Re-raise to ensure the error is visible
20
+
21
+ # Define a route for the home page (GET request)
22
+ @superkart_sales_api.get('/')
23
  def home():
24
+ print("--- API: Home route accessed ---")
25
+ """
26
+ This function handles GET requests to the root URL ('/') of the API.
27
+ It returns a simple welcome message.
28
+ """
29
+ return "Welcome to the SuperKart Sales Prediction API!"
30
 
31
+ # Define an endpoint for single sales prediction (POST request)
32
+ @superkart_sales_api.post('/v1/sales')
33
+ def predict_sales():
34
+ print("--- API: Single sales prediction route accessed ---")
35
+ """
36
+ This function handles POST requests to the '/v1/sales' endpoint.
37
+ It expects a JSON payload containing product and store details and returns
38
+ the predicted sales as a JSON response.
39
+ """
40
+ # Get the JSON data from the request body
41
+ input_data_json = request.get_json()
42
 
43
+ # Extract relevant features from the JSON data, matching x_train columns
44
+ # The model expects original feature names before one-hot encoding
45
  sample = {
46
+ 'Product_Id': input_data_json['Product_Id'],
47
+ 'Product_Weight': input_data_json['Product_Weight'],
48
+ 'Product_Sugar_Content': input_data_json['Product_Sugar_Content'],
49
+ 'Product_Allocated_Area': input_data_json['Product_Allocated_Area'],
50
+ 'Product_Type': input_data_json['Product_Type'],
51
+ 'Product_MRP': input_data_json['Product_MRP'],
52
+ 'Store_Id': input_data_json['Store_Id'],
53
+ 'Store_Size': input_data_json['Store_Size'],
54
+ 'Store_Location_City_Type': input_data_json['Store_Location_City_Type'],
55
+ 'Store_Current_Age': input_data_json['Store_Current_Age']
56
  }
57
 
58
+ # Convert the extracted data into a Pandas DataFrame
59
+ input_df = pd.DataFrame([sample])
60
+
61
+ # Make prediction
62
+ predicted_sales = model.predict(input_df)[0]
63
+
64
+ # Convert predicted_sales to Python float and round
65
+ predicted_sales = round(float(predicted_sales), 2)
66
+
67
+ # Return the predicted sales
68
+ return jsonify({'Predicted Sales': predicted_sales})
69
+
70
 
71
+ # Define an endpoint for batch prediction (POST request)
72
+ @superkart_sales_api.post('/v1/salesbatch')
73
+ def predict_sales_batch():
74
+ print("--- API: Batch sales prediction route accessed ---")
75
+ """
76
+ This function handles POST requests to the '/v1/salesbatch' endpoint.
77
+ It expects a CSV file containing product and store details for multiple entries
78
+ and returns the predicted sales as a list in the JSON response.
79
+ """
80
+ # Get the uploaded CSV file from the request
81
+ file = request.files['file']
82
 
83
+ # Read the CSV file into a Pandas DataFrame
84
+ input_df_batch = pd.read_csv(file)
85
 
86
+ # Make predictions for all entries in the DataFrame
87
+ predicted_sales_batch = model.predict(input_df_batch).tolist()
88
 
89
+ # Round each prediction and convert to float
90
+ predicted_sales_batch = [round(float(s), 2) for s in predicted_sales_batch]
91
 
92
+ # Return the predictions list as a JSON response
93
+ return jsonify({'Predicted Sales': predicted_sales_batch})
 
 
 
 
 
94
 
95
+ # Run the Flask application in debug mode if this script is executed directly
96
+ # When deploying with Gunicorn, this block is usually commented out or removed
97
  if __name__ == '__main__':
98
+ print("--- app.py: Running Flask app in debug mode ---")
99
+ superkart_sales_api.run(debug=True)