Spaces:
Sleeping
Sleeping
File size: 4,484 Bytes
3b26a2a 179e727 6582ce8 3192444 179e727 6582ce8 179e727 6582ce8 3192444 6582ce8 3192444 179e727 6582ce8 179e727 3192444 179e727 3192444 6582ce8 3192444 6582ce8 3192444 6582ce8 179e727 6582ce8 179e727 3192444 179e727 3192444 179e727 6582ce8 179e727 6582ce8 179e727 3192444 179e727 6582ce8 3192444 179e727 3192444 179e727 6582ce8 179e727 6582ce8 3b26a2a 3192444 6582ce8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 | import streamlit as st
import pandas as pd
import numpy as np
import joblib
import os
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
# ================= SETTINGS =================
# When True, model outputs (and the stored y_test) are treated as log-scale
# values and mapped back to sales units with np.expm1.
USE_LOG_TARGET = True
# Must stay the first Streamlit command executed by the script.
st.set_page_config(page_title="Store Sales Forecasting", layout="wide")
BASE_DIR = os.path.dirname(os.path.abspath(__file__))


@st.cache_resource(show_spinner="Loading model...")
def _load_artifacts(base_dir):
    """Load the fitted model and its training feature names from *base_dir*.

    Streamlit re-executes this whole script on every widget interaction;
    caching the result means the pickles are read from disk once per server
    process instead of on every click.

    NOTE: joblib.load unpickles, which can execute arbitrary code — only point
    this at trusted, bundled artifact files.

    Returns:
        (model, feature_names) as stored in model.pkl and features.pkl.
    """
    loaded_model = joblib.load(os.path.join(base_dir, "model.pkl"))
    loaded_feature_names = joblib.load(os.path.join(base_dir, "features.pkl"))
    return loaded_model, loaded_feature_names


model, feature_names = _load_artifacts(BASE_DIR)
# ================= LOAD TEST DATA =================
X_test_path = os.path.join(BASE_DIR, "X_test.npy")
y_test_path = os.path.join(BASE_DIR, "y_test.npy")


@st.cache_data(show_spinner="Evaluating model on test set...")
def _compute_test_rmse(x_path, y_path, _model, log_target):
    """Return the RMSE of *_model* on the saved test arrays, in sales units.

    Cached because predicting on the whole test set is expensive and Streamlit
    reruns the script on every interaction. The leading underscore on
    ``_model`` tells st.cache_data not to hash it, so the cache key is
    (x_path, y_path, log_target) and the result is reused for the lifetime of
    the server process.

    Args:
        x_path: Path to the saved feature matrix (.npy).
        y_path: Path to the saved targets (.npy).
        _model: Fitted estimator exposing ``predict``.
        log_target: If True, predictions and targets are mapped back from log
            scale with np.expm1 before scoring.
    """
    x_test = np.load(x_path)
    y_true = np.load(y_path)
    y_pred = _model.predict(x_test)
    if log_target:
        y_pred = np.expm1(y_pred)
        y_true = np.expm1(y_true)
    return float(np.sqrt(mean_squared_error(y_true, y_pred)))


# Both files are required; checking only X_test.npy let a missing y_test.npy
# crash the whole app with FileNotFoundError.
if os.path.exists(X_test_path) and os.path.exists(y_test_path):
    rmse = _compute_test_rmse(X_test_path, y_test_path, model, USE_LOG_TARGET)
else:
    rmse = None
# ================= TITLE =================
# Emoji below were previously mojibake (UTF-8 bytes decoded as TIS-620,
# e.g. "๐ฎ" for 🔮); restored to the intended characters.
st.title("📈 Store Sales Forecasting")
st.markdown("Predict daily store sales using Machine Learning.")
tab1, tab2 = st.tabs(["🔮 Prediction", "📊 Model Insights"])
# ================= PREDICTION TAB =================
with tab1:
    st.subheader("Input Features")

    # Product families are one-hot encoded as "family_<NAME>" feature columns.
    FAMILY_PREFIX = "family_"
    families = [
        name[len(FAMILY_PREFIX):]
        for name in feature_names
        if name.startswith(FAMILY_PREFIX)
    ]

    # st.button() is True only during the single rerun triggered by the click.
    # Storing the example in local variables (and hiding the inputs) meant the
    # example was discarded on the very next rerun — e.g. the one triggered by
    # "Predict Sales" — so it was never actually used. Writing the values into
    # the widgets' session_state keys pre-fills the inputs and keeps them.
    # Widget state may only be set before the widget is created in this run.
    if st.button("🎲 Load Example"):
        st.session_state["input_date"] = pd.Timestamp("2017-08-15").date()
        st.session_state["input_store_nbr"] = 1
        st.session_state["input_onpromotion"] = 5
        if families:
            st.session_state["input_family"] = families[0]

    date = st.date_input("Date", key="input_date")
    col1, col2 = st.columns(2)
    with col1:
        store_nbr = st.number_input(
            "Store Number", min_value=1, step=1, key="input_store_nbr"
        )
        onpromotion = st.number_input(
            "On Promotion", min_value=0, step=1, key="input_onpromotion"
        )
    with col2:
        family = st.selectbox("Product Family", families, key="input_family")

    # Start from an all-zero row in the training column order, then fill in
    # the numeric/calendar inputs and the chosen family's one-hot flag.
    input_dict = dict.fromkeys(feature_names, 0)
    input_dict.update(
        store_nbr=store_nbr,
        onpromotion=onpromotion,
        year=date.year,
        month=date.month,
        day=date.day,
        dayofweek=date.weekday(),
    )
    # selectbox returns None when there are no family columns; don't invent a
    # "family_None" column the model was never trained on.
    if family is not None:
        input_dict[f"{FAMILY_PREFIX}{family}"] = 1
    features = pd.DataFrame([input_dict])

    # ================= PREDICT =================
    if st.button("Predict Sales"):
        with st.spinner("Making prediction..."):
            pred = model.predict(features)[0]
            if USE_LOG_TARGET:
                # Model output is log-scale; map back to sales units.
                pred = np.expm1(pred)

        st.markdown("## 📈 Predicted Sales")
        col1, col2 = st.columns(2)
        with col1:
            st.metric("💰 Sales", f"{pred:,.2f}")
        with col2:
            st.metric("🏪 Store", store_nbr)

        # download
        result_df = pd.DataFrame({
            "store_nbr": [store_nbr],
            "family": [family],
            "prediction": [pred],
        })
        st.download_button(
            "⬇ Download prediction",
            result_df.to_csv(index=False),
            "prediction.csv",
            "text/csv",
        )
# ================= MODEL INSIGHTS =================
with tab2:
    st.subheader("Model Performance")
    # Compare against None: an RMSE of exactly 0.0 is a valid score and must
    # not be mistaken for "no test data".
    if rmse is not None:
        st.metric("RMSE", f"{rmse:,.2f}")
    else:
        st.info("Upload X_test.npy & y_test.npy to display RMSE.")

    # ================= FEATURE IMPORTANCE =================
    if hasattr(model, "feature_importances_"):
        st.subheader("Top Feature Importances")
        importance = pd.Series(model.feature_importances_, index=feature_names)
        top = importance.sort_values(ascending=False).head(15)
        fig, ax = plt.subplots()
        top.sort_values().plot(kind="barh", ax=ax)
        st.pyplot(fig)
        # pyplot keeps every figure alive until closed; without this, each
        # rerun leaked figures (and memory) into pyplot's global registry.
        plt.close(fig)

        # grouped importance: collapse all one-hot family columns into one bar
        st.subheader("Grouped Importance")
        is_family = importance.index.str.startswith("family_")
        family_imp = importance[is_family].sum()
        other_imp = importance[~is_family]
        grouped = pd.concat([
            pd.Series({"family_total": family_imp}),
            other_imp,
        ]).sort_values(ascending=False).head(10)
        fig2, ax2 = plt.subplots()
        grouped.sort_values().plot(kind="barh", ax=ax2)
        st.pyplot(fig2)
        plt.close(fig2)

    # ================= MODEL INFO =================
    st.subheader("Model Info")
    # "  \n" is a Markdown hard line break, so each fact renders on its own
    # line; no indentation inside the string (4+ leading spaces would turn
    # the text into a Markdown code block).
    st.info(
        f"Model type: **{type(model).__name__}**  \n"
        f"Features used: **{len(feature_names)}**  \n"
        f"Log target: **{USE_LOG_TARGET}**"
    )