Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import joblib | |
| import os | |
| import matplotlib.pyplot as plt | |
| from sklearn.metrics import mean_squared_error | |
# ================= SETTINGS =================
# When True, model outputs (and the stored test targets) are passed through
# np.expm1 before being used further down in this script.
# NOTE(review): presumably the model was fit on log1p-transformed sales —
# confirm against the training code.
USE_LOG_TARGET = True

# First Streamlit call of the script: page title and wide layout.
st.set_page_config(layout="wide", page_title="Store Sales Forecasting")

# The pickled artifacts are looked up in the directory that holds this script.
BASE_DIR = os.path.dirname(__file__)

# Load the fitted model first, then the list of feature/column names.
# NOTE(review): both files are re-read on every Streamlit rerun; consider a
# cached loader if the deployed Streamlit version supports one.
model, feature_names = (
    joblib.load(os.path.join(BASE_DIR, artifact))
    for artifact in ("model.pkl", "features.pkl")
)
# ================= LOAD TEST DATA =================
# Optional hold-out arrays; when present they are used to compute an RMSE
# that is displayed under "Model Performance".
X_test_path = os.path.join(BASE_DIR, "X_test.npy")
y_test_path = os.path.join(BASE_DIR, "y_test.npy")

# rmse stays None unless BOTH files exist. Checking only X_test.npy would
# crash in np.load() whenever y_test.npy is missing.
rmse = None
if os.path.exists(X_test_path) and os.path.exists(y_test_path):
    X_test = np.load(X_test_path)
    y_test = np.load(y_test_path)
    y_pred_test = model.predict(X_test)

    if USE_LOG_TARGET:
        # Undo the log transform on predictions and targets alike so the
        # error is computed on the same scale for both.
        y_pred_test = np.expm1(y_pred_test)
        y_test = np.expm1(y_test)

    # Plain float so the value formats and compares like an ordinary number.
    rmse = float(np.sqrt(mean_squared_error(y_test, y_pred_test)))
# ================= TITLE =================
# The emoji in these labels had been corrupted into mojibake; they are
# restored here as proper Unicode characters.
# NOTE(review): the title and "Model Insights" glyphs were not fully
# recoverable from the corrupted text — confirm the intended emoji.
st.title("📈 Store Sales Forecasting")
st.markdown("Predict daily store sales using Machine Learning.")

# tab1 hosts the prediction form, tab2 the model diagnostics.
tab1, tab2 = st.tabs(["🔮 Prediction", "📊 Model Insights"])
# ================= PREDICTION TAB =================
# Feature columns whose name starts with this prefix are treated as the
# per-family indicator columns (one column per product family).
_FAMILY_PREFIX = "family_"


def _load_example_inputs(example_family):
    """Pre-fill the input widgets with a fixed example.

    Used as the ``on_click`` callback of the "Load Example" button. Callbacks
    run before the script is re-executed, so the keyed widgets below pick the
    values up from session state and keep them on later reruns (for instance
    the one triggered by "Predict Sales").

    Args:
        example_family: Family to select, or ``None`` to leave the family
            widget untouched (no family columns available).
    """
    st.session_state["input_date"] = pd.to_datetime("2017-08-15").date()
    st.session_state["input_store_nbr"] = 1
    st.session_state["input_onpromotion"] = 5
    if example_family is not None:
        st.session_state["input_family"] = example_family


def _build_feature_row(columns, store_nbr, onpromotion, date, family):
    """Return a one-row DataFrame with exactly the columns in ``columns``.

    Args:
        columns: Ordered feature names expected by the model.
        store_nbr: Store number entered by the user.
        onpromotion: "On Promotion" value entered by the user.
        date: Object exposing ``year``, ``month``, ``day`` and ``weekday()``.
        family: Selected product family, or ``None`` when none is available.

    Returns:
        A DataFrame with one row. Columns not set from the inputs are 0, and
        values whose column is not in ``columns`` are dropped instead of being
        appended as extra columns.
    """
    values = {
        "store_nbr": store_nbr,
        "onpromotion": onpromotion,
        "year": date.year,
        "month": date.month,
        "day": date.day,
        "dayofweek": date.weekday(),
    }
    if family is not None:
        values[f"{_FAMILY_PREFIX}{family}"] = 1
    # reindex enforces both the column set and the column order.
    return pd.DataFrame([values]).reindex(columns=list(columns), fill_value=0)


