BeyzaTopbas commited on
Commit
6582ce8
Β·
verified Β·
1 Parent(s): c695e1a

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +90 -23
src/streamlit_app.py CHANGED
@@ -6,7 +6,9 @@ import os
6
  import matplotlib.pyplot as plt
7
  from sklearn.metrics import mean_squared_error
8
 
9
- # ================= CONFIG =================
 
 
10
  st.set_page_config(page_title="Store Sales Forecasting", layout="wide")
11
 
12
  BASE_DIR = os.path.dirname(__file__)
@@ -14,17 +16,23 @@ BASE_DIR = os.path.dirname(__file__)
14
  model = joblib.load(os.path.join(BASE_DIR, "model.pkl"))
15
  feature_names = joblib.load(os.path.join(BASE_DIR, "features.pkl"))
16
 
17
- # test data (optioneel voor insights)
18
  X_test_path = os.path.join(BASE_DIR, "X_test.npy")
19
  y_test_path = os.path.join(BASE_DIR, "y_test.npy")
20
 
21
  if os.path.exists(X_test_path):
22
  X_test = np.load(X_test_path)
23
  y_test = np.load(y_test_path)
 
24
  y_pred_test = model.predict(X_test)
 
 
 
 
 
25
  rmse = np.sqrt(mean_squared_error(y_test, y_pred_test))
26
  else:
27
- X_test, y_test, rmse = None, None, None
28
 
29
  # ================= TITLE =================
30
  st.title("πŸ›’ Store Sales Forecasting")
@@ -39,22 +47,28 @@ with tab1:
39
 
40
  families = [c.replace("family_", "") for c in feature_names if "family_" in c]
41
 
42
- col1, col2 = st.columns(2)
 
 
 
 
 
 
43
 
44
- with col1:
45
- store_nbr = st.number_input("Store Number", 1)
46
- onpromotion = st.number_input("On Promotion", 0)
47
- family = st.selectbox("Product Family", families)
48
 
49
- with col2:
50
- date = st.date_input("Date")
 
51
 
52
- year = date.year
53
- month = date.month
54
- day = date.day
55
- dayofweek = date.weekday()
 
 
 
56
 
57
- # -------- One-hot encoding in background --------
58
  input_dict = dict.fromkeys(feature_names, 0)
59
 
60
  input_dict["store_nbr"] = store_nbr
@@ -63,7 +77,6 @@ with tab1:
63
  input_dict["month"] = month
64
  input_dict["day"] = day
65
  input_dict["dayofweek"] = dayofweek
66
-
67
  input_dict[f"family_{family}"] = 1
68
 
69
  features = pd.DataFrame([input_dict])
@@ -71,22 +84,48 @@ with tab1:
71
  # ================= PREDICT =================
72
  if st.button("Predict Sales"):
73
 
74
- prediction = model.predict(features)[0]
 
 
 
 
 
75
 
76
  st.markdown("## πŸ“ˆ Predicted Sales")
