Amittripipathi commited on
Commit
f085fbd
·
verified ·
1 Parent(s): a21f07c

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +14 -12
src/streamlit_app.py CHANGED
@@ -2,10 +2,13 @@ import streamlit as st
2
  import pandas as pd
3
  from datetime import datetime
4
  import joblib
 
 
 
5
 
6
  # Load the trained model
7
  def load_model():
8
- return joblib.load("SuperKart_sales_prediction_model_v1_0.joblib")
9
 
10
  model = load_model()
11
 
@@ -39,21 +42,19 @@ input_data = {
39
  'Store_Type': [store_type],
40
  }
41
 
42
- # Custom transformer to calculate store age
43
- class StoreAgeCalculator(BaseEstimator, TransformerMixin):
44
- def __init__(self):
45
- self.current_year = datetime.now().year
46
 
47
- def fit(self, X, y=None):
 
 
48
  return self
49
 
50
- def transform(self, X):
51
- X = X.copy()
52
- X['Store_Age'] = self.current_year - X['Store_Establishment_Year']
53
- return X.drop(columns=['Store_Establishment_Year'])
54
 
55
- # Convert the input data to a DataFrame
56
- input_df = pd.DataFrame(input_data)
57
 
58
  # Convert categorical columns to category type
59
  input_df['Product_Sugar_Content'] = input_df['Product_Sugar_Content'].astype('category')
@@ -63,6 +64,7 @@ input_df['Store_Size'] = input_df['Store_Size'].astype('category')
63
  input_df['Store_Location_City_Type'] = input_df['Store_Location_City_Type'].astype('category')
64
  input_df['Store_Type'] = input_df['Store_Type'].astype('category')
65
 
 
66
  # Make predictions
67
  if st.button("Predict"):
68
  predictions = model.predict(input_df)
 
2
  import pandas as pd
3
  from datetime import datetime
4
  import joblib
5
+ from sklearn.base import BaseEstimator, TransformerMixin
6
+ from datetime import datetime
7
+ from transformers import SugarContentReplacer,StoreAgeCalculator # Import the custom transformer
8
 
9
  # Load the trained model
10
  def load_model():
11
+ return joblib.load("src/SuperKart_sales_prediction_model_v1_0.joblib")
12
 
13
  model = load_model()
14
 
 
42
  'Store_Type': [store_type],
43
  }
44
 
45
+ # Convert the input data to a DataFrame
46
+ input_df = pd.DataFrame(input_data)
 
 
47
 
48
+ # Custom transformer to replace 'reg' with 'Regular' in Product_Sugar_Content
49
+ class SugarContentReplacer(BaseEstimator, TransformerMixin):
50
+ def fit(self, input_df, y=None):
51
  return self
52
 
53
+ def transform(self, input_df):
54
+ input_df = input_df.copy()
55
+ input_df['Product_Sugar_Content'] = input_df['Product_Sugar_Content'].replace('reg', 'Regular')
56
+ return input_df
57
 
 
 
58
 
59
  # Convert categorical columns to category type
60
  input_df['Product_Sugar_Content'] = input_df['Product_Sugar_Content'].astype('category')
 
64
  input_df['Store_Location_City_Type'] = input_df['Store_Location_City_Type'].astype('category')
65
  input_df['Store_Type'] = input_df['Store_Type'].astype('category')
66
 
67
+
68
  # Make predictions
69
  if st.button("Predict"):
70
  predictions = model.predict(input_df)