Commit 84548c1
Parent(s): cc73c46

Add demo for sales forecasting

Files changed:
- app/core/config.py  +4 −6
- app/frontend/dashboard.py  +380 −0
- app/frontend/data_viz.py  +228 −0
- app/frontend/gradio_ui.py  +209 −269
- app/frontend/ui_template.py  +420 −0
- app/main.py  +9 −0
- app/services/prediction.py  +731 −0
- app/utils/data_generator.py  +774 −0
- app/utils/data_loader.py  +113 −0
- app/utils/plots.py  +131 −0
- app/utils/utils.py  +119 −0
- app/utils/visualization_code.py  +522 −0
app/core/config.py
CHANGED
@@ -9,9 +9,8 @@ class Settings(BaseSettings):
 
     # Server
     HOST: str = "0.0.0.0"
-    # HOST: str = "127.0.0.1"
     PORT: int = 5050
-    API_PREFIX: str = "
+    API_PREFIX: str = "x"
 
     # Model
     MODEL_CHECKPOINT: str = "yainage90/fashion-object-detection"
@@ -19,13 +18,12 @@ class Settings(BaseSettings):
 
     # Security
     SECRET_KEY: str = "xxx"
-    ALGORITHM: str = "HS256"
-    ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 30  # 1 month
-
     API_TOKEN: str = "xxx"
+    ALGORITHM: str = ".xxx"
+    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
 
     class Config:
-
+        env_file = ".env"
        case_sensitive = True
 
 settings = Settings()
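What the Config change does in practice: with the pydantic v1-style BaseSettings used in this file, env_file = ".env" lets a local .env file override the hard-coded defaults, and case_sensitive = True means variable names must match exactly. A minimal sketch, with placeholder values rather than the project's real settings:

from pydantic import BaseSettings  # pydantic v1-style settings

class Settings(BaseSettings):
    HOST: str = "0.0.0.0"
    PORT: int = 5050
    SECRET_KEY: str = "change-me"  # placeholder

    class Config:
        env_file = ".env"       # a line such as PORT=8080 in .env overrides the default
        case_sensitive = True   # "PORT" and "port" are treated as different names

settings = Settings()
print(settings.PORT)  # 8080 if .env sets it, otherwise 5050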
app/frontend/dashboard.py
ADDED
|
@@ -0,0 +1,380 @@
import pandas as pd
import gradio as gr
import matplotlib.pyplot as plt
from datetime import datetime

from app.frontend.data_viz import (
    plot_category_distribution,
    plot_day_of_week_pattern,
    plot_sales_distribution,
    plot_sales_time_series,
    plot_store_comparison,
)

# Mocking st.session_state for Gradio logic compatibility
class SessionState(dict):
    def __getattr__(self, item): return self.get(item)
    def __setattr__(self, key, value): self[key] = value

session_state = SessionState()

def configure_filters(data, start_date, end_date, selected_store_input, selected_categories):
    """Logic-only version of configure_filters (removing st.sidebar calls)"""

    # Resolve store selection logic from the original code
    selected_store = "All Stores"
    selected_store_name = "All Stores"

    if "store_name" in data.columns:
        selected_store_name = selected_store_input
    elif "store" in data.columns:
        selected_store = selected_store_input

    # Filter data based on selection
    filtered_data = data.copy()

    # Gradio strings to datetime.date
    start_dt = pd.to_datetime(start_date).date()
    end_dt = pd.to_datetime(end_date).date()

    mask = (filtered_data["date"].dt.date >= start_dt) & (
        filtered_data["date"].dt.date <= end_dt
    )

    # Apply store filter
    if "store_name" in data.columns and selected_store_name != "All Stores":
        mask &= filtered_data["store_name"] == selected_store_name
    elif "store" in data.columns and selected_store != "All Stores":
        mask &= filtered_data["store"] == selected_store

    # Apply category filter
    if selected_categories:
        mask &= filtered_data["category"].isin(selected_categories)

    # Update state for other functions
    session_state.selected_store = selected_store
    session_state.selected_store_name = selected_store_name
    session_state.start_date = start_dt
    session_state.end_date = end_dt

    return filtered_data[mask]

def display_kpis(filtered_data):
    """Logic-only version of display_kpis (returns raw KPI values for the UI)"""
    total_sales = filtered_data["sales"].sum()
    avg_daily_sales = filtered_data.groupby("date")["sales"].sum().mean()

    if len(filtered_data["date"].unique()) >= 2:
        mid_date = session_state.start_date + (session_state.end_date - session_state.start_date) / 2
        period1_data = filtered_data[filtered_data["date"].dt.date <= mid_date]
        period2_data = filtered_data[filtered_data["date"].dt.date > mid_date]
        period1_sales = period1_data["sales"].sum() if not period1_data.empty else 0
        period2_sales = period2_data["sales"].sum() if not period2_data.empty else 0
        sales_change_pct = (((period2_sales - period1_sales) / period1_sales * 100) if period1_sales > 0 else 0)
    else:
        sales_change_pct = 0

    if "transactions" in filtered_data.columns:
        total_transactions = filtered_data["transactions"].sum()
    else:
        total_transactions = filtered_data.shape[0]

    avg_transaction_value = (total_sales / total_transactions if total_transactions > 0 else 0)

    # Return raw metric values; update_kpis_html formats them for display
    return (
        total_sales,
        sales_change_pct,
        avg_daily_sales,
        total_transactions,
        avg_transaction_value
    )

def display_sales_trends(filtered_data):
    """Logic-only version of display_sales_trends (returns figures)"""
    fig1 = plot_sales_time_series(
        filtered_data,
        session_state.selected_store,
        session_state.selected_store_name,
    )

    fig2 = None
    if len(filtered_data["date"].unique()) >= 7:
        fig2 = plot_day_of_week_pattern(filtered_data)

    return fig1, fig2

def display_performance_breakdown(filtered_data):
    """Logic-only version of display_performance_breakdown (returns DF and Fig)"""
    category_df = pd.DataFrame()
    fig_cat = None
    store_df = pd.DataFrame()
    fig_store = None

    if "category" in filtered_data.columns and len(filtered_data["category"].unique()) > 1:
        category_sales = filtered_data.groupby("category")["sales"].sum().sort_values(ascending=False)
        category_sales_pct = (category_sales / category_sales.sum() * 100).round(1)
        category_df = pd.DataFrame({"Sales": category_sales, "Percentage": category_sales_pct}).reset_index()
        category_df["Sales"] = category_df["Sales"].apply(lambda x: f"${x:,.2f}")
        category_df["Percentage"] = category_df["Percentage"].apply(lambda x: f"{x}%")
        fig_cat = plot_category_distribution(filtered_data)

    if (session_state.selected_store_name == "All Stores" and session_state.selected_store == "All Stores") and \
            ("store_name" in filtered_data.columns or "store" in filtered_data.columns):
        store_identifier = "store_name" if "store_name" in filtered_data.columns else "store"
        store_sales = filtered_data.groupby(store_identifier)["sales"].sum().sort_values(ascending=False)
        top_stores = store_sales.head(10)
        store_df = pd.DataFrame({"Store": top_stores.index, "Sales": top_stores.values})
        store_df["Sales"] = store_df["Sales"].apply(lambda x: f"${x:,.2f}")
        fig_store = plot_store_comparison(filtered_data, store_identifier)

    return category_df, fig_cat, store_df, fig_store

def format_kpi_html(label, value_str, delta_pct=None):
    """Create HTML for a metric card"""

    # Process delta
    delta_html = ""
    if delta_pct is not None and delta_pct != 0:
        if delta_pct > 0:
            color = "color: #38a169;"  # Green
            arrow = "▲"
        else:
            color = "color: #e53e3e;"  # Red
            arrow = "▼"

        # Format delta, e.g. "▲ 4.2%"
        delta_str = f"{arrow} {abs(delta_pct):.1f}%"
        delta_html = f'<div style="{color} font-size: 14px; font-weight: 500; margin-top: 5px; line-height: 1;">{delta_str}</div>'

    html_output = f"""
    <div style="font-family: Arial, sans-serif; padding: 10px;">
        <div style="font-size: 14px; color: #555; margin-bottom: 5px;">{label}</div>
        <div style="font-size: 30px; font-weight: 600; color: #1a1a1a; line-height: 1;">{value_str}</div>
        {delta_html}
    </div>
    """
    return html_output

def update_kpis_html(total_sales, sales_change_pct, avg_daily_sales, total_transactions, avg_transaction_value):
    """Wrapper that formats the KPI values as HTML cards"""

    html1 = format_kpi_html(
        "💰 Total Sales",
        f"${total_sales:,.2f}",
        sales_change_pct
    )

    html2 = format_kpi_html(
        "📊 Avg Daily Sales",
        f"${avg_daily_sales:,.2f}"
    )

    html3 = format_kpi_html(
        "🛒 Total Transactions",
        f"{total_transactions:,}"
    )

    html4 = format_kpi_html(
        "💵 Avg Transaction Value",
        f"${avg_transaction_value:,.2f}"
    )

    return html1, html2, html3, html4

def historical_sales_view(data):
    """Main Gradio interface builder"""

    def run_dashboard_update(start_date, end_date, store_selection, categories):
        # 1. Logic: Filter
        filtered_data = configure_filters(data, start_date, end_date, store_selection, categories)

        if filtered_data.empty:
            empty_msg = "⚠️ No data available for the selected filters. Please adjust your selections."
            # One placeholder per output component: 4 KPI cards, 2 trend plots,
            # category table/plot, store table/plot, distribution plot, detail table
            return [empty_msg] * 4 + [None, None, pd.DataFrame(), None, pd.DataFrame(), None, None, pd.DataFrame()]

        # 2. Logic: KPIs
        kpi_metrics = display_kpis(filtered_data)
        html1, html2, html3, html4 = update_kpis_html(
            *kpi_metrics
        )

        # 3. Logic: Trends
        fig_ts, fig_dow = display_sales_trends(filtered_data)

        # 4. Logic: Breakdown
        cat_df, fig_cat, store_df, fig_store = display_performance_breakdown(filtered_data)

        # 5. Logic: Distribution
        fig_dist = plot_sales_distribution(filtered_data)

        # 6. Logic: Table
        detailed_table = filtered_data.sort_values("date", ascending=False)

        return (
            html1, html2, html3, html4,
            fig_ts, fig_dow,
            cat_df, fig_cat,
            store_df, fig_store,
            fig_dist,
            detailed_table
        )

    # Define the app layout (compatible with older Gradio versions)
    with gr.Blocks(title="Store Sales Dashboard") as demo:
        # Sidebar - filters (fixed)
        with gr.Sidebar(position="right"):
            gr.Markdown("## 🔍 Dashboard Filters")
            gr.Markdown("---")

            # Date filters
            gr.Markdown("### 📅 Date Range")
            min_date = data["date"].min().date()
            max_date = data["date"].max().date()

            start_in = gr.DateTime(
                label="From",
                value=str(min_date),
                type="string",
                interactive=True
            )
            end_in = gr.DateTime(
                label="To",
                value=str(max_date),
                type="string",
                interactive=True
            )

            gr.Markdown("---")

            # Store filter
            gr.Markdown("### 🏬 Store Selection")
            if "store_name" in data.columns:
                opts = ["All Stores"] + sorted(data["store_name"].unique().tolist())
            elif "store" in data.columns:
                opts = ["All Stores"] + sorted(data["store"].unique().tolist())
            else:
                opts = ["All Stores"]

            store_in = gr.Dropdown(
                choices=opts,
                value="All Stores",
                label="Select Store",
                interactive=True
            )

            # Category filter
            cat_in = None
            if "category" in data.columns:
                gr.Markdown("---")
                gr.Markdown("### 📦 Product Categories")
                cats = sorted(data["category"].unique().tolist())
                cat_in = gr.CheckboxGroup(
                    choices=cats,
                    value=cats,
                    label="Select Categories",
                    interactive=True
                )

            gr.Markdown("---")
            btn = gr.Button("🔄 Update Dashboard", variant="primary", size="lg")

            gr.Markdown(
                """
                <br>
                💡 **Tip:** Adjust filters and click Update to refresh
                """
            )

        # Main dashboard column
        with gr.Column(scale=1):
            # Header
            gr.Markdown(
                """
                # 📊 Store Sales Dashboard
                ### Comprehensive sales analytics and performance insights
                """
            )
            # KPI section
            gr.Markdown("## 📈 Key Performance Indicators")
            with gr.Row():
                m1 = gr.HTML(label=None, scale=1, container=True)
                m2 = gr.HTML(label=None, scale=1, container=True)
                m3 = gr.HTML(label=None, scale=1, container=True)
                m4 = gr.HTML(label=None, scale=1, container=True)

            gr.Markdown("---")

            # Sales trends section
            gr.Markdown("## 📉 Sales Trends Analysis")
            with gr.Row():
                p_ts = gr.Plot(label="📈 Sales Time Series", container=True, scale=1)
                p_dow = gr.Plot(label="📅 Weekly Patterns", container=True, scale=1)

            gr.Markdown("---")

            # Performance breakdown section
            gr.Markdown("## 🎯 Performance Breakdown")

            # Category performance section
            gr.Markdown("### 📦 Category Performance")
            with gr.Row():
                with gr.Column(scale=1):
                    df_cat = gr.DataFrame(label="Category Sales Data", max_height=300)
                with gr.Column(scale=1):
                    p_cat = gr.Plot(label="Sales by Category", container=True)

            gr.Markdown("---")

            # Store comparison section
            gr.Markdown("### 🏪 Store Comparison (Top 10)")
            with gr.Row():
                with gr.Column(scale=1):
                    df_store = gr.DataFrame(label="Top Performing Stores", max_height=300)
                with gr.Column(scale=2):
                    p_store = gr.Plot(label="Top 10 Stores by Sales", container=True)

            gr.Markdown("---")

            # Sales distribution section
            gr.Markdown("## 📊 Sales Distribution")
            p_dist = gr.Plot(label="Distribution Analysis", container=True)

            gr.Markdown("---")

            # Detailed data section
            with gr.Accordion("📋 View Detailed Sales Data", open=True):
                gr.Markdown("*Complete transaction history for the selected period*")
                df_detailed = gr.DataFrame(max_height=400)

            # Footer
            gr.Markdown(
                """
                ---
                <div style='text-align: center; color: #666; font-size: 0.9em;'>
                    📊 Store Sales Dashboard | Powered by Gradio
                </div>
                """
            )

        # Link event - Update button
        btn.click(
            run_dashboard_update,
            inputs=[start_in, end_in, store_in, cat_in],
            outputs=[m1, m2, m3, m4, p_ts, p_dow, df_cat, p_cat, df_store, p_store, p_dist, df_detailed]
        )

        # Auto-load initial data on page load
        demo.load(
            run_dashboard_update,
            inputs=[start_in, end_in, store_in, cat_in],
            outputs=[m1, m2, m3, m4, p_ts, p_dow, df_cat, p_cat, df_store, p_store, p_dist, df_detailed]
        )

    return demo

# Usage:
# if __name__ == "__main__":
#     df = pd.read_csv("your_data.csv", parse_dates=['date'])
#     app = historical_sales_view(df)
#     app.launch()
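The filter and KPI helpers above are plain functions, so they can be smoke-tested without launching the Gradio app. A small illustrative check (the toy frame and its numbers are invented; only the column names date, store, category, and sales come from this module):

import pandas as pd
from app.frontend.dashboard import configure_filters, display_kpis

toy = pd.DataFrame({
    "date": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]),
    "store": [1, 1, 2],
    "category": ["Shoes", "Shirts", "Shoes"],
    "sales": [100.0, 150.0, 80.0],
})

# Keep the first two days, all stores, both categories
filtered = configure_filters(toy, "2024-01-01", "2024-01-02", "All Stores", ["Shoes", "Shirts"])
total, change_pct, avg_daily, n_tx, avg_tx = display_kpis(filtered)
print(total, n_tx)  # 250.0 2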
app/frontend/data_viz.py
ADDED
|
@@ -0,0 +1,228 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def plot_sales_forecast(
    historical_data, prediction_date, prediction_value, store_id=None
):
    """
    Plot historical sales with prediction point
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    # Filter for specific store if provided
    if store_id is not None and "store" in historical_data.columns:
        plot_data = historical_data[historical_data["store"] == store_id].copy()
    else:
        plot_data = historical_data.copy()

    # Group by date if multiple records per date
    if len(plot_data) > len(plot_data["date"].unique()):
        plot_data = plot_data.groupby("date")["sales"].sum().reset_index()

    # Sort by date
    plot_data = plot_data.sort_values("date")

    # Plot historical data
    ax.plot(plot_data["date"], plot_data["sales"], label="Historical Sales")

    # Add prediction point
    ax.scatter(
        prediction_date, prediction_value, color="red", s=100, label="Prediction"
    )

    # Formatting
    ax.set_xlabel("Date")
    ax.set_ylabel("Sales")
    if store_id is not None:
        ax.set_title(f"Sales Forecast for Store {store_id}")
    else:
        ax.set_title("Sales Forecast")
    ax.legend()
    fig.autofmt_xdate()

    return fig


def plot_sales_time_series(
    filtered_data, selected_store=None, selected_store_name=None
):
    """Generate time series plot of sales with moving average"""
    fig, ax = plt.subplots(figsize=(7, 6))

    # Plot data based on store selection
    if selected_store_name == "All Stores" and selected_store == "All Stores":
        # Group by date for the trend line
        sales_by_date = filtered_data.groupby("date")["sales"].sum()
        ax.plot(sales_by_date.index, sales_by_date.values, "b-")

        # Add moving average
        if len(sales_by_date) > 7:
            sales_by_date_df = sales_by_date.reset_index()
            sales_by_date_df["MA7"] = sales_by_date_df["sales"].rolling(window=7).mean()
            ax.plot(
                sales_by_date_df["date"],
                sales_by_date_df["MA7"],
                "r--",
                label="7-Day Moving Avg",
            )
            ax.legend()
    else:
        # Single store - show daily sales and trend
        sales_by_date = filtered_data.groupby("date")["sales"].sum()
        ax.plot(sales_by_date.index, sales_by_date.values, "b-")

        # Add moving average if enough data
        if len(sales_by_date) > 7:
            sales_by_date_df = sales_by_date.reset_index()
            sales_by_date_df["MA7"] = sales_by_date_df["sales"].rolling(window=7).mean()
            ax.plot(
                sales_by_date_df["date"],
                sales_by_date_df["MA7"],
                "r--",
                label="7-Day Moving Avg",
            )
            ax.legend()

    ax.set_xlabel("")
    ax.set_ylabel("Sales ($)")

    if "store_name" in filtered_data.columns and selected_store_name != "All Stores":
        ax.set_title(f"Daily Sales - {selected_store_name}")
    elif "store" in filtered_data.columns and selected_store != "All Stores":
        ax.set_title(f"Daily Sales - Store {selected_store}")
    else:
        ax.set_title("Daily Sales - All Stores")

    fig.autofmt_xdate()
    return fig


def plot_day_of_week_pattern(filtered_data):
    """Generate bar chart showing sales by day of week"""
    fig, ax = plt.subplots(figsize=(7, 7))

    # Add day of week name
    day_names = [
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday",
    ]
    filtered_data["day_name"] = filtered_data["date"].dt.dayofweek.apply(
        lambda x: day_names[x]
    )

    # Group by day of week
    day_sales = filtered_data.groupby("day_name")["sales"].mean().reindex(day_names)

    # Calculate average line
    avg_daily = day_sales.mean()

    # Create bar chart with average line
    bars = ax.bar(day_sales.index, day_sales.values, color="skyblue")
    ax.axhline(y=avg_daily, color="red", linestyle="--", label="Daily Average")

    # Highlight best and worst days
    best_day = day_sales.idxmax()
    worst_day = day_sales.idxmin()

    for i, (day, sales) in enumerate(day_sales.items()):
        if day == best_day:
            bars[i].set_color("green")
        elif day == worst_day:
            bars[i].set_color("orange")

    ax.set_xlabel("")
    ax.set_ylabel("Average Sales ($)")
    ax.set_title("Sales by Day of Week")
    plt.xticks(rotation=45)
    ax.legend()

    return fig


def plot_category_distribution(filtered_data):
    """Generate pie chart of sales by category"""
    fig, ax = plt.subplots(figsize=(8, 6))

    category_sales = (
        filtered_data.groupby("category")["sales"].sum().sort_values(ascending=False)
    )

    top_categories = category_sales.head(5)
    others = category_sales.iloc[5:].sum() if len(category_sales) > 5 else 0

    if others > 0:
        plot_data = pd.concat([top_categories, pd.Series([others], index=["Others"])])
    else:
        plot_data = top_categories

    plt.pie(
        plot_data,
        labels=plot_data.index,
        autopct="%1.1f%%",
        startangle=90,
        shadow=False,
    )
    plt.axis("equal")
    plt.title("Sales by Category")

    return fig


def plot_store_comparison(filtered_data, store_identifier="store"):
    """Generate horizontal bar chart for top stores by sales"""
    fig, ax = plt.subplots(figsize=(12, 6))

    # Group by store
    store_sales = (
        filtered_data.groupby(store_identifier)["sales"]
        .sum()
        .sort_values(ascending=False)
    )

    # Take top 10 stores
    top_stores = store_sales.head(10)

    # Plot horizontal bar chart
    y_pos = np.arange(len(top_stores))
    ax.barh(y_pos, top_stores.values, align="center")
    ax.set_yticks(y_pos)
    ax.set_yticklabels(top_stores.index)
    ax.invert_yaxis()  # Labels read top-to-bottom
    ax.set_xlabel("Sales ($)")
    ax.set_title("Top 10 Stores by Sales")

    return fig


def plot_sales_distribution(filtered_data):
    """Generate histogram with KDE and summary statistics"""
    fig, ax = plt.subplots(figsize=(18, 4))

    # Create histogram with KDE
    sns.histplot(filtered_data["sales"], bins=30, kde=True, ax=ax)

    # Add vertical lines for key statistics
    median_sales = filtered_data["sales"].median()
    mean_sales = filtered_data["sales"].mean()

    ax.axvline(
        x=median_sales, color="r", linestyle="--", label=f"Median: ${median_sales:.2f}"
    )
    ax.axvline(
        x=mean_sales, color="g", linestyle="--", label=f"Mean: ${mean_sales:.2f}"
    )

    ax.set_xlabel("Sales ($)")
    ax.set_ylabel("Frequency")
    ax.set_title("Sales Distribution")
    ax.legend()

    return fig
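Each helper in data_viz.py returns a Matplotlib figure, so the plots can also be produced outside the dashboard. An illustrative run on synthetic data (the random series and output file names are made up):

import numpy as np
import pandas as pd
from app.frontend.data_viz import plot_day_of_week_pattern, plot_sales_distribution

rng = np.random.default_rng(0)
dates = pd.date_range("2024-01-01", periods=60, freq="D")
df = pd.DataFrame({"date": dates, "sales": rng.gamma(2.0, 50.0, size=len(dates))})

plot_day_of_week_pattern(df).savefig("day_of_week.png")
plot_sales_distribution(df).savefig("sales_distribution.png")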
app/frontend/gradio_ui.py
CHANGED
|
@@ -1,59 +1,65 @@
|
|
| 1 |
-
import os
|
| 2 |
import gradio as gr
|
| 3 |
import requests
|
| 4 |
from PIL import Image, ImageDraw, ImageFont
|
| 5 |
import io
|
| 6 |
from typing import List, Dict, Any
|
| 7 |
-
from datetime import datetime
|
| 8 |
-
from pathlib import Path
|
| 9 |
import random
|
| 10 |
-
from app.core.config import settings
|
| 11 |
-
from app.core.security import create_access_token
|
| 12 |
|
| 13 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
try:
|
| 15 |
from app.utils.logger import logger
|
|
|
|
|
|
|
| 16 |
except ImportError:
|
|
|
|
|
|
|
|
|
|
| 17 |
import logging
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
-
logging.basicConfig(level=logging.INFO)
|
| 20 |
|
| 21 |
-
#
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
except ImportError:
|
| 26 |
-
HAS_UI_TEMPLATE = False
|
| 27 |
-
logger.warning("ui_template not found, using basic styling")
|
| 28 |
|
| 29 |
-
#
|
| 30 |
API_BASE_URL = "http://localhost:5050"
|
| 31 |
API_VERSION = "v1"
|
| 32 |
API_ENDPOINT = f"{API_BASE_URL}/api/{API_VERSION}/detect/image"
|
| 33 |
API_HEALTH_ENDPOINT = f"{API_BASE_URL}/api/{API_VERSION}/health"
|
| 34 |
API_BATCH_ENDPOINT = f"{API_BASE_URL}/api/{API_VERSION}/detect/batch"
|
| 35 |
-
|
| 36 |
SHOW_GRADIO_API = "undocumented"
|
| 37 |
|
| 38 |
|
| 39 |
-
# ==================== Fashion Detection Client ====================
|
| 40 |
class FashionDetectionClient:
|
| 41 |
"""Client for interacting with the Fashion Detection API"""
|
| 42 |
|
| 43 |
def __init__(self, base_url: str = API_BASE_URL, token: str = None):
|
| 44 |
self.base_url = base_url
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
else:
|
| 48 |
-
token = settings.API_TOKEN
|
| 49 |
-
self.token = token
|
| 50 |
-
self.headers = {"X-Token": token}
|
| 51 |
self.session = requests.Session()
|
| 52 |
self.session.headers.update(self.headers)
|
| 53 |
|
| 54 |
def check_health(self) -> Dict[str, Any]:
|
| 55 |
"""Check API health status"""
|
| 56 |
-
logger.info(">>> check_health called")
|
| 57 |
try:
|
| 58 |
response = self.session.get(API_HEALTH_ENDPOINT, timeout=10)
|
| 59 |
response.raise_for_status()
|
|
@@ -69,7 +75,6 @@ class FashionDetectionClient:
|
|
| 69 |
|
| 70 |
def detect_single_image(self, image: Image.Image, threshold: float = 0.4) -> Dict[str, Any]:
|
| 71 |
"""Detect objects in a single image"""
|
| 72 |
-
logger.info(">>> detect_single_image function")
|
| 73 |
try:
|
| 74 |
img_byte_arr = io.BytesIO()
|
| 75 |
image.save(img_byte_arr, format='PNG')
|
|
@@ -77,7 +82,7 @@ class FashionDetectionClient:
|
|
| 77 |
|
| 78 |
files = {"file": ("image.png", img_byte_arr, "image/png")}
|
| 79 |
params = {"threshold": threshold} if threshold else {}
|
| 80 |
-
|
| 81 |
response = self.session.post(
|
| 82 |
API_ENDPOINT,
|
| 83 |
files=files,
|
|
@@ -85,14 +90,12 @@ class FashionDetectionClient:
|
|
| 85 |
timeout=30
|
| 86 |
)
|
| 87 |
response.raise_for_status()
|
| 88 |
-
logger.info(f">>response {response}")
|
| 89 |
return response.json()
|
| 90 |
|
| 91 |
except requests.exceptions.RequestException as e:
|
| 92 |
-
logger.info(f"Lỗi: {response.status_code}. Chi tiết: {response.json()}, API_TOKEN={self.token}")
|
| 93 |
return {
|
| 94 |
"success": False,
|
| 95 |
-
"error": f"API request failed: {str(e)}
|
| 96 |
"details": f"URL: {API_ENDPOINT}"
|
| 97 |
}
|
| 98 |
except Exception as e:
|
|
@@ -129,7 +132,6 @@ class FashionDetectionClient:
|
|
| 129 |
}
|
| 130 |
|
| 131 |
|
| 132 |
-
# ==================== Drawing Functions ====================
|
| 133 |
def draw_bounding_boxes_pil(image: Image.Image, detections: List[Dict[str, Any]]) -> Image.Image:
|
| 134 |
"""Draw bounding boxes on PIL Image"""
|
| 135 |
img_with_boxes = image.copy()
|
|
@@ -212,62 +214,12 @@ def format_detection_results(result: Dict[str, Any]) -> str:
|
|
| 212 |
return result_text
|
| 213 |
|
| 214 |
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
"""Convert Gradio NamedString objects (file paths) to PIL Images"""
|
| 218 |
-
pil_images = []
|
| 219 |
-
for file in gradio_files:
|
| 220 |
-
try:
|
| 221 |
-
file_path = file.name if hasattr(file, 'name') else file
|
| 222 |
-
pil_image = Image.open(file_path)
|
| 223 |
-
if pil_image.mode != "RGB":
|
| 224 |
-
pil_image = pil_image.convert("RGB")
|
| 225 |
-
pil_images.append(pil_image)
|
| 226 |
-
except Exception as e:
|
| 227 |
-
logger.error(f"Error converting image {file_path}: {str(e)}")
|
| 228 |
-
return pil_images
|
| 229 |
-
|
| 230 |
-
#
|
| 231 |
-
def generate_test_token():
|
| 232 |
-
access_token = create_access_token(
|
| 233 |
-
data={"sub": "test_user"},
|
| 234 |
-
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
| 235 |
-
)
|
| 236 |
-
print(f"Generated test token: {access_token}")
|
| 237 |
-
return access_token
|
| 238 |
-
|
| 239 |
-
# ==================== Main Application ====================
|
| 240 |
-
def create_app():
|
| 241 |
-
"""Create the main Gradio application"""
|
| 242 |
-
|
| 243 |
-
# Configure UI template if available
|
| 244 |
-
if HAS_UI_TEMPLATE:
|
| 245 |
-
ui.configure(
|
| 246 |
-
project_name="Intelligent Retail Decision Making System",
|
| 247 |
-
year="2025",
|
| 248 |
-
about="AI-powered fashion detection and retail analytics",
|
| 249 |
-
description="An integrated platform for fashion item detection and sales forecasting.",
|
| 250 |
-
colors={
|
| 251 |
-
"primary": "#0F6CBD",
|
| 252 |
-
"accent": "#C4314B",
|
| 253 |
-
"success": "#2E7D32",
|
| 254 |
-
"bg1": "#F0F7FF",
|
| 255 |
-
"bg2": "#E8F0FA",
|
| 256 |
-
"bg3": "#DDE7F8"
|
| 257 |
-
},
|
| 258 |
-
meta_items=[
|
| 259 |
-
("Model", "Fashion Detection & Sales Forecasting"),
|
| 260 |
-
("Features", "Object Detection & Predictive Analytics"),
|
| 261 |
-
]
|
| 262 |
-
)
|
| 263 |
-
|
| 264 |
-
# Initialize API client
|
| 265 |
-
api_client = FashionDetectionClient()
|
| 266 |
|
| 267 |
-
# ==================== Prediction Functions ====================
|
| 268 |
def predict_single_image(image: Image.Image, threshold: float) -> tuple:
|
| 269 |
"""Predict objects in a single image"""
|
| 270 |
-
logger.info(">>> predict_single_image
|
| 271 |
try:
|
| 272 |
health_status = api_client.check_health()
|
| 273 |
if not health_status.get('success', False):
|
|
@@ -275,20 +227,20 @@ def create_app():
|
|
| 275 |
|
| 276 |
result = api_client.detect_single_image(image, threshold)
|
| 277 |
result_text = format_detection_results(result)
|
| 278 |
-
|
| 279 |
if result.get('success', False) and result.get('detections'):
|
| 280 |
image_with_boxes = draw_bounding_boxes_pil(image, result['detections'])
|
| 281 |
return image_with_boxes, result_text
|
| 282 |
else:
|
| 283 |
return image, result_text
|
| 284 |
-
|
| 285 |
except Exception as e:
|
| 286 |
error_msg = f"❌ Prediction error: {str(e)}"
|
| 287 |
return image, error_msg
|
| 288 |
|
| 289 |
def predict_batch_images(images: List[Image.Image], threshold: float):
|
| 290 |
"""Predict objects in multiple images"""
|
| 291 |
-
logger.info(">>> predict_batch_images
|
| 292 |
try:
|
| 293 |
if not images:
|
| 294 |
return [], "Please upload at least one image."
|
|
@@ -329,21 +281,32 @@ def create_app():
|
|
| 329 |
except Exception as e:
|
| 330 |
return [], f"❌ Batch prediction error: {str(e)}"
|
| 331 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
def check_api_health():
|
| 333 |
"""Check and display API health status"""
|
| 334 |
-
logger.info(">>> check_api_health
|
| 335 |
health_status = api_client.check_health()
|
|
|
|
| 336 |
|
| 337 |
-
if health_status.get('success', False)
|
| 338 |
-
|
| 339 |
-
status_text = "Healthy"
|
| 340 |
-
else:
|
| 341 |
-
status_emoji = "❌"
|
| 342 |
-
status_text = "Unhealthy"
|
| 343 |
|
| 344 |
health_info = f"{status_emoji} API Status: {status_text}\n\n"
|
| 345 |
health_info += f"📡 Endpoint: {API_BASE_URL}\n"
|
| 346 |
-
health_info += f"
|
| 347 |
|
| 348 |
if health_status.get('success', False):
|
| 349 |
health_info += f"🚀 Version: {health_status.get('version', 'N/A')}\n"
|
|
@@ -355,193 +318,171 @@ def create_app():
|
|
| 355 |
|
| 356 |
return health_info
|
| 357 |
|
| 358 |
-
|
| 359 |
-
"
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
|
| 372 |
-
|
| 373 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
|
| 382 |
-
gr.Markdown("AI-powered fashion detection and retail analytics platform")
|
| 383 |
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
gr.HTML(ui.render_info_card(
|
| 387 |
-
icon="📈",
|
| 388 |
-
title="About this Application"
|
| 389 |
-
))
|
| 390 |
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
# Tab 1: Fashion Detection
|
| 394 |
-
with gr.Tab("👔 Fashion Detection"):
|
| 395 |
-
gr.Markdown("## Fashion Item Detection")
|
| 396 |
-
|
| 397 |
-
# API Health Section
|
| 398 |
-
with gr.Row():
|
| 399 |
-
with gr.Column():
|
| 400 |
-
gr.Markdown("### 📊 API Status")
|
| 401 |
-
health_btn = gr.Button("Check API Health", variant="secondary")
|
| 402 |
-
health_output = gr.Textbox(label="API Health Status", lines=6, interactive=False)
|
| 403 |
-
|
| 404 |
-
# Single Image Detection
|
| 405 |
-
with gr.Row():
|
| 406 |
-
with gr.Column():
|
| 407 |
-
gr.Markdown("### 📷 Single Image Detection")
|
| 408 |
-
single_image = gr.Image(type="pil", label="Upload Fashion Image")
|
| 409 |
-
threshold_slider = gr.Slider(
|
| 410 |
-
minimum=0.1, maximum=0.9, value=0.4, step=0.05,
|
| 411 |
-
label="Detection Confidence Threshold"
|
| 412 |
-
)
|
| 413 |
-
single_btn = gr.Button("Detect Objects", variant="primary")
|
| 414 |
-
|
| 415 |
-
with gr.Column():
|
| 416 |
-
single_output_image = gr.Image(label="Detection Results", interactive=False)
|
| 417 |
-
single_output_text = gr.Textbox(label="Detection Results", lines=12)
|
| 418 |
-
|
| 419 |
-
# Batch Image Detection
|
| 420 |
-
with gr.Row():
|
| 421 |
-
with gr.Column():
|
| 422 |
-
gr.Markdown("### 📦 Batch Image Detection")
|
| 423 |
-
batch_images = gr.File(
|
| 424 |
-
label="Upload Multiple Images",
|
| 425 |
-
file_count="multiple",
|
| 426 |
-
file_types=["image"]
|
| 427 |
-
)
|
| 428 |
-
batch_threshold = gr.Slider(
|
| 429 |
-
minimum=0.1, maximum=0.9, value=0.4, step=0.05,
|
| 430 |
-
label="Detection Confidence Threshold"
|
| 431 |
-
)
|
| 432 |
-
batch_btn = gr.Button("Process Batch", variant="primary")
|
| 433 |
-
|
| 434 |
-
with gr.Column():
|
| 435 |
-
batch_output_images = gr.Gallery(
|
| 436 |
-
label="Detection Results",
|
| 437 |
-
columns=3,
|
| 438 |
-
height="auto",
|
| 439 |
-
interactive=False
|
| 440 |
-
)
|
| 441 |
-
batch_output_text = gr.Textbox(label="Batch Results", lines=15)
|
| 442 |
-
|
| 443 |
-
# Examples
|
| 444 |
-
if os.path.exists("static/examples"):
|
| 445 |
-
gr.Examples(
|
| 446 |
-
examples=[
|
| 447 |
-
["static/examples/image1.png"],
|
| 448 |
-
["static/examples/image2.png"],
|
| 449 |
-
["static/examples/image3.png"]
|
| 450 |
-
],
|
| 451 |
-
inputs=single_image,
|
| 452 |
-
label="Try these example images"
|
| 453 |
-
)
|
| 454 |
-
|
| 455 |
-
# Event handlers
|
| 456 |
-
health_btn.click(
|
| 457 |
-
fn=check_api_health,
|
| 458 |
-
outputs=health_output,
|
| 459 |
-
api_visibility=SHOW_GRADIO_API
|
| 460 |
-
)
|
| 461 |
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
|
|
|
| 468 |
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
|
|
|
|
|
|
|
|
|
| 475 |
|
| 476 |
-
# Tab 2: Historical Sales Analysis
|
| 477 |
-
with gr.Tab("📊 Historical Sales Analysis"):
|
| 478 |
-
gr.Markdown("### Explore and visualize historical sales data")
|
| 479 |
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
"Load Historical Analysis",
|
| 483 |
-
variant="primary",
|
| 484 |
-
size="lg"
|
| 485 |
-
)
|
| 486 |
|
| 487 |
-
|
|
|
|
| 488 |
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
with gr.Column(scale=2):
|
| 525 |
-
gr.Markdown("#### Prediction Results")
|
| 526 |
-
prediction_output = gr.HTML(label="Forecast")
|
| 527 |
-
|
| 528 |
-
predict_btn.click(
|
| 529 |
-
fn=make_prediction,
|
| 530 |
-
inputs=[date_input, forecast_horizon],
|
| 531 |
-
outputs=prediction_output
|
| 532 |
-
)
|
| 533 |
|
| 534 |
# Footer
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 545 |
gr.Markdown("© 2025 Intelligent Retail System. All rights reserved.")
|
| 546 |
|
| 547 |
return demo
|
|
@@ -550,7 +491,7 @@ def create_app():
|
|
| 550 |
# ==================== Application Entry Point ====================
|
| 551 |
def main():
|
| 552 |
"""Main entry point"""
|
| 553 |
-
demo =
|
| 554 |
|
| 555 |
# Custom CSS
|
| 556 |
custom_css = """
|
|
@@ -559,8 +500,7 @@ def main():
|
|
| 559 |
.error {color: red; font-weight: bold;}
|
| 560 |
"""
|
| 561 |
|
| 562 |
-
|
| 563 |
-
custom_css = ui.get_custom_css() + custom_css
|
| 564 |
|
| 565 |
demo.launch(
|
| 566 |
server_name="0.0.0.0",
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import requests
|
| 3 |
from PIL import Image, ImageDraw, ImageFont
|
| 4 |
import io
|
| 5 |
from typing import List, Dict, Any
|
| 6 |
+
from datetime import datetime
|
|
|
|
| 7 |
import random
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
# Import modules from sales forecasting app
|
| 10 |
+
try:
|
| 11 |
+
import app.frontend.ui_template as ui
|
| 12 |
+
from app.utils.data_loader import (
|
| 13 |
+
load_data,
|
| 14 |
+
load_feature_engineered_data,
|
| 15 |
+
load_feature_stats,
|
| 16 |
+
load_model,
|
| 17 |
+
)
|
| 18 |
+
from app.frontend.dashboard import historical_sales_view
|
| 19 |
+
from app.services.prediction import sales_prediction_view
|
| 20 |
+
SALES_MODULE_AVAILABLE = True
|
| 21 |
+
except ImportError:
|
| 22 |
+
SALES_MODULE_AVAILABLE = False
|
| 23 |
+
print("Warning: Sales forecasting modules not available")
|
| 24 |
+
|
| 25 |
+
# Import fashion detection modules
|
| 26 |
try:
|
| 27 |
from app.utils.logger import logger
|
| 28 |
+
from app.core.config import settings
|
| 29 |
+
FASHION_MODULE_AVAILABLE = True
|
| 30 |
except ImportError:
|
| 31 |
+
FASHION_MODULE_AVAILABLE = False
|
| 32 |
+
print("Warning: Fashion detection modules not available")
|
| 33 |
+
# Fallback logger
|
| 34 |
import logging
|
| 35 |
logger = logging.getLogger(__name__)
|
|
|
|
| 36 |
|
| 37 |
+
# Fallback settings
|
| 38 |
+
class Settings:
|
| 39 |
+
API_TOKEN = "your-api-token-here"
|
| 40 |
+
settings = Settings()
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
# Configuration for Fashion Detection API
|
| 43 |
API_BASE_URL = "http://localhost:5050"
|
| 44 |
API_VERSION = "v1"
|
| 45 |
API_ENDPOINT = f"{API_BASE_URL}/api/{API_VERSION}/detect/image"
|
| 46 |
API_HEALTH_ENDPOINT = f"{API_BASE_URL}/api/{API_VERSION}/health"
|
| 47 |
API_BATCH_ENDPOINT = f"{API_BASE_URL}/api/{API_VERSION}/detect/batch"
|
|
|
|
| 48 |
SHOW_GRADIO_API = "undocumented"
|
| 49 |
|
| 50 |
|
|
|
|
| 51 |
class FashionDetectionClient:
|
| 52 |
"""Client for interacting with the Fashion Detection API"""
|
| 53 |
|
| 54 |
def __init__(self, base_url: str = API_BASE_URL, token: str = None):
|
| 55 |
self.base_url = base_url
|
| 56 |
+
self.token = token or (settings.API_TOKEN if FASHION_MODULE_AVAILABLE else "default-token")
|
| 57 |
+
self.headers = {"X-Token": self.token}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
self.session = requests.Session()
|
| 59 |
self.session.headers.update(self.headers)
|
| 60 |
|
| 61 |
def check_health(self) -> Dict[str, Any]:
|
| 62 |
"""Check API health status"""
|
|
|
|
| 63 |
try:
|
| 64 |
response = self.session.get(API_HEALTH_ENDPOINT, timeout=10)
|
| 65 |
response.raise_for_status()
|
|
|
|
| 75 |
|
| 76 |
def detect_single_image(self, image: Image.Image, threshold: float = 0.4) -> Dict[str, Any]:
|
| 77 |
"""Detect objects in a single image"""
|
|
|
|
| 78 |
try:
|
| 79 |
img_byte_arr = io.BytesIO()
|
| 80 |
image.save(img_byte_arr, format='PNG')
|
|
|
|
| 82 |
|
| 83 |
files = {"file": ("image.png", img_byte_arr, "image/png")}
|
| 84 |
params = {"threshold": threshold} if threshold else {}
|
| 85 |
+
|
| 86 |
response = self.session.post(
|
| 87 |
API_ENDPOINT,
|
| 88 |
files=files,
|
|
|
|
| 90 |
timeout=30
|
| 91 |
)
|
| 92 |
response.raise_for_status()
|
|
|
|
| 93 |
return response.json()
|
| 94 |
|
| 95 |
except requests.exceptions.RequestException as e:
|
|
|
|
| 96 |
return {
|
| 97 |
"success": False,
|
| 98 |
+
"error": f"API request failed: {str(e)}",
|
| 99 |
"details": f"URL: {API_ENDPOINT}"
|
| 100 |
}
|
| 101 |
except Exception as e:
|
|
|
|
| 132 |
}
|
| 133 |
|
| 134 |
|
|
|
|
| 135 |
def draw_bounding_boxes_pil(image: Image.Image, detections: List[Dict[str, Any]]) -> Image.Image:
|
| 136 |
"""Draw bounding boxes on PIL Image"""
|
| 137 |
img_with_boxes = image.copy()
|
|
|
|
| 214 |
return result_text
|
| 215 |
|
| 216 |
|
| 217 |
+
def create_fashion_detection_tab(api_client: FashionDetectionClient):
|
| 218 |
+
"""Create the Fashion Detection tab"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
|
|
|
| 220 |
def predict_single_image(image: Image.Image, threshold: float) -> tuple:
|
| 221 |
"""Predict objects in a single image"""
|
| 222 |
+
logger.info(">>> predict_single_image ping clicked")
|
| 223 |
try:
|
| 224 |
health_status = api_client.check_health()
|
| 225 |
if not health_status.get('success', False):
|
|
|
|
| 227 |
|
| 228 |
result = api_client.detect_single_image(image, threshold)
|
| 229 |
result_text = format_detection_results(result)
|
| 230 |
+
|
| 231 |
if result.get('success', False) and result.get('detections'):
|
| 232 |
image_with_boxes = draw_bounding_boxes_pil(image, result['detections'])
|
| 233 |
return image_with_boxes, result_text
|
| 234 |
else:
|
| 235 |
return image, result_text
|
| 236 |
+
|
| 237 |
except Exception as e:
|
| 238 |
error_msg = f"❌ Prediction error: {str(e)}"
|
| 239 |
return image, error_msg
|
| 240 |
|
| 241 |
def predict_batch_images(images: List[Image.Image], threshold: float):
|
| 242 |
"""Predict objects in multiple images"""
|
| 243 |
+
logger.info(">>> predict_batch_images ping clicked")
|
| 244 |
try:
|
| 245 |
if not images:
|
| 246 |
return [], "Please upload at least one image."
|
|
|
|
| 281 |
except Exception as e:
|
| 282 |
return [], f"❌ Batch prediction error: {str(e)}"
|
| 283 |
|
| 284 |
+
def convert_to_pil_images(gradio_files: List) -> List[Image.Image]:
|
| 285 |
+
"""Convert Gradio file objects to PIL Images"""
|
| 286 |
+
pil_images = []
|
| 287 |
+
for file in gradio_files:
|
| 288 |
+
try:
|
| 289 |
+
file_path = file.name if hasattr(file, 'name') else file
|
| 290 |
+
pil_image = Image.open(file_path)
|
| 291 |
+
if pil_image.mode != "RGB":
|
| 292 |
+
pil_image = pil_image.convert("RGB")
|
| 293 |
+
pil_images.append(pil_image)
|
| 294 |
+
except Exception as e:
|
| 295 |
+
logger.error(f"Error converting image {file_path}: {str(e)}")
|
| 296 |
+
return pil_images
|
| 297 |
+
|
| 298 |
def check_api_health():
|
| 299 |
"""Check and display API health status"""
|
| 300 |
+
logger.info(">>> check_api_health ping clicked")
|
| 301 |
health_status = api_client.check_health()
|
| 302 |
+
logger.info(health_status)
|
| 303 |
|
| 304 |
+
status_emoji = "✅" if health_status.get('success', False) else "❌"
|
| 305 |
+
status_text = "Healthy" if health_status.get('success', False) else "Unhealthy"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
health_info = f"{status_emoji} API Status: {status_text}\n\n"
|
| 308 |
health_info += f"📡 Endpoint: {API_BASE_URL}\n"
|
| 309 |
+
health_info += f"🕐 Checked: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
| 310 |
|
| 311 |
if health_status.get('success', False):
|
| 312 |
health_info += f"🚀 Version: {health_status.get('version', 'N/A')}\n"
|
|
|
|
| 318 |
|
| 319 |
return health_info
|
| 320 |
|
| 321 |
+
with gr.Column():
|
| 322 |
+
gr.Markdown("# 👔 Fashion Detection System")
|
| 323 |
+
gr.Markdown("Upload images to detect fashion items using our AI-powered API")
|
| 324 |
+
|
| 325 |
+
# API Health Section
|
| 326 |
+
with gr.Row():
|
| 327 |
+
with gr.Column():
|
| 328 |
+
gr.Markdown("## 📊 API Status")
|
| 329 |
+
health_btn = gr.Button("Check API Health", variant="secondary")
|
| 330 |
+
health_output = gr.Textbox(label="API Health Status", lines=6, interactive=False)
|
| 331 |
+
|
| 332 |
+
# Single Image Detection
|
| 333 |
+
with gr.Row():
|
| 334 |
+
with gr.Column():
|
| 335 |
+
gr.Markdown("## 📷 Single Image Detection")
|
| 336 |
+
single_image = gr.Image(type="pil", label="Upload Fashion Image")
|
| 337 |
+
threshold_slider = gr.Slider(
|
| 338 |
+
minimum=0.1, maximum=0.9, value=0.4, step=0.05,
|
| 339 |
+
label="Detection Confidence Threshold"
|
| 340 |
+
)
|
| 341 |
+
single_btn = gr.Button("Detect Objects", variant="primary")
|
| 342 |
+
|
| 343 |
+
with gr.Column():
|
| 344 |
+
single_output_image = gr.Image(label="Detection Results", interactive=False)
|
| 345 |
+
single_output_text = gr.Textbox(label="Detection Results", lines=12)
|
| 346 |
+
|
| 347 |
+
# Batch Image Detection
|
| 348 |
+
with gr.Row():
|
| 349 |
+
with gr.Column():
|
| 350 |
+
gr.Markdown("## 📦 Batch Image Detection")
|
| 351 |
+
batch_images = gr.File(
|
| 352 |
+
label="Upload Multiple Images",
|
| 353 |
+
file_count="multiple",
|
| 354 |
+
file_types=["image"]
|
| 355 |
+
)
|
| 356 |
+
batch_threshold = gr.Slider(
|
| 357 |
+
minimum=0.1, maximum=0.9, value=0.4, step=0.05,
|
| 358 |
+
label="Detection Confidence Threshold"
|
| 359 |
+
)
|
| 360 |
+
batch_btn = gr.Button("Process Batch", variant="primary")
|
| 361 |
+
|
| 362 |
+
with gr.Column():
|
| 363 |
+
batch_output_images = gr.Gallery(
|
| 364 |
+
label="Detection Results",
|
| 365 |
+
columns=3,
|
| 366 |
+
height="auto",
|
| 367 |
+
interactive=False
|
| 368 |
+
)
|
| 369 |
+
batch_output_text = gr.Textbox(label="Batch Results", lines=15)
|
| 370 |
+
|
| 371 |
+
# Examples
|
| 372 |
+
gr.Examples(
|
| 373 |
+
examples=[
|
| 374 |
+
["static/examples/image1.png"],
|
| 375 |
+
["static/examples/image2.png"],
|
| 376 |
+
["static/examples/image3.png"]
|
| 377 |
+
],
|
| 378 |
+
inputs=single_image,
|
| 379 |
+
label="Try these example images"
|
| 380 |
+
)
|
| 381 |
|
| 382 |
+
# Event handlers
|
| 383 |
+
health_btn.click(
|
| 384 |
+
fn=check_api_health,
|
| 385 |
+
outputs=health_output,
|
| 386 |
+
api_visibility=SHOW_GRADIO_API
|
| 387 |
+
)
|
| 388 |
|
| 389 |
+
single_btn.click(
|
| 390 |
+
fn=predict_single_image,
|
| 391 |
+
inputs=[single_image, threshold_slider],
|
| 392 |
+
outputs=[single_output_image, single_output_text],
|
| 393 |
+
api_visibility=SHOW_GRADIO_API
|
| 394 |
+
)
|
| 395 |
|
| 396 |
+
batch_btn.click(
|
| 397 |
+
fn=lambda images, threshold: predict_batch_images(convert_to_pil_images(images), threshold),
|
| 398 |
+
inputs=[batch_images, batch_threshold],
|
| 399 |
+
outputs=[batch_output_images, batch_output_text],
|
| 400 |
+
api_visibility=SHOW_GRADIO_API
|
| 401 |
+
)
|
| 402 |
|
|
|
|
| 403 |
|
| 404 |
+
def create_sales_forecasting_tab(data, model, feature_stats):
|
| 405 |
+
"""Create the Sales Forecasting tab"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
|
| 407 |
+
with gr.Column():
|
| 408 |
+
gr.Markdown("# 📈 Sales Forecasting System")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
| 410 | +     # Page selector for sales forecasting sub-sections
| 411 | +     page_selector = gr.Dropdown(
| 412 | +         choices=["Historical Sales Analysis", "Sales Prediction"],
| 413 | +         value="Historical Sales Analysis",
| 414 | +         label="Choose a view",
| 415 | +         interactive=True
| 416 | +     )
| 417 | +
| 418 | +     # Render content based on selection
| 419 | +     @gr.render(inputs=page_selector)
| 420 | +     def render_sales_content(page):
| 421 | +         if page == "Historical Sales Analysis":
| 422 | +             historical_sales_view(data)
| 423 | +         else:
| 424 | +             print("Loading feature engineered data for prediction...")
| 425 | +             feature_engineered_data = load_feature_engineered_data()
| 426 | +             sales_prediction_view(data, model, feature_stats, feature_engineered_data)
| 427 |
| 428 |
| 429 | + def create_gradio_interface():
| 430 | +     """Create the Gradio application"""
| 431 |
| 432 | +     # Initialize API client for fashion detection
| 433 | +     api_client = FashionDetectionClient()
| 434 |
| 435 | +     # Load sales forecasting data if available
| 436 | +     sales_data = None
| 437 | +     sales_model = None
| 438 | +     sales_feature_stats = None
| 439 |
| 440 | +     if SALES_MODULE_AVAILABLE:
| 441 | +         try:
| 442 | +             sales_data = load_data()
| 443 | +             sales_model = load_model()
| 444 | +             sales_feature_stats = load_feature_stats()
| 445 | +         except Exception as e:
| 446 | +             print(f"Warning: Could not load sales forecasting data: {e}")
| 447 | +
| 448 | +     # Create main interface
| 449 | +     with gr.Blocks(
| 450 | +         title="💡 Intelligent Retail Decision Making System",
| 451 | +     ) as demo:
| 452 | +
| 453 | +         gr.Markdown("# 💡 Intelligent Retail Decision Making System")
| 454 | +         gr.Markdown("### Comprehensive AI-powered solution for retail analytics and product detection")
| 455 | +
| 456 | +         # Main navigation tabs
| 457 | +         with gr.Tabs():
| 458 | +             # Fashion Detection Tab
| 459 | +             with gr.Tab("👔 Fashion Detection"):
| 460 | +                 create_fashion_detection_tab(api_client)
| 461 | +
| 462 | +             # Sales Forecasting Tab
| 463 | +             if SALES_MODULE_AVAILABLE and sales_data is not None:
| 464 | +                 with gr.Tab("📈 Sales Forecasting"):
| 465 | +                     create_sales_forecasting_tab(sales_data, sales_model, sales_feature_stats)
| 466 | +             else:
| 467 | +                 with gr.Tab("📈 Sales Forecasting"):
| 468 | +                     gr.Markdown("## ⚠️ Sales Forecasting Module Not Available")
| 469 | +                     gr.Markdown("Please ensure all required dependencies are installed.")
| 470 |
| 471 |           # Footer
| 472 | +         try:
| 473 | +             if SALES_MODULE_AVAILABLE:
| 474 | +                 ui.create_footer(
| 475 | +                     logo_path="static/intelligent_retail.png",
| 476 | +                     creator_name="Thi-Diem-My Le",
| 477 | +                     creator_link="https://beacons.ai/elizabethmyn",
| 478 | +                     org_name="AI VIET NAM",
| 479 | +                     org_link="https://aivietnam.edu.vn/"
| 480 | +                 )
| 481 | +             else:
| 482 | +                 gr.Markdown("---")
| 483 | +                 gr.Markdown("### Created by Thi-Diem-My Le | AI VIET NAM")
| 484 | +         except Exception as e:
| 485 | +             print(f"Warning: Could not create footer: {e}")
| 486 |           gr.Markdown("© 2025 Intelligent Retail System. All rights reserved.")
| 487 |
| 488 |       return demo
| 491 |   # ==================== Application Entry Point ====================
| 492 |   def main():
| 493 |       """Main entry point"""
| 494 | +     demo = create_gradio_interface()
| 495 |
| 496 |       # Custom CSS
| 497 |       custom_css = """
| 500 |       .error {color: red; font-weight: bold;}
| 501 |       """
| 502 |
| 503 | +     custom_css = ui.get_custom_css() + custom_css
| 504 |
| 505 |       demo.launch(
| 506 |           server_name="0.0.0.0",
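The page selector added above relies on Gradio's dynamic rendering: the decorated function is re-executed whenever the dropdown value changes, so the Historical Analysis and Prediction views are rebuilt on demand rather than created once at startup. A minimal standalone sketch of that pattern (assuming Gradio 4.x, where @gr.render accepts an inputs= argument; component and function names here are illustrative, not from this repository):

import gradio as gr

with gr.Blocks() as sketch:
    view = gr.Dropdown(
        ["Historical Sales Analysis", "Sales Prediction"],
        value="Historical Sales Analysis",
        label="Choose a view",
    )

    @gr.render(inputs=view)
    def show(selected):
        # Re-runs on every dropdown change; the components created here
        # replace whatever the previous call rendered.
        if selected == "Historical Sales Analysis":
            gr.Markdown("Rendering the historical analysis view")
        else:
            gr.Markdown("Rendering the prediction view")

sketch.launch()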
app/frontend/ui_template.py
ADDED
|
@@ -0,0 +1,420 @@
| 1 |
+
import os
|
| 2 |
+
import base64
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Optional, Dict, List, Tuple
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ThemeConfig:
|
| 9 |
+
"""Centralized theme configuration with validation."""
|
| 10 |
+
|
| 11 |
+
def __init__(self):
|
| 12 |
+
# Default color palette
|
| 13 |
+
self.primary_color = "#0F6CBD"
|
| 14 |
+
self.accent_color = "#C4314B"
|
| 15 |
+
self.success_color = "#2E7D32"
|
| 16 |
+
self.bg1 = "#F0F7FF"
|
| 17 |
+
self.bg2 = "#E8F0FA"
|
| 18 |
+
self.bg3 = "#DDE7F8"
|
| 19 |
+
self.font_family = (
|
| 20 |
+
"'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', "
|
| 21 |
+
"Roboto, 'Helvetica Neue', Arial, sans-serif"
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Metadata
|
| 25 |
+
self.project_name = "Heart Project"
|
| 26 |
+
self.year = "2025"
|
| 27 |
+
self.about = ""
|
| 28 |
+
self.description = ""
|
| 29 |
+
self.meta_items: List[Tuple[str, str]] = []
|
| 30 |
+
|
| 31 |
+
# Cache for CSS
|
| 32 |
+
self._css_cache: Optional[str] = None
|
| 33 |
+
|
| 34 |
+
def update_colors(self, **kwargs) -> None:
|
| 35 |
+
"""Update color scheme with validation."""
|
| 36 |
+
valid_keys = {'primary', 'accent', 'success', 'bg1', 'bg2', 'bg3'}
|
| 37 |
+
for key, value in kwargs.items():
|
| 38 |
+
if key not in valid_keys or value is None:
|
| 39 |
+
continue
|
| 40 |
+
if not self._is_valid_color(value):
|
| 41 |
+
raise ValueError(f"Invalid color format for {key}: {value}")
|
| 42 |
+
setattr(self, f"{key}_color" if not key.startswith('bg') else key, value)
|
| 43 |
+
self._invalidate_cache()
|
| 44 |
+
|
| 45 |
+
def update_font(self, font_family: str) -> None:
|
| 46 |
+
"""Update font family."""
|
| 47 |
+
if font_family and isinstance(font_family, str):
|
| 48 |
+
self.font_family = font_family
|
| 49 |
+
self._invalidate_cache()
|
| 50 |
+
|
| 51 |
+
def update_meta(self, project_name: Optional[str] = None,
|
| 52 |
+
year: Optional[str] = None,
|
| 53 |
+
about: Optional[str] = None,
|
| 54 |
+
description: Optional[str] = None,
|
| 55 |
+
meta_items: Optional[List[Tuple[str, str]]] = None) -> None:
|
| 56 |
+
"""Update metadata."""
|
| 57 |
+
if project_name is not None:
|
| 58 |
+
self.project_name = project_name
|
| 59 |
+
if year is not None:
|
| 60 |
+
self.year = year
|
| 61 |
+
if about is not None:
|
| 62 |
+
self.about = about
|
| 63 |
+
if description is not None:
|
| 64 |
+
self.description = description
|
| 65 |
+
if meta_items is not None:
|
| 66 |
+
self.meta_items = meta_items
|
| 67 |
+
|
| 68 |
+
@staticmethod
|
| 69 |
+
def _is_valid_color(color: str) -> bool:
|
| 70 |
+
"""Validate hex color format."""
|
| 71 |
+
return isinstance(color, str) and (
|
| 72 |
+
color.startswith('#') and len(color) in (4, 7, 9)
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def _invalidate_cache(self) -> None:
|
| 76 |
+
"""Clear CSS cache when theme changes."""
|
| 77 |
+
self._css_cache = None
|
| 78 |
+
|
| 79 |
+
def get_css(self) -> str:
|
| 80 |
+
"""Get or generate CSS with caching."""
|
| 81 |
+
if self._css_cache is None:
|
| 82 |
+
self._css_cache = self._build_css()
|
| 83 |
+
return self._css_cache
|
| 84 |
+
|
| 85 |
+
def _build_css(self) -> str:
|
| 86 |
+
"""Build the complete CSS string."""
|
| 87 |
+
return f"""
|
| 88 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
|
| 89 |
+
|
| 90 |
+
.gradio-container {{
|
| 91 |
+
min-height: 100vh !important;
|
| 92 |
+
width: 100vw !important;
|
| 93 |
+
margin: 0 !important;
|
| 94 |
+
padding: 0px !important;
|
| 95 |
+
background: linear-gradient(135deg, {self.bg1} 0%, {self.bg2} 50%, {self.bg3} 100%);
|
| 96 |
+
background-size: 600% 600%;
|
| 97 |
+
animation: gradientBG 7s ease infinite;
|
| 98 |
+
}}
|
| 99 |
+
|
| 100 |
+
/* Global font setup */
|
| 101 |
+
body, .gradio-container, .gr-block, .gr-markdown, .gr-button, .gr-input,
|
| 102 |
+
.gr-dropdown, .gr-number, .gr-plot, .gr-dataframe, .gr-accordion, .gr-form,
|
| 103 |
+
.gr-textbox, .gr-html, table, th, td, label, h1, h2, h3, h4, h5, h6, p, span, div {{
|
| 104 |
+
font-family: {self.font_family} !important;
|
| 105 |
+
}}
|
| 106 |
+
|
| 107 |
+
@keyframes gradientBG {{
|
| 108 |
+
0% {{background-position: 0% 50%;}}
|
| 109 |
+
50% {{background-position: 100% 50%;}}
|
| 110 |
+
100% {{background-position: 0% 50%;}}
|
| 111 |
+
}}
|
| 112 |
+
|
| 113 |
+
/* Minimize spacing and padding */
|
| 114 |
+
.content-wrap {{
|
| 115 |
+
padding: 2px !important;
|
| 116 |
+
margin: 0 !important;
|
| 117 |
+
}}
|
| 118 |
+
|
| 119 |
+
/* Reduce component spacing */
|
| 120 |
+
.gr-row {{
|
| 121 |
+
gap: 5px !important;
|
| 122 |
+
margin: 2px 0 !important;
|
| 123 |
+
}}
|
| 124 |
+
|
| 125 |
+
.gr-column {{
|
| 126 |
+
gap: 4px !important;
|
| 127 |
+
padding: 4px !important;
|
| 128 |
+
}}
|
| 129 |
+
|
| 130 |
+
/* Accordion optimization */
|
| 131 |
+
.gr-accordion {{
|
| 132 |
+
margin: 4px 0 !important;
|
| 133 |
+
}}
|
| 134 |
+
|
| 135 |
+
.gr-accordion .gr-accordion-content {{
|
| 136 |
+
padding: 2px !important;
|
| 137 |
+
}}
|
| 138 |
+
|
| 139 |
+
/* Form elements spacing */
|
| 140 |
+
.gr-form {{
|
| 141 |
+
gap: 2px !important;
|
| 142 |
+
}}
|
| 143 |
+
|
| 144 |
+
/* Button styling */
|
| 145 |
+
.gr-button {{
|
| 146 |
+
margin: 2px 0 !important;
|
| 147 |
+
}}
|
| 148 |
+
|
| 149 |
+
/* DataFrame optimization */
|
| 150 |
+
.gr-dataframe {{
|
| 151 |
+
margin: 4px 0 !important;
|
| 152 |
+
}}
|
| 153 |
+
|
| 154 |
+
/* Remove horizontal scroll from data preview */
|
| 155 |
+
.gr-dataframe .wrap {{
|
| 156 |
+
overflow-x: auto !important;
|
| 157 |
+
max-width: 100% !important;
|
| 158 |
+
}}
|
| 159 |
+
|
| 160 |
+
/* Plot optimization */
|
| 161 |
+
.gr-plot {{
|
| 162 |
+
margin: 4px 0 !important;
|
| 163 |
+
}}
|
| 164 |
+
|
| 165 |
+
/* Reduce markdown margins */
|
| 166 |
+
.gr-markdown {{
|
| 167 |
+
margin: 2px 0 !important;
|
| 168 |
+
}}
|
| 169 |
+
|
| 170 |
+
/* Footer positioning */
|
| 171 |
+
.sticky-footer {{
|
| 172 |
+
position: fixed;
|
| 173 |
+
bottom: 0px;
|
| 174 |
+
left: 0;
|
| 175 |
+
width: 100%;
|
| 176 |
+
background: {self.bg1};
|
| 177 |
+
padding: 6px !important;
|
| 178 |
+
box-shadow: 0 -2px 10px rgba(0,0,0,0.1);
|
| 179 |
+
z-index: 1000;
|
| 180 |
+
}}
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# Global theme instance
|
| 185 |
+
_theme = ThemeConfig()
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def configure(project_name: Optional[str] = None,
|
| 189 |
+
year: Optional[str] = None,
|
| 190 |
+
about: Optional[str] = None,
|
| 191 |
+
description: Optional[str] = None,
|
| 192 |
+
colors: Optional[Dict[str, str]] = None,
|
| 193 |
+
font_family: Optional[str] = None,
|
| 194 |
+
meta_items: Optional[List[Tuple[str, str]]] = None) -> None:
|
| 195 |
+
"""
|
| 196 |
+
One-call configuration for the entire theme.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
project_name: Name of the project
|
| 200 |
+
year: Project year
|
| 201 |
+
about: About project
|
| 202 |
+
description: Project description
|
| 203 |
+
colors: Dict with keys: primary, accent, success, bg1, bg2, bg3
|
| 204 |
+
font_family: CSS font family string
|
| 205 |
+
meta_items: List of (label, value) tuples for metadata
|
| 206 |
+
"""
|
| 207 |
+
if colors:
|
| 208 |
+
_theme.update_colors(**colors)
|
| 209 |
+
if font_family:
|
| 210 |
+
_theme.update_font(font_family)
|
| 211 |
+
_theme.update_meta(project_name, year, about, description, meta_items)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def get_custom_css() -> str:
|
| 215 |
+
"""Get the current custom CSS."""
|
| 216 |
+
return _theme.get_css()
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def _image_to_base64(image_path: str) -> str:
|
| 220 |
+
"""
|
| 221 |
+
Convert image to base64 string with better error handling.
|
| 222 |
+
|
| 223 |
+
Args:
|
| 224 |
+
image_path: Relative path to image file
|
| 225 |
+
|
| 226 |
+
Returns:
|
| 227 |
+
Base64 encoded string
|
| 228 |
+
|
| 229 |
+
Raises:
|
| 230 |
+
FileNotFoundError: If image file doesn't exist
|
| 231 |
+
"""
|
| 232 |
+
current_dir = Path(__file__).parent
|
| 233 |
+
full_path = current_dir / image_path
|
| 234 |
+
|
| 235 |
+
if not full_path.exists():
|
| 236 |
+
raise FileNotFoundError(f"Image not found: {full_path}")
|
| 237 |
+
|
| 238 |
+
with open(full_path, "rb") as f:
|
| 239 |
+
return base64.b64encode(f.read()).decode("utf-8")
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def create_header(logo_path: str = "static/intelligent_retail.png") -> None:
|
| 243 |
+
"""
|
| 244 |
+
Create a header with logo and project name.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
logo_path: Path to logo image
|
| 248 |
+
"""
|
| 249 |
+
with gr.Row():
|
| 250 |
+
with gr.Column(scale=2):
|
| 251 |
+
try:
|
| 252 |
+
logo_base64 = _image_to_base64(logo_path)
|
| 253 |
+
gr.HTML(
|
| 254 |
+
f"""<img src="data:image/png;base64,{logo_base64}"
|
| 255 |
+
alt="Logo"
|
| 256 |
+
style="height:100px;width:auto;margin:0 auto;margin-bottom:18px;display:block;">"""
|
| 257 |
+
)
|
| 258 |
+
except FileNotFoundError:
|
| 259 |
+
gr.HTML("<div style='text-align:center;color:#999;'>Logo not found</div>")
|
| 260 |
+
|
| 261 |
+
with gr.Column(scale=2):
|
| 262 |
+
gr.HTML(f"""
|
| 263 |
+
<div style="display:flex;justify-content:flex-start;align-items:center;gap:30px;">
|
| 264 |
+
<div>
|
| 265 |
+
<h1 style="margin-bottom:0;color:{_theme.primary_color};font-size:2.32em;font-weight:bold;">
|
| 266 |
+
{_theme.project_name}
|
| 267 |
+
</h1>
|
| 268 |
+
<p style="margin-top:4px;font-size:1.1em;color:#555;">{_theme.about}</p>
|
| 269 |
+
</div>
|
| 270 |
+
</div>
|
| 271 |
+
""")
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def create_footer(logo_path: str = "static/intelligent_retail.png",
|
| 275 |
+
creator_name: str = "Thi-Diem-My Le",
|
| 276 |
+
creator_link: str = "https://beacons.ai/elizabethmyn",
|
| 277 |
+
org_name: str = "AI VIET NAM",
|
| 278 |
+
org_link: str = "https://aivietnam.edu.vn/") -> gr.HTML:
|
| 279 |
+
"""
|
| 280 |
+
Create a sticky footer with creator information.
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
logo_path: Path to logo image
|
| 284 |
+
creator_name: Name of creator
|
| 285 |
+
creator_link: Link to creator profile
|
| 286 |
+
org_name: Organization name
|
| 287 |
+
org_link: Link to organization
|
| 288 |
+
|
| 289 |
+
Returns:
|
| 290 |
+
Gradio HTML component
|
| 291 |
+
"""
|
| 292 |
+
try:
|
| 293 |
+
logo_base64 = _image_to_base64(logo_path)
|
| 294 |
+
logo_html = f'<img src="data:image/png;base64,{logo_base64}" alt="Logo" style="height:0px;width:auto;">'
|
| 295 |
+
except FileNotFoundError:
|
| 296 |
+
logo_html = ""
|
| 297 |
+
|
| 298 |
+
footer_html = f"""
|
| 299 |
+
<style>
|
| 300 |
+
.sticky-footer{{
|
| 301 |
+
position:fixed;
|
| 302 |
+
bottom:0px;
|
| 303 |
+
left:0;
|
| 304 |
+
width:100%;
|
| 305 |
+
background:#E8F5E8;
|
| 306 |
+
padding:10px;
|
| 307 |
+
box-shadow:0 -2px 10px rgba(0,0,0,0.1);
|
| 308 |
+
z-index:1000;
|
| 309 |
+
}}
|
| 310 |
+
.content-wrap{{padding-bottom:60px;}}
|
| 311 |
+
</style>
|
| 312 |
+
<div class="sticky-footer">
|
| 313 |
+
<div style="text-align:center;font-size:18px;color:#888">
|
| 314 |
+
Created by
|
| 315 |
+
<a href="{creator_link}" target="_blank"
|
| 316 |
+
style="color:#465C88;text-decoration:none;font-weight:bold;display:inline-flex;align-items:center;">
|
| 317 |
+
{creator_name}
|
| 318 |
+
{logo_html}
|
| 319 |
+
</a>
|
| 320 |
+
from
|
| 321 |
+
<a href="{org_link}" target="_blank"
|
| 322 |
+
style="color:#355724;text-decoration:none;font-weight:bold;">
|
| 323 |
+
{org_name}
|
| 324 |
+
</a>
|
| 325 |
+
</div>
|
| 326 |
+
</div>
|
| 327 |
+
"""
|
| 328 |
+
return gr.HTML(footer_html)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def render_info_card(description: Optional[str] = None,
|
| 332 |
+
meta_items: Optional[List[Tuple[str, str]]] = None,
|
| 333 |
+
icon: str = "🧠",
|
| 334 |
+
title: str = "About this demo") -> str:
|
| 335 |
+
"""
|
| 336 |
+
Render an informational card.
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
description: Card description text
|
| 340 |
+
meta_items: List of (label, value) tuples
|
| 341 |
+
icon: Emoji or icon for the card
|
| 342 |
+
title: Card title
|
| 343 |
+
|
| 344 |
+
Returns:
|
| 345 |
+
HTML string for the card
|
| 346 |
+
"""
|
| 347 |
+
desc = description if description is not None else _theme.description
|
| 348 |
+
items = meta_items if meta_items is not None else _theme.meta_items
|
| 349 |
+
|
| 350 |
+
meta_html = ""
|
| 351 |
+
if items:
|
| 352 |
+
meta_html = "".join([f"<span><strong>{k}</strong>: {v}</span><br>" for k, v in items])
|
| 353 |
+
|
| 354 |
+
return f"""
|
| 355 |
+
<div style="margin:8px 0 8px 0;">
|
| 356 |
+
<div style="background:#F5F9FF;border-left:6px solid {_theme.primary_color};
|
| 357 |
+
padding:14px 16px;border-radius:10px;box-shadow:0 1px 3px rgba(0,0,0,0.06);">
|
| 358 |
+
<div style="display:flex;gap:14px;align-items:flex-start;">
|
| 359 |
+
<div style="font-size:22px;">{icon}</div>
|
| 360 |
+
<div>
|
| 361 |
+
<div style="font-weight:700;color:{_theme.primary_color};margin-bottom:4px;">{title}</div>
|
| 362 |
+
<div style="color:#000;font-size:14px;line-height:1.5;">{desc}</div>
|
| 363 |
+
{f'<div style="margin-top:8px;color:#000;font-size:13px;">{meta_html}</div>' if meta_html else ''}
|
| 364 |
+
</div>
|
| 365 |
+
</div>
|
| 366 |
+
</div>
|
| 367 |
+
</div>
|
| 368 |
+
"""
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def render_disclaimer(text: str,
|
| 372 |
+
icon: str = "⚠️",
|
| 373 |
+
title: str = "Educational Use Only") -> str:
|
| 374 |
+
"""
|
| 375 |
+
Render a disclaimer/warning card.
|
| 376 |
+
|
| 377 |
+
Args:
|
| 378 |
+
text: Warning text
|
| 379 |
+
icon: Warning icon/emoji
|
| 380 |
+
title: Warning title
|
| 381 |
+
|
| 382 |
+
Returns:
|
| 383 |
+
HTML string for the disclaimer
|
| 384 |
+
"""
|
| 385 |
+
return f"""
|
| 386 |
+
<div style="margin:8px 0 6px 0;">
|
| 387 |
+
<div style="background:#FFF4F4;border-left:6px solid {_theme.accent_color};
|
| 388 |
+
padding:12px 16px;border-radius:8px;box-shadow:0 1px 3px rgba(0,0,0,0.06);">
|
| 389 |
+
<div style="display:flex;gap:10px;align-items:flex-start;color:#000;">
|
| 390 |
+
<span style="font-size:20px">{icon}</span>
|
| 391 |
+
<div>
|
| 392 |
+
<div style="font-weight:700;margin-bottom:4px;">{title}</div>
|
| 393 |
+
<div style="font-size:14px;line-height:1.4;">{text}</div>
|
| 394 |
+
</div>
|
| 395 |
+
</div>
|
| 396 |
+
</div>
|
| 397 |
+
</div>
|
| 398 |
+
"""
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
# Backward compatibility - expose old function names
|
| 402 |
+
def set_colors(**kwargs):
|
| 403 |
+
"""Legacy function - use configure() instead."""
|
| 404 |
+
_theme.update_colors(**kwargs)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def set_font(font_family: str):
|
| 408 |
+
"""Legacy function - use configure() instead."""
|
| 409 |
+
_theme.update_font(font_family)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def set_meta(**kwargs):
|
| 413 |
+
"""Legacy function - use configure() instead."""
|
| 414 |
+
_theme.update_meta(**kwargs)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
# Expose custom_css as a property for backward compatibility
|
| 418 |
+
@property
|
| 419 |
+
def custom_css():
|
| 420 |
+
return _theme.get_css()
|
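One caveat in the backward-compatibility shim that closes ui_template.py: @property only takes effect on classes, so the module-level custom_css ends up being a property object rather than a CSS string when read as ui_template.custom_css. A module-level __getattr__ (PEP 562, Python 3.7+) is the usual way to expose a dynamically computed module attribute; a minimal sketch that would replace the @property block:

def __getattr__(name: str):
    # Called only for names not found in the module's globals,
    # e.g. ui_template.custom_css
    if name == "custom_css":
        return _theme.get_css()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")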
app/main.py
CHANGED
|
@@ -7,6 +7,9 @@ from app.core.config import settings
|  7 |   from app.api.routes import detection, health
|  8 |   from app.utils.logger import logger
|  9 |
| 10 | + from datetime import timedelta
| 11 | + from app.core.security import create_access_token
| 12 | +
| 13 |   # Create FastAPI application
| 14 |   app = FastAPI(
| 15 |       title=settings.APP_NAME,

@@ -17,6 +20,12 @@ app = FastAPI(
| 20 |       openapi_url="/api/openapi.json"
| 21 |   )
| 22 |
| 23 | + access_token = create_access_token(
| 24 | +     data={"sub": "test_user"},
| 25 | +     expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
| 26 | + )
| 27 | + print(f"Generated test token: {access_token}")
| 28 | +
| 29 |   # Add CORS middleware
| 30 |   app.add_middleware(
| 31 |       CORSMiddleware,
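The new startup code assumes app.core.security.create_access_token exists; that module is not part of this diff, so the following is orientation only. A typical FastAPI-style implementation, assuming the python-jose dependency and the SECRET_KEY/ALGORITHM values from Settings, looks roughly like:

from datetime import datetime, timedelta, timezone
from typing import Optional

from jose import jwt

from app.core.config import settings

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Sign a JWT carrying `data` plus an expiry claim (sketch, not the project's actual code)."""
    to_encode = data.copy()
    expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=15))
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)

Generating and printing a test token at import time is convenient for the demo, but it runs on every worker start; a dedicated route or CLI command is the usual longer-term home for this kind of helper.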
|
app/services/prediction.py
ADDED
|
@@ -0,0 +1,731 @@
| 1 |
+
from datetime import datetime, timedelta
|
| 2 |
+
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import seaborn as sns
|
| 7 |
+
import gradio as gr
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def sales_prediction_view(data, model, feature_stats, feature_engineered_data):
|
| 11 |
+
"""Display the sales prediction tool interface"""
|
| 12 |
+
|
| 13 |
+
if model is None:
|
| 14 |
+
return gr.Interface(
|
| 15 |
+
fn=lambda: "Model not loaded. Please check if the model file exists.",
|
| 16 |
+
inputs=[],
|
| 17 |
+
outputs=gr.Textbox(label="Error"),
|
| 18 |
+
title="Sales Prediction Tool"
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
if feature_engineered_data.empty:
|
| 22 |
+
return gr.Interface(
|
| 23 |
+
fn=lambda: "Feature engineered data not loaded.",
|
| 24 |
+
inputs=[],
|
| 25 |
+
outputs=gr.Textbox(label="Error"),
|
| 26 |
+
title="Sales Prediction Tool"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# Determine store and item column names
|
| 30 |
+
store_col = "store_id" if "store_id" in feature_engineered_data.columns else "store"
|
| 31 |
+
item_col = "item_id" if "item_id" in feature_engineered_data.columns else "item"
|
| 32 |
+
|
| 33 |
+
# Check for store/item name columns
|
| 34 |
+
has_store_names = "store_name" in feature_engineered_data.columns
|
| 35 |
+
has_item_names = "item_name" in feature_engineered_data.columns
|
| 36 |
+
|
| 37 |
+
# Create mapping dictionaries for names if available
|
| 38 |
+
store_names, item_names = create_name_mappings(
|
| 39 |
+
feature_engineered_data, store_col, item_col, has_store_names, has_item_names
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Get unique store and item lists
|
| 43 |
+
stores = sorted(feature_engineered_data[store_col].unique())
|
| 44 |
+
|
| 45 |
+
# Create store options
|
| 46 |
+
if has_store_names:
|
| 47 |
+
store_options = [f"{store_id} - {store_names[store_id]}" for store_id in stores]
|
| 48 |
+
else:
|
| 49 |
+
store_options = stores
|
| 50 |
+
|
| 51 |
+
def update_items(store_selection):
|
| 52 |
+
"""Update item dropdown based on selected store"""
|
| 53 |
+
if has_store_names:
|
| 54 |
+
store_id = int(store_selection.split(" - ")[0])
|
| 55 |
+
else:
|
| 56 |
+
store_id = store_selection
|
| 57 |
+
|
| 58 |
+
store_items = feature_engineered_data[feature_engineered_data[store_col] == store_id][item_col].unique()
|
| 59 |
+
|
| 60 |
+
if has_item_names:
|
| 61 |
+
item_options = [
|
| 62 |
+
f"{item_id} - {item_names[item_id]}"
|
| 63 |
+
for item_id in store_items
|
| 64 |
+
if item_id in item_names
|
| 65 |
+
]
|
| 66 |
+
else:
|
| 67 |
+
item_options = sorted(store_items)
|
| 68 |
+
|
| 69 |
+
return gr.Dropdown(choices=item_options)
|
| 70 |
+
|
| 71 |
+
def predict_sales(store_selection, item_selection, prediction_date, is_holiday,
|
| 72 |
+
special_event, promotion_impact, event_impact, clearance_impact,
|
| 73 |
+
launch_impact, temperature, weather_condition, humidity,
|
| 74 |
+
competition_level, supply_chain):
|
| 75 |
+
"""Wrapper function for prediction with all inputs"""
|
| 76 |
+
|
| 77 |
+
# Parse store and item IDs
|
| 78 |
+
if has_store_names:
|
| 79 |
+
store_id = int(store_selection.split(" - ")[0])
|
| 80 |
+
else:
|
| 81 |
+
store_id = store_selection
|
| 82 |
+
|
| 83 |
+
if has_item_names:
|
| 84 |
+
item_id = int(item_selection.split(" - ")[0])
|
| 85 |
+
else:
|
| 86 |
+
item_id = item_selection
|
| 87 |
+
|
| 88 |
+
# Collect prediction inputs
|
| 89 |
+
prediction_inputs = collect_prediction_inputs_from_values(
|
| 90 |
+
prediction_date, is_holiday, special_event, promotion_impact,
|
| 91 |
+
event_impact, clearance_impact, launch_impact, temperature,
|
| 92 |
+
weather_condition, humidity, competition_level, supply_chain
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# Generate prediction and return results
|
| 96 |
+
return generate_prediction(
|
| 97 |
+
feature_engineered_data,
|
| 98 |
+
model,
|
| 99 |
+
store_id,
|
| 100 |
+
item_id,
|
| 101 |
+
store_col,
|
| 102 |
+
item_col,
|
| 103 |
+
prediction_inputs,
|
| 104 |
+
has_store_names,
|
| 105 |
+
has_item_names,
|
| 106 |
+
store_names,
|
| 107 |
+
item_names,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# Get initial items for first store
|
| 111 |
+
initial_store = store_options[0] if store_options else None
|
| 112 |
+
if initial_store:
|
| 113 |
+
if has_store_names:
|
| 114 |
+
initial_store_id = int(initial_store.split(" - ")[0])
|
| 115 |
+
else:
|
| 116 |
+
initial_store_id = initial_store
|
| 117 |
+
|
| 118 |
+
initial_items = feature_engineered_data[feature_engineered_data[store_col] == initial_store_id][item_col].unique()
|
| 119 |
+
|
| 120 |
+
if has_item_names:
|
| 121 |
+
initial_item_options = [
|
| 122 |
+
f"{item_id} - {item_names[item_id]}"
|
| 123 |
+
for item_id in initial_items
|
| 124 |
+
if item_id in item_names
|
| 125 |
+
]
|
| 126 |
+
else:
|
| 127 |
+
initial_item_options = sorted(initial_items)
|
| 128 |
+
else:
|
| 129 |
+
initial_item_options = []
|
| 130 |
+
|
| 131 |
+
# Build Gradio interface
|
| 132 |
+
with gr.Blocks(title="Sales Prediction Tool") as demo:
|
| 133 |
+
gr.Markdown("# Sales Prediction Tool")
|
| 134 |
+
|
| 135 |
+
with gr.Row():
|
| 136 |
+
# Left column - Product Selection
|
| 137 |
+
with gr.Column(scale=1):
|
| 138 |
+
gr.Markdown("## Product Selection")
|
| 139 |
+
store_dropdown = gr.Dropdown(
|
| 140 |
+
choices=store_options,
|
| 141 |
+
label="Select Store",
|
| 142 |
+
value=initial_store,
|
| 143 |
+
interactive=True,
|
| 144 |
+
allow_custom_value=False
|
| 145 |
+
)
|
| 146 |
+
item_dropdown = gr.Dropdown(
|
| 147 |
+
choices=initial_item_options,
|
| 148 |
+
label="Select Product",
|
| 149 |
+
value=initial_item_options[0] if initial_item_options else None,
|
| 150 |
+
interactive=True,
|
| 151 |
+
allow_custom_value=False
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Update items when store changes
|
| 155 |
+
store_dropdown.change(
|
| 156 |
+
fn=update_items,
|
| 157 |
+
inputs=[store_dropdown],
|
| 158 |
+
outputs=[item_dropdown]
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# Right column - Prediction Parameters
|
| 162 |
+
with gr.Column(scale=2):
|
| 163 |
+
gr.Markdown("## Prediction Parameters")
|
| 164 |
+
|
| 165 |
+
with gr.Row():
|
| 166 |
+
with gr.Column():
|
| 167 |
+
prediction_date = gr.Textbox(
|
| 168 |
+
label="Prediction Date (YYYY-MM-DD)",
|
| 169 |
+
value=(datetime.now() + timedelta(days=1)).strftime("%Y-%m-%d"),
|
| 170 |
+
interactive=True
|
| 171 |
+
)
|
| 172 |
+
is_holiday = gr.Checkbox(label="Holiday", value=False, interactive=True)
|
| 173 |
+
special_event = gr.Dropdown(
|
| 174 |
+
choices=["None", "Sale/Promotion", "Local Event",
|
| 175 |
+
"Inventory Clearance", "New Product Launch"],
|
| 176 |
+
label="Special Event",
|
| 177 |
+
value="None",
|
| 178 |
+
interactive=True
|
| 179 |
+
)
|
| 180 |
+
promotion_impact = gr.Slider(-50, 100, value=20, label="Promotion Impact (%)", interactive=True)
|
| 181 |
+
event_impact = gr.Slider(-20, 50, value=10, label="Event Impact (%)", interactive=True)
|
| 182 |
+
clearance_impact = gr.Slider(-70, 30, value=-10, label="Clearance Impact (%)", interactive=True)
|
| 183 |
+
launch_impact = gr.Slider(0, 200, value=50, label="Launch Impact (%)", interactive=True)
|
| 184 |
+
|
| 185 |
+
with gr.Column():
|
| 186 |
+
temperature = gr.Slider(-10.0, 40.0, value=20.0, label="Temperature (°C)", interactive=True)
|
| 187 |
+
weather_condition = gr.Dropdown(
|
| 188 |
+
choices=["Clear", "Cloudy", "Rainy", "Snowy", "Stormy"],
|
| 189 |
+
label="Weather Condition",
|
| 190 |
+
value="Clear",
|
| 191 |
+
interactive=True
|
| 192 |
+
)
|
| 193 |
+
gr.Markdown("*Note: Weather impacts vary by product category*")
|
| 194 |
+
|
| 195 |
+
with gr.Column():
|
| 196 |
+
humidity = gr.Slider(0, 100, value=50, label="Humidity (%)", interactive=True)
|
| 197 |
+
competition_level = gr.Radio(
|
| 198 |
+
choices=["Low", "Medium", "High"],
|
| 199 |
+
label="Competition Level",
|
| 200 |
+
value="Medium",
|
| 201 |
+
interactive=True
|
| 202 |
+
)
|
| 203 |
+
supply_chain = gr.Radio(
|
| 204 |
+
choices=["Constrained", "Normal", "Abundant"],
|
| 205 |
+
label="Supply Chain Status",
|
| 206 |
+
value="Normal",
|
| 207 |
+
interactive=True
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
predict_btn = gr.Button("Predict Sales", variant="primary")
|
| 211 |
+
|
| 212 |
+
# Output section
|
| 213 |
+
gr.Markdown("## Prediction Results")
|
| 214 |
+
with gr.Row():
|
| 215 |
+
result_text = gr.Textbox(label="Results", lines=10)
|
| 216 |
+
result_plot1 = gr.Plot(label="Sales History")
|
| 217 |
+
|
| 218 |
+
with gr.Row():
|
| 219 |
+
result_plot2 = gr.Plot(label="Weekly Pattern")
|
| 220 |
+
result_plot3 = gr.Plot(label="Feature Importance")
|
| 221 |
+
|
| 222 |
+
# Connect button to prediction function
|
| 223 |
+
predict_btn.click(
|
| 224 |
+
fn=predict_sales,
|
| 225 |
+
inputs=[
|
| 226 |
+
store_dropdown, item_dropdown, prediction_date, is_holiday,
|
| 227 |
+
special_event, promotion_impact, event_impact, clearance_impact,
|
| 228 |
+
launch_impact, temperature, weather_condition, humidity,
|
| 229 |
+
competition_level, supply_chain
|
| 230 |
+
],
|
| 231 |
+
outputs=[result_text, result_plot1, result_plot2, result_plot3]
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
return demo
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def create_name_mappings(df, store_col, item_col, has_store_names, has_item_names):
|
| 238 |
+
"""Create mapping dictionaries for store and item names"""
|
| 239 |
+
|
| 240 |
+
store_names = {}
|
| 241 |
+
item_names = {}
|
| 242 |
+
|
| 243 |
+
if has_store_names:
|
| 244 |
+
# Create store ID to name mapping
|
| 245 |
+
for _, row in df[[store_col, "store_name"]].drop_duplicates().iterrows():
|
| 246 |
+
store_names[row[store_col]] = row["store_name"]
|
| 247 |
+
|
| 248 |
+
if has_item_names:
|
| 249 |
+
# Create item ID to name mapping
|
| 250 |
+
for _, row in df[[item_col, "item_name"]].drop_duplicates().iterrows():
|
| 251 |
+
item_names[row[item_col]] = row["item_name"]
|
| 252 |
+
|
| 253 |
+
return store_names, item_names
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def create_product_selection_sidebar(
|
| 257 |
+
df,
|
| 258 |
+
stores,
|
| 259 |
+
store_col,
|
| 260 |
+
item_col,
|
| 261 |
+
has_store_names,
|
| 262 |
+
has_item_names,
|
| 263 |
+
store_names,
|
| 264 |
+
item_names,
|
| 265 |
+
):
|
| 266 |
+
"""Create sidebar for store and product selection"""
|
| 267 |
+
# This function is kept for compatibility but not used in Gradio version
|
| 268 |
+
# The logic is integrated into sales_prediction_view
|
| 269 |
+
pass
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def collect_prediction_inputs():
|
| 273 |
+
"""Collect all prediction inputs from the user"""
|
| 274 |
+
# This function is kept for compatibility but adapted for Gradio
|
| 275 |
+
# See collect_prediction_inputs_from_values instead
|
| 276 |
+
pass
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def collect_prediction_inputs_from_values(
|
| 280 |
+
prediction_date_str, is_holiday, special_event, promotion_impact,
|
| 281 |
+
event_impact, clearance_impact, launch_impact, temperature,
|
| 282 |
+
weather_condition, humidity, competition_level, supply_chain
|
| 283 |
+
):
|
| 284 |
+
"""Collect all prediction inputs from provided values"""
|
| 285 |
+
|
| 286 |
+
# Parse date
|
| 287 |
+
prediction_date = datetime.strptime(prediction_date_str, "%Y-%m-%d").date()
|
| 288 |
+
|
| 289 |
+
# Calculate special event factor
|
| 290 |
+
special_event_factor = 1.0
|
| 291 |
+
if special_event == "Sale/Promotion":
|
| 292 |
+
special_event_factor = promotion_impact / 100 + 1.0
|
| 293 |
+
elif special_event == "Local Event":
|
| 294 |
+
special_event_factor = event_impact / 100 + 1.0
|
| 295 |
+
elif special_event == "Inventory Clearance":
|
| 296 |
+
special_event_factor = clearance_impact / 100 + 1.0
|
| 297 |
+
elif special_event == "New Product Launch":
|
| 298 |
+
special_event_factor = launch_impact / 100 + 1.0
|
| 299 |
+
|
| 300 |
+
# Determine temperature category
|
| 301 |
+
if temperature < 15:
|
| 302 |
+
temp_category = "Cool"
|
| 303 |
+
elif temperature < 25:
|
| 304 |
+
temp_category = "Warm"
|
| 305 |
+
else:
|
| 306 |
+
temp_category = "Hot"
|
| 307 |
+
|
| 308 |
+
# Determine humidity level
|
| 309 |
+
if humidity < 40:
|
| 310 |
+
humidity_level = "Low"
|
| 311 |
+
elif humidity < 70:
|
| 312 |
+
humidity_level = "Medium"
|
| 313 |
+
else:
|
| 314 |
+
humidity_level = "High"
|
| 315 |
+
|
| 316 |
+
# Calculate derived parameters
|
| 317 |
+
month = prediction_date.month
|
| 318 |
+
if month in [3, 4, 5]:
|
| 319 |
+
season = "spring"
|
| 320 |
+
elif month in [6, 7, 8]:
|
| 321 |
+
season = "summer"
|
| 322 |
+
elif month in [9, 10, 11]:
|
| 323 |
+
season = "fall"
|
| 324 |
+
else:
|
| 325 |
+
season = "winter"
|
| 326 |
+
|
| 327 |
+
quarter = (prediction_date.month - 1) // 3 + 1
|
| 328 |
+
day_of_week = prediction_date.weekday()
|
| 329 |
+
is_weekend = 1 if day_of_week >= 5 else 0
|
| 330 |
+
|
| 331 |
+
# Calculate factors
|
| 332 |
+
weather_factor = {
|
| 333 |
+
"Clear": 1.0,
|
| 334 |
+
"Cloudy": 0.95,
|
| 335 |
+
"Rainy": 0.9,
|
| 336 |
+
"Snowy": 0.8,
|
| 337 |
+
"Stormy": 0.7,
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
competition_factor = {"Low": 1.1, "Medium": 1.0, "High": 0.9}
|
| 341 |
+
supply_factor = {"Constrained": 0.9, "Normal": 1.0, "Abundant": 1.05}
|
| 342 |
+
weekend_factor = 1.15 if is_weekend else 1.0
|
| 343 |
+
|
| 344 |
+
# Combined adjustment factor
|
| 345 |
+
adjustment_factor = (
|
| 346 |
+
special_event_factor
|
| 347 |
+
* weather_factor.get(weather_condition, 1.0)
|
| 348 |
+
* competition_factor.get(competition_level, 1.0)
|
| 349 |
+
* supply_factor.get(supply_chain, 1.0)
|
| 350 |
+
* weekend_factor
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
return {
|
| 354 |
+
"date": prediction_date,
|
| 355 |
+
"is_holiday": is_holiday,
|
| 356 |
+
"temperature": temperature,
|
| 357 |
+
"temp_category": temp_category,
|
| 358 |
+
"humidity": humidity,
|
| 359 |
+
"humidity_level": humidity_level,
|
| 360 |
+
"season": season,
|
| 361 |
+
"quarter": quarter,
|
| 362 |
+
"day_of_week": day_of_week,
|
| 363 |
+
"is_weekend": is_weekend,
|
| 364 |
+
"special_event": special_event,
|
| 365 |
+
"weather_condition": weather_condition,
|
| 366 |
+
"competition_level": competition_level,
|
| 367 |
+
"supply_chain": supply_chain,
|
| 368 |
+
"adjustment_factor": adjustment_factor,
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def generate_prediction(
|
| 373 |
+
feature_engineered_data,
|
| 374 |
+
model,
|
| 375 |
+
store_id,
|
| 376 |
+
item_id,
|
| 377 |
+
store_col,
|
| 378 |
+
item_col,
|
| 379 |
+
prediction_inputs,
|
| 380 |
+
has_store_names,
|
| 381 |
+
has_item_names,
|
| 382 |
+
store_names,
|
| 383 |
+
item_names,
|
| 384 |
+
):
|
| 385 |
+
"""Generate sales prediction and display results"""
|
| 386 |
+
|
| 387 |
+
try:
|
| 388 |
+
# Find recent samples for the same store-item combination
|
| 389 |
+
recent_samples = (
|
| 390 |
+
feature_engineered_data[
|
| 391 |
+
(feature_engineered_data[store_col] == store_id)
|
| 392 |
+
& (feature_engineered_data[item_col] == item_id)
|
| 393 |
+
]
|
| 394 |
+
.sort_values("date", ascending=False)
|
| 395 |
+
.head(5)
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
if recent_samples.empty:
|
| 399 |
+
return "No historical data found for this product-store combination.", None, None, None
|
| 400 |
+
|
| 401 |
+
# Create input based on most recent sample
|
| 402 |
+
input_row = prepare_prediction_input(recent_samples, prediction_inputs)
|
| 403 |
+
|
| 404 |
+
# Create DataFrame for prediction
|
| 405 |
+
input_df = pd.DataFrame([input_row])
|
| 406 |
+
|
| 407 |
+
# Get the features that the model expects
|
| 408 |
+
if hasattr(model, "feature_name_"):
|
| 409 |
+
model_features = model.feature_name_
|
| 410 |
+
else:
|
| 411 |
+
model_features = [
|
| 412 |
+
col
|
| 413 |
+
for col in input_df.columns
|
| 414 |
+
if col
|
| 415 |
+
not in ["sales", "date", "variation_factor", "adjustment_factor"]
|
| 416 |
+
]
|
| 417 |
+
|
| 418 |
+
# Select only the features used by the model
|
| 419 |
+
X_pred = input_df[model_features]
|
| 420 |
+
|
| 421 |
+
# Make prediction
|
| 422 |
+
base_prediction = model.predict(X_pred)[0]
|
| 423 |
+
|
| 424 |
+
# Apply adjustment factors
|
| 425 |
+
adjusted_prediction = base_prediction
|
| 426 |
+
|
| 427 |
+
# Apply the variation factor if it exists
|
| 428 |
+
if "variation_factor" in input_row:
|
| 429 |
+
adjusted_prediction *= input_row["variation_factor"]
|
| 430 |
+
|
| 431 |
+
# Apply adjustment factor from user inputs
|
| 432 |
+
if "adjustment_factor" in prediction_inputs:
|
| 433 |
+
adjusted_prediction *= prediction_inputs["adjustment_factor"]
|
| 434 |
+
|
| 435 |
+
# Display results
|
| 436 |
+
result_text, plot1, plot2, plot3 = display_prediction_results(
|
| 437 |
+
adjusted_prediction,
|
| 438 |
+
base_prediction,
|
| 439 |
+
store_id,
|
| 440 |
+
item_id,
|
| 441 |
+
prediction_inputs,
|
| 442 |
+
feature_engineered_data,
|
| 443 |
+
store_col,
|
| 444 |
+
item_col,
|
| 445 |
+
has_store_names,
|
| 446 |
+
has_item_names,
|
| 447 |
+
store_names,
|
| 448 |
+
item_names,
|
| 449 |
+
model,
|
| 450 |
+
model_features,
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
return result_text, plot1, plot2, plot3
|
| 454 |
+
|
| 455 |
+
except Exception as e:
|
| 456 |
+
import traceback
|
| 457 |
+
error_msg = f"Error making prediction: {str(e)}\n\n{traceback.format_exc()}"
|
| 458 |
+
return error_msg, None, None, None
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def prepare_prediction_input(recent_samples, prediction_inputs):
|
| 462 |
+
"""Prepare input row for prediction based on recent sample and user inputs"""
|
| 463 |
+
|
| 464 |
+
# Create input row based on most recent sample
|
| 465 |
+
input_row = recent_samples.iloc[0].copy()
|
| 466 |
+
|
| 467 |
+
# Update with user inputs
|
| 468 |
+
input_row["date"] = pd.to_datetime(prediction_inputs["date"])
|
| 469 |
+
input_row["day"] = prediction_inputs["date"].day
|
| 470 |
+
input_row["month"] = prediction_inputs["date"].month
|
| 471 |
+
input_row["year"] = prediction_inputs["date"].year
|
| 472 |
+
input_row["quarter"] = prediction_inputs["quarter"]
|
| 473 |
+
input_row["is_holiday"] = int(prediction_inputs["is_holiday"])
|
| 474 |
+
|
| 475 |
+
# Add day of week information
|
| 476 |
+
input_row["day_of_week"] = input_row["date"].dayofweek
|
| 477 |
+
input_row["day_of_month"] = input_row["date"].day
|
| 478 |
+
input_row["is_weekend"] = 1 if input_row["day_of_week"] >= 5 else 0
|
| 479 |
+
|
| 480 |
+
# Update actual temperature and humidity values if they exist in the dataframe
|
| 481 |
+
if "temperature" in input_row:
|
| 482 |
+
input_row["temperature"] = prediction_inputs["temperature"]
|
| 483 |
+
|
| 484 |
+
if "humidity" in input_row:
|
| 485 |
+
input_row["humidity"] = prediction_inputs["humidity"]
|
| 486 |
+
|
| 487 |
+
# Update temperature and humidity categories
|
| 488 |
+
for category in ["Cool", "Warm", "Hot"]:
|
| 489 |
+
if f"temp_category_{category}" in input_row:
|
| 490 |
+
input_row[f"temp_category_{category}"] = (
|
| 491 |
+
1 if category == prediction_inputs["temp_category"] else 0
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
for level in ["Low", "Medium", "High"]:
|
| 495 |
+
if f"humidity_level_{level}" in input_row:
|
| 496 |
+
input_row[f"humidity_level_{level}"] = (
|
| 497 |
+
1 if level == prediction_inputs["humidity_level"] else 0
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
# Update season
|
| 501 |
+
for s in ["spring", "summer", "fall", "winter", "wet"]:
|
| 502 |
+
if f"season_{s}" in input_row:
|
| 503 |
+
input_row[f"season_{s}"] = 1 if s == prediction_inputs["season"] else 0
|
| 504 |
+
|
| 505 |
+
# Set a random variation factor
|
| 506 |
+
variation_factor = 1.0 + np.random.uniform(-0.02, 0.02)
|
| 507 |
+
input_row["variation_factor"] = variation_factor
|
| 508 |
+
|
| 509 |
+
return input_row
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def display_prediction_results(
|
| 513 |
+
prediction_value,
|
| 514 |
+
base_prediction,
|
| 515 |
+
store_id,
|
| 516 |
+
item_id,
|
| 517 |
+
prediction_inputs,
|
| 518 |
+
historical_data,
|
| 519 |
+
store_col,
|
| 520 |
+
item_col,
|
| 521 |
+
has_store_names,
|
| 522 |
+
has_item_names,
|
| 523 |
+
store_names,
|
| 524 |
+
item_names,
|
| 525 |
+
model,
|
| 526 |
+
model_features,
|
| 527 |
+
):
|
| 528 |
+
"""Display prediction results with visualizations"""
|
| 529 |
+
|
| 530 |
+
# Build result text
|
| 531 |
+
result_lines = []
|
| 532 |
+
result_lines.append("=" * 50)
|
| 533 |
+
result_lines.append("PREDICTION RESULTS")
|
| 534 |
+
result_lines.append("=" * 50)
|
| 535 |
+
result_lines.append(f"\nPredicted Sales: ${prediction_value:,.2f}")
|
| 536 |
+
|
| 537 |
+
if has_store_names:
|
| 538 |
+
result_lines.append(f"Store: {store_names[store_id]}")
|
| 539 |
+
else:
|
| 540 |
+
result_lines.append(f"Store ID: {store_id}")
|
| 541 |
+
|
| 542 |
+
if has_item_names:
|
| 543 |
+
result_lines.append(f"Product: {item_names[item_id]}")
|
| 544 |
+
else:
|
| 545 |
+
result_lines.append(f"Product ID: {item_id}")
|
| 546 |
+
|
| 547 |
+
result_lines.append(f"Date: {prediction_inputs['date'].strftime('%B %d, %Y')}")
|
| 548 |
+
result_lines.append(f"Season: {prediction_inputs['season'].capitalize()}")
|
| 549 |
+
if prediction_inputs["is_holiday"]:
|
| 550 |
+
result_lines.append("Holiday: Yes")
|
| 551 |
+
|
| 552 |
+
# Adjustment details
|
| 553 |
+
result_lines.append(f"\n{'='*50}")
|
| 554 |
+
result_lines.append("ADJUSTMENT DETAILS")
|
| 555 |
+
result_lines.append("="*50)
|
| 556 |
+
result_lines.append(f"Base prediction: ${base_prediction:.2f}")
|
| 557 |
+
result_lines.append(f"Final prediction: ${prediction_value:.2f}")
|
| 558 |
+
result_lines.append(f"Total adjustment: {prediction_inputs['adjustment_factor']:.2f}x")
|
| 559 |
+
result_lines.append(f"\nEvent: {prediction_inputs['special_event']}")
|
| 560 |
+
result_lines.append(f"Weather: {prediction_inputs['weather_condition']}")
|
| 561 |
+
result_lines.append(f"Competition: {prediction_inputs['competition_level']}")
|
| 562 |
+
result_lines.append(f"Supply: {prediction_inputs['supply_chain']}")
|
| 563 |
+
result_lines.append(f"Weekend: {'Yes' if prediction_inputs['is_weekend'] else 'No'}")
|
| 564 |
+
result_lines.append(f"Holiday: {'Yes' if prediction_inputs['is_holiday'] else 'No'}")
|
| 565 |
+
|
| 566 |
+
# Get historical context
|
| 567 |
+
historical = historical_data[
|
| 568 |
+
(historical_data[store_col] == store_id)
|
| 569 |
+
& (historical_data[item_col] == item_id)
|
| 570 |
+
].sort_values("date")
|
| 571 |
+
|
| 572 |
+
if "sales" in historical.columns and len(historical) > 0:
|
| 573 |
+
last_value = historical["sales"].iloc[-1]
|
| 574 |
+
last_date = historical["date"].iloc[-1]
|
| 575 |
+
avg_sales = historical["sales"].mean()
|
| 576 |
+
max_sales = historical["sales"].max()
|
| 577 |
+
max_date = historical.loc[historical["sales"].idxmax(), "date"]
|
| 578 |
+
|
| 579 |
+
result_lines.append(f"\n{'='*50}")
|
| 580 |
+
result_lines.append("HISTORICAL CONTEXT")
|
| 581 |
+
result_lines.append("="*50)
|
| 582 |
+
result_lines.append(f"Historical Average: ${avg_sales:,.2f}")
|
| 583 |
+
result_lines.append(f"Period: {historical['date'].min().strftime('%b %d, %Y')} to {historical['date'].max().strftime('%b %d, %Y')}")
|
| 584 |
+
result_lines.append(f"\nLast Recorded Sales: ${last_value:,.2f}")
|
| 585 |
+
result_lines.append(f"Date: {last_date.strftime('%b %d, %Y')}")
|
| 586 |
+
result_lines.append(f"\nHistorical Maximum: ${max_sales:,.2f}")
|
| 587 |
+
result_lines.append(f"Date: {max_date.strftime('%b %d, %Y')}")
|
| 588 |
+
|
| 589 |
+
result_text = "\n".join(result_lines)
|
| 590 |
+
|
| 591 |
+
# Create visualizations
|
| 592 |
+
plot1 = display_historical_context(historical, prediction_inputs["date"], prediction_value)
|
| 593 |
+
plot2 = display_weekly_pattern(historical, prediction_inputs["date"])
|
| 594 |
+
plot3 = display_feature_importance(model, model_features)
|
| 595 |
+
|
| 596 |
+
return result_text, plot1, plot2, plot3
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def display_historical_context(historical_data, prediction_date, prediction_value):
|
| 600 |
+
"""Display historical context visualizations"""
|
| 601 |
+
|
| 602 |
+
if "sales" not in historical_data.columns or historical_data.empty:
|
| 603 |
+
return None
|
| 604 |
+
|
| 605 |
+
# Limit to last 2 months
|
| 606 |
+
last_date = historical_data["date"].max()
|
| 607 |
+
two_months_ago = last_date - pd.Timedelta(days=60)
|
| 608 |
+
recent_history = historical_data[historical_data["date"] >= two_months_ago].copy()
|
| 609 |
+
|
| 610 |
+
if recent_history.empty:
|
| 611 |
+
return None
|
| 612 |
+
|
| 613 |
+
# Plot recent sales history
|
| 614 |
+
fig, ax = plt.subplots(figsize=(6, 2.5))
|
| 615 |
+
|
| 616 |
+
# Plot historical sales
|
| 617 |
+
ax.plot(
|
| 618 |
+
recent_history["date"],
|
| 619 |
+
recent_history["sales"],
|
| 620 |
+
"b-",
|
| 621 |
+
label="Sales",
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
# Add the prediction point
|
| 625 |
+
ax.scatter(
|
| 626 |
+
prediction_date,
|
| 627 |
+
prediction_value,
|
| 628 |
+
color="red",
|
| 629 |
+
s=60,
|
| 630 |
+
label="Prediction",
|
| 631 |
+
)
|
| 632 |
+
|
| 633 |
+
# Add moving average
|
| 634 |
+
if len(recent_history) > 7:
|
| 635 |
+
recent_history["MA7"] = recent_history["sales"].rolling(window=7).mean()
|
| 636 |
+
ax.plot(
|
| 637 |
+
recent_history["date"],
|
| 638 |
+
recent_history["MA7"],
|
| 639 |
+
"g--",
|
| 640 |
+
label="7-Day Avg",
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
ax.set_xlabel("")
|
| 644 |
+
ax.set_ylabel("Sales ($)")
|
| 645 |
+
ax.set_title("Last 60 Days Sales History")
|
| 646 |
+
ax.legend(loc="upper left", fontsize="x-small")
|
| 647 |
+
fig.autofmt_xdate(rotation=45)
|
| 648 |
+
fig.tight_layout()
|
| 649 |
+
|
| 650 |
+
return fig
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
def display_weekly_pattern(recent_history, prediction_date):
|
| 654 |
+
"""Display weekly sales pattern visualization"""
|
| 655 |
+
|
| 656 |
+
if len(recent_history) < 7:
|
| 657 |
+
return None
|
| 658 |
+
|
| 659 |
+
# Add day of week
|
| 660 |
+
recent_history = recent_history.copy()
|
| 661 |
+
recent_history["day_of_week"] = recent_history["date"].dt.dayofweek
|
| 662 |
+
day_names = [
|
| 663 |
+
"Monday",
|
| 664 |
+
"Tuesday",
|
| 665 |
+
"Wednesday",
|
| 666 |
+
"Thursday",
|
| 667 |
+
"Friday",
|
| 668 |
+
"Saturday",
|
| 669 |
+
"Sunday",
|
| 670 |
+
]
|
| 671 |
+
|
| 672 |
+
# Group by day of week
|
| 673 |
+
day_sales = recent_history.groupby("day_of_week")["sales"].mean()
|
| 674 |
+
day_sales_df = pd.DataFrame(
|
| 675 |
+
{
|
| 676 |
+
"day_name": [day_names[i] for i in range(7) if i in day_sales.index],
|
| 677 |
+
"sales": [day_sales[i] for i in range(7) if i in day_sales.index],
|
| 678 |
+
}
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
# Plot
|
| 682 |
+
fig, ax = plt.subplots(figsize=(6, 2.5))
|
| 683 |
+
|
| 684 |
+
# Plot day of week pattern
|
| 685 |
+
sns.barplot(x="day_name", y="sales", data=day_sales_df, ax=ax)
|
| 686 |
+
|
| 687 |
+
# Highlight the day of the prediction
|
| 688 |
+
prediction_day = prediction_date.weekday()
|
| 689 |
+
for i, patch in enumerate(ax.patches):
|
| 690 |
+
if day_sales_df.iloc[i]["day_name"] == day_names[prediction_day]:
|
| 691 |
+
patch.set_facecolor("red")
|
| 692 |
+
|
| 693 |
+
ax.set_xlabel("")
|
| 694 |
+
ax.set_ylabel("Avg Sales ($)")
|
| 695 |
+
ax.set_title("Sales by Day of Week")
|
| 696 |
+
plt.xticks(rotation=45, fontsize=8)
|
| 697 |
+
fig.tight_layout()
|
| 698 |
+
|
| 699 |
+
return fig
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
def display_feature_importance(model, model_features):
|
| 703 |
+
"""Display feature importance visualization"""
|
| 704 |
+
|
| 705 |
+
if not hasattr(model, "feature_importances_"):
|
| 706 |
+
return None
|
| 707 |
+
|
| 708 |
+
# Get feature importances
|
| 709 |
+
importances = model.feature_importances_
|
| 710 |
+
|
| 711 |
+
# Create DataFrame with feature importances
|
| 712 |
+
importance_df = (
|
| 713 |
+
pd.DataFrame({"Feature": model_features, "Importance": importances})
|
| 714 |
+
.sort_values("Importance", ascending=False)
|
| 715 |
+
.head(8)
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
# Clean feature names for display
|
| 719 |
+
importance_df["Feature"] = importance_df["Feature"].apply(
|
| 720 |
+
lambda x: x.replace("_", " ").title()
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
+
# Plot feature importances
|
| 724 |
+
fig, ax = plt.subplots(figsize=(6, 2.5))
|
| 725 |
+
sns.barplot(x="Importance", y="Feature", data=importance_df, ax=ax)
|
| 726 |
+
ax.set_title("Top Factors Influencing Sales Prediction")
|
| 727 |
+
plt.xticks(fontsize=8)
|
| 728 |
+
plt.yticks(fontsize=8)
|
| 729 |
+
fig.tight_layout()
|
| 730 |
+
|
| 731 |
+
return fig
|
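To make the combined adjustment factor computed in collect_prediction_inputs_from_values concrete, here is one worked scenario using the factor tables shown above (the scenario values are illustrative, not taken from the dataset): a weekend Sale/Promotion at +20% with rainy weather, high competition, and a normal supply chain scales the model's base prediction by about 1.12.

weather_factor = {"Clear": 1.0, "Cloudy": 0.95, "Rainy": 0.9, "Snowy": 0.8, "Stormy": 0.7}
competition_factor = {"Low": 1.1, "Medium": 1.0, "High": 0.9}
supply_factor = {"Constrained": 0.9, "Normal": 1.0, "Abundant": 1.05}

special_event_factor = 20 / 100 + 1.0    # Sale/Promotion at +20% -> 1.20
adjustment_factor = (
    special_event_factor
    * weather_factor["Rainy"]            # 0.90
    * competition_factor["High"]         # 0.90
    * supply_factor["Normal"]            # 1.00
    * 1.15                               # weekend uplift
)
print(round(adjustment_factor, 4))       # 1.1178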
app/utils/data_generator.py
ADDED
|
@@ -0,0 +1,774 @@
import os
from datetime import datetime, timedelta

import numpy as np
import pandas as pd

# Set random seed for reproducibility
np.random.seed(2025)


def generate_store_data():
    """Generate store data"""

    # Define provinces and stores
    provinces = ["Hanoi", "Ho Chi Minh City"]

    stores = [
        # Hanoi stores
        {"id": 1, "name": "Hoan Kiem Market", "province": "Hanoi"},
        {"id": 2, "name": "Ba Dinh Supermarket", "province": "Hanoi"},
        {"id": 3, "name": "Dong Da Mall", "province": "Hanoi"},
        {"id": 4, "name": "Tay Ho Store", "province": "Hanoi"},
        {"id": 5, "name": "Long Bien Shop", "province": "Hanoi"},
        # Ho Chi Minh City stores
        {"id": 6, "name": "District 1 Market", "province": "Ho Chi Minh City"},
        {"id": 7, "name": "Ben Thanh Store", "province": "Ho Chi Minh City"},
        {"id": 8, "name": "Saigon Supermarket", "province": "Ho Chi Minh City"},
        {"id": 9, "name": "Phu Nhuan Shop", "province": "Ho Chi Minh City"},
        {"id": 10, "name": "Binh Thanh Market", "province": "Ho Chi Minh City"},
    ]

    return provinces, stores


def generate_item_data():
    """Generate item data"""

    # Define categories and items
    categories = [
        "Staples",
        "Dairy & Frozen",
        "Beverages & Snacks",
        "Household & Personal Care",
        "Baby & Health",
    ]

    items = [
        # Staples
        {"id": 1, "name": "Rice", "category": "Staples", "base_price": 20.0, "base_sales": 15, "volatility": 0.3},
        {"id": 2, "name": "Noodles", "category": "Staples", "base_price": 15.0, "base_sales": 12, "volatility": 0.25},
        {"id": 3, "name": "Bread", "category": "Staples", "base_price": 10.0, "base_sales": 20, "volatility": 0.4},
        {"id": 4, "name": "Flour", "category": "Staples", "base_price": 12.0, "base_sales": 8, "volatility": 0.2},
        {"id": 5, "name": "Cooking Oil", "category": "Staples", "base_price": 25.0, "base_sales": 10, "volatility": 0.15},
        {"id": 6, "name": "Sugar", "category": "Staples", "base_price": 8.0, "base_sales": 7, "volatility": 0.1},
        # Dairy & Frozen
        {"id": 7, "name": "Milk", "category": "Dairy & Frozen", "base_price": 18.0, "base_sales": 30, "volatility": 0.35},
        {"id": 8, "name": "Cheese", "category": "Dairy & Frozen", "base_price": 35.0, "base_sales": 12, "volatility": 0.3},
        {"id": 9, "name": "Yogurt", "category": "Dairy & Frozen", "base_price": 12.0, "base_sales": 25, "volatility": 0.4},
        {"id": 10, "name": "Ice Cream", "category": "Dairy & Frozen", "base_price": 30.0, "base_sales": 15, "volatility": 0.5},
        {"id": 11, "name": "Frozen Vegetables", "category": "Dairy & Frozen", "base_price": 22.0, "base_sales": 10, "volatility": 0.25},
        # Beverages & Snacks
        {"id": 12, "name": "Soda", "category": "Beverages & Snacks", "base_price": 15.0, "base_sales": 40, "volatility": 0.45},
        {"id": 13, "name": "Juice", "category": "Beverages & Snacks", "base_price": 20.0, "base_sales": 30, "volatility": 0.4},
        {"id": 14, "name": "Water", "category": "Beverages & Snacks", "base_price": 10.0, "base_sales": 50, "volatility": 0.3},
        {"id": 15, "name": "Coffee", "category": "Beverages & Snacks", "base_price": 45.0, "base_sales": 20, "volatility": 0.25},
        {"id": 16, "name": "Tea", "category": "Beverages & Snacks", "base_price": 35.0, "base_sales": 15, "volatility": 0.2},
        {"id": 17, "name": "Chips", "category": "Beverages & Snacks", "base_price": 12.0, "base_sales": 35, "volatility": 0.45},
        {"id": 18, "name": "Cookies", "category": "Beverages & Snacks", "base_price": 18.0, "base_sales": 30, "volatility": 0.4},
        {"id": 19, "name": "Chocolate", "category": "Beverages & Snacks", "base_price": 22.0, "base_sales": 25, "volatility": 0.35},
        # Household & Personal Care
        {"id": 20, "name": "Soap", "category": "Household & Personal Care", "base_price": 8.0, "base_sales": 20, "volatility": 0.2},
        {"id": 21, "name": "Shampoo", "category": "Household & Personal Care", "base_price": 25.0, "base_sales": 15, "volatility": 0.25},
        {"id": 22, "name": "Toothpaste", "category": "Household & Personal Care", "base_price": 15.0, "base_sales": 18, "volatility": 0.15},
        {"id": 23, "name": "Laundry Detergent", "category": "Household & Personal Care", "base_price": 40.0, "base_sales": 12, "volatility": 0.2},
        {"id": 24, "name": "Paper Towels", "category": "Household & Personal Care", "base_price": 20.0, "base_sales": 14, "volatility": 0.3},
        {"id": 25, "name": "Toilet Paper", "category": "Household & Personal Care", "base_price": 25.0, "base_sales": 16, "volatility": 0.25},
        {"id": 26, "name": "Trash Bags", "category": "Household & Personal Care", "base_price": 18.0, "base_sales": 10, "volatility": 0.15},
        {"id": 27, "name": "Dishwashing Liquid", "category": "Household & Personal Care", "base_price": 15.0, "base_sales": 11, "volatility": 0.2},
        {"id": 28, "name": "All-Purpose Cleaner", "category": "Household & Personal Care", "base_price": 22.0, "base_sales": 9, "volatility": 0.15},
        # Baby & Health
        {"id": 29, "name": "Diapers", "category": "Baby & Health", "base_price": 45.0, "base_sales": 25, "volatility": 0.3},
        {"id": 30, "name": "Baby Food", "category": "Baby & Health", "base_price": 20.0, "base_sales": 15, "volatility": 0.25},
        {"id": 31, "name": "Baby Wipes", "category": "Baby & Health", "base_price": 15.0, "base_sales": 20, "volatility": 0.2},
        {"id": 32, "name": "Pain Relievers", "category": "Baby & Health", "base_price": 30.0, "base_sales": 10, "volatility": 0.15},
        {"id": 33, "name": "Vitamins", "category": "Baby & Health", "base_price": 40.0, "base_sales": 8, "volatility": 0.2},
        {"id": 34, "name": "Cold & Flu Medicine", "category": "Baby & Health", "base_price": 35.0, "base_sales": 7, "volatility": 0.4},
        {"id": 35, "name": "First Aid Kit", "category": "Baby & Health", "base_price": 50.0, "base_sales": 5, "volatility": 0.1},
    ]

    return categories, items


def calculate_daily_sales(date, store, item, weather_data=None):
    """
    Calculate daily sales based on various factors.
    Returns an integer value for sales quantity.
    """
    # Base sales for this item
    base_sales = item["base_sales"]

    # Store factor (some stores have higher sales)
    store_factor = 0.8 + (store["id"] % 10) / 10  # 0.8 to 1.7

    # Day of week factor (weekend boost)
    day_of_week = date.weekday()  # 0 = Monday, 6 = Sunday
    weekday_factor = 1.0
    if day_of_week >= 5:  # Weekend
        weekday_factor = 1.3

    # Monthly seasonality
    month = date.month
    # Higher sales in December (holidays), lower in February
    month_factor = 1.0 + 0.3 * (month == 12) - 0.1 * (month == 2)

    # Quarterly business cycle
    quarter = (month - 1) // 3 + 1
    quarter_factor = 1.0 + 0.05 * (quarter - 2.5)  # Q3-Q4 slightly higher

    # Holiday effects
    holiday_factor = 1.0
    # Vietnamese New Year (Tet) - usually in late January or early February
    if (month == 1 and date.day >= 27) or (month == 2 and date.day <= 5):
        holiday_factor = 1.5
    # National Day (September 2)
    elif month == 9 and date.day == 2:
        holiday_factor = 1.3
    # Year-end shopping
    elif month == 12 and date.day >= 20:
        holiday_factor = 1.4

    # Weather effects if weather data is provided
    weather_factor = 1.0
    if weather_data is not None:
        # Find weather for this date and province
        date_str = date.strftime("%Y-%m-%d")
        province = store["province"]
        day_weather = weather_data.get((date_str, province))

        if day_weather:
            temp = day_weather["temperature"]
            humidity = day_weather["humidity"]

            # Temperature effects differ by item category
            if item["category"] == "Beverages & Snacks":
                # More beverages sold in hot weather
                if temp > 28:
                    weather_factor *= 1.3
                elif temp < 18:
                    weather_factor *= 0.9
            elif item["category"] == "Dairy & Frozen":
                # More ice cream in hot weather
                if temp > 28:
                    weather_factor *= 1.4
                elif temp < 18:
                    weather_factor *= 0.8

            # Rain effect (approximated by high humidity)
            if humidity > 80:
                # People buy more when staying indoors
                if item["category"] in [
                    "Beverages & Snacks",
                    "Household & Personal Care",
                ]:
                    weather_factor *= 1.2

    # Year-over-year growth (for 2017 data)
    yoy_growth = 1.0
    if date.year == 2017:
        # 5% general growth with some category variations
        category_growth = {
            "Staples": 1.03,
            "Dairy & Frozen": 1.05,
            "Beverages & Snacks": 1.08,
            "Household & Personal Care": 1.05,
            "Baby & Health": 1.07,
        }
        yoy_growth = category_growth.get(item["category"], 1.05)

    # Random variation
    random_factor = np.random.normal(1.0, item["volatility"])

    # Calculate final sales
    sales = (
        base_sales
        * store_factor
        * weekday_factor
        * month_factor
        * quarter_factor
        * holiday_factor
        * weather_factor
        * yoy_growth
        * random_factor
    )

    # Ensure minimum sales of 1 unit, rounded to the nearest integer
    sales = max(1, int(round(sales)))

    return sales


def generate_weather_data(start_date, end_date, provinces):
    """Generate synthetic weather data"""

    # Define base temperatures and humidity for each province
    province_weather = {
        "Hanoi": {
            "base_temp": {1: 16, 2: 17, 3: 20, 4: 24, 5: 28, 6: 30, 7: 30, 8: 29, 9: 28, 10: 25, 11: 21, 12: 18},
            "temp_variation": 3.5,
            "base_humidity": {1: 80, 2: 83, 3: 85, 4: 85, 5: 80, 6: 80, 7: 83, 8: 85, 9: 83, 10: 78, 11: 75, 12: 77},
            "humidity_variation": 10,
            "seasons": {1: "winter", 2: "winter", 3: "spring", 4: "spring", 5: "summer", 6: "summer",
                        7: "summer", 8: "summer", 9: "fall", 10: "fall", 11: "fall", 12: "winter"},
        },
        "Ho Chi Minh City": {
            "base_temp": {1: 26, 2: 27, 3: 28, 4: 29, 5: 29, 6: 28, 7: 28, 8: 28, 9: 28, 10: 27, 11: 27, 12: 26},
            "temp_variation": 2.0,
            "base_humidity": {1: 70, 2: 70, 3: 70, 4: 75, 5: 80, 6: 83, 7: 85, 8: 85, 9: 88, 10: 85, 11: 80, 12: 75},
            "humidity_variation": 8,
            "seasons": {1: "dry", 2: "dry", 3: "dry", 4: "dry", 5: "wet", 6: "wet",
                        7: "wet", 8: "wet", 9: "wet", 10: "wet", 11: "wet", 12: "dry"},
        },
    }

    # Create date range
    date_list = []
    current_date = start_date
    while current_date <= end_date:
        date_list.append(current_date)
        current_date += timedelta(days=1)

    # Generate weather data
    weather_data = []
    weather_dict = {}  # For lookup during sales calculation

    for date in date_list:
        month = date.month
        for province in provinces:
            # Get base values for this province and month
            base_temp = province_weather[province]["base_temp"][month]
            temp_variation = province_weather[province]["temp_variation"]
            base_humidity = province_weather[province]["base_humidity"][month]
            humidity_variation = province_weather[province]["humidity_variation"]
            season = province_weather[province]["seasons"][month]

            # Add random variation
            temperature = base_temp + np.random.uniform(-temp_variation, temp_variation)
            humidity = base_humidity + np.random.uniform(-humidity_variation, humidity_variation)

            # Round to one decimal place
            temperature = round(temperature, 1)
            humidity = round(humidity, 1)

            # Ensure humidity is within realistic range
            humidity = max(40, min(95, humidity))

            # Add to weather data
            weather_data.append(
                {
                    "city": province,
                    "date": date.strftime("%Y-%m-%d"),
                    "temperature": temperature,
                    "humidity": humidity,
                    "season": season,
                }
            )

            # Add to lookup dictionary
            weather_dict[(date.strftime("%Y-%m-%d"), province)] = {
                "temperature": temperature,
                "humidity": humidity,
                "season": season,
            }

    return pd.DataFrame(weather_data), weather_dict


def generate_sales_data(start_date, end_date, stores, items, weather_dict):
    """Generate synthetic sales data"""

    # Create date range
    date_list = []
    current_date = start_date
    while current_date <= end_date:
        date_list.append(current_date)
        current_date += timedelta(days=1)

    # Generate sales data
    sales_data = []

    # For each date, store, and item, calculate sales
    for date in date_list:
        for store in stores:
            # Not all stores carry all items.
            # Use store_id to deterministically select items.
            store_seed = store["id"] * 10
            np.random.seed(store_seed)

            # Select a subset of items for this store
            store_items = []
            for item in items:
                # 80% chance of carrying an item
                if np.random.random() < 0.8:
                    store_items.append(item)

            # Reset random seed
            np.random.seed(None)

            # Calculate sales for each item
            for item in store_items:
                # Calculate sales for this combination
                sales_value = calculate_daily_sales(date, store, item, weather_dict)

                # Add to sales data
                sales_data.append(
                    {
                        "date": date.strftime("%Y-%m-%d"),
                        "province": store["province"],
                        "store_id": store["id"],
                        "store_name": store["name"],
                        "category": item["category"],
                        "item_id": item["id"],
                        "item_name": item["name"],
                        "sales": sales_value,
                    }
                )

    return pd.DataFrame(sales_data)


def add_outliers_and_nans(data, outlier_percentage=0.01, nan_percentage=0.1):
    """Add outliers and NaN values to the data set"""
    # Copy the original data to avoid modifying the input directly
    modified_data = data.copy()

    # Calculate the number of rows to add outliers and NaN values
    num_rows = len(modified_data)
    num_outliers = int(num_rows * outlier_percentage / 100)
    num_nans = int(num_rows * nan_percentage / 100)

    # Add outliers to the 'sales' column
    np.random.seed(2025)
    outlier_indices = np.random.choice(num_rows, num_outliers, replace=False)
    modified_data.loc[outlier_indices, "sales"] *= 3  # Increase sales by a factor to create outliers

    # Add NaN values to the 'sales' column
    nan_indices = np.random.choice(num_rows, num_nans, replace=False)
    modified_data.loc[nan_indices, "sales"] = np.nan

    return modified_data


def check_missing_values(df):
    """Check missing values"""
    df_nan = pd.DataFrame(
        {
            "counts": df.isna().sum(),
            "ratio (%)": np.round(df.isna().sum() / df.shape[0], 4) * 100,
        }
    )
    return df_nan


def main():
    """Main function to generate all data"""
    print("Generating synthetic data for Sales Forecasting with XAI project...")

    # Create output directory if it doesn't exist
    os.makedirs("data", exist_ok=True)

    # Generate store and item data
    provinces, stores = generate_store_data()
    categories, items = generate_item_data()

    print(
        f"Created {len(stores)} stores and {len(items)} items across {len(categories)} categories"
    )

    # Define date ranges
    start_date_2016 = datetime(2016, 1, 1)
    end_date_2016 = datetime(2016, 12, 31)

    start_date_2017 = datetime(2017, 1, 1)
    end_date_2017 = datetime(2017, 12, 31)

    # Generate weather data for both years
    print("Generating weather data...")
    weather_df, weather_dict = generate_weather_data(
        start_date_2016, end_date_2017, provinces
    )

    # Save weather data
    weather_df.to_csv("data/weather_data.csv", index=False)
    print(f"Saved weather data with {len(weather_df)} records")

    # Generate 2016 sales data
    print("Generating 2016 sales data...")
    sales_2016 = generate_sales_data(
        start_date_2016, end_date_2016, stores, items, weather_dict
    )

    sales_2016 = add_outliers_and_nans(
        sales_2016, outlier_percentage=0.5, nan_percentage=1
    )

    # Save 2016 sales data
    sales_2016.to_csv("data/2016_sales.csv", index=False)
    print(f"Saved 2016 sales data with {len(sales_2016)} records")

    # Generate 2017 sales data
    print("Generating 2017 sales data...")
    sales_2017 = generate_sales_data(
        start_date_2017, end_date_2017, stores, items, weather_dict
    )

    sales_2017 = add_outliers_and_nans(
        sales_2017, outlier_percentage=0.5, nan_percentage=1
    )

    # Save 2017 sales data
    sales_2017.to_csv("data/2017_sales.csv", index=False)
    print(f"Saved 2017 sales data with {len(sales_2017)} records")

    # Print statistics
    print("\nData Generation Complete!")
    print(f"Total weather records: {len(weather_df)}")
    print(f"Total 2016 sales records: {len(sales_2016)}")
    print(f"Total 2017 sales records: {len(sales_2017)}")
    print(
        f"Total combined records: {len(weather_df) + len(sales_2016) + len(sales_2017)}"
    )

    print("\nSales Statistics:")
    print(f"2016 Average Sales: {sales_2016['sales'].mean():.2f} units")
    print(f"2016 Max Sales: {sales_2016['sales'].max()} units")
    print(f"2017 Average Sales: {sales_2017['sales'].mean():.2f} units")
    print(f"2017 Max Sales: {sales_2017['sales'].max()} units")
    print(f"Missing values: {check_missing_values(sales_2016)}")
    print(f"Missing values: {check_missing_values(sales_2017)}")

    print("\nFiles saved to data/ directory:")
    print("- data/weather_data.csv")
    print("- data/2016_sales.csv")
    print("- data/2017_sales.csv")


if __name__ == "__main__":
    main()
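Usage note (not part of the diff): a minimal sketch of driving the generator from another script rather than through `main()`, useful when only a short date range is needed. The `app.utils.data_generator` import path is an assumption based on the file location in this commit.

from datetime import datetime

import app.utils.data_generator as dg  # assumed import path for the module above

# Build reference data and a one-month slice of synthetic sales
provinces, stores = dg.generate_store_data()
categories, items = dg.generate_item_data()
weather_df, weather_dict = dg.generate_weather_data(
    datetime(2016, 1, 1), datetime(2016, 1, 31), provinces
)
january_sales = dg.generate_sales_data(
    datetime(2016, 1, 1), datetime(2016, 1, 31), stores, items, weather_dict
)
print(january_sales.groupby("category")["sales"].sum())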
app/utils/data_loader.py
ADDED
@@ -0,0 +1,113 @@
import json
import pickle
import pandas as pd
import gradio as gr
import pyarrow.feather as feather
from functools import lru_cache

# --- Data & Model Loading Logic ---

def load_model():
    """Load the trained sales forecast model"""
    try:
        with open("models/sales_forecast_model.pkl", "rb") as file:
            model = pickle.load(file)
        return model
    except FileNotFoundError:
        # Using gr.Error for UI notification if called within an interaction
        # or standard print for startup logs
        print("Error: 'models/sales_forecast_model.pkl' not found.")
        return None

def load_feature_stats():
    """Load feature statistics used for normalization"""
    try:
        with open("models/feature_stats.json", "r") as file:
            feature_stats = json.load(file)
        return feature_stats
    except FileNotFoundError:
        print("Error: 'models/feature_stats.json' not found.")
        return {}

@lru_cache(maxsize=1)
def load_data():
    """Load preprocessed sales data (lru_cache replaces @st.cache_data)"""
    try:
        df = pd.read_csv("data/sales_data_preprocessed.csv")
        if "date" in df.columns:
            df["date"] = pd.to_datetime(df["date"])
        return df
    except FileNotFoundError:
        print("Error: 'data/sales_data_preprocessed.csv' not found.")
        return pd.DataFrame(columns=["date", "store", "sales"])

def load_feature_engineered_data():
    """Load feature engineered data with extended features"""
    try:
        feature_engineered_data = feather.read_feather(
            "data/feature_engineered_data_55_features.feather"
        )
        return feature_engineered_data
    except Exception as e:
        print(f"Error loading feature engineered data: {str(e)}")
        return pd.DataFrame()

# --- Processing Logic ---

def preprocess_data(df, feature_stats=None):
    """Preprocess data for prediction (simplified version)"""
    # Create a copy to avoid modifying the original
    processed_df = df.copy()

    # Extract date features if date column exists
    if "date" in processed_df.columns:
        processed_df["day_of_week"] = processed_df["date"].dt.dayofweek
        processed_df["day_of_month"] = processed_df["date"].dt.day
        processed_df["month"] = processed_df["date"].dt.month
        processed_df["year"] = processed_df["date"].dt.year
        processed_df["is_weekend"] = processed_df["day_of_week"].apply(
            lambda x: 1 if x >= 5 else 0
        )

    # Normalize numerical features if stats are provided
    if feature_stats:
        for feature, stats in feature_stats.items():
            if feature in processed_df.columns and "mean" in stats and "std" in stats:
                processed_df[feature] = (processed_df[feature] - stats["mean"]) / stats["std"]

    return processed_df

# --- Gradio UI Implementation ---

# Load resources once when the app starts
model = load_model()
stats = load_feature_stats()

def predict_sales_ui(store_id):
    """Example function to link the logic to a Gradio interface"""
    if model is None:
        raise gr.Error("Model not loaded. Check server logs.")

    data = load_data()
    # Apply your logic
    processed = preprocess_data(data, stats)

    # Filter for the specific store
    store_data = processed[processed['store'] == store_id]

    # Return results (placeholder for actual model.predict logic)
    return store_data.head()

# Simple Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Sales Forecast Prediction")
    store_input = gr.Number(label="Enter Store ID")
    output_table = gr.DataFrame(label="Preprocessed Data Preview")
    btn = gr.Button("Predict")

    btn.click(fn=predict_sales_ui, inputs=store_input, outputs=output_table)

if __name__ == "__main__":
    demo.launch()
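Usage note (not part of the diff): a small sketch of what `preprocess_data` derives from a toy frame. The `toy_stats` dict is hypothetical but follows the same {feature: {"mean", "std"}} shape the function expects.

import pandas as pd

toy = pd.DataFrame(
    {
        "date": pd.to_datetime(["2017-01-07", "2017-01-09"]),  # a Saturday and a Monday
        "store": [1, 1],
        "sales": [120.0, 80.0],
    }
)

# Hypothetical normalization stats, same shape as models/feature_stats.json
toy_stats = {"sales": {"mean": 100.0, "std": 20.0}}

out = preprocess_data(toy, toy_stats)
print(out[["date", "day_of_week", "is_weekend", "sales"]])
# is_weekend is 1 for the Saturday row; sales become z-scores (1.0 and -1.0)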
app/utils/plots.py
ADDED
@@ -0,0 +1,131 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def plot_sales(df, store_id=1, item_id=1):
    """Plot sales and visualize missing values"""

    df_2plot = df.query("(store_id==@store_id)&(item_id==@item_id)")
    store_name = df_2plot["store_name"].iloc[-1]
    item_name = df_2plot["item_name"].iloc[-1]

    fig, ax = plt.subplots(figsize=(6, 3))
    df_2plot[["date", "sales"]].plot(x="date", y="sales", ax=ax, legend=False)

    # Find the dates where 'sales' is NaN; forward-fill them so markers sit on the plotted line
    nan_indices = df_2plot[df_2plot["sales"].isna()].index

    if len(nan_indices) >= 1:
        df_2plot = df_2plot.assign(sales=lambda df: df["sales"].fillna(method="ffill"))
        # Draw red markers at the positions of the originally missing values
        nan_dates = df_2plot.loc[nan_indices, "date"]
        nan_sales = df_2plot.loc[nan_indices, "sales"]
        for date, sales in zip(nan_dates, nan_sales):
            ax.annotate(
                "-",
                xy=(date, sales),
                color="red",  # Set text color to red
                size=20,
            )

    # Set plot labels and legend
    ax.set_xlabel("Date")
    ax.set_ylabel("Sales")
    ax.set_title(f"Store: {store_name} - Item: {item_name}")
    ax.legend()
    plt.show()


def plot_forecast_single(flat_df, store_item):
    """
    Plot actual vs predicted sales for one store-item combo from flattened Prophet predictions.
    """
    df = flat_df[flat_df["store_item"] == store_item].copy()

    if df.empty:
        print(f"No data found for: {store_item}")
        return

    plt.figure(figsize=(12, 6))
    sns.lineplot(data=df, x="ds", y="y", label="Actual", color="black")
    sns.lineplot(data=df, x="ds", y="yhat", label="Forecast", color="blue")
    plt.fill_between(
        df["ds"],
        df["yhat_lower"],
        df["yhat_upper"],
        color="blue",
        alpha=0.2,
        label="Confidence Interval",
    )
    plt.title(f"Forecast vs Actual for {store_item}")
    plt.xlabel("Date")
    plt.ylabel("Sales")
    plt.xticks(rotation=45)
    plt.legend()
    # plt.grid(True)
    plt.tight_layout()
    plt.show()


def plot_sales_predictions(
    df_prediction, store_id=1, nrows=6, ncols=5, figsize=(20, 20)
):
    """
    Plots actual vs predicted sales for items in a given store.

    Parameters:
        df_prediction (DataFrame): Must include ['store_id', 'item_id', 'date', 'sales', 'prediction']
        store_id (int): Store to filter on
        nrows (int): Rows of subplots
        ncols (int): Columns of subplots
        figsize (tuple): Size of the full figure
    """
    df_sample = df_prediction[df_prediction["store_id"] == store_id]
    store_name = df_sample["store_name"].iloc[-1]

    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    axes = axes.flatten()

    item_ids = sorted(df_sample["item_id"].unique())

    for i, ax in enumerate(axes):
        if i >= len(item_ids):
            ax.axis("off")  # Hide unused subplots
            continue

        item_id = item_ids[i]
        df2plot = df_sample[df_sample["item_id"] == item_id]

        # Check for an empty slice before indexing into it
        if df2plot.empty:
            ax.axis("off")
            continue

        item_name = df2plot["item_name"].iloc[-1]

        # Plot actual and predicted sales
        ax.plot(df2plot["date"], df2plot["sales"], label="Actual", color="blue")
        ax.plot(
            df2plot["date"],
            df2plot["prediction"],
            label="Forecast",
            color="red",
            linestyle="--",
            marker=".",
        )

        ax.set_title(f"Item: {item_name}")
        ax.set_xlabel("")
        ax.set_ylabel("Sales")
        ax.tick_params(axis="x", rotation=45)
        ax.grid(True)

    # Only add legend to the first subplot
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", ncol=2, fontsize=12)

    plt.tight_layout(rect=[0, 0, 1, 0.97])  # Leave space for the legend
    fig.suptitle(
        f"Sales Forecast vs Actual - Store {store_name}", fontsize=16, fontweight="bold"
    )
    plt.show()
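Usage note (not part of the diff): a minimal sketch of calling `plot_sales_predictions` on a tiny hand-made prediction frame; all values are illustrative.

import pandas as pd

demo_pred = pd.DataFrame(
    {
        "store_id": [1, 1, 1, 1],
        "store_name": ["Hoan Kiem Market"] * 4,
        "item_id": [1, 1, 2, 2],
        "item_name": ["Rice", "Rice", "Noodles", "Noodles"],
        "date": pd.to_datetime(["2017-12-01", "2017-12-02"] * 2),
        "sales": [15, 18, 12, 11],
        "prediction": [14.2, 17.5, 12.8, 10.6],
    }
)

# One row of two panels is enough for two items
plot_sales_predictions(demo_pred, store_id=1, nrows=1, ncols=2, figsize=(10, 4))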
app/utils/utils.py
ADDED
@@ -0,0 +1,119 @@
import pickle

import lightgbm as lgbm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def fill_misisng_values(df):
    """Fill NaN values in the 'sales' column with the mean of non-NaN values"""
    df_filled = df.copy()
    df_filled["sales"] = df_filled["sales"].fillna(df_filled["sales"].mean())
    return df_filled


def correct_outliers(df, factor=3):
    """Identify and correct outliers in the 'sales' column by reducing them to the mean"""
    df_corrected = df.copy()

    # Identify outliers using z-score
    z_scores = (df_corrected["sales"] - df_corrected["sales"].mean()) / df_corrected["sales"].std()
    outlier_indices = np.abs(z_scores) > factor  # Adjust the threshold as needed
    # Correct outliers by reducing them to the mean
    df_corrected.loc[outlier_indices, "sales"] = df_corrected["sales"].mean()

    return df_corrected


def get_sample_stores(df: pd.DataFrame, store_id: int = 1) -> pd.DataFrame:
    """Get the sample store with the given store_id"""
    grouped = df.groupby("store_id")
    sample_store = grouped.get_group(store_id)
    return sample_store


def save_data(df, file_path, file_format="feather"):
    """
    Save a DataFrame to a specified file format.

    Parameters:
    - df (pd.DataFrame): The DataFrame to be saved.
    - file_path (str): The path where the file will be saved.
    - file_format (str): The format in which to save the file. Supported formats: 'feather', 'csv'.
      Default is 'feather'.

    Example:
    ```python
    # Assuming df is the DataFrame you want to save
    save_data(df, 'output_data.feather', file_format='feather')
    ```

    Note:
    - Make sure to have the required libraries (pandas and feather-format) installed.
    """
    if file_format.lower() == "feather":
        # Save to Feather format
        df.to_feather(file_path)
        print(f"DataFrame saved to {file_path} in Feather format.")
    elif file_format.lower() == "csv":
        # Save to CSV format
        df.to_csv(file_path, index=False)
        print(f"DataFrame saved to {file_path} in CSV format.")
    else:
        print(
            f"Error: Unsupported file format '{file_format}'. Supported formats: 'feather', 'csv'."
        )


def flatten_prophet_predictions(predictions_dict):
    """Flatten a dict of per-store-item Prophet prediction frames into a single DataFrame"""
    all_dfs = []

    for store_item, df in predictions_dict.items():
        df = df.copy()
        df["store_item"] = store_item
        all_dfs.append(df)

    return pd.concat(all_dfs, ignore_index=True)


def load_model(file_path):
    """
    Load a machine learning model from a file.

    Parameters:
    - file_path: The file path from where the model will be loaded.

    Returns:
    - The loaded model.
    """
    try:
        with open(file_path, "rb") as file:
            model = pickle.load(file)
        print(f"Sklearn model loaded from {file_path}")

    except (pickle.UnpicklingError, FileNotFoundError):
        # If loading as a pickled scikit-learn model fails or the file is not found,
        # assume it is a LightGBM model saved in the native Booster text format
        model = lgbm.Booster(model_file=file_path)
        print(f"LightGBM model loaded from {file_path}")

    return model


# Function to calculate WAPE (Weighted Absolute Percentage Error)
def weighted_absolute_percentage_error(y_true, y_pred):
    """
    Calculate Weighted Absolute Percentage Error

    Args:
        y_true: Actual values
        y_pred: Predicted values

    Returns:
        WAPE value (percentage)
    """
    y_true, y_pred = np.array(y_true), np.array(y_pred)
    return 100 * np.sum(np.abs(y_true - y_pred)) / np.sum(np.abs(y_true))
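Usage note (not part of the diff): a quick hand-checkable example of `weighted_absolute_percentage_error`.

actual = [100, 50, 25]
forecast = [90, 55, 25]

# |100-90| + |50-55| + |25-25| = 15; total actual = 175 -> 15/175 * 100 ≈ 8.57%
print(round(weighted_absolute_percentage_error(actual, forecast), 2))  # 8.57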
app/utils/visualization_code.py
ADDED
@@ -0,0 +1,522 @@
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import matplotlib.ticker as ticker
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import seaborn as sns
|
| 8 |
+
from matplotlib.dates import DateFormatter
|
| 9 |
+
|
| 10 |
+
# Set up plotting style
|
| 11 |
+
plt.style.use("seaborn-v0_8-whitegrid")
|
| 12 |
+
sns.set_palette("deep")
|
| 13 |
+
plt.rcParams["figure.figsize"] = (14, 8)
|
| 14 |
+
plt.rcParams["font.size"] = 12
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def visualize_predictions_by_store_item(test_results, output_dir="visualizations"):
|
| 18 |
+
"""
|
| 19 |
+
Create visualizations of actual vs predicted values for each store-item combination.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
test_results: DataFrame containing test results with columns:
|
| 23 |
+
'date', 'store_name', 'item_name', 'sales', 'prediction'
|
| 24 |
+
output_dir: Directory to save the visualizations
|
| 25 |
+
"""
|
| 26 |
+
# Create output directory if it doesn't exist
|
| 27 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 28 |
+
|
| 29 |
+
# Create a time series plot for each store-item combination
|
| 30 |
+
store_items = test_results.groupby(["store_name", "item_name"])
|
| 31 |
+
|
| 32 |
+
# Get total number of combinations for progress tracking
|
| 33 |
+
total_combinations = len(store_items)
|
| 34 |
+
print(
|
| 35 |
+
f"Creating visualizations for {total_combinations} store-item combinations..."
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Counter for progress tracking
|
| 39 |
+
counter = 0
|
| 40 |
+
|
| 41 |
+
# For each store-item combination, create a plot
|
| 42 |
+
for (store, item), group in store_items:
|
| 43 |
+
# Sort by date to ensure proper time series order
|
| 44 |
+
group = group.sort_values("date")
|
| 45 |
+
|
| 46 |
+
# Convert date to datetime if it's not already
|
| 47 |
+
if not pd.api.types.is_datetime64_any_dtype(group["date"]):
|
| 48 |
+
group["date"] = pd.to_datetime(group["date"])
|
| 49 |
+
|
| 50 |
+
# Create the plot
|
| 51 |
+
fig, ax = plt.subplots(figsize=(14, 6))
|
| 52 |
+
|
| 53 |
+
# Plot actual and predicted values
|
| 54 |
+
ax.plot(
|
| 55 |
+
group["date"], group["sales"], "o-", label="Actual", alpha=0.7, linewidth=2
|
| 56 |
+
)
|
| 57 |
+
ax.plot(
|
| 58 |
+
group["date"],
|
| 59 |
+
group["prediction"],
|
| 60 |
+
"s--",
|
| 61 |
+
label="Predicted",
|
| 62 |
+
alpha=0.7,
|
| 63 |
+
linewidth=2,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# Calculate error metrics for this store-item
|
| 67 |
+
mae = np.mean(np.abs(group["sales"] - group["prediction"]))
|
| 68 |
+
mape = (
|
| 69 |
+
np.mean(np.abs((group["sales"] - group["prediction"]) / group["sales"]))
|
| 70 |
+
* 100
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# Add title and labels
|
| 74 |
+
ax.set_title(f"Store: {store}, Item: {item}\nMAE: {mae:.2f}, MAPE: {mape:.2f}%")
|
| 75 |
+
ax.set_xlabel("Date")
|
| 76 |
+
ax.set_ylabel("Sales")
|
| 77 |
+
|
| 78 |
+
# Format x-axis dates
|
| 79 |
+
date_formatter = DateFormatter("%Y-%m-%d")
|
| 80 |
+
ax.xaxis.set_major_formatter(date_formatter)
|
| 81 |
+
# Rotate date labels for better readability
|
| 82 |
+
plt.xticks(rotation=45)
|
| 83 |
+
|
| 84 |
+
# Add grid for easier reading
|
| 85 |
+
ax.grid(True, linestyle="--", alpha=0.7)
|
| 86 |
+
|
| 87 |
+
# Add legend
|
| 88 |
+
ax.legend()
|
| 89 |
+
|
| 90 |
+
# Adjust layout
|
| 91 |
+
plt.tight_layout()
|
| 92 |
+
|
| 93 |
+
# Save the figure
|
| 94 |
+
safe_store = store.replace(" ", "_").replace("/", "_")
|
| 95 |
+
safe_item = item.replace(" ", "_").replace("/", "_")
|
| 96 |
+
filename = f"{safe_store}_{safe_item}.png"
|
| 97 |
+
plt.savefig(os.path.join(output_dir, filename))
|
| 98 |
+
|
| 99 |
+
# Close the figure to free memory
|
| 100 |
+
plt.close(fig)
|
| 101 |
+
|
| 102 |
+
# Update progress
|
| 103 |
+
counter += 1
|
| 104 |
+
if counter % 10 == 0:
|
| 105 |
+
print(f"Processed {counter}/{total_combinations} combinations")
|
| 106 |
+
|
| 107 |
+
print(f"All visualizations saved to {output_dir}/")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def visualize_aggregated_predictions(test_results, output_dir="visualizations"):
|
| 111 |
+
"""
|
| 112 |
+
Create aggregated visualizations of actual vs predicted values by store, item, and date.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
test_results: DataFrame containing test results
|
| 116 |
+
output_dir: Directory to save the visualizations
|
| 117 |
+
"""
|
| 118 |
+
# Create output directory if it doesn't exist
|
| 119 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 120 |
+
|
| 121 |
+
# Ensure date is in datetime format
|
| 122 |
+
if not pd.api.types.is_datetime64_any_dtype(test_results["date"]):
|
| 123 |
+
test_results["date"] = pd.to_datetime(test_results["date"])
|
| 124 |
+
|
| 125 |
+
# 1. Aggregate by date
|
| 126 |
+
daily_results = (
|
| 127 |
+
test_results.groupby("date")
|
| 128 |
+
.agg({"sales": "sum", "prediction": "sum"})
|
| 129 |
+
.reset_index()
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Plot daily aggregated results
|
| 133 |
+
fig, ax = plt.subplots(figsize=(14, 6))
|
| 134 |
+
ax.plot(
|
| 135 |
+
daily_results["date"],
|
| 136 |
+
daily_results["sales"],
|
| 137 |
+
"o-",
|
| 138 |
+
label="Actual",
|
| 139 |
+
alpha=0.7,
|
| 140 |
+
linewidth=2,
|
| 141 |
+
)
|
| 142 |
+
ax.plot(
|
| 143 |
+
daily_results["date"],
|
| 144 |
+
daily_results["prediction"],
|
| 145 |
+
"s--",
|
| 146 |
+
label="Predicted",
|
| 147 |
+
alpha=0.7,
|
| 148 |
+
linewidth=2,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# Add title and labels
|
| 152 |
+
ax.set_title("Total Daily Sales: Actual vs Predicted")
|
| 153 |
+
ax.set_xlabel("Date")
|
| 154 |
+
ax.set_ylabel("Total Sales")
|
| 155 |
+
|
| 156 |
+
# Format x-axis dates
|
| 157 |
+
date_formatter = DateFormatter("%Y-%m-%d")
|
| 158 |
+
ax.xaxis.set_major_formatter(date_formatter)
|
| 159 |
+
plt.xticks(rotation=45)
|
| 160 |
+
|
| 161 |
+
# Add grid and legend
|
| 162 |
+
ax.grid(True, linestyle="--", alpha=0.7)
|
| 163 |
+
ax.legend()
|
| 164 |
+
|
| 165 |
+
# Adjust layout and save
|
| 166 |
+
plt.tight_layout()
|
| 167 |
+
plt.savefig(os.path.join(output_dir, "total_daily_sales.png"))
|
| 168 |
+
plt.close(fig)
|
| 169 |
+
|
| 170 |
+
# 2. Aggregate by store
|
| 171 |
+
store_results = (
|
| 172 |
+
test_results.groupby(["store_name", "date"])
|
| 173 |
+
.agg({"sales": "sum", "prediction": "sum"})
|
| 174 |
+
.reset_index()
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Plot for each store
|
| 178 |
+
stores = store_results["store_name"].unique()
|
| 179 |
+
for store in stores:
|
| 180 |
+
store_data = store_results[store_results["store_name"] == store]
|
| 181 |
+
|
| 182 |
+
fig, ax = plt.subplots(figsize=(14, 6))
|
| 183 |
+
ax.plot(
|
| 184 |
+
store_data["date"],
|
| 185 |
+
store_data["sales"],
|
| 186 |
+
"o-",
|
| 187 |
+
label="Actual",
|
| 188 |
+
alpha=0.7,
|
| 189 |
+
linewidth=2,
|
| 190 |
+
)
|
| 191 |
+
ax.plot(
|
| 192 |
+
store_data["date"],
|
| 193 |
+
store_data["prediction"],
|
| 194 |
+
"s--",
|
| 195 |
+
label="Predicted",
|
| 196 |
+
alpha=0.7,
|
| 197 |
+
linewidth=2,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# Add title and labels
|
| 201 |
+
ax.set_title(f"Store: {store} - Total Daily Sales")
|
| 202 |
+
ax.set_xlabel("Date")
|
| 203 |
+
ax.set_ylabel("Total Sales")
|
| 204 |
+
|
| 205 |
+
# Format x-axis dates
|
| 206 |
+
ax.xaxis.set_major_formatter(date_formatter)
|
| 207 |
+
plt.xticks(rotation=45)
|
| 208 |
+
|
| 209 |
+
# Add grid and legend
|
| 210 |
+
ax.grid(True, linestyle="--", alpha=0.7)
|
| 211 |
+
ax.legend()
|
| 212 |
+
|
| 213 |
+
# Adjust layout and save
|
| 214 |
+
plt.tight_layout()
|
| 215 |
+
safe_store = store.replace(" ", "_").replace("/", "_")
|
| 216 |
+
plt.savefig(os.path.join(output_dir, f"store_{safe_store}_total.png"))
|
| 217 |
+
plt.close(fig)
|
| 218 |
+
|
| 219 |
+
# 3. Aggregate by item
|
| 220 |
+
item_results = (
|
| 221 |
+
test_results.groupby(["item_name", "date"])
|
| 222 |
+
.agg({"sales": "sum", "prediction": "sum"})
|
| 223 |
+
.reset_index()
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# Plot for each item
|
| 227 |
+
items = item_results["item_name"].unique()
|
| 228 |
+
for item in items:
|
| 229 |
+
item_data = item_results[item_results["item_name"] == item]
|
| 230 |
+
|
| 231 |
+
fig, ax = plt.subplots(figsize=(14, 6))
|
| 232 |
+
ax.plot(
|
| 233 |
+
item_data["date"],
|
+            item_data["sales"],
+            "o-",
+            label="Actual",
+            alpha=0.7,
+            linewidth=2,
+        )
+        ax.plot(
+            item_data["date"],
+            item_data["prediction"],
+            "s--",
+            label="Predicted",
+            alpha=0.7,
+            linewidth=2,
+        )
+
+        # Add title and labels
+        ax.set_title(f"Item: {item} - Total Daily Sales")
+        ax.set_xlabel("Date")
+        ax.set_ylabel("Total Sales")
+
+        # Format x-axis dates
+        ax.xaxis.set_major_formatter(date_formatter)
+        plt.xticks(rotation=45)
+
+        # Add grid and legend
+        ax.grid(True, linestyle="--", alpha=0.7)
+        ax.legend()
+
+        # Adjust layout and save
+        plt.tight_layout()
+        safe_item = item.replace(" ", "_").replace("/", "_")
+        plt.savefig(os.path.join(output_dir, f"item_{safe_item}_total.png"))
+        plt.close(fig)
+
+    print(f"Aggregated visualizations saved to {output_dir}/")
+
+
+def create_interactive_dashboard(test_results, output_dir="visualizations"):
+    """
+    Create an interactive HTML dashboard with plots for all store-item combinations.
+    Requires Plotly and Dash libraries.
+
+    Args:
+        test_results: DataFrame containing test results
+        output_dir: Directory to save the dashboard
+    """
+    try:
+        import plotly.express as px
+        import plotly.graph_objects as go
+        from plotly.subplots import make_subplots
+
+        print("Creating interactive dashboard...")
+
+        # Create output directory if it doesn't exist
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Ensure date is in datetime format
+        if not pd.api.types.is_datetime64_any_dtype(test_results["date"]):
+            test_results["date"] = pd.to_datetime(test_results["date"])
+
+        # Create overall performance figure
+        daily_results = (
+            test_results.groupby("date")
+            .agg({"sales": "sum", "prediction": "sum"})
+            .reset_index()
+        )
+
+        fig = go.Figure()
+        fig.add_trace(
+            go.Scatter(
+                x=daily_results["date"],
+                y=daily_results["sales"],
+                mode="lines+markers",
+                name="Actual",
+                line=dict(color="blue"),
+            )
+        )
+        fig.add_trace(
+            go.Scatter(
+                x=daily_results["date"],
+                y=daily_results["prediction"],
+                mode="lines+markers",
+                name="Predicted",
+                line=dict(color="red", dash="dash"),
+            )
+        )
+
+        fig.update_layout(
+            title="Overall Sales Performance: Actual vs Predicted",
+            xaxis_title="Date",
+            yaxis_title="Total Sales",
+            legend_title="Series",
+            height=600,
+        )
+
+        # Save the overall chart as HTML
+        fig.write_html(os.path.join(output_dir, "overall_performance.html"))
+
+        # Create an error heatmap
+        store_item_error = (
+            test_results.groupby(["store_name", "item_name"])
+            .apply(
+                lambda x: np.mean(np.abs((x["sales"] - x["prediction"]) / x["sales"]))
+                * 100
+            )
+            .reset_index()
+        )
+        store_item_error.columns = ["store_name", "item_name", "mape"]
+
+        # Pivot the data for the heatmap
+        heatmap_data = store_item_error.pivot(
+            index="store_name", columns="item_name", values="mape"
+        )
+
+        # Create heatmap figure
+        heatmap_fig = px.imshow(
+            heatmap_data,
+            labels=dict(x="Item", y="Store", color="MAPE (%)"),
+            x=heatmap_data.columns,
+            y=heatmap_data.index,
+            color_continuous_scale="RdBu_r",
+            title="Mean Absolute Percentage Error by Store and Item",
+        )
+
+        heatmap_fig.update_layout(height=800, width=1200)
+
+        # Save the heatmap as HTML
+        heatmap_fig.write_html(os.path.join(output_dir, "error_heatmap.html"))
+
+        print(f"Interactive dashboard elements saved to {output_dir}/")
+
+    except ImportError:
+        print("Could not create interactive dashboard. Plotly library is required.")
+        print("Install it with: pip install plotly dash")
+
+
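Note on the MAPE heatmap in create_interactive_dashboard above: the groupby/apply lambda divides by the actual sales, so any store-item group that contains zero-sale rows produces an infinite (or NaN) percentage error and an unusable heatmap cell. A minimal zero-safe variant is sketched below; it is not part of this commit, safe_mape is a hypothetical helper name, and it assumes the same sales/prediction column names used in this file.

import numpy as np


def safe_mape(group):
    """MAPE (%) over rows with non-zero actual sales; NaN if a group has none."""
    mask = group["sales"] != 0
    if not mask.any():
        return float("nan")
    actual = group.loc[mask, "sales"]
    predicted = group.loc[mask, "prediction"]
    return float(np.abs((actual - predicted) / actual).mean() * 100)


# Drop-in replacement for the lambda in the heatmap aggregation:
# store_item_error = (
#     test_results.groupby(["store_name", "item_name"]).apply(safe_mape).reset_index()
# )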
+
def visualize_error_distribution(test_results, output_dir="visualizations"):
|
| 371 |
+
"""
|
| 372 |
+
Visualize the distribution and patterns of prediction errors.
|
| 373 |
+
|
| 374 |
+
Args:
|
| 375 |
+
test_results: DataFrame containing test results
|
| 376 |
+
output_dir: Directory to save the visualizations
|
| 377 |
+
"""
|
| 378 |
+
# Create output directory if it doesn't exist
|
| 379 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 380 |
+
|
| 381 |
+
# Calculate errors
|
| 382 |
+
test_results["error"] = test_results["sales"] - test_results["prediction"]
|
| 383 |
+
test_results["abs_error"] = np.abs(test_results["error"])
|
| 384 |
+
test_results["pct_error"] = (test_results["error"] / test_results["sales"]) * 100
|
| 385 |
+
|
| 386 |
+
# 1. Error distribution histogram
|
| 387 |
+
plt.figure(figsize=(12, 6))
|
| 388 |
+
sns.histplot(test_results["error"], kde=True, bins=50)
|
| 389 |
+
plt.axvline(x=0, color="red", linestyle="--")
|
| 390 |
+
plt.title("Distribution of Prediction Errors")
|
| 391 |
+
plt.xlabel("Error (Actual - Predicted)")
|
| 392 |
+
plt.ylabel("Frequency")
|
| 393 |
+
plt.grid(True, linestyle="--", alpha=0.7)
|
| 394 |
+
plt.tight_layout()
|
| 395 |
+
plt.savefig(os.path.join(output_dir, "error_distribution.png"))
|
| 396 |
+
plt.close()
|
| 397 |
+
|
| 398 |
+
# 2. Error vs Actual Sales
|
| 399 |
+
plt.figure(figsize=(12, 6))
|
| 400 |
+
plt.scatter(test_results["sales"], test_results["error"], alpha=0.5)
|
| 401 |
+
plt.axhline(y=0, color="red", linestyle="--")
|
| 402 |
+
plt.title("Prediction Error vs Actual Sales")
|
| 403 |
+
plt.xlabel("Actual Sales")
|
| 404 |
+
plt.ylabel("Error (Actual - Predicted)")
|
| 405 |
+
plt.grid(True, linestyle="--", alpha=0.7)
|
| 406 |
+
plt.tight_layout()
|
| 407 |
+
plt.savefig(os.path.join(output_dir, "error_vs_sales.png"))
|
| 408 |
+
plt.close()
|
| 409 |
+
|
| 410 |
+
# 3. Error over time
|
| 411 |
+
plt.figure(figsize=(14, 6))
|
| 412 |
+
# Ensure date is in datetime format
|
| 413 |
+
if not pd.api.types.is_datetime64_any_dtype(test_results["date"]):
|
| 414 |
+
test_results["date"] = pd.to_datetime(test_results["date"])
|
| 415 |
+
|
| 416 |
+
# Group by date to see overall error trend
|
| 417 |
+
daily_error = test_results.groupby("date")["error"].mean().reset_index()
|
| 418 |
+
plt.plot(daily_error["date"], daily_error["error"], "o-")
|
| 419 |
+
plt.axhline(y=0, color="red", linestyle="--")
|
| 420 |
+
plt.title("Mean Prediction Error Over Time")
|
| 421 |
+
plt.xlabel("Date")
|
| 422 |
+
plt.ylabel("Mean Error")
|
| 423 |
+
date_formatter = DateFormatter("%Y-%m-%d")
|
| 424 |
+
plt.gca().xaxis.set_major_formatter(date_formatter)
|
| 425 |
+
plt.xticks(rotation=45)
|
| 426 |
+
plt.grid(True, linestyle="--", alpha=0.7)
|
| 427 |
+
plt.tight_layout()
|
| 428 |
+
plt.savefig(os.path.join(output_dir, "error_over_time.png"))
|
| 429 |
+
plt.close()
|
| 430 |
+
|
| 431 |
+
# 4. Error by day of week
|
| 432 |
+
test_results["day_of_week"] = test_results["date"].dt.dayofweek
|
| 433 |
+
test_results["day_name"] = test_results["date"].dt.day_name()
|
| 434 |
+
|
| 435 |
+
plt.figure(figsize=(12, 6))
|
| 436 |
+
day_error = (
|
| 437 |
+
test_results.groupby("day_name")["pct_error"]
|
| 438 |
+
.mean()
|
| 439 |
+
.reindex(
|
| 440 |
+
[
|
| 441 |
+
"Monday",
|
| 442 |
+
"Tuesday",
|
| 443 |
+
"Wednesday",
|
| 444 |
+
"Thursday",
|
| 445 |
+
"Friday",
|
| 446 |
+
"Saturday",
|
| 447 |
+
"Sunday",
|
| 448 |
+
]
|
| 449 |
+
)
|
| 450 |
+
)
|
| 451 |
+
sns.barplot(x=day_error.index, y=day_error.values)
|
| 452 |
+
plt.title("Mean Percentage Error by Day of Week")
|
| 453 |
+
plt.xlabel("Day of Week")
|
| 454 |
+
plt.ylabel("Mean Percentage Error (%)")
|
| 455 |
+
plt.axhline(y=0, color="red", linestyle="--")
|
| 456 |
+
plt.grid(True, linestyle="--", alpha=0.7)
|
| 457 |
+
plt.tight_layout()
|
| 458 |
+
plt.savefig(os.path.join(output_dir, "error_by_day_of_week.png"))
|
| 459 |
+
plt.close()
|
| 460 |
+
|
| 461 |
+
# 5. Error by category - only if 'category' column exists
|
| 462 |
+
if "category" in test_results.columns:
|
| 463 |
+
plt.figure(figsize=(12, 6))
|
| 464 |
+
cat_error = test_results.groupby("category")["pct_error"].mean().sort_values()
|
| 465 |
+
sns.barplot(x=cat_error.index, y=cat_error.values)
|
| 466 |
+
plt.title("Mean Percentage Error by Category")
|
| 467 |
+
plt.xlabel("Category")
|
| 468 |
+
plt.ylabel("Mean Percentage Error (%)")
|
| 469 |
+
plt.axhline(y=0, color="red", linestyle="--")
|
| 470 |
+
plt.xticks(rotation=45)
|
| 471 |
+
plt.grid(True, linestyle="--", alpha=0.7)
|
| 472 |
+
plt.tight_layout()
|
| 473 |
+
plt.savefig(os.path.join(output_dir, "error_by_category.png"))
|
| 474 |
+
plt.close()
|
| 475 |
+
|
| 476 |
+
print(f"Error analysis visualizations saved to {output_dir}/")
|
| 477 |
+
|
| 478 |
+
|
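For orientation, the three error columns added at the top of visualize_error_distribution are simple transforms of one residual: for a row with sales = 50 and prediction = 40, error = 50 - 40 = 10, abs_error = 10, and pct_error = 10 / 50 * 100 = 20%. The function only needs date, sales, and prediction columns (a category column, if present, adds the fifth chart). A small smoke-test call, using a hypothetical synthetic frame (not part of this commit), could look like this:

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
demo = pd.DataFrame(
    {
        "date": pd.date_range("2024-01-01", periods=14, freq="D"),
        "sales": rng.integers(20, 60, size=14),  # strictly positive, so pct_error stays finite
        "prediction": rng.integers(20, 60, size=14),
    }
)

visualize_error_distribution(demo, output_dir="visualizations_demo")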
+def create_forecast_dashboard(
+    model, X_test, y_test, test_results, data, output_dir="visualizations"
+):
+    """
+    Create a comprehensive dashboard of forecast visualizations.
+
+    Args:
+        model: Trained model
+        X_test: Test features
+        y_test: Test target values
+        test_results: DataFrame with test results
+        data: Original data with date, store, item info
+        output_dir: Directory to save visualizations
+    """
+    # Create all visualizations
+    print("Creating forecast visualizations...")
+
+    # 1. Individual store-item visualizations (limited to avoid too many plots)
+    # Get the top 20 store-item combinations by sales volume
+    store_item_sales = (
+        test_results.groupby(["store_name", "item_name"])["sales"].sum().reset_index()
+    )
+    top_combinations = store_item_sales.sort_values("sales", ascending=False).head(20)
+
+    # Filter test_results to include only these top combinations
+    top_results = pd.merge(
+        test_results,
+        top_combinations[["store_name", "item_name"]],
+        on=["store_name", "item_name"],
+    )
+
+    # Create visualizations for top combinations
+    visualize_predictions_by_store_item(top_results, output_dir)
+
+    # 2. Aggregated visualizations
+    visualize_aggregated_predictions(test_results, output_dir)
+
+    # 3. Error distribution and patterns
+    visualize_error_distribution(test_results, output_dir)
+
+    # 4. Try to create interactive dashboard if plotly is available
+    create_interactive_dashboard(test_results, output_dir)
+
+    print("Forecast visualization dashboard created successfully!")
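As written, the body of create_forecast_dashboard only reads test_results and output_dir; model, X_test, y_test, and data are accepted so the function slots into the wider training pipeline. A hedged end-to-end call, reusing the synthetic demo frame from the earlier sketch and assuming the per-store/per-item plotting helpers defined earlier in this file accept the same columns, might look like:

# Hypothetical wiring, extending the earlier `demo` frame with the columns the
# per-store and per-item plots group on.
demo["store_name"] = "Store A"
demo["item_name"] = "T-Shirt"

# The remaining parameters are untouched by the body shown in this diff, so
# placeholders are enough for a smoke test.
create_forecast_dashboard(None, None, None, demo, None, output_dir="viz_demo")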