77
- st.success(f"πŸ’° {prediction:,.2f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  # ================= MODEL INSIGHTS =================
80
  with tab2:
81
 
82
  st.subheader("Model Performance")
83
 
84
- if rmse is not None:
85
  st.metric("RMSE", f"{rmse:,.2f}")
86
  else:
87
  st.info("Upload X_test.npy & y_test.npy to display RMSE.")
88
 
89
- # -------- Feature Importance --------
90
  if hasattr(model, "feature_importances_"):
91
 
92
  st.subheader("Top Feature Importances")
@@ -94,8 +133,36 @@ with tab2:
94
  importance = pd.Series(
95
  model.feature_importances_,
96
  index=feature_names
97
- ).sort_values(ascending=False).head(15)
 
 
98
 
99
  fig, ax = plt.subplots()
100
- importance.sort_values().plot(kind="barh", ax=ax)
101
- st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import matplotlib.pyplot as plt
7
  from sklearn.metrics import mean_squared_error
8
 
9
+ # ================= SETTINGS =================
10
+ USE_LOG_TARGET = True
11
+
12
  st.set_page_config(page_title="Store Sales Forecasting", layout="wide")
13
 
14
  BASE_DIR = os.path.dirname(__file__)
 
16
  model = joblib.load(os.path.join(BASE_DIR, "model.pkl"))
17
  feature_names = joblib.load(os.path.join(BASE_DIR, "features.pkl"))
18
 
19
+ # ================= LOAD TEST DATA =================
20
  X_test_path = os.path.join(BASE_DIR, "X_test.npy")
21
  y_test_path = os.path.join(BASE_DIR, "y_test.npy")
22
 
23
  if os.path.exists(X_test_path):
24
  X_test = np.load(X_test_path)
25
  y_test = np.load(y_test_path)
26
+
27
  y_pred_test = model.predict(X_test)
28
+
29
+ if USE_LOG_TARGET:
30
+ y_pred_test = np.expm1(y_pred_test)
31
+ y_test = np.expm1(y_test)
32
+
33
  rmse = np.sqrt(mean_squared_error(y_test, y_pred_test))
34
  else:
35
+ rmse = None
36
 
37
  # ================= TITLE =================
38
  st.title("πŸ›’ Store Sales Forecasting")
 
47
 
48
  families = [c.replace("family_", "") for c in feature_names if "family_" in c]
49
 
50
+ if st.button("🎲 Load Example"):
51
+ store_nbr = 1
52
+ onpromotion = 5
53
+ date = pd.to_datetime("2017-08-15")
54
+ family = families[0]
55
+ else:
56
+ date = st.date_input("Date")
57
 
58
+ col1, col2 = st.columns(2)
 
 
 
59
 
60
+ with col1:
61
+ store_nbr = st.number_input("Store Number", 1)
62
+ onpromotion = st.number_input("On Promotion", 0)
63
 
64
+ with col2:
65
+ family = st.selectbox("Product Family", families)
66
+
67
+ year = date.year
68
+ month = date.month
69
+ day = date.day
70
+ dayofweek = date.weekday()
71
 
 
72
  input_dict = dict.fromkeys(feature_names, 0)
73
 
74
  input_dict["store_nbr"] = store_nbr
 
77
  input_dict["month"] = month
78
  input_dict["day"] = day
79
  input_dict["dayofweek"] = dayofweek
 
80
  input_dict[f"family_{family}"] = 1
81
 
82
  features = pd.DataFrame([input_dict])
 
84
  # ================= PREDICT =================
85
  if st.button("Predict Sales"):
86
 
87
+ with st.spinner("Making prediction..."):
88
+
89
+ pred = model.predict(features)[0]
90
+
91
+ if USE_LOG_TARGET:
92
+ pred = np.expm1(pred)
93
 
94
  st.markdown("## πŸ“ˆ Predicted Sales")
95
+
96
+ col1, col2 = st.columns(2)
97
+
98
+ with col1:
99
+ st.metric("πŸ’° Sales", f"{pred:,.2f}")
100
+
101
+ with col2:
102
+ st.metric("πŸͺ Store", store_nbr)
103
+
104
+ # download
105
+ result_df = pd.DataFrame({
106
+ "store_nbr": [store_nbr],
107
+ "family": [family],
108
+ "prediction": [pred]
109
+ })
110
+
111
+ st.download_button(
112
+ "⬇ Download prediction",
113
+ result_df.to_csv(index=False),
114
+ "prediction.csv",
115
+ "text/csv"
116
+ )
117
 
118
  # ================= MODEL INSIGHTS =================
119
  with tab2:
120
 
121
  st.subheader("Model Performance")
122
 
123
+ if rmse:
124
  st.metric("RMSE", f"{rmse:,.2f}")
125
  else:
126
  st.info("Upload X_test.npy & y_test.npy to display RMSE.")
127
 
128
+ # ================= FEATURE IMPORTANCE =================
129
  if hasattr(model, "feature_importances_"):
130
 
131
  st.subheader("Top Feature Importances")
 
133
  importance = pd.Series(
134
  model.feature_importances_,
135
  index=feature_names
136
+ )
137
+
138
+ top = importance.sort_values(ascending=False).head(15)
139
 
140
  fig, ax = plt.subplots()
141
+ top.sort_values().plot(kind="barh", ax=ax)
142
+ st.pyplot(fig)
143
+
144
+ # grouped importance
145
+ st.subheader("Grouped Importance")
146
+
147
+ family_imp = importance[importance.index.str.contains("family_")].sum()
148
+ other_imp = importance[~importance.index.str.contains("family_")]
149
+
150
+ grouped = pd.concat([
151
+ pd.Series({"family_total": family_imp}),
152
+ other_imp
153
+ ]).sort_values(ascending=False).head(10)
154
+
155
+ fig2, ax2 = plt.subplots()
156
+ grouped.sort_values().plot(kind="barh", ax=ax2)
157
+ st.pyplot(fig2)
158
+
159
+ # ================= MODEL INFO =================
160
+ st.subheader("Model Info")
161
+
162
+ st.info(f"""
163
+ Model type: **{type(model).__name__}**
164
+
165
+ Features used: **{len(feature_names)}**
166
+
167
+ Log target: **{USE_LOG_TARGET}**
168
+ """)