with tab1:
    st.subheader("Input Features")

    # Prefix match (not substring match) so that only real family columns
    # are offered, and only the leading prefix is stripped from the name.
    families = [
        name[len(_FAMILY_PREFIX):]
        for name in feature_names
        if name.startswith(_FAMILY_PREFIX)
    ]

    # The example is written into session state by the callback instead of
    # being assigned in an `if st.button(...)` branch: a button is only True
    # for the single rerun in which it was clicked, so values assigned there
    # were lost as soon as "Predict Sales" was pressed.
    st.button(
        "🎲 Load Example",
        on_click=_load_example_inputs,
        args=(families[0] if families else None,),
    )

    date = st.date_input("Date", key="input_date")
    col1, col2 = st.columns(2)
    with col1:
        store_nbr = st.number_input(
            "Store Number", min_value=1, key="input_store_nbr"
        )
        onpromotion = st.number_input(
            "On Promotion", min_value=0, key="input_onpromotion"
        )
    with col2:
        family = st.selectbox("Product Family", families, key="input_family")

    features = _build_feature_row(
        feature_names, store_nbr, onpromotion, date, family
    )

    # ================= PREDICT =================
    if st.button("Predict Sales"):
        with st.spinner("Making prediction..."):
            pred = model.predict(features)[0]
            if USE_LOG_TARGET:
                # Map the model output back through expm1 before display.
                pred = np.expm1(pred)
        pred = float(pred)

        # NOTE(review): the heading emoji was corrupted beyond exact
        # recovery — confirm the intended glyph.
        st.markdown("## 📈 Predicted Sales")
        col1, col2 = st.columns(2)
        with col1:
            st.metric("💰 Sales", f"{pred:,.2f}")
        with col2:
            st.metric("🏪 Store", store_nbr)

        # download
        result_df = pd.DataFrame(
            {
                "store_nbr": [store_nbr],
                "family": [family],
                "prediction": [pred],
            }
        )
        st.download_button(
            "⬇ Download prediction",
            result_df.to_csv(index=False),
            "prediction.csv",
            "text/csv",
        )
# ================= MODEL INSIGHTS =================
with tab2:
    st.subheader("Model Performance")
    # Compare against None explicitly: an RMSE of exactly 0.0 is a valid
    # score and must not be reported as "missing".
    if rmse is not None:
        st.metric("RMSE", f"{rmse:,.2f}")
    else:
        st.info("Upload X_test.npy & y_test.npy to display RMSE.")

    # ================= FEATURE IMPORTANCE =================
    # Only shown for models that expose `feature_importances_`.
    if hasattr(model, "feature_importances_"):
        st.subheader("Top Feature Importances")
        importance = pd.Series(
            model.feature_importances_,
            index=feature_names,
        )

        # 15 largest importances, re-sorted ascending so the biggest bar
        # ends up at the top of the horizontal bar chart.
        top = importance.sort_values(ascending=False).head(15)
        fig, ax = plt.subplots()
        top.sort_values().plot(kind="barh", ax=ax)
        st.pyplot(fig)
        # The script reruns on every interaction; close the figure so
        # pyplot does not keep accumulating open figures.
        plt.close(fig)

        # grouped importance: all family_* columns collapsed into one entry
        st.subheader("Grouped Importance")
        is_family = importance.index.str.startswith("family_")
        family_imp = importance[is_family].sum()
        other_imp = importance[~is_family]
        grouped = (
            pd.concat([pd.Series({"family_total": family_imp}), other_imp])
            .sort_values(ascending=False)
            .head(10)
        )
        fig2, ax2 = plt.subplots()
        grouped.sort_values().plot(kind="barh", ax=ax2)
        st.pyplot(fig2)
        plt.close(fig2)

    # ================= MODEL INFO =================
    st.subheader("Model Info")
    # Built with explicit newlines so the rendered markdown does not depend
    # on how this source file happens to be indented.
    st.info(
        "\n"
        f"Model type: **{type(model).__name__}**\n"
        f"Features used: **{len(feature_names)}**\n"
        f"Log target: **{USE_LOG_TARGET}**\n"
    )