Amittripipathi commited on
Commit
b7be1e7
·
verified ·
1 Parent(s): 42c27bc

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +21 -0
src/streamlit_app.py CHANGED
@@ -42,6 +42,26 @@ input_data = {
42
  # Convert the input data to a DataFrame
43
  input_df = pd.DataFrame(input_data)
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  # Convert categorical columns to category type
46
  input_df['Product_Sugar_Content'] = input_df['Product_Sugar_Content'].astype('category')
47
  input_df['Product_Type'] = input_df['Product_Type'].astype('category')
@@ -50,6 +70,7 @@ input_df['Store_Size'] = input_df['Store_Size'].astype('category')
50
  input_df['Store_Location_City_Type'] = input_df['Store_Location_City_Type'].astype('category')
51
  input_df['Store_Type'] = input_df['Store_Type'].astype('category')
52
 
 
53
  # Make predictions
54
  if st.button("Predict"):
55
  predictions = model.predict(input_df)
 
42
  # Convert the input data to a DataFrame
43
  input_df = pd.DataFrame(input_data)
44
 
45
+ # Custom transformer to replace 'reg' with 'Regular' in Product_Sugar_Content
46
+ class SugarContentReplacer(BaseEstimator, TransformerMixin):
47
+ def fit(self, X, y=None):
48
+ return self
49
+
50
+ def transform(self, X):
51
+ X = X.copy()
52
+ X['Product_Sugar_Content'] = X['Product_Sugar_Content'].replace('reg', 'Regular')
53
+ return X
54
+
55
+ # Add get_feature_names_out method
56
+ def get_feature_names_out(self, input_features=None):
57
+ if input_features is None:
58
+ # Assuming the transformer operates on a single column if input_features is not provided
59
+ return ['Product_Sugar_Content']
60
+ else:
61
+ # Return the input feature names as the output feature names
62
+ return input_features
63
+
64
+
65
  # Convert categorical columns to category type
66
  input_df['Product_Sugar_Content'] = input_df['Product_Sugar_Content'].astype('category')
67
  input_df['Product_Type'] = input_df['Product_Type'].astype('category')
 
70
  input_df['Store_Location_City_Type'] = input_df['Store_Location_City_Type'].astype('category')
71
  input_df['Store_Type'] = input_df['Store_Type'].astype('category')
72
 
73
+
74
  # Make predictions
75
  if st.button("Predict"):
76
  predictions = model.predict(input_df)