Store-Sales-Forecasting / src /streamlit_app.py
BeyzaTopbas's picture
Update src/streamlit_app.py
6582ce8 verified
import streamlit as st
import pandas as pd
import numpy as np
import joblib
import os
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
# ================= SETTINGS =================
# When True, raw model outputs (and the stored test targets) are mapped
# through np.expm1 before being shown.
# NOTE(review): presumably the model was trained on log1p(sales) - confirm
# against the training code.
USE_LOG_TARGET = True

# Page config is issued before any other Streamlit call in this script.
st.set_page_config(page_title="Store Sales Forecasting", layout="wide")

# Absolute directory of this file, so artifacts resolve independently of the
# process working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))


@st.cache_resource
def _load_artifact(filename):
    """Load a joblib artifact that lives next to this script.

    Streamlit re-executes the whole script on every widget interaction;
    caching keeps the artifact in memory instead of unpickling it on each
    rerun. The cached object is shared between reruns and sessions, so
    callers must treat it as read-only.

    Args:
        filename: File name relative to ``BASE_DIR``.

    Returns:
        The object stored in the artifact.
    """
    # NOTE: joblib artifacts are pickles - loading one executes arbitrary
    # code, so only trusted files may be placed next to the app.
    return joblib.load(os.path.join(BASE_DIR, filename))


model = _load_artifact("model.pkl")
feature_names = _load_artifact("features.pkl")
# ================= LOAD TEST DATA =================
# The hold-out set is optional. `rmse` stays None unless BOTH arrays are
# present, so a partially uploaded test set degrades to "no RMSE" instead of
# crashing the app at startup.
X_test_path = os.path.join(BASE_DIR, "X_test.npy")
y_test_path = os.path.join(BASE_DIR, "y_test.npy")

rmse = None
if os.path.isfile(X_test_path) and os.path.isfile(y_test_path):
    X_test = np.load(X_test_path)
    y_test = np.load(y_test_path)
    y_pred_test = model.predict(X_test)
    if USE_LOG_TARGET:
        # Undo the log transform on predictions and targets alike so the
        # error is expressed on the same scale as the displayed sales.
        y_pred_test = np.expm1(y_pred_test)
        y_test = np.expm1(y_test)
    rmse = np.sqrt(mean_squared_error(y_test, y_pred_test))
# ================= TITLE =================
# Emoji are stored as real Unicode characters (the file must be saved as
# UTF-8) so they render as icons rather than as mis-decoded byte sequences.
st.title("🛒 Store Sales Forecasting")
st.markdown("Predict daily store sales using Machine Learning.")

# tab1 hosts the interactive prediction form, tab2 the model diagnostics.
tab1, tab2 = st.tabs(["🔮 Prediction", "📊 Model Insights"])
# ================= PREDICTION TAB =================
with tab1:
    st.subheader("Input Features")

    # Product families are one-hot columns named "family_<name>". Only a
    # leading prefix is stripped, so f"family_{family}" below always maps
    # back to an existing column.
    family_prefix = "family_"
    families = [
        c[len(family_prefix):]
        for c in feature_names
        if c.startswith(family_prefix)
    ]

    # st.button is True only during the single rerun caused by its click.
    # Writing the example into the widgets' session-state slots (before the
    # widgets are created) keeps the values alive for the later rerun that
    # "Predict Sales" triggers.
    if st.button("🎲 Load Example"):
        st.session_state["input_date"] = pd.to_datetime("2017-08-15").date()
        st.session_state["input_store_nbr"] = 1
        st.session_state["input_onpromotion"] = 5
        if families:
            st.session_state["input_family"] = families[0]

    date = st.date_input("Date", key="input_date")
    col1, col2 = st.columns(2)
    with col1:
        # Second positional argument is min_value.
        store_nbr = st.number_input("Store Number", 1, key="input_store_nbr")
        onpromotion = st.number_input("On Promotion", 0, key="input_onpromotion")
    with col2:
        family = st.selectbox("Product Family", families, key="input_family")

    # Calendar features derived from the selected date.
    year = date.year
    month = date.month
    day = date.day
    dayofweek = date.weekday()

    # Start from an all-zero row and fill in the known values.
    input_dict = dict.fromkeys(feature_names, 0)
    input_dict["store_nbr"] = store_nbr
    input_dict["onpromotion"] = onpromotion
    input_dict["year"] = year
    input_dict["month"] = month
    input_dict["day"] = day
    input_dict["dayofweek"] = dayofweek
    if family is not None:  # selectbox yields None when there are no options
        input_dict[f"family_{family}"] = 1

    # Align to the training schema: exactly the model's columns, in order.
    # Any key the model does not know is dropped instead of becoming an
    # extra column that would break model.predict.
    features = pd.DataFrame([input_dict]).reindex(
        columns=list(feature_names), fill_value=0
    )

    # ================= PREDICT =================
    if st.button("Predict Sales"):
        with st.spinner("Making prediction..."):
            pred = model.predict(features)[0]
            if USE_LOG_TARGET:
                # Map the model output back to the sales scale.
                pred = np.expm1(pred)

        st.markdown("## 📈 Predicted Sales")
        col1, col2 = st.columns(2)
        with col1:
            st.metric("💰 Sales", f"{pred:,.2f}")
        with col2:
            st.metric("🏪 Store", store_nbr)

        # Offer the single prediction as a one-row CSV download.
        result_df = pd.DataFrame({
            "store_nbr": [store_nbr],
            "family": [family],
            "prediction": [pred],
        })
        st.download_button(
            "⬇ Download prediction",
            result_df.to_csv(index=False),
            "prediction.csv",
            "text/csv",
        )
# ================= MODEL INSIGHTS =================
with tab2:
    st.subheader("Model Performance")
    # Compare against None explicitly: an RMSE of 0.0 is falsy but valid.
    if rmse is not None:
        st.metric("RMSE", f"{rmse:,.2f}")
    else:
        st.info("Upload X_test.npy & y_test.npy to display RMSE.")

    # ================= FEATURE IMPORTANCE =================
    # Only models exposing `feature_importances_` get the importance charts.
    if hasattr(model, "feature_importances_"):
        st.subheader("Top Feature Importances")
        importance = pd.Series(
            model.feature_importances_,
            index=feature_names,
        )
        top = importance.sort_values(ascending=False).head(15)
        fig, ax = plt.subplots()
        # Ascending order puts the largest bar at the top of a barh plot.
        top.sort_values().plot(kind="barh", ax=ax)
        st.pyplot(fig)
        # pyplot keeps every figure registered until it is closed; without
        # this, each script rerun would leave another figure in memory.
        plt.close(fig)

        # Grouped importance: all "family_" columns collapsed into a single
        # entry, compared against the remaining individual features.
        st.subheader("Grouped Importance")
        is_family = importance.index.str.contains("family_")
        family_imp = importance[is_family].sum()
        other_imp = importance[~is_family]
        grouped = pd.concat([
            pd.Series({"family_total": family_imp}),
            other_imp,
        ]).sort_values(ascending=False).head(10)
        fig2, ax2 = plt.subplots()
        grouped.sort_values().plot(kind="barh", ax=ax2)
        st.pyplot(fig2)
        plt.close(fig2)

    # ================= MODEL INFO =================
    st.subheader("Model Info")
    st.info(f"""
    Model type: **{type(model).__name__}**
    Features used: **{len(feature_names)}**
    Log target: **{USE_LOG_TARGET}**
    """)