BeyzaTopbas commited on
Commit
3192444
ยท
verified ยท
1 Parent(s): d876b29

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +45 -29
src/streamlit_app.py CHANGED
@@ -7,25 +7,24 @@ import matplotlib.pyplot as plt
7
  from sklearn.metrics import mean_squared_error
8
 
9
  # ================= CONFIG =================
10
- st.set_page_config(page_title="Store Sales Forecasting", layout="centered")
11
 
12
  BASE_DIR = os.path.dirname(__file__)
13
 
14
  model = joblib.load(os.path.join(BASE_DIR, "model.pkl"))
15
  feature_names = joblib.load(os.path.join(BASE_DIR, "features.pkl"))
16
 
17
- # ================= LOAD TEST DATA (OPTIONAL) =================
18
  X_test_path = os.path.join(BASE_DIR, "X_test.npy")
19
  y_test_path = os.path.join(BASE_DIR, "y_test.npy")
20
 
21
  if os.path.exists(X_test_path):
22
  X_test = np.load(X_test_path)
23
  y_test = np.load(y_test_path)
24
-
25
- y_pred = model.predict(X_test)
26
- rmse = np.sqrt(mean_squared_error(y_test, y_pred))
27
  else:
28
- X_test, y_test, y_pred, rmse = None, None, None, None
29
 
30
  # ================= TITLE =================
31
  st.title("๐Ÿ›’ Store Sales Forecasting")
@@ -33,44 +32,61 @@ st.markdown("Predict daily store sales using Machine Learning.")
33
 
34
  tab1, tab2 = st.tabs(["๐Ÿ”ฎ Prediction", "๐Ÿ“Š Model Insights"])
35
 
36
- # ================= TAB 1 โ€“ PREDICTION =================
37
  with tab1:
38
 
39
  st.subheader("Input Features")
40
 
41
- input_data = {}
 
 
 
 
 
 
 
 
 
 
42
 
43
- for feature in feature_names:
44
- input_data[feature] = st.number_input(feature, value=0.0)
 
 
45
 
46
- input_df = pd.DataFrame([input_data])
 
47
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  if st.button("Predict Sales"):
49
 
50
- prediction = model.predict(input_df)[0]
51
 
52
  st.markdown("## ๐Ÿ“ˆ Predicted Sales")
53
- st.success(f"${prediction:,.2f}")
54
 
55
- # ================= TAB 2 โ€“ MODEL INSIGHTS =================
56
  with tab2:
57
 
58
  st.subheader("Model Performance")
59
 
60
- if X_test is None:
61
- st.info("Upload X_test.npy and y_test.npy to see performance.")
62
  else:
63
- st.metric("RMSE", round(rmse, 2))
64
-
65
- # Actual vs Predicted
66
- fig, ax = plt.subplots(figsize=(10, 4))
67
- ax.plot(y_test[:200], label="Actual")
68
- ax.plot(y_pred[:200], label="Predicted")
69
- ax.legend()
70
- ax.set_title("Actual vs Predicted Sales")
71
- st.pyplot(fig)
72
 
73
- # ================= FEATURE IMPORTANCE =================
74
  if hasattr(model, "feature_importances_"):
75
 
76
  st.subheader("Top Feature Importances")
@@ -80,6 +96,6 @@ with tab2:
80
  index=feature_names
81
  ).sort_values(ascending=False).head(15)
82
 
83
- fig2, ax2 = plt.subplots()
84
- importance.sort_values().plot(kind="barh", ax=ax2)
85
- st.pyplot(fig2)
 
7
  from sklearn.metrics import mean_squared_error
8
 
9
  # ================= CONFIG =================
10
+ st.set_page_config(page_title="Store Sales Forecasting", layout="wide")
11
 
12
  BASE_DIR = os.path.dirname(__file__)
13
 
14
  model = joblib.load(os.path.join(BASE_DIR, "model.pkl"))
15
  feature_names = joblib.load(os.path.join(BASE_DIR, "features.pkl"))
16
 
17
+ # test data (optioneel voor insights)
18
  X_test_path = os.path.join(BASE_DIR, "X_test.npy")
19
  y_test_path = os.path.join(BASE_DIR, "y_test.npy")
20
 
21
  if os.path.exists(X_test_path):
22
  X_test = np.load(X_test_path)
23
  y_test = np.load(y_test_path)
24
+ y_pred_test = model.predict(X_test)
25
+ rmse = np.sqrt(mean_squared_error(y_test, y_pred_test))
 
26
  else:
27
+ X_test, y_test, rmse = None, None, None
28
 
29
  # ================= TITLE =================
30
  st.title("๐Ÿ›’ Store Sales Forecasting")
 
32
 
33
  tab1, tab2 = st.tabs(["๐Ÿ”ฎ Prediction", "๐Ÿ“Š Model Insights"])
34
 
35
+ # ================= PREDICTION TAB =================
36
  with tab1:
37
 
38
  st.subheader("Input Features")
39
 
40
+ families = [c.replace("family_", "") for c in feature_names if "family_" in c]
41
+
42
+ col1, col2 = st.columns(2)
43
+
44
+ with col1:
45
+ store_nbr = st.number_input("Store Number", 1)
46
+ onpromotion = st.number_input("On Promotion", 0)
47
+ family = st.selectbox("Product Family", families)
48
+
49
+ with col2:
50
+ date = st.date_input("Date")
51
 
52
+ year = date.year
53
+ month = date.month
54
+ day = date.day
55
+ dayofweek = date.weekday()
56
 
57
+ # -------- One-hot encoding in background --------
58
+ input_dict = dict.fromkeys(feature_names, 0)
59
 
60
+ input_dict["store_nbr"] = store_nbr
61
+ input_dict["onpromotion"] = onpromotion
62
+ input_dict["year"] = year
63
+ input_dict["month"] = month
64
+ input_dict["day"] = day
65
+ input_dict["dayofweek"] = dayofweek
66
+
67
+ input_dict[f"family_{family}"] = 1
68
+
69
+ features = pd.DataFrame([input_dict])
70
+
71
+ # ================= PREDICT =================
72
  if st.button("Predict Sales"):
73
 
74
+ prediction = model.predict(features)[0]
75
 
76
  st.markdown("## ๐Ÿ“ˆ Predicted Sales")
77
+ st.success(f"๐Ÿ’ฐ {prediction:,.2f}")
78
 
79
+ # ================= MODEL INSIGHTS =================
80
  with tab2:
81
 
82
  st.subheader("Model Performance")
83
 
84
+ if rmse is not None:
85
+ st.metric("RMSE", f"{rmse:,.2f}")
86
  else:
87
+ st.info("Upload X_test.npy & y_test.npy to display RMSE.")
 
 
 
 
 
 
 
 
88
 
89
+ # -------- Feature Importance --------
90
  if hasattr(model, "feature_importances_"):
91
 
92
  st.subheader("Top Feature Importances")
 
96
  index=feature_names
97
  ).sort_values(ascending=False).head(15)
98
 
99
+ fig, ax = plt.subplots()
100
+ importance.sort_values().plot(kind="barh", ax=ax)
101
+ st.pyplot(fig)