elizabethmyn commited on
Commit
84548c1
·
1 Parent(s): cc73c46

Add demo for sales forecasting

Browse files
app/core/config.py CHANGED
@@ -9,9 +9,8 @@ class Settings(BaseSettings):
9
 
10
  # Server
11
  HOST: str = "0.0.0.0"
12
- # HOST: str = "127.0.0.1"
13
  PORT: int = 5050
14
- API_PREFIX: str = "/api/v1"
15
 
16
  # Model
17
  MODEL_CHECKPOINT: str = "yainage90/fashion-object-detection"
@@ -19,13 +18,12 @@ class Settings(BaseSettings):
19
 
20
  # Security
21
  SECRET_KEY: str = "xxx"
22
- ALGORITHM: str = "HS256"
23
- ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 30 # 1 month
24
-
25
  API_TOKEN: str = "xxx"
 
 
26
 
27
  class Config:
28
- # env_file = ".env"
29
  case_sensitive = True
30
 
31
  settings = Settings()
 
9
 
10
  # Server
11
  HOST: str = "0.0.0.0"
 
12
  PORT: int = 5050
13
+ API_PREFIX: str = "x"
14
 
15
  # Model
16
  MODEL_CHECKPOINT: str = "yainage90/fashion-object-detection"
 
18
 
19
  # Security
20
  SECRET_KEY: str = "xxx"
 
 
 
21
  API_TOKEN: str = "xxx"
22
+ ALGORITHM: str = ".xxx"
23
+ ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
24
 
25
  class Config:
26
+ env_file = ".env"
27
  case_sensitive = True
28
 
29
  settings = Settings()
app/frontend/dashboard.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import gradio as gr
3
+ import matplotlib.pyplot as plt
4
+ from datetime import datetime
5
+
6
+ from app.frontend.data_viz import (
7
+ plot_category_distribution,
8
+ plot_day_of_week_pattern,
9
+ plot_sales_distribution,
10
+ plot_sales_time_series,
11
+ plot_store_comparison,
12
+ )
13
+
14
# Minimal stand-in for Streamlit's st.session_state so the ported
# dashboard logic can keep using attribute-style state access.
class SessionState(dict):
    """Dict subclass with attribute access; missing attributes resolve to None."""

    def __getattr__(self, name):
        # dict.get returns None for absent keys instead of raising AttributeError.
        return self.get(name)

    def __setattr__(self, name, value):
        # Attributes are stored as ordinary dict entries.
        self[name] = value

session_state = SessionState()
20
+
21
def configure_filters(data, start_date, end_date, selected_store_input, selected_categories):
    """Filter `data` by date range, store, and categories (UI-free port).

    Mirrors the original Streamlit sidebar logic: the store input is
    interpreted as a store *name* when a `store_name` column exists,
    otherwise as a store id via the `store` column. The resolved
    selection and date bounds are published on `session_state` for the
    downstream display helpers. Returns the filtered DataFrame.
    """
    store_id = "All Stores"
    store_name = "All Stores"
    if "store_name" in data.columns:
        store_name = selected_store_input
    elif "store" in data.columns:
        store_id = selected_store_input

    working = data.copy()

    # Gradio hands dates over as strings; normalise to datetime.date.
    lo = pd.to_datetime(start_date).date()
    hi = pd.to_datetime(end_date).date()
    keep = (working["date"].dt.date >= lo) & (working["date"].dt.date <= hi)

    # Store filter (name takes precedence over id, matching the resolution above).
    if "store_name" in data.columns and store_name != "All Stores":
        keep &= working["store_name"] == store_name
    elif "store" in data.columns and store_id != "All Stores":
        keep &= working["store"] == store_id

    # Category filter only when a non-empty selection was made.
    if selected_categories:
        keep &= working["category"].isin(selected_categories)

    # Share the resolved selection with the other dashboard helpers.
    session_state.selected_store = store_id
    session_state.selected_store_name = store_name
    session_state.start_date = lo
    session_state.end_date = hi

    return working[keep]
61
+
62
def display_kpis(filtered_data):
    """Compute the dashboard KPI values for the filtered data.

    Returns a tuple of raw numbers:
    (total sales, percent change between the two halves of the selected
    date range, average daily sales, total transactions, average
    transaction value).
    """
    total_sales = filtered_data["sales"].sum()
    avg_daily_sales = filtered_data.groupby("date")["sales"].sum().mean()

    # Period-over-period change: split the selected range in half and
    # compare the two halves. Needs at least two distinct dates.
    if len(filtered_data["date"].unique()) >= 2:
        span = session_state.end_date - session_state.start_date
        mid_date = session_state.start_date + span / 2
        first_half = filtered_data[filtered_data["date"].dt.date <= mid_date]
        second_half = filtered_data[filtered_data["date"].dt.date > mid_date]
        first_sales = 0 if first_half.empty else first_half["sales"].sum()
        second_sales = 0 if second_half.empty else second_half["sales"].sum()
        if first_sales > 0:
            sales_change_pct = (second_sales - first_sales) / first_sales * 100
        else:
            sales_change_pct = 0
    else:
        sales_change_pct = 0

    # Fall back to the row count when there is no transactions column.
    if "transactions" in filtered_data.columns:
        total_transactions = filtered_data["transactions"].sum()
    else:
        total_transactions = filtered_data.shape[0]

    avg_transaction_value = total_sales / total_transactions if total_transactions > 0 else 0

    return (
        total_sales,
        sales_change_pct,
        avg_daily_sales,
        total_transactions,
        avg_transaction_value,
    )
92
+
93
def display_sales_trends(filtered_data):
    """Build the trend figures for the dashboard.

    Returns (time_series_fig, day_of_week_fig); the weekday pattern is
    only produced when at least seven distinct dates are present,
    otherwise its slot is None.
    """
    ts_fig = plot_sales_time_series(
        filtered_data,
        session_state.selected_store,
        session_state.selected_store_name,
    )
    dow_fig = (
        plot_day_of_week_pattern(filtered_data)
        if len(filtered_data["date"].unique()) >= 7
        else None
    )
    return ts_fig, dow_fig
106
+
107
def display_performance_breakdown(filtered_data):
    """Build the category and store breakdown artefacts.

    Returns (category_df, category_fig, store_df, store_fig); each
    entry is an empty DataFrame / None when the corresponding breakdown
    does not apply to the current selection.
    """
    cat_table = pd.DataFrame()
    cat_fig = None
    store_table = pd.DataFrame()
    store_fig = None

    # Category split is only meaningful with more than one category.
    if "category" in filtered_data.columns and len(filtered_data["category"].unique()) > 1:
        by_cat = filtered_data.groupby("category")["sales"].sum().sort_values(ascending=False)
        pct = (by_cat / by_cat.sum() * 100).round(1)
        cat_table = pd.DataFrame({"Sales": by_cat, "Percentage": pct}).reset_index()
        cat_table["Sales"] = cat_table["Sales"].apply(lambda v: f"${v:,.2f}")
        cat_table["Percentage"] = cat_table["Percentage"].apply(lambda v: f"{v}%")
        cat_fig = plot_category_distribution(filtered_data)

    # Store comparison only when looking at *all* stores.
    all_stores_selected = (
        session_state.selected_store_name == "All Stores"
        and session_state.selected_store == "All Stores"
    )
    has_store_col = "store_name" in filtered_data.columns or "store" in filtered_data.columns
    if all_stores_selected and has_store_col:
        key = "store_name" if "store_name" in filtered_data.columns else "store"
        by_store = filtered_data.groupby(key)["sales"].sum().sort_values(ascending=False)
        leaders = by_store.head(10)
        store_table = pd.DataFrame({"Store": leaders.index, "Sales": leaders.values})
        store_table["Sales"] = store_table["Sales"].apply(lambda v: f"${v:,.2f}")
        store_fig = plot_store_comparison(filtered_data, key)

    return cat_table, cat_fig, store_table, store_fig
132
+
133
def format_kpi_html(label, value_str, delta_pct=None):
    """Render one KPI metric as an HTML card.

    `delta_pct`, when given and non-zero, is shown under the value as a
    green up-arrow or red down-arrow percentage (Streamlit-metric style).
    """
    delta_html = ""
    if delta_pct is not None and delta_pct != 0:
        # Up = green, down = red.
        if delta_pct > 0:
            color = "color: #38a169;"
            arrow = "▲"
        else:
            color = "color: #e53e3e;"
            arrow = "▼"
        # Example rendering: "▲ 4.2%"
        delta_str = f"{arrow} {abs(delta_pct):.1f}%"
        delta_html = f'<div style="{color} font-size: 14px; font-weight: 500; margin-top: 5px; line-height: 1;">{delta_str}</div>'

    return f"""
    <div style="font-family: Arial, sans-serif; padding: 10px;">
    <div style="font-size: 14px; color: #555; margin-bottom: 5px;">{label}</div>
    <div style="font-size: 30px; font-weight: 600; color: #1a1a1a; line-height: 1;">{value_str}</div>
    {delta_html}
    </div>
    """
158
+
159
def update_kpis_html(total_sales, sales_change_pct, avg_daily_sales, total_transactions, avg_transaction_value):
    """Format the raw KPI numbers into the four HTML metric cards.

    Only the Total Sales card carries a delta (period-over-period
    change); the others are plain value cards.
    """
    return (
        format_kpi_html("💰 Total Sales", f"${total_sales:,.2f}", sales_change_pct),
        format_kpi_html("📊 Avg Daily Sales", f"${avg_daily_sales:,.2f}"),
        format_kpi_html("🛒 Total Transactions", f"{total_transactions:,}"),
        format_kpi_html("💵 Avg Transaction Value", f"${avg_transaction_value:,.2f}"),
    )
184
+
185
def historical_sales_view(data):
    """Build the Gradio Blocks dashboard for historical sales analysis.

    `data` must contain `date` (datetime64) and `sales` columns;
    `store`/`store_name`, `category` and `transactions` are optional and
    enable the matching filters/metrics. Returns the gr.Blocks app.
    """

    def run_dashboard_update(start_date, end_date, store_selection, categories):
        """Recompute every dashboard output for the current filters."""
        # 1. Filter
        filtered_data = configure_filters(data, start_date, end_date, store_selection, categories)

        if filtered_data.empty:
            empty_msg = "⚠️ No data available for the selected filters. Please adjust your selections."
            # BUGFIX: the original returned 10 values here for 12 output
            # components, which crashes Gradio on an empty result. Return
            # exactly one value per output: 4 KPI cards, 2 trend plots,
            # category df/plot, store df/plot, distribution plot, table.
            return (
                empty_msg, empty_msg, empty_msg, empty_msg,
                None, None,
                pd.DataFrame(), None,
                pd.DataFrame(), None,
                None,
                pd.DataFrame(),
            )

        # 2. KPIs
        html1, html2, html3, html4 = update_kpis_html(*display_kpis(filtered_data))

        # 3. Trends
        fig_ts, fig_dow = display_sales_trends(filtered_data)

        # 4. Breakdown
        cat_df, fig_cat, store_df, fig_store = display_performance_breakdown(filtered_data)

        # 5. Distribution
        fig_dist = plot_sales_distribution(filtered_data)

        # 6. Table (most recent first)
        detailed_table = filtered_data.sort_values("date", ascending=False)

        return (
            html1, html2, html3, html4,
            fig_ts, fig_dow,
            cat_df, fig_cat,
            store_df, fig_store,
            fig_dist,
            detailed_table,
        )

    with gr.Blocks(title="Store Sales Dashboard") as demo:
        # Sidebar (docked right) holding all the filters
        with gr.Sidebar(position="right"):
            gr.Markdown("## 🔍 Dashboard Filters")
            gr.Markdown("---")

            # Date filters, defaulting to the full data range
            gr.Markdown("### 📅 Date Range")
            min_date = data["date"].min().date()
            max_date = data["date"].max().date()

            start_in = gr.DateTime(
                label="From",
                value=str(min_date),
                type="string",
                interactive=True
            )
            end_in = gr.DateTime(
                label="To",
                value=str(max_date),
                type="string",
                interactive=True
            )

            gr.Markdown("---")

            # Store filter: prefer human-readable names when available
            gr.Markdown("### 🏬 Store Selection")
            if "store_name" in data.columns:
                opts = ["All Stores"] + sorted(data["store_name"].unique().tolist())
            elif "store" in data.columns:
                opts = ["All Stores"] + sorted(data["store"].unique().tolist())
            else:
                opts = ["All Stores"]

            store_in = gr.Dropdown(
                choices=opts,
                value="All Stores",
                label="Select Store",
                interactive=True
            )

            # Category filter. BUGFIX: the original left cat_in = None when
            # there is no "category" column and then passed None into
            # btn.click(inputs=...), which Gradio rejects. A hidden, empty
            # CheckboxGroup keeps the event wiring valid in both cases.
            if "category" in data.columns:
                gr.Markdown("---")
                gr.Markdown("### 📦 Product Categories")
                cats = sorted(data["category"].unique().tolist())
                cat_in = gr.CheckboxGroup(
                    choices=cats,
                    value=cats,
                    label="Select Categories",
                    interactive=True
                )
            else:
                cat_in = gr.CheckboxGroup(choices=[], value=[], visible=False)

            gr.Markdown("---")
            btn = gr.Button("🔄 Update Dashboard", variant="primary", size="lg")

            gr.Markdown(
                """
                <br>
                💡 **Tip:** Adjust filters and click Update to refresh
                """
            )

        # Main dashboard column
        with gr.Column(scale=1):
            # Header
            gr.Markdown(
                """
                # 📊 Store Sales Dashboard
                ### Comprehensive sales analytics and performance insights
                """
            )
            # KPI cards
            gr.Markdown("## 📈 Key Performance Indicators")
            with gr.Row():
                m1 = gr.HTML(label=None, scale=1, container=True)
                m2 = gr.HTML(label=None, scale=1, container=True)
                m3 = gr.HTML(label=None, scale=1, container=True)
                m4 = gr.HTML(label=None, scale=1, container=True)

            gr.Markdown("---")

            # Trend plots
            gr.Markdown("## 📉 Sales Trends Analysis")
            with gr.Row():
                p_ts = gr.Plot(label="📈 Sales Time Series", container=True, scale=1)
                p_dow = gr.Plot(label="📅 Weekly Patterns", container=True, scale=1)

            gr.Markdown("---")

            gr.Markdown("## 🎯 Performance Breakdown")

            gr.Markdown("### 📦 Category Performance")
            with gr.Row():
                with gr.Column(scale=1):
                    df_cat = gr.DataFrame(label="Category Sales Data", max_height=300)
                with gr.Column(scale=1):
                    p_cat = gr.Plot(label="Sales by Category", container=True)

            gr.Markdown("---")

            gr.Markdown("### 🏪 Store Comparison (Top 10)")
            with gr.Row():
                with gr.Column(scale=1):
                    df_store = gr.DataFrame(label="Top Performing Stores", max_height=300)
                with gr.Column(scale=2):
                    p_store = gr.Plot(label="Top 10 Stores by Sales", container=True)

            gr.Markdown("---")

            gr.Markdown("## 📊 Sales Distribution")
            p_dist = gr.Plot(label="Distribution Analysis", container=True)

            gr.Markdown("---")

            with gr.Accordion("📋 View Detailed Sales Data", open=True):
                gr.Markdown("*Complete transaction history for the selected period*")
                df_detailed = gr.DataFrame(max_height=400)

            # Footer
            gr.Markdown(
                """
                ---
                <div style='text-align: center; color: #666; font-size: 0.9em;'>
                📊 Store Sales Dashboard | Powered by Gradio
                </div>
                """
            )

        # One shared wiring for the Update button and the initial page load.
        inputs = [start_in, end_in, store_in, cat_in]
        outputs = [m1, m2, m3, m4, p_ts, p_dow, df_cat, p_cat, df_store, p_store, p_dist, df_detailed]
        btn.click(run_dashboard_update, inputs=inputs, outputs=outputs)
        demo.load(run_dashboard_update, inputs=inputs, outputs=outputs)

    return demo
375
+
376
+ # Usage:
377
+ # if __name__ == "__main__":
378
+ # df = pd.read_csv("your_data.csv", parse_dates=['date'])
379
+ # app = historical_sales_view(df)
380
+ # app.launch()
app/frontend/data_viz.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import pandas as pd
4
+ import seaborn as sns
5
+
6
+
7
def plot_sales_forecast(
    historical_data, prediction_date, prediction_value, store_id=None
):
    """Plot historical sales as a line with the forecast as a red dot.

    When `store_id` is given and the data has a `store` column, only
    that store's history is drawn. Returns the matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    # Restrict to one store when requested and possible.
    if store_id is not None and "store" in historical_data.columns:
        series = historical_data[historical_data["store"] == store_id].copy()
    else:
        series = historical_data.copy()

    # Collapse to one row per date when there are duplicate dates.
    if len(series) > len(series["date"].unique()):
        series = series.groupby("date")["sales"].sum().reset_index()

    series = series.sort_values("date")

    ax.plot(series["date"], series["sales"], label="Historical Sales")
    ax.scatter(
        prediction_date, prediction_value, color="red", s=100, label="Prediction"
    )

    ax.set_xlabel("Date")
    ax.set_ylabel("Sales")
    title = "Sales Forecast" if store_id is None else f"Sales Forecast for Store {store_id}"
    ax.set_title(title)
    ax.legend()
    fig.autofmt_xdate()

    return fig
47
+
48
+
49
def plot_sales_time_series(
    filtered_data, selected_store=None, selected_store_name=None
):
    """Generate a time-series plot of daily sales with a 7-day moving average.

    The store arguments only affect the title; the aggregation itself is
    the same for "all stores" and single-store data because the input is
    expected to be pre-filtered. Returns the matplotlib Figure.

    NOTE(review): the original duplicated this entire body across an
    if/else on the store selection with byte-identical branches; the
    duplication is collapsed here with no behavior change.
    """
    fig, ax = plt.subplots(figsize=(7, 6))

    # Daily totals drive the main trend line.
    sales_by_date = filtered_data.groupby("date")["sales"].sum()
    ax.plot(sales_by_date.index, sales_by_date.values, "b-")

    # Overlay a 7-day moving average when there is enough history.
    if len(sales_by_date) > 7:
        ma = sales_by_date.reset_index()
        ma["MA7"] = ma["sales"].rolling(window=7).mean()
        ax.plot(ma["date"], ma["MA7"], "r--", label="7-Day Moving Avg")
        ax.legend()

    ax.set_xlabel("")
    ax.set_ylabel("Sales ($)")

    # The title reflects the current store selection.
    if "store_name" in filtered_data.columns and selected_store_name != "All Stores":
        ax.set_title(f"Daily Sales - {selected_store_name}")
    elif "store" in filtered_data.columns and selected_store != "All Stores":
        ax.set_title(f"Daily Sales - Store {selected_store}")
    else:
        ax.set_title("Daily Sales - All Stores")

    fig.autofmt_xdate()
    return fig
101
+
102
+
103
def plot_day_of_week_pattern(filtered_data):
    """Generate a bar chart of average sales per weekday.

    The best day is highlighted green, the worst orange, with a dashed
    red line at the overall daily average. Returns the matplotlib Figure.

    BUGFIX: the original wrote a `day_name` column into `filtered_data`
    in place, mutating the caller's DataFrame (and triggering
    SettingWithCopyWarning on slices); the weekday labels are now
    computed locally instead.
    """
    fig, ax = plt.subplots(figsize=(7, 7))

    day_names = [
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday",
    ]

    # Weekday label per row, kept as a local Series (no input mutation).
    day_name = filtered_data["date"].dt.dayofweek.apply(lambda i: day_names[i])

    # Mean sales per weekday, in calendar order.
    day_sales = filtered_data.groupby(day_name)["sales"].mean().reindex(day_names)

    # Reference line at the across-weekday average.
    avg_daily = day_sales.mean()

    bars = ax.bar(day_sales.index, day_sales.values, color="skyblue")
    ax.axhline(y=avg_daily, color="red", linestyle="--", label="Daily Average")

    # Highlight best and worst days.
    best_day = day_sales.idxmax()
    worst_day = day_sales.idxmin()
    for i, (day, _) in enumerate(day_sales.items()):
        if day == best_day:
            bars[i].set_color("green")
        elif day == worst_day:
            bars[i].set_color("orange")

    ax.set_xlabel("")
    ax.set_ylabel("Average Sales ($)")
    ax.set_title("Sales by Day of Week")
    plt.xticks(rotation=45)
    ax.legend()

    return fig
148
+
149
+
150
def plot_category_distribution(filtered_data):
    """Generate a pie chart of sales by category (top 5 plus "Others").

    Returns the matplotlib Figure.

    Improvement: draws via the figure's own Axes instead of the pyplot
    current-axes global state (the original mixed `ax` with `plt.pie` /
    `plt.axis` / `plt.title`), so the chart cannot land on an unrelated
    figure if another one becomes current.
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    category_sales = (
        filtered_data.groupby("category")["sales"].sum().sort_values(ascending=False)
    )

    # Keep the 5 largest categories; lump the remainder into "Others".
    top_categories = category_sales.head(5)
    others = category_sales.iloc[5:].sum() if len(category_sales) > 5 else 0

    if others > 0:
        plot_data = pd.concat([top_categories, pd.Series([others], index=["Others"])])
    else:
        plot_data = top_categories

    ax.pie(
        plot_data,
        labels=plot_data.index,
        autopct="%1.1f%%",
        startangle=90,
        shadow=False,
    )
    ax.axis("equal")
    ax.set_title("Sales by Category")

    return fig
177
+
178
+
179
def plot_store_comparison(filtered_data, store_identifier="store"):
    """Horizontal bar chart of the ten best-selling stores.

    `store_identifier` is the column to group by ("store" or
    "store_name"). Returns the matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    # Total sales per store, largest first, trimmed to the top ten.
    top_stores = (
        filtered_data.groupby(store_identifier)["sales"]
        .sum()
        .sort_values(ascending=False)
        .head(10)
    )

    positions = np.arange(len(top_stores))
    ax.barh(positions, top_stores.values, align="center")
    ax.set_yticks(positions)
    ax.set_yticklabels(top_stores.index)
    ax.invert_yaxis()  # largest store reads first, top-to-bottom
    ax.set_xlabel("Sales ($)")
    ax.set_title("Top 10 Stores by Sales")

    return fig
203
+
204
+
205
def plot_sales_distribution(filtered_data):
    """Histogram of per-record sales with a KDE overlay.

    Median (red) and mean (green) are marked with dashed vertical lines.
    Returns the matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=(18, 4))

    sns.histplot(filtered_data["sales"], bins=30, kde=True, ax=ax)

    # Mark the two central-tendency statistics.
    median_sales = filtered_data["sales"].median()
    mean_sales = filtered_data["sales"].mean()
    ax.axvline(
        x=median_sales, color="r", linestyle="--", label=f"Median: ${median_sales:.2f}"
    )
    ax.axvline(
        x=mean_sales, color="g", linestyle="--", label=f"Mean: ${mean_sales:.2f}"
    )

    ax.set_xlabel("Sales ($)")
    ax.set_ylabel("Frequency")
    ax.set_title("Sales Distribution")
    ax.legend()

    return fig
app/frontend/gradio_ui.py CHANGED
@@ -1,59 +1,65 @@
1
- import os
2
  import gradio as gr
3
  import requests
4
  from PIL import Image, ImageDraw, ImageFont
5
  import io
6
  from typing import List, Dict, Any
7
- from datetime import datetime, timedelta
8
- from pathlib import Path
9
  import random
10
- from app.core.config import settings
11
- from app.core.security import create_access_token
12
 
13
- # Try to import logger, fallback if not available
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  try:
15
  from app.utils.logger import logger
 
 
16
  except ImportError:
 
 
 
17
  import logging
18
  logger = logging.getLogger(__name__)
19
- logging.basicConfig(level=logging.INFO)
20
 
21
- # Try to import ui_template, fallback if not available
22
- try:
23
- import ui_template as ui
24
- HAS_UI_TEMPLATE = True
25
- except ImportError:
26
- HAS_UI_TEMPLATE = False
27
- logger.warning("ui_template not found, using basic styling")
28
 
29
- # ==================== API Configuration ====================
30
  API_BASE_URL = "http://localhost:5050"
31
  API_VERSION = "v1"
32
  API_ENDPOINT = f"{API_BASE_URL}/api/{API_VERSION}/detect/image"
33
  API_HEALTH_ENDPOINT = f"{API_BASE_URL}/api/{API_VERSION}/health"
34
  API_BATCH_ENDPOINT = f"{API_BASE_URL}/api/{API_VERSION}/detect/batch"
35
-
36
  SHOW_GRADIO_API = "undocumented"
37
 
38
 
39
- # ==================== Fashion Detection Client ====================
40
  class FashionDetectionClient:
41
  """Client for interacting with the Fashion Detection API"""
42
 
43
  def __init__(self, base_url: str = API_BASE_URL, token: str = None):
44
  self.base_url = base_url
45
- if token is None or token == "xxx":
46
- token = generate_test_token()
47
- else:
48
- token = settings.API_TOKEN
49
- self.token = token
50
- self.headers = {"X-Token": token}
51
  self.session = requests.Session()
52
  self.session.headers.update(self.headers)
53
 
54
  def check_health(self) -> Dict[str, Any]:
55
  """Check API health status"""
56
- logger.info(">>> check_health called")
57
  try:
58
  response = self.session.get(API_HEALTH_ENDPOINT, timeout=10)
59
  response.raise_for_status()
@@ -69,7 +75,6 @@ class FashionDetectionClient:
69
 
70
  def detect_single_image(self, image: Image.Image, threshold: float = 0.4) -> Dict[str, Any]:
71
  """Detect objects in a single image"""
72
- logger.info(">>> detect_single_image function")
73
  try:
74
  img_byte_arr = io.BytesIO()
75
  image.save(img_byte_arr, format='PNG')
@@ -77,7 +82,7 @@ class FashionDetectionClient:
77
 
78
  files = {"file": ("image.png", img_byte_arr, "image/png")}
79
  params = {"threshold": threshold} if threshold else {}
80
- logger.info(f">>Sending request to {API_ENDPOINT} with params={params}")
81
  response = self.session.post(
82
  API_ENDPOINT,
83
  files=files,
@@ -85,14 +90,12 @@ class FashionDetectionClient:
85
  timeout=30
86
  )
87
  response.raise_for_status()
88
- logger.info(f">>response {response}")
89
  return response.json()
90
 
91
  except requests.exceptions.RequestException as e:
92
- logger.info(f"Lỗi: {response.status_code}. Chi tiết: {response.json()}, API_TOKEN={self.token}")
93
  return {
94
  "success": False,
95
- "error": f"API request failed: {str(e)}\n",
96
  "details": f"URL: {API_ENDPOINT}"
97
  }
98
  except Exception as e:
@@ -129,7 +132,6 @@ class FashionDetectionClient:
129
  }
130
 
131
 
132
- # ==================== Drawing Functions ====================
133
  def draw_bounding_boxes_pil(image: Image.Image, detections: List[Dict[str, Any]]) -> Image.Image:
134
  """Draw bounding boxes on PIL Image"""
135
  img_with_boxes = image.copy()
@@ -212,62 +214,12 @@ def format_detection_results(result: Dict[str, Any]) -> str:
212
  return result_text
213
 
214
 
215
- # ==================== Helper Functions ====================
216
- def convert_to_pil_images(gradio_files: List) -> List[Image.Image]:
217
- """Convert Gradio NamedString objects (file paths) to PIL Images"""
218
- pil_images = []
219
- for file in gradio_files:
220
- try:
221
- file_path = file.name if hasattr(file, 'name') else file
222
- pil_image = Image.open(file_path)
223
- if pil_image.mode != "RGB":
224
- pil_image = pil_image.convert("RGB")
225
- pil_images.append(pil_image)
226
- except Exception as e:
227
- logger.error(f"Error converting image {file_path}: {str(e)}")
228
- return pil_images
229
-
230
- #
231
- def generate_test_token():
232
- access_token = create_access_token(
233
- data={"sub": "test_user"},
234
- expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
235
- )
236
- print(f"Generated test token: {access_token}")
237
- return access_token
238
-
239
- # ==================== Main Application ====================
240
- def create_app():
241
- """Create the main Gradio application"""
242
-
243
- # Configure UI template if available
244
- if HAS_UI_TEMPLATE:
245
- ui.configure(
246
- project_name="Intelligent Retail Decision Making System",
247
- year="2025",
248
- about="AI-powered fashion detection and retail analytics",
249
- description="An integrated platform for fashion item detection and sales forecasting.",
250
- colors={
251
- "primary": "#0F6CBD",
252
- "accent": "#C4314B",
253
- "success": "#2E7D32",
254
- "bg1": "#F0F7FF",
255
- "bg2": "#E8F0FA",
256
- "bg3": "#DDE7F8"
257
- },
258
- meta_items=[
259
- ("Model", "Fashion Detection & Sales Forecasting"),
260
- ("Features", "Object Detection & Predictive Analytics"),
261
- ]
262
- )
263
-
264
- # Initialize API client
265
- api_client = FashionDetectionClient()
266
 
267
- # ==================== Prediction Functions ====================
268
  def predict_single_image(image: Image.Image, threshold: float) -> tuple:
269
  """Predict objects in a single image"""
270
- logger.info(">>> predict_single_image called")
271
  try:
272
  health_status = api_client.check_health()
273
  if not health_status.get('success', False):
@@ -275,20 +227,20 @@ def create_app():
275
 
276
  result = api_client.detect_single_image(image, threshold)
277
  result_text = format_detection_results(result)
278
- logger.info(f">>> predict_single_image result_text: {result_text}")
279
  if result.get('success', False) and result.get('detections'):
280
  image_with_boxes = draw_bounding_boxes_pil(image, result['detections'])
281
  return image_with_boxes, result_text
282
  else:
283
  return image, result_text
284
- logger.info(">>> predict_single_image completed")
285
  except Exception as e:
286
  error_msg = f"❌ Prediction error: {str(e)}"
287
  return image, error_msg
288
 
289
  def predict_batch_images(images: List[Image.Image], threshold: float):
290
  """Predict objects in multiple images"""
291
- logger.info(">>> predict_batch_images called")
292
  try:
293
  if not images:
294
  return [], "Please upload at least one image."
@@ -329,21 +281,32 @@ def create_app():
329
  except Exception as e:
330
  return [], f"❌ Batch prediction error: {str(e)}"
331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  def check_api_health():
333
  """Check and display API health status"""
334
- logger.info(">>> check_api_health called")
335
  health_status = api_client.check_health()
 
336
 
337
- if health_status.get('success', False):
338
- status_emoji = ""
339
- status_text = "Healthy"
340
- else:
341
- status_emoji = "❌"
342
- status_text = "Unhealthy"
343
 
344
  health_info = f"{status_emoji} API Status: {status_text}\n\n"
345
  health_info += f"📡 Endpoint: {API_BASE_URL}\n"
346
- health_info += f"🕒 Checked: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
347
 
348
  if health_status.get('success', False):
349
  health_info += f"🚀 Version: {health_status.get('version', 'N/A')}\n"
@@ -355,193 +318,171 @@ def create_app():
355
 
356
  return health_info
357
 
358
- def load_historical():
359
- """Load historical sales analysis"""
360
- try:
361
- return "<div style='padding: 20px;'>Historical sales analysis would be displayed here.</div>"
362
- except Exception as e:
363
- return f"<div style='color: red; padding: 20px;'>Error loading analysis: {str(e)}</div>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
- def make_prediction(date_str, horizon):
366
- """Make sales prediction"""
367
- try:
368
- return f"<div style='padding: 20px;'>Prediction for {date_str} with horizon {horizon} days would be displayed here.</div>"
369
- except Exception as e:
370
- return f"<div style='color: red; padding: 20px;'>Error generating prediction: {str(e)}</div>"
371
 
372
- # ==================== Create Gradio Interface ====================
373
- demo = gr.Blocks(title="💡 Intelligent Retail Decision Making System")
 
 
 
 
374
 
375
- with demo:
376
- # Header
377
- if HAS_UI_TEMPLATE:
378
- ui.create_header(logo_path="static/intelligent_retail.png")
379
- else:
380
- gr.Markdown("#💡 Intelligent Retail Decision Making System")
381
 
382
- gr.Markdown("AI-powered fashion detection and retail analytics platform")
383
 
384
- # Info card
385
- if HAS_UI_TEMPLATE:
386
- gr.HTML(ui.render_info_card(
387
- icon="📈",
388
- title="About this Application"
389
- ))
390
 
391
- # Main tabs
392
- with gr.Tabs():
393
- # Tab 1: Fashion Detection
394
- with gr.Tab("👔 Fashion Detection"):
395
- gr.Markdown("## Fashion Item Detection")
396
-
397
- # API Health Section
398
- with gr.Row():
399
- with gr.Column():
400
- gr.Markdown("### 📊 API Status")
401
- health_btn = gr.Button("Check API Health", variant="secondary")
402
- health_output = gr.Textbox(label="API Health Status", lines=6, interactive=False)
403
-
404
- # Single Image Detection
405
- with gr.Row():
406
- with gr.Column():
407
- gr.Markdown("### 📷 Single Image Detection")
408
- single_image = gr.Image(type="pil", label="Upload Fashion Image")
409
- threshold_slider = gr.Slider(
410
- minimum=0.1, maximum=0.9, value=0.4, step=0.05,
411
- label="Detection Confidence Threshold"
412
- )
413
- single_btn = gr.Button("Detect Objects", variant="primary")
414
-
415
- with gr.Column():
416
- single_output_image = gr.Image(label="Detection Results", interactive=False)
417
- single_output_text = gr.Textbox(label="Detection Results", lines=12)
418
-
419
- # Batch Image Detection
420
- with gr.Row():
421
- with gr.Column():
422
- gr.Markdown("### 📦 Batch Image Detection")
423
- batch_images = gr.File(
424
- label="Upload Multiple Images",
425
- file_count="multiple",
426
- file_types=["image"]
427
- )
428
- batch_threshold = gr.Slider(
429
- minimum=0.1, maximum=0.9, value=0.4, step=0.05,
430
- label="Detection Confidence Threshold"
431
- )
432
- batch_btn = gr.Button("Process Batch", variant="primary")
433
-
434
- with gr.Column():
435
- batch_output_images = gr.Gallery(
436
- label="Detection Results",
437
- columns=3,
438
- height="auto",
439
- interactive=False
440
- )
441
- batch_output_text = gr.Textbox(label="Batch Results", lines=15)
442
-
443
- # Examples
444
- if os.path.exists("static/examples"):
445
- gr.Examples(
446
- examples=[
447
- ["static/examples/image1.png"],
448
- ["static/examples/image2.png"],
449
- ["static/examples/image3.png"]
450
- ],
451
- inputs=single_image,
452
- label="Try these example images"
453
- )
454
-
455
- # Event handlers
456
- health_btn.click(
457
- fn=check_api_health,
458
- outputs=health_output,
459
- api_visibility=SHOW_GRADIO_API
460
- )
461
 
462
- single_btn.click(
463
- fn=predict_single_image,
464
- inputs=[single_image, threshold_slider],
465
- outputs=[single_output_image, single_output_text],
466
- api_visibility=SHOW_GRADIO_API
467
- )
 
468
 
469
- batch_btn.click(
470
- fn=lambda images, threshold: predict_batch_images(convert_to_pil_images(images), threshold),
471
- inputs=[batch_images, batch_threshold],
472
- outputs=[batch_output_images, batch_output_text],
473
- api_visibility=SHOW_GRADIO_API
474
- )
 
 
 
475
 
476
- # Tab 2: Historical Sales Analysis
477
- with gr.Tab("📊 Historical Sales Analysis"):
478
- gr.Markdown("### Explore and visualize historical sales data")
479
 
480
- with gr.Row():
481
- analyze_btn = gr.Button(
482
- "Load Historical Analysis",
483
- variant="primary",
484
- size="lg"
485
- )
486
 
487
- historical_output = gr.HTML(label="Analysis Results")
 
488
 
489
- analyze_btn.click(
490
- fn=load_historical,
491
- outputs=historical_output
492
- )
493
 
494
- # Tab 3: Sales Prediction
495
- with gr.Tab("🔮 Sales Prediction"):
496
- gr.Markdown("### Generate sales forecasts using machine learning")
497
-
498
- with gr.Row():
499
- with gr.Column(scale=1):
500
- gr.Markdown("#### Input Parameters")
501
-
502
- date_input = gr.Textbox(
503
- label="Prediction Date",
504
- placeholder="YYYY-MM-DD",
505
- info="Enter the date for prediction",
506
- value=""
507
- )
508
-
509
- forecast_horizon = gr.Slider(
510
- minimum=1,
511
- maximum=30,
512
- value=7,
513
- step=1,
514
- label="Forecast Horizon (days)",
515
- info="Number of days to forecast"
516
- )
517
-
518
- predict_btn = gr.Button(
519
- "Generate Prediction",
520
- variant="primary",
521
- size="lg"
522
- )
523
-
524
- with gr.Column(scale=2):
525
- gr.Markdown("#### Prediction Results")
526
- prediction_output = gr.HTML(label="Forecast")
527
-
528
- predict_btn.click(
529
- fn=make_prediction,
530
- inputs=[date_input, forecast_horizon],
531
- outputs=prediction_output
532
- )
533
 
534
  # Footer
535
- if HAS_UI_TEMPLATE:
536
- ui.create_footer(
537
- logo_path="static/intelligent_retail.png",
538
- creator_name="Thi-Diem-My Le",
539
- creator_link="https://beacons.ai/elizabethmyn",
540
- org_name="AI VIET NAM",
541
- org_link="https://aivietnam.edu.vn/"
542
- )
543
- else:
544
- gr.Markdown("---")
 
 
 
 
545
  gr.Markdown("© 2025 Intelligent Retail System. All rights reserved.")
546
 
547
  return demo
@@ -550,7 +491,7 @@ def create_app():
550
  # ==================== Application Entry Point ====================
551
  def main():
552
  """Main entry point"""
553
- demo = create_app()
554
 
555
  # Custom CSS
556
  custom_css = """
@@ -559,8 +500,7 @@ def main():
559
  .error {color: red; font-weight: bold;}
560
  """
561
 
562
- if HAS_UI_TEMPLATE:
563
- custom_css = ui.get_custom_css() + custom_css
564
 
565
  demo.launch(
566
  server_name="0.0.0.0",
 
 
1
  import gradio as gr
2
  import requests
3
  from PIL import Image, ImageDraw, ImageFont
4
  import io
5
  from typing import List, Dict, Any
6
+ from datetime import datetime
 
7
  import random
 
 
8
 
9
+ # Import modules from sales forecasting app
10
+ try:
11
+ import app.frontend.ui_template as ui
12
+ from app.utils.data_loader import (
13
+ load_data,
14
+ load_feature_engineered_data,
15
+ load_feature_stats,
16
+ load_model,
17
+ )
18
+ from app.frontend.dashboard import historical_sales_view
19
+ from app.services.prediction import sales_prediction_view
20
+ SALES_MODULE_AVAILABLE = True
21
+ except ImportError:
22
+ SALES_MODULE_AVAILABLE = False
23
+ print("Warning: Sales forecasting modules not available")
24
+
25
+ # Import fashion detection modules
26
  try:
27
  from app.utils.logger import logger
28
+ from app.core.config import settings
29
+ FASHION_MODULE_AVAILABLE = True
30
  except ImportError:
31
+ FASHION_MODULE_AVAILABLE = False
32
+ print("Warning: Fashion detection modules not available")
33
+ # Fallback logger
34
  import logging
35
  logger = logging.getLogger(__name__)
 
36
 
37
+ # Fallback settings
38
+ class Settings:
39
+ API_TOKEN = "your-api-token-here"
40
+ settings = Settings()
 
 
 
41
 
42
+ # Configuration for Fashion Detection API
43
  API_BASE_URL = "http://localhost:5050"
44
  API_VERSION = "v1"
45
  API_ENDPOINT = f"{API_BASE_URL}/api/{API_VERSION}/detect/image"
46
  API_HEALTH_ENDPOINT = f"{API_BASE_URL}/api/{API_VERSION}/health"
47
  API_BATCH_ENDPOINT = f"{API_BASE_URL}/api/{API_VERSION}/detect/batch"
 
48
  SHOW_GRADIO_API = "undocumented"
49
 
50
 
 
51
  class FashionDetectionClient:
52
  """Client for interacting with the Fashion Detection API"""
53
 
54
  def __init__(self, base_url: str = API_BASE_URL, token: str = None):
55
  self.base_url = base_url
56
+ self.token = token or (settings.API_TOKEN if FASHION_MODULE_AVAILABLE else "default-token")
57
+ self.headers = {"X-Token": self.token}
 
 
 
 
58
  self.session = requests.Session()
59
  self.session.headers.update(self.headers)
60
 
61
  def check_health(self) -> Dict[str, Any]:
62
  """Check API health status"""
 
63
  try:
64
  response = self.session.get(API_HEALTH_ENDPOINT, timeout=10)
65
  response.raise_for_status()
 
75
 
76
  def detect_single_image(self, image: Image.Image, threshold: float = 0.4) -> Dict[str, Any]:
77
  """Detect objects in a single image"""
 
78
  try:
79
  img_byte_arr = io.BytesIO()
80
  image.save(img_byte_arr, format='PNG')
 
82
 
83
  files = {"file": ("image.png", img_byte_arr, "image/png")}
84
  params = {"threshold": threshold} if threshold else {}
85
+
86
  response = self.session.post(
87
  API_ENDPOINT,
88
  files=files,
 
90
  timeout=30
91
  )
92
  response.raise_for_status()
 
93
  return response.json()
94
 
95
  except requests.exceptions.RequestException as e:
 
96
  return {
97
  "success": False,
98
+ "error": f"API request failed: {str(e)}",
99
  "details": f"URL: {API_ENDPOINT}"
100
  }
101
  except Exception as e:
 
132
  }
133
 
134
 
 
135
  def draw_bounding_boxes_pil(image: Image.Image, detections: List[Dict[str, Any]]) -> Image.Image:
136
  """Draw bounding boxes on PIL Image"""
137
  img_with_boxes = image.copy()
 
214
  return result_text
215
 
216
 
217
+ def create_fashion_detection_tab(api_client: FashionDetectionClient):
218
+ """Create the Fashion Detection tab"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
 
220
  def predict_single_image(image: Image.Image, threshold: float) -> tuple:
221
  """Predict objects in a single image"""
222
+ logger.info(">>> predict_single_image ping clicked")
223
  try:
224
  health_status = api_client.check_health()
225
  if not health_status.get('success', False):
 
227
 
228
  result = api_client.detect_single_image(image, threshold)
229
  result_text = format_detection_results(result)
230
+
231
  if result.get('success', False) and result.get('detections'):
232
  image_with_boxes = draw_bounding_boxes_pil(image, result['detections'])
233
  return image_with_boxes, result_text
234
  else:
235
  return image, result_text
236
+
237
  except Exception as e:
238
  error_msg = f"❌ Prediction error: {str(e)}"
239
  return image, error_msg
240
 
241
  def predict_batch_images(images: List[Image.Image], threshold: float):
242
  """Predict objects in multiple images"""
243
+ logger.info(">>> predict_batch_images ping clicked")
244
  try:
245
  if not images:
246
  return [], "Please upload at least one image."
 
281
  except Exception as e:
282
  return [], f"❌ Batch prediction error: {str(e)}"
283
 
284
+ def convert_to_pil_images(gradio_files: List) -> List[Image.Image]:
285
+ """Convert Gradio file objects to PIL Images"""
286
+ pil_images = []
287
+ for file in gradio_files:
288
+ try:
289
+ file_path = file.name if hasattr(file, 'name') else file
290
+ pil_image = Image.open(file_path)
291
+ if pil_image.mode != "RGB":
292
+ pil_image = pil_image.convert("RGB")
293
+ pil_images.append(pil_image)
294
+ except Exception as e:
295
+ logger.error(f"Error converting image {file_path}: {str(e)}")
296
+ return pil_images
297
+
298
  def check_api_health():
299
  """Check and display API health status"""
300
+ logger.info(">>> check_api_health ping clicked")
301
  health_status = api_client.check_health()
302
+ logger.info(health_status)
303
 
304
+ status_emoji = "✅" if health_status.get('success', False) else "❌"
305
+ status_text = "Healthy" if health_status.get('success', False) else "Unhealthy"
 
 
 
 
306
 
307
  health_info = f"{status_emoji} API Status: {status_text}\n\n"
308
  health_info += f"📡 Endpoint: {API_BASE_URL}\n"
309
+ health_info += f"🕐 Checked: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
310
 
311
  if health_status.get('success', False):
312
  health_info += f"🚀 Version: {health_status.get('version', 'N/A')}\n"
 
318
 
319
  return health_info
320
 
321
+ with gr.Column():
322
+ gr.Markdown("# 👔 Fashion Detection System")
323
+ gr.Markdown("Upload images to detect fashion items using our AI-powered API")
324
+
325
+ # API Health Section
326
+ with gr.Row():
327
+ with gr.Column():
328
+ gr.Markdown("## 📊 API Status")
329
+ health_btn = gr.Button("Check API Health", variant="secondary")
330
+ health_output = gr.Textbox(label="API Health Status", lines=6, interactive=False)
331
+
332
+ # Single Image Detection
333
+ with gr.Row():
334
+ with gr.Column():
335
+ gr.Markdown("## 📷 Single Image Detection")
336
+ single_image = gr.Image(type="pil", label="Upload Fashion Image")
337
+ threshold_slider = gr.Slider(
338
+ minimum=0.1, maximum=0.9, value=0.4, step=0.05,
339
+ label="Detection Confidence Threshold"
340
+ )
341
+ single_btn = gr.Button("Detect Objects", variant="primary")
342
+
343
+ with gr.Column():
344
+ single_output_image = gr.Image(label="Detection Results", interactive=False)
345
+ single_output_text = gr.Textbox(label="Detection Results", lines=12)
346
+
347
+ # Batch Image Detection
348
+ with gr.Row():
349
+ with gr.Column():
350
+ gr.Markdown("## 📦 Batch Image Detection")
351
+ batch_images = gr.File(
352
+ label="Upload Multiple Images",
353
+ file_count="multiple",
354
+ file_types=["image"]
355
+ )
356
+ batch_threshold = gr.Slider(
357
+ minimum=0.1, maximum=0.9, value=0.4, step=0.05,
358
+ label="Detection Confidence Threshold"
359
+ )
360
+ batch_btn = gr.Button("Process Batch", variant="primary")
361
+
362
+ with gr.Column():
363
+ batch_output_images = gr.Gallery(
364
+ label="Detection Results",
365
+ columns=3,
366
+ height="auto",
367
+ interactive=False
368
+ )
369
+ batch_output_text = gr.Textbox(label="Batch Results", lines=15)
370
+
371
+ # Examples
372
+ gr.Examples(
373
+ examples=[
374
+ ["static/examples/image1.png"],
375
+ ["static/examples/image2.png"],
376
+ ["static/examples/image3.png"]
377
+ ],
378
+ inputs=single_image,
379
+ label="Try these example images"
380
+ )
381
 
382
+ # Event handlers
383
+ health_btn.click(
384
+ fn=check_api_health,
385
+ outputs=health_output,
386
+ api_visibility=SHOW_GRADIO_API
387
+ )
388
 
389
+ single_btn.click(
390
+ fn=predict_single_image,
391
+ inputs=[single_image, threshold_slider],
392
+ outputs=[single_output_image, single_output_text],
393
+ api_visibility=SHOW_GRADIO_API
394
+ )
395
 
396
+ batch_btn.click(
397
+ fn=lambda images, threshold: predict_batch_images(convert_to_pil_images(images), threshold),
398
+ inputs=[batch_images, batch_threshold],
399
+ outputs=[batch_output_images, batch_output_text],
400
+ api_visibility=SHOW_GRADIO_API
401
+ )
402
 
 
403
 
404
+ def create_sales_forecasting_tab(data, model, feature_stats):
405
+ """Create the Sales Forecasting tab"""
 
 
 
 
406
 
407
+ with gr.Column():
408
+ gr.Markdown("# 📈 Sales Forecasting System")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
 
410
+ # Page selector for sales forecasting sub-sections
411
+ page_selector = gr.Dropdown(
412
+ choices=["Historical Sales Analysis", "Sales Prediction"],
413
+ value="Historical Sales Analysis",
414
+ label="Choose a view",
415
+ interactive=True
416
+ )
417
 
418
+ # Render content based on selection
419
+ @gr.render(inputs=page_selector)
420
+ def render_sales_content(page):
421
+ if page == "Historical Sales Analysis":
422
+ historical_sales_view(data)
423
+ else:
424
+ print("Loading feature engineered data for prediction...")
425
+ feature_engineered_data = load_feature_engineered_data()
426
+ sales_prediction_view(data, model, feature_stats, feature_engineered_data)
427
 
 
 
 
428
 
429
+ def create_gradio_interface():
430
+ """Create the Gradio application"""
 
 
 
 
431
 
432
+ # Initialize API client for fashion detection
433
+ api_client = FashionDetectionClient()
434
 
435
+ # Load sales forecasting data if available
436
+ sales_data = None
437
+ sales_model = None
438
+ sales_feature_stats = None
439
 
440
+ if SALES_MODULE_AVAILABLE:
441
+ try:
442
+ sales_data = load_data()
443
+ sales_model = load_model()
444
+ sales_feature_stats = load_feature_stats()
445
+ except Exception as e:
446
+ print(f"Warning: Could not load sales forecasting data: {e}")
447
+
448
+ # Create main interface
449
+ with gr.Blocks(
450
+ title="💡 Intelligent Retail Decision Making System",
451
+ ) as demo:
452
+
453
+ gr.Markdown("# 💡 Intelligent Retail Decision Making System")
454
+ gr.Markdown("### Comprehensive AI-powered solution for retail analytics and product detection")
455
+
456
+ # Main navigation tabs
457
+ with gr.Tabs():
458
+ # Fashion Detection Tab
459
+ with gr.Tab("👔 Fashion Detection"):
460
+ create_fashion_detection_tab(api_client)
461
+
462
+ # Sales Forecasting Tab
463
+ if SALES_MODULE_AVAILABLE and sales_data is not None:
464
+ with gr.Tab("📈 Sales Forecasting"):
465
+ create_sales_forecasting_tab(sales_data, sales_model, sales_feature_stats)
466
+ else:
467
+ with gr.Tab("📈 Sales Forecasting"):
468
+ gr.Markdown("## ⚠️ Sales Forecasting Module Not Available")
469
+ gr.Markdown("Please ensure all required dependencies are installed.")
 
 
 
 
 
 
 
 
 
470
 
471
  # Footer
472
+ try:
473
+ if SALES_MODULE_AVAILABLE:
474
+ ui.create_footer(
475
+ logo_path="static/intelligent_retail.png",
476
+ creator_name="Thi-Diem-My Le",
477
+ creator_link="https://beacons.ai/elizabethmyn",
478
+ org_name="AI VIET NAM",
479
+ org_link="https://aivietnam.edu.vn/"
480
+ )
481
+ else:
482
+ gr.Markdown("---")
483
+ gr.Markdown("### Created by Thi-Diem-My Le | AI VIET NAM")
484
+ except Exception as e:
485
+ print(f"Warning: Could not create footer: {e}")
486
  gr.Markdown("© 2025 Intelligent Retail System. All rights reserved.")
487
 
488
  return demo
 
491
  # ==================== Application Entry Point ====================
492
  def main():
493
  """Main entry point"""
494
+ demo = create_gradio_interface()
495
 
496
  # Custom CSS
497
  custom_css = """
 
500
  .error {color: red; font-weight: bold;}
501
  """
502
 
503
+ custom_css = ui.get_custom_css() + custom_css
 
504
 
505
  demo.launch(
506
  server_name="0.0.0.0",
app/frontend/ui_template.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ from pathlib import Path
4
+ from typing import Optional, Dict, List, Tuple
5
+ import gradio as gr
6
+
7
+
8
+ class ThemeConfig:
9
+ """Centralized theme configuration with validation."""
10
+
11
+ def __init__(self):
12
+ # Default color palette
13
+ self.primary_color = "#0F6CBD"
14
+ self.accent_color = "#C4314B"
15
+ self.success_color = "#2E7D32"
16
+ self.bg1 = "#F0F7FF"
17
+ self.bg2 = "#E8F0FA"
18
+ self.bg3 = "#DDE7F8"
19
+ self.font_family = (
20
+ "'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', "
21
+ "Roboto, 'Helvetica Neue', Arial, sans-serif"
22
+ )
23
+
24
+ # Metadata
25
+ self.project_name = "Heart Project"
26
+ self.year = "2025"
27
+ self.about = ""
28
+ self.description = ""
29
+ self.meta_items: List[Tuple[str, str]] = []
30
+
31
+ # Cache for CSS
32
+ self._css_cache: Optional[str] = None
33
+
34
+ def update_colors(self, **kwargs) -> None:
35
+ """Update color scheme with validation."""
36
+ valid_keys = {'primary', 'accent', 'success', 'bg1', 'bg2', 'bg3'}
37
+ for key, value in kwargs.items():
38
+ if key not in valid_keys or value is None:
39
+ continue
40
+ if not self._is_valid_color(value):
41
+ raise ValueError(f"Invalid color format for {key}: {value}")
42
+ setattr(self, f"{key}_color" if not key.startswith('bg') else key, value)
43
+ self._invalidate_cache()
44
+
45
+ def update_font(self, font_family: str) -> None:
46
+ """Update font family."""
47
+ if font_family and isinstance(font_family, str):
48
+ self.font_family = font_family
49
+ self._invalidate_cache()
50
+
51
+ def update_meta(self, project_name: Optional[str] = None,
52
+ year: Optional[str] = None,
53
+ about: Optional[str] = None,
54
+ description: Optional[str] = None,
55
+ meta_items: Optional[List[Tuple[str, str]]] = None) -> None:
56
+ """Update metadata."""
57
+ if project_name is not None:
58
+ self.project_name = project_name
59
+ if year is not None:
60
+ self.year = year
61
+ if about is not None:
62
+ self.about = about
63
+ if description is not None:
64
+ self.description = description
65
+ if meta_items is not None:
66
+ self.meta_items = meta_items
67
+
68
+ @staticmethod
69
+ def _is_valid_color(color: str) -> bool:
70
+ """Validate hex color format."""
71
+ return isinstance(color, str) and (
72
+ color.startswith('#') and len(color) in (4, 7, 9)
73
+ )
74
+
75
+ def _invalidate_cache(self) -> None:
76
+ """Clear CSS cache when theme changes."""
77
+ self._css_cache = None
78
+
79
+ def get_css(self) -> str:
80
+ """Get or generate CSS with caching."""
81
+ if self._css_cache is None:
82
+ self._css_cache = self._build_css()
83
+ return self._css_cache
84
+
85
+ def _build_css(self) -> str:
86
+ """Build the complete CSS string."""
87
+ return f"""
88
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
89
+
90
+ .gradio-container {{
91
+ min-height: 100vh !important;
92
+ width: 100vw !important;
93
+ margin: 0 !important;
94
+ padding: 0px !important;
95
+ background: linear-gradient(135deg, {self.bg1} 0%, {self.bg2} 50%, {self.bg3} 100%);
96
+ background-size: 600% 600%;
97
+ animation: gradientBG 7s ease infinite;
98
+ }}
99
+
100
+ /* Global font setup */
101
+ body, .gradio-container, .gr-block, .gr-markdown, .gr-button, .gr-input,
102
+ .gr-dropdown, .gr-number, .gr-plot, .gr-dataframe, .gr-accordion, .gr-form,
103
+ .gr-textbox, .gr-html, table, th, td, label, h1, h2, h3, h4, h5, h6, p, span, div {{
104
+ font-family: {self.font_family} !important;
105
+ }}
106
+
107
+ @keyframes gradientBG {{
108
+ 0% {{background-position: 0% 50%;}}
109
+ 50% {{background-position: 100% 50%;}}
110
+ 100% {{background-position: 0% 50%;}}
111
+ }}
112
+
113
+ /* Minimize spacing and padding */
114
+ .content-wrap {{
115
+ padding: 2px !important;
116
+ margin: 0 !important;
117
+ }}
118
+
119
+ /* Reduce component spacing */
120
+ .gr-row {{
121
+ gap: 5px !important;
122
+ margin: 2px 0 !important;
123
+ }}
124
+
125
+ .gr-column {{
126
+ gap: 4px !important;
127
+ padding: 4px !important;
128
+ }}
129
+
130
+ /* Accordion optimization */
131
+ .gr-accordion {{
132
+ margin: 4px 0 !important;
133
+ }}
134
+
135
+ .gr-accordion .gr-accordion-content {{
136
+ padding: 2px !important;
137
+ }}
138
+
139
+ /* Form elements spacing */
140
+ .gr-form {{
141
+ gap: 2px !important;
142
+ }}
143
+
144
+ /* Button styling */
145
+ .gr-button {{
146
+ margin: 2px 0 !important;
147
+ }}
148
+
149
+ /* DataFrame optimization */
150
+ .gr-dataframe {{
151
+ margin: 4px 0 !important;
152
+ }}
153
+
154
+ /* Remove horizontal scroll from data preview */
155
+ .gr-dataframe .wrap {{
156
+ overflow-x: auto !important;
157
+ max-width: 100% !important;
158
+ }}
159
+
160
+ /* Plot optimization */
161
+ .gr-plot {{
162
+ margin: 4px 0 !important;
163
+ }}
164
+
165
+ /* Reduce markdown margins */
166
+ .gr-markdown {{
167
+ margin: 2px 0 !important;
168
+ }}
169
+
170
+ /* Footer positioning */
171
+ .sticky-footer {{
172
+ position: fixed;
173
+ bottom: 0px;
174
+ left: 0;
175
+ width: 100%;
176
+ background: {self.bg1};
177
+ padding: 6px !important;
178
+ box-shadow: 0 -2px 10px rgba(0,0,0,0.1);
179
+ z-index: 1000;
180
+ }}
181
+ """
182
+
183
+
184
+ # Global theme instance
185
+ _theme = ThemeConfig()
186
+
187
+
188
+ def configure(project_name: Optional[str] = None,
189
+ year: Optional[str] = None,
190
+ about: Optional[str] = None,
191
+ description: Optional[str] = None,
192
+ colors: Optional[Dict[str, str]] = None,
193
+ font_family: Optional[str] = None,
194
+ meta_items: Optional[List[Tuple[str, str]]] = None) -> None:
195
+ """
196
+ One-call configuration for the entire theme.
197
+
198
+ Args:
199
+ project_name: Name of the project
200
+ year: Project year
201
+ about: About project
202
+ description: Project description
203
+ colors: Dict with keys: primary, accent, success, bg1, bg2, bg3
204
+ font_family: CSS font family string
205
+ meta_items: List of (label, value) tuples for metadata
206
+ """
207
+ if colors:
208
+ _theme.update_colors(**colors)
209
+ if font_family:
210
+ _theme.update_font(font_family)
211
+ _theme.update_meta(project_name, year, about, description, meta_items)
212
+
213
+
214
def get_custom_css() -> str:
    """Return the CSS generated from the active theme configuration."""
    return _theme.get_css()
217
+
218
+
219
+ def _image_to_base64(image_path: str) -> str:
220
+ """
221
+ Convert image to base64 string with better error handling.
222
+
223
+ Args:
224
+ image_path: Relative path to image file
225
+
226
+ Returns:
227
+ Base64 encoded string
228
+
229
+ Raises:
230
+ FileNotFoundError: If image file doesn't exist
231
+ """
232
+ current_dir = Path(__file__).parent
233
+ full_path = current_dir / image_path
234
+
235
+ if not full_path.exists():
236
+ raise FileNotFoundError(f"Image not found: {full_path}")
237
+
238
+ with open(full_path, "rb") as f:
239
+ return base64.b64encode(f.read()).decode("utf-8")
240
+
241
+
242
+ def create_header(logo_path: str = "static/intelligent_retail.png") -> None:
243
+ """
244
+ Create a header with logo and project name.
245
+
246
+ Args:
247
+ logo_path: Path to logo image
248
+ """
249
+ with gr.Row():
250
+ with gr.Column(scale=2):
251
+ try:
252
+ logo_base64 = _image_to_base64(logo_path)
253
+ gr.HTML(
254
+ f"""<img src="data:image/png;base64,{logo_base64}"
255
+ alt="Logo"
256
+ style="height:100px;width:auto;margin:0 auto;margin-bottom:18px;display:block;">"""
257
+ )
258
+ except FileNotFoundError:
259
+ gr.HTML("<div style='text-align:center;color:#999;'>Logo not found</div>")
260
+
261
+ with gr.Column(scale=2):
262
+ gr.HTML(f"""
263
+ <div style="display:flex;justify-content:flex-start;align-items:center;gap:30px;">
264
+ <div>
265
+ <h1 style="margin-bottom:0;color:{_theme.primary_color};font-size:2.32em;font-weight:bold;">
266
+ {_theme.project_name}
267
+ </h1>
268
+ <p style="margin-top:4px;font-size:1.1em;color:#555;">{_theme.about}</p>
269
+ </div>
270
+ </div>
271
+ """)
272
+
273
+
274
+ def create_footer(logo_path: str = "static/intelligent_retail.png",
275
+ creator_name: str = "Thi-Diem-My Le",
276
+ creator_link: str = "https://beacons.ai/elizabethmyn",
277
+ org_name: str = "AI VIET NAM",
278
+ org_link: str = "https://aivietnam.edu.vn/") -> gr.HTML:
279
+ """
280
+ Create a sticky footer with creator information.
281
+
282
+ Args:
283
+ logo_path: Path to logo image
284
+ creator_name: Name of creator
285
+ creator_link: Link to creator profile
286
+ org_name: Organization name
287
+ org_link: Link to organization
288
+
289
+ Returns:
290
+ Gradio HTML component
291
+ """
292
+ try:
293
+ logo_base64 = _image_to_base64(logo_path)
294
+ logo_html = f'<img src="data:image/png;base64,{logo_base64}" alt="Logo" style="height:0px;width:auto;">'
295
+ except FileNotFoundError:
296
+ logo_html = ""
297
+
298
+ footer_html = f"""
299
+ <style>
300
+ .sticky-footer{{
301
+ position:fixed;
302
+ bottom:0px;
303
+ left:0;
304
+ width:100%;
305
+ background:#E8F5E8;
306
+ padding:10px;
307
+ box-shadow:0 -2px 10px rgba(0,0,0,0.1);
308
+ z-index:1000;
309
+ }}
310
+ .content-wrap{{padding-bottom:60px;}}
311
+ </style>
312
+ <div class="sticky-footer">
313
+ <div style="text-align:center;font-size:18px;color:#888">
314
+ Created by
315
+ <a href="{creator_link}" target="_blank"
316
+ style="color:#465C88;text-decoration:none;font-weight:bold;display:inline-flex;align-items:center;">
317
+ {creator_name}
318
+ {logo_html}
319
+ </a>
320
+ from
321
+ <a href="{org_link}" target="_blank"
322
+ style="color:#355724;text-decoration:none;font-weight:bold;">
323
+ {org_name}
324
+ </a>
325
+ </div>
326
+ </div>
327
+ """
328
+ return gr.HTML(footer_html)
329
+
330
+
331
+ def render_info_card(description: Optional[str] = None,
332
+ meta_items: Optional[List[Tuple[str, str]]] = None,
333
+ icon: str = "🧠",
334
+ title: str = "About this demo") -> str:
335
+ """
336
+ Render an informational card.
337
+
338
+ Args:
339
+ description: Card description text
340
+ meta_items: List of (label, value) tuples
341
+ icon: Emoji or icon for the card
342
+ title: Card title
343
+
344
+ Returns:
345
+ HTML string for the card
346
+ """
347
+ desc = description if description is not None else _theme.description
348
+ items = meta_items if meta_items is not None else _theme.meta_items
349
+
350
+ meta_html = ""
351
+ if items:
352
+ meta_html = "".join([f"<span><strong>{k}</strong>: {v}</span><br>" for k, v in items])
353
+
354
+ return f"""
355
+ <div style="margin:8px 0 8px 0;">
356
+ <div style="background:#F5F9FF;border-left:6px solid {_theme.primary_color};
357
+ padding:14px 16px;border-radius:10px;box-shadow:0 1px 3px rgba(0,0,0,0.06);">
358
+ <div style="display:flex;gap:14px;align-items:flex-start;">
359
+ <div style="font-size:22px;">{icon}</div>
360
+ <div>
361
+ <div style="font-weight:700;color:{_theme.primary_color};margin-bottom:4px;">{title}</div>
362
+ <div style="color:#000;font-size:14px;line-height:1.5;">{desc}</div>
363
+ {f'<div style="margin-top:8px;color:#000;font-size:13px;">{meta_html}</div>' if meta_html else ''}
364
+ </div>
365
+ </div>
366
+ </div>
367
+ </div>
368
+ """
369
+
370
+
371
def render_disclaimer(text: str,
                      icon: str = "⚠️",
                      title: str = "Educational Use Only") -> str:
    """Build the HTML snippet for a disclaimer/warning card.

    Args:
        text: Body text of the warning.
        icon: Emoji or icon shown beside the message.
        title: Bold heading of the card.

    Returns:
        An HTML string styled with the theme's accent color.
    """
    accent = _theme.accent_color
    return f"""
    <div style="margin:8px 0 6px 0;">
      <div style="background:#FFF4F4;border-left:6px solid {accent};
                  padding:12px 16px;border-radius:8px;box-shadow:0 1px 3px rgba(0,0,0,0.06);">
        <div style="display:flex;gap:10px;align-items:flex-start;color:#000;">
          <span style="font-size:20px">{icon}</span>
          <div>
            <div style="font-weight:700;margin-bottom:4px;">{title}</div>
            <div style="font-size:14px;line-height:1.4;">{text}</div>
          </div>
        </div>
      </div>
    </div>
    """
399
+
400
+
401
# Backward compatibility - expose old function names
def set_colors(**kwargs):
    """Legacy function - use configure() instead."""
    _theme.update_colors(**kwargs)


def set_font(font_family: str):
    """Legacy function - use configure() instead."""
    _theme.update_font(font_family)


def set_meta(**kwargs):
    """Legacy function - use configure() instead."""
    _theme.update_meta(**kwargs)
415
+
416
+
417
+ # Expose custom_css as a property for backward compatibility
418
+ @property
419
+ def custom_css():
420
+ return _theme.get_css()
app/main.py CHANGED
@@ -7,6 +7,9 @@ from app.core.config import settings
7
  from app.api.routes import detection, health
8
  from app.utils.logger import logger
9
 
 
 
 
10
  # Create FastAPI application
11
  app = FastAPI(
12
  title=settings.APP_NAME,
@@ -17,6 +20,12 @@ app = FastAPI(
17
  openapi_url="/api/openapi.json"
18
  )
19
 
 
 
 
 
 
 
20
  # Add CORS middleware
21
  app.add_middleware(
22
  CORSMiddleware,
 
7
  from app.api.routes import detection, health
8
  from app.utils.logger import logger
9
 
10
+ from datetime import timedelta
11
+ from app.core.security import create_access_token
12
+
13
  # Create FastAPI application
14
  app = FastAPI(
15
  title=settings.APP_NAME,
 
20
  openapi_url="/api/openapi.json"
21
  )
22
 
23
+ access_token = create_access_token(
24
+ data={"sub": "test_user"},
25
+ expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
26
+ )
27
+ print(f"Generated test token: {access_token}")
28
+
29
  # Add CORS middleware
30
  app.add_middleware(
31
  CORSMiddleware,
app/services/prediction.py ADDED
@@ -0,0 +1,731 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+ import seaborn as sns
7
+ import gradio as gr
8
+
9
+
10
def sales_prediction_view(data, model, feature_stats, feature_engineered_data):
    """Display the sales prediction tool interface.

    Builds the full Gradio Blocks dashboard: store/product selection,
    prediction parameters (date, events, weather, competition, supply),
    and the result panel (text summary + three plots).

    Args:
        data: Raw sales data (not used directly in this view).
        model: Fitted model used for prediction; must expose ``predict``.
        feature_stats: Feature statistics (not used directly in this view).
        feature_engineered_data: DataFrame with a store column (``store_id``
            or ``store``), an item column (``item_id`` or ``item``), a
            ``date`` column, and optionally ``store_name`` / ``item_name``.

    Returns:
        A Gradio interface: a minimal error placeholder when the model or
        data is missing, otherwise the assembled ``gr.Blocks`` app.
    """

    # Fail fast with a minimal error interface when prerequisites are missing.
    if model is None:
        return gr.Interface(
            fn=lambda: "Model not loaded. Please check if the model file exists.",
            inputs=[],
            outputs=gr.Textbox(label="Error"),
            title="Sales Prediction Tool"
        )

    if feature_engineered_data.empty:
        return gr.Interface(
            fn=lambda: "Feature engineered data not loaded.",
            inputs=[],
            outputs=gr.Textbox(label="Error"),
            title="Sales Prediction Tool"
        )

    # Determine store and item column names (the schema varies between datasets).
    store_col = "store_id" if "store_id" in feature_engineered_data.columns else "store"
    item_col = "item_id" if "item_id" in feature_engineered_data.columns else "item"

    # Check for optional human-readable store/item name columns.
    has_store_names = "store_name" in feature_engineered_data.columns
    has_item_names = "item_name" in feature_engineered_data.columns

    # Create ID -> display-name mapping dictionaries when names are available.
    store_names, item_names = create_name_mappings(
        feature_engineered_data, store_col, item_col, has_store_names, has_item_names
    )

    # Get unique store list for the store dropdown.
    stores = sorted(feature_engineered_data[store_col].unique())

    # Dropdown labels use the "<id> - <name>" format when names exist.
    if has_store_names:
        store_options = [f"{store_id} - {store_names[store_id]}" for store_id in stores]
    else:
        store_options = stores

    def update_items(store_selection):
        """Update item dropdown based on selected store."""
        # Parse the numeric store id back out of the "<id> - <name>" label.
        if has_store_names:
            store_id = int(store_selection.split(" - ")[0])
        else:
            store_id = store_selection

        store_items = feature_engineered_data[feature_engineered_data[store_col] == store_id][item_col].unique()

        if has_item_names:
            item_options = [
                f"{item_id} - {item_names[item_id]}"
                for item_id in store_items
                if item_id in item_names
            ]
        else:
            item_options = sorted(store_items)

        return gr.Dropdown(choices=item_options)

    def predict_sales(store_selection, item_selection, prediction_date, is_holiday,
                      special_event, promotion_impact, event_impact, clearance_impact,
                      launch_impact, temperature, weather_condition, humidity,
                      competition_level, supply_chain):
        """Wrapper function for prediction with all inputs.

        Parses the widget values, normalizes them into the model-input dict,
        and returns (result_text, plot1, plot2, plot3) for the output panel.
        """

        # Parse store and item IDs from the "<id> - <name>" dropdown labels.
        if has_store_names:
            store_id = int(store_selection.split(" - ")[0])
        else:
            store_id = store_selection

        if has_item_names:
            item_id = int(item_selection.split(" - ")[0])
        else:
            item_id = item_selection

        # Normalize raw widget values into the prediction-input dict.
        prediction_inputs = collect_prediction_inputs_from_values(
            prediction_date, is_holiday, special_event, promotion_impact,
            event_impact, clearance_impact, launch_impact, temperature,
            weather_condition, humidity, competition_level, supply_chain
        )

        # Generate prediction and return results.
        return generate_prediction(
            feature_engineered_data,
            model,
            store_id,
            item_id,
            store_col,
            item_col,
            prediction_inputs,
            has_store_names,
            has_item_names,
            store_names,
            item_names,
        )

    # Pre-compute the item choices for the first store so the item dropdown
    # is populated before any interaction happens.
    initial_store = store_options[0] if store_options else None
    if initial_store:
        if has_store_names:
            initial_store_id = int(initial_store.split(" - ")[0])
        else:
            initial_store_id = initial_store

        initial_items = feature_engineered_data[feature_engineered_data[store_col] == initial_store_id][item_col].unique()

        if has_item_names:
            initial_item_options = [
                f"{item_id} - {item_names[item_id]}"
                for item_id in initial_items
                if item_id in item_names
            ]
        else:
            initial_item_options = sorted(initial_items)
    else:
        initial_item_options = []

    # Build Gradio interface.
    with gr.Blocks(title="Sales Prediction Tool") as demo:
        gr.Markdown("# Sales Prediction Tool")

        with gr.Row():
            # Left column - Product Selection
            with gr.Column(scale=1):
                gr.Markdown("## Product Selection")
                store_dropdown = gr.Dropdown(
                    choices=store_options,
                    label="Select Store",
                    value=initial_store,
                    interactive=True,
                    allow_custom_value=False
                )
                item_dropdown = gr.Dropdown(
                    choices=initial_item_options,
                    label="Select Product",
                    value=initial_item_options[0] if initial_item_options else None,
                    interactive=True,
                    allow_custom_value=False
                )

                # Update item choices whenever the store selection changes.
                store_dropdown.change(
                    fn=update_items,
                    inputs=[store_dropdown],
                    outputs=[item_dropdown]
                )

            # Right column - Prediction Parameters
            with gr.Column(scale=2):
                gr.Markdown("## Prediction Parameters")

                with gr.Row():
                    with gr.Column():
                        # Default to tomorrow's date.
                        prediction_date = gr.Textbox(
                            label="Prediction Date (YYYY-MM-DD)",
                            value=(datetime.now() + timedelta(days=1)).strftime("%Y-%m-%d"),
                            interactive=True
                        )
                        is_holiday = gr.Checkbox(label="Holiday", value=False, interactive=True)
                        special_event = gr.Dropdown(
                            choices=["None", "Sale/Promotion", "Local Event",
                                     "Inventory Clearance", "New Product Launch"],
                            label="Special Event",
                            value="None",
                            interactive=True
                        )
                        # Per-event impact sliders; only the one matching the
                        # chosen special event is applied downstream.
                        promotion_impact = gr.Slider(-50, 100, value=20, label="Promotion Impact (%)", interactive=True)
                        event_impact = gr.Slider(-20, 50, value=10, label="Event Impact (%)", interactive=True)
                        clearance_impact = gr.Slider(-70, 30, value=-10, label="Clearance Impact (%)", interactive=True)
                        launch_impact = gr.Slider(0, 200, value=50, label="Launch Impact (%)", interactive=True)

                    with gr.Column():
                        temperature = gr.Slider(-10.0, 40.0, value=20.0, label="Temperature (°C)", interactive=True)
                        weather_condition = gr.Dropdown(
                            choices=["Clear", "Cloudy", "Rainy", "Snowy", "Stormy"],
                            label="Weather Condition",
                            value="Clear",
                            interactive=True
                        )
                        gr.Markdown("*Note: Weather impacts vary by product category*")

                    with gr.Column():
                        humidity = gr.Slider(0, 100, value=50, label="Humidity (%)", interactive=True)
                        competition_level = gr.Radio(
                            choices=["Low", "Medium", "High"],
                            label="Competition Level",
                            value="Medium",
                            interactive=True
                        )
                        supply_chain = gr.Radio(
                            choices=["Constrained", "Normal", "Abundant"],
                            label="Supply Chain Status",
                            value="Normal",
                            interactive=True
                        )

        predict_btn = gr.Button("Predict Sales", variant="primary")

        # Output section
        gr.Markdown("## Prediction Results")
        with gr.Row():
            result_text = gr.Textbox(label="Results", lines=10)
            result_plot1 = gr.Plot(label="Sales History")

        with gr.Row():
            result_plot2 = gr.Plot(label="Weekly Pattern")
            result_plot3 = gr.Plot(label="Feature Importance")

        # Connect button to prediction function.
        predict_btn.click(
            fn=predict_sales,
            inputs=[
                store_dropdown, item_dropdown, prediction_date, is_holiday,
                special_event, promotion_impact, event_impact, clearance_impact,
                launch_impact, temperature, weather_condition, humidity,
                competition_level, supply_chain
            ],
            outputs=[result_text, result_plot1, result_plot2, result_plot3]
        )

    return demo
235
+
236
+
237
def create_name_mappings(df, store_col, item_col, has_store_names, has_item_names):
    """Create mapping dictionaries for store and item names.

    Args:
        df: DataFrame containing the ID columns and, when the flags are set,
            the ``store_name`` / ``item_name`` columns.
        store_col: Name of the store ID column.
        item_col: Name of the item ID column.
        has_store_names: Whether ``df`` carries a ``store_name`` column.
        has_item_names: Whether ``df`` carries an ``item_name`` column.

    Returns:
        Tuple ``(store_names, item_names)`` of ``{id: name}`` dicts; either
        dict is empty when the corresponding name column is absent.
    """
    store_names = {}
    item_names = {}

    # dict(zip(...)) over the de-duplicated pairs replaces the original
    # per-row iterrows() loops: same result (later duplicates still win,
    # since dict insertion overwrites), but vectorized column access.
    if has_store_names:
        pairs = df[[store_col, "store_name"]].drop_duplicates()
        store_names = dict(zip(pairs[store_col], pairs["store_name"]))

    if has_item_names:
        pairs = df[[item_col, "item_name"]].drop_duplicates()
        item_names = dict(zip(pairs[item_col], pairs["item_name"]))

    return store_names, item_names
254
+
255
+
256
def create_product_selection_sidebar(
    df,
    stores,
    store_col,
    item_col,
    has_store_names,
    has_item_names,
    store_names,
    item_names,
):
    """No-op placeholder kept so older imports continue to work.

    The store/product selection UI now lives inside
    ``sales_prediction_view`` in the Gradio version of the app, so this
    function intentionally does nothing and returns ``None``.
    """
    return None
270
+
271
+
272
def collect_prediction_inputs():
    """No-op placeholder kept so older imports continue to work.

    The Gradio version collects widget values through
    ``collect_prediction_inputs_from_values`` instead; this function
    intentionally does nothing and returns ``None``.
    """
    return None
277
+
278
+
279
def collect_prediction_inputs_from_values(
    prediction_date_str, is_holiday, special_event, promotion_impact,
    event_impact, clearance_impact, launch_impact, temperature,
    weather_condition, humidity, competition_level, supply_chain
):
    """Normalize raw UI widget values into the prediction-input dict.

    Parses the date, derives calendar fields (season, quarter, weekday,
    weekend flag), bins temperature and humidity into categories, and folds
    the event/weather/competition/supply/weekend effects into a single
    ``adjustment_factor`` multiplier.

    Returns:
        Dict of normalized inputs consumed by ``prepare_prediction_input``
        and ``generate_prediction``.
    """
    prediction_date = datetime.strptime(prediction_date_str, "%Y-%m-%d").date()

    # Each special event maps to the user-supplied impact percentage for it;
    # "None" (or anything unknown) leaves the factor at 1.0.
    impact_by_event = {
        "Sale/Promotion": promotion_impact,
        "Local Event": event_impact,
        "Inventory Clearance": clearance_impact,
        "New Product Launch": launch_impact,
    }
    if special_event in impact_by_event:
        special_event_factor = impact_by_event[special_event] / 100 + 1.0
    else:
        special_event_factor = 1.0

    # Bin continuous weather readings into the categorical levels the
    # one-hot feature columns expect.
    temp_category = "Cool" if temperature < 15 else ("Warm" if temperature < 25 else "Hot")
    humidity_level = "Low" if humidity < 40 else ("Medium" if humidity < 70 else "High")

    # Calendar-derived fields; month index 1..12 maps onto the season table.
    month = prediction_date.month
    season_table = ["winter", "winter", "spring", "spring", "spring",
                    "summer", "summer", "summer", "fall", "fall", "fall",
                    "winter"]
    season = season_table[month - 1]

    quarter = (month - 1) // 3 + 1
    day_of_week = prediction_date.weekday()
    is_weekend = int(day_of_week >= 5)

    # Multipliers for the remaining conditions; unknown values fall back to
    # a neutral 1.0.
    weather_mult = {
        "Clear": 1.0,
        "Cloudy": 0.95,
        "Rainy": 0.9,
        "Snowy": 0.8,
        "Stormy": 0.7,
    }.get(weather_condition, 1.0)
    competition_mult = {"Low": 1.1, "Medium": 1.0, "High": 0.9}.get(competition_level, 1.0)
    supply_mult = {"Constrained": 0.9, "Normal": 1.0, "Abundant": 1.05}.get(supply_chain, 1.0)
    weekend_mult = 1.15 if is_weekend else 1.0

    # Combined adjustment factor (multiplication order matches the original
    # implementation for bit-identical float results).
    adjustment_factor = (
        special_event_factor
        * weather_mult
        * competition_mult
        * supply_mult
        * weekend_mult
    )

    return {
        "date": prediction_date,
        "is_holiday": is_holiday,
        "temperature": temperature,
        "temp_category": temp_category,
        "humidity": humidity,
        "humidity_level": humidity_level,
        "season": season,
        "quarter": quarter,
        "day_of_week": day_of_week,
        "is_weekend": is_weekend,
        "special_event": special_event,
        "weather_condition": weather_condition,
        "competition_level": competition_level,
        "supply_chain": supply_chain,
        "adjustment_factor": adjustment_factor,
    }
370
+
371
+
372
def generate_prediction(
    feature_engineered_data,
    model,
    store_id,
    item_id,
    store_col,
    item_col,
    prediction_inputs,
    has_store_names,
    has_item_names,
    store_names,
    item_names,
):
    """Generate sales prediction and display results.

    Builds one model-input row from the most recent history for the
    (store_id, item_id) pair, runs ``model.predict`` on it, scales the raw
    output by the variation and user-adjustment factors, and formats the
    outcome for the UI.

    Returns:
        Tuple ``(result_text, plot1, plot2, plot3)``; the plots are ``None``
        (with an explanatory message in ``result_text``) when no history
        exists or an exception occurs.
    """

    try:
        # Newest 5 rows for this store/item combination, most recent first.
        recent_samples = (
            feature_engineered_data[
                (feature_engineered_data[store_col] == store_id)
                & (feature_engineered_data[item_col] == item_id)
            ]
            .sort_values("date", ascending=False)
            .head(5)
        )

        if recent_samples.empty:
            return "No historical data found for this product-store combination.", None, None, None

        # Clone the newest observation and overwrite it with the user inputs.
        input_row = prepare_prediction_input(recent_samples, prediction_inputs)

        # Single-row frame: the model expects 2-D input.
        input_df = pd.DataFrame([input_row])

        # Prefer the feature list stored on the fitted model
        # (``feature_name_`` — presumably a LightGBM-style estimator; TODO
        # confirm against the training code). Otherwise fall back to every
        # column except the known non-feature ones.
        if hasattr(model, "feature_name_"):
            model_features = model.feature_name_
        else:
            model_features = [
                col
                for col in input_df.columns
                if col
                not in ["sales", "date", "variation_factor", "adjustment_factor"]
            ]

        # Select only the features used by the model.
        X_pred = input_df[model_features]

        # Raw model output before any scaling.
        base_prediction = model.predict(X_pred)[0]

        # Apply adjustment factors on top of the raw prediction.
        adjusted_prediction = base_prediction

        # Small random variation injected by prepare_prediction_input.
        if "variation_factor" in input_row:
            adjusted_prediction *= input_row["variation_factor"]

        # Combined event/weather/competition/supply/weekend multiplier from
        # the user's inputs.
        if "adjustment_factor" in prediction_inputs:
            adjusted_prediction *= prediction_inputs["adjustment_factor"]

        # Format the text summary and the three figures for the UI.
        result_text, plot1, plot2, plot3 = display_prediction_results(
            adjusted_prediction,
            base_prediction,
            store_id,
            item_id,
            prediction_inputs,
            feature_engineered_data,
            store_col,
            item_col,
            has_store_names,
            has_item_names,
            store_names,
            item_names,
            model,
            model_features,
        )

        return result_text, plot1, plot2, plot3

    except Exception as e:
        # Surface the full traceback in the UI instead of crashing the app.
        import traceback
        error_msg = f"Error making prediction: {str(e)}\n\n{traceback.format_exc()}"
        return error_msg, None, None, None
459
+
460
+
461
def prepare_prediction_input(recent_samples, prediction_inputs):
    """Prepare input row for prediction based on recent sample and user inputs.

    Args:
        recent_samples: DataFrame of historical rows for one store/item pair,
            sorted newest-first (row 0 is used as the template, so lag and
            rolling features carry over from the latest observation).
        prediction_inputs: Dict produced by
            ``collect_prediction_inputs_from_values``.

    Returns:
        A pandas Series with calendar/weather fields overwritten from the
        user's inputs, plus a random ``variation_factor`` entry.
    """

    # Start from the most recent observation as the template row.
    input_row = recent_samples.iloc[0].copy()

    # Overwrite calendar fields with the requested prediction date.
    input_row["date"] = pd.to_datetime(prediction_inputs["date"])
    input_row["day"] = prediction_inputs["date"].day
    input_row["month"] = prediction_inputs["date"].month
    input_row["year"] = prediction_inputs["date"].year
    input_row["quarter"] = prediction_inputs["quarter"]
    input_row["is_holiday"] = int(prediction_inputs["is_holiday"])

    # Derived day-of-week information.
    input_row["day_of_week"] = input_row["date"].dayofweek
    input_row["day_of_month"] = input_row["date"].day
    input_row["is_weekend"] = 1 if input_row["day_of_week"] >= 5 else 0

    # Raw weather readings, only when the dataset carries these columns.
    if "temperature" in input_row:
        input_row["temperature"] = prediction_inputs["temperature"]

    if "humidity" in input_row:
        input_row["humidity"] = prediction_inputs["humidity"]

    # One-hot temperature/humidity category columns (if present in the row).
    for category in ["Cool", "Warm", "Hot"]:
        if f"temp_category_{category}" in input_row:
            input_row[f"temp_category_{category}"] = (
                1 if category == prediction_inputs["temp_category"] else 0
            )

    for level in ["Low", "Medium", "High"]:
        if f"humidity_level_{level}" in input_row:
            input_row[f"humidity_level_{level}"] = (
                1 if level == prediction_inputs["humidity_level"] else 0
            )

    # One-hot season columns; "wet" presumably exists in some datasets'
    # schemas — TODO confirm against the feature-engineering step.
    for s in ["spring", "summer", "fall", "winter", "wet"]:
        if f"season_{s}" in input_row:
            input_row[f"season_{s}"] = 1 if s == prediction_inputs["season"] else 0

    # Inject a small random +/-2% variation so repeated predictions differ
    # slightly. NOTE(review): non-deterministic unless numpy's global seed
    # is fixed by the caller.
    variation_factor = 1.0 + np.random.uniform(-0.02, 0.02)
    input_row["variation_factor"] = variation_factor

    return input_row
510
+
511
+
512
def display_prediction_results(
    prediction_value,
    base_prediction,
    store_id,
    item_id,
    prediction_inputs,
    historical_data,
    store_col,
    item_col,
    has_store_names,
    has_item_names,
    store_names,
    item_names,
    model,
    model_features,
):
    """Display prediction results with visualizations.

    Formats the prediction, the adjustment breakdown, and historical context
    into a text report, then builds the three companion figures.

    Args:
        prediction_value: Final (adjusted) predicted sales value.
        base_prediction: Raw model output before adjustment factors.
        prediction_inputs: Dict from ``collect_prediction_inputs_from_values``.
        historical_data: Full feature-engineered DataFrame (filtered here to
            the selected store/item).
        model / model_features: Fitted model and its feature list, used for
            the feature-importance plot.

    Returns:
        Tuple ``(result_text, plot1, plot2, plot3)`` where the plots may be
        ``None`` when there is not enough data to draw them.
    """

    # Build the plain-text report line by line.
    result_lines = []
    result_lines.append("=" * 50)
    result_lines.append("PREDICTION RESULTS")
    result_lines.append("=" * 50)
    result_lines.append(f"\nPredicted Sales: ${prediction_value:,.2f}")

    # Prefer human-readable names when available.
    if has_store_names:
        result_lines.append(f"Store: {store_names[store_id]}")
    else:
        result_lines.append(f"Store ID: {store_id}")

    if has_item_names:
        result_lines.append(f"Product: {item_names[item_id]}")
    else:
        result_lines.append(f"Product ID: {item_id}")

    result_lines.append(f"Date: {prediction_inputs['date'].strftime('%B %d, %Y')}")
    result_lines.append(f"Season: {prediction_inputs['season'].capitalize()}")
    if prediction_inputs["is_holiday"]:
        result_lines.append("Holiday: Yes")

    # Adjustment details: how the base prediction was scaled.
    result_lines.append(f"\n{'='*50}")
    result_lines.append("ADJUSTMENT DETAILS")
    result_lines.append("="*50)
    result_lines.append(f"Base prediction: ${base_prediction:.2f}")
    result_lines.append(f"Final prediction: ${prediction_value:.2f}")
    result_lines.append(f"Total adjustment: {prediction_inputs['adjustment_factor']:.2f}x")
    result_lines.append(f"\nEvent: {prediction_inputs['special_event']}")
    result_lines.append(f"Weather: {prediction_inputs['weather_condition']}")
    result_lines.append(f"Competition: {prediction_inputs['competition_level']}")
    result_lines.append(f"Supply: {prediction_inputs['supply_chain']}")
    result_lines.append(f"Weekend: {'Yes' if prediction_inputs['is_weekend'] else 'No'}")
    result_lines.append(f"Holiday: {'Yes' if prediction_inputs['is_holiday'] else 'No'}")

    # Historical context for the selected store/item, oldest first.
    historical = historical_data[
        (historical_data[store_col] == store_id)
        & (historical_data[item_col] == item_id)
    ].sort_values("date")

    if "sales" in historical.columns and len(historical) > 0:
        last_value = historical["sales"].iloc[-1]
        last_date = historical["date"].iloc[-1]
        avg_sales = historical["sales"].mean()
        max_sales = historical["sales"].max()
        max_date = historical.loc[historical["sales"].idxmax(), "date"]

        result_lines.append(f"\n{'='*50}")
        result_lines.append("HISTORICAL CONTEXT")
        result_lines.append("="*50)
        result_lines.append(f"Historical Average: ${avg_sales:,.2f}")
        result_lines.append(f"Period: {historical['date'].min().strftime('%b %d, %Y')} to {historical['date'].max().strftime('%b %d, %Y')}")
        result_lines.append(f"\nLast Recorded Sales: ${last_value:,.2f}")
        result_lines.append(f"Date: {last_date.strftime('%b %d, %Y')}")
        result_lines.append(f"\nHistorical Maximum: ${max_sales:,.2f}")
        result_lines.append(f"Date: {max_date.strftime('%b %d, %Y')}")

    result_text = "\n".join(result_lines)

    # Companion figures: recent history, weekly pattern, feature importances.
    plot1 = display_historical_context(historical, prediction_inputs["date"], prediction_value)
    plot2 = display_weekly_pattern(historical, prediction_inputs["date"])
    plot3 = display_feature_importance(model, model_features)

    return result_text, plot1, plot2, plot3
597
+
598
+
599
def display_historical_context(historical_data, prediction_date, prediction_value):
    """Plot the last 60 days of sales with the new prediction highlighted.

    Args:
        historical_data: Sales history for one store/item, with ``date`` and
            ``sales`` columns.
        prediction_date: Date of the new prediction (plotted as a red dot).
        prediction_value: Predicted sales value for that date.

    Returns:
        A matplotlib Figure, or ``None`` when there is no usable history.
    """
    # Nothing to draw without a sales column or any rows at all.
    if historical_data.empty or "sales" not in historical_data.columns:
        return None

    # Restrict the view to the trailing two months of data.
    cutoff = historical_data["date"].max() - pd.Timedelta(days=60)
    window = historical_data[historical_data["date"] >= cutoff].copy()
    if window.empty:
        return None

    fig, ax = plt.subplots(figsize=(6, 2.5))

    # Historical sales line plus the prediction as a standalone red marker.
    ax.plot(window["date"], window["sales"], "b-", label="Sales")
    ax.scatter(prediction_date, prediction_value, color="red", s=60, label="Prediction")

    # Overlay a 7-day moving average once there is enough data for one.
    if len(window) > 7:
        window["MA7"] = window["sales"].rolling(window=7).mean()
        ax.plot(window["date"], window["MA7"], "g--", label="7-Day Avg")

    ax.set_xlabel("")
    ax.set_ylabel("Sales ($)")
    ax.set_title("Last 60 Days Sales History")
    ax.legend(loc="upper left", fontsize="x-small")
    fig.autofmt_xdate(rotation=45)
    fig.tight_layout()

    return fig
651
+
652
+
653
def display_weekly_pattern(recent_history, prediction_date):
    """Bar-chart the average sales per weekday, highlighting the prediction day.

    Args:
        recent_history: Sales history with ``date`` and ``sales`` columns.
        prediction_date: Date being predicted; its weekday's bar turns red.

    Returns:
        A matplotlib Figure, or ``None`` with fewer than 7 rows of history.
    """
    if len(recent_history) < 7:
        return None

    # Work on a copy so the caller's frame is untouched.
    frame = recent_history.copy()
    frame["day_of_week"] = frame["date"].dt.dayofweek

    weekday_labels = [
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday",
    ]

    # Average sales per weekday, kept in Monday-to-Sunday order and limited
    # to the weekdays that actually occur in the data.
    mean_by_day = frame.groupby("day_of_week")["sales"].mean()
    present_days = [d for d in range(7) if d in mean_by_day.index]
    day_sales_df = pd.DataFrame(
        {
            "day_name": [weekday_labels[d] for d in present_days],
            "sales": [mean_by_day[d] for d in present_days],
        }
    )

    fig, ax = plt.subplots(figsize=(6, 2.5))
    sns.barplot(x="day_name", y="sales", data=day_sales_df, ax=ax)

    # Paint the bar for the prediction's weekday red.
    target_label = weekday_labels[prediction_date.weekday()]
    for idx, bar in enumerate(ax.patches):
        if day_sales_df.iloc[idx]["day_name"] == target_label:
            bar.set_facecolor("red")

    ax.set_xlabel("")
    ax.set_ylabel("Avg Sales ($)")
    ax.set_title("Sales by Day of Week")
    plt.xticks(rotation=45, fontsize=8)
    fig.tight_layout()

    return fig
700
+
701
+
702
def display_feature_importance(model, model_features):
    """Plot the top-8 feature importances of the fitted model.

    Args:
        model: Fitted estimator; must expose ``feature_importances_``.
        model_features: Feature names aligned with the importances.

    Returns:
        A matplotlib Figure, or ``None`` when the model carries no
        ``feature_importances_`` attribute.
    """
    if not hasattr(model, "feature_importances_"):
        return None

    # Pair names with importances, keep only the eight strongest.
    top = (
        pd.DataFrame(
            {"Feature": model_features, "Importance": model.feature_importances_}
        )
        .sort_values("Importance", ascending=False)
        .head(8)
    )

    # Prettify snake_case feature names for display.
    top["Feature"] = top["Feature"].apply(lambda name: name.replace("_", " ").title())

    fig, ax = plt.subplots(figsize=(6, 2.5))
    sns.barplot(x="Importance", y="Feature", data=top, ax=ax)
    ax.set_title("Top Factors Influencing Sales Prediction")
    plt.xticks(fontsize=8)
    plt.yticks(fontsize=8)
    fig.tight_layout()

    return fig
app/utils/data_generator.py ADDED
@@ -0,0 +1,774 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datetime import datetime, timedelta
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ # Set random seed for reproducibility
8
+ np.random.seed(2025)
9
+
10
+
11
def generate_store_data():
    """Return the static demo catalog of provinces and stores.

    Returns:
        Tuple ``(provinces, stores)`` where ``provinces`` is the list of
        province names and ``stores`` is a list of dicts with ``id``,
        ``name`` and ``province`` keys (ids 1-5 Hanoi, 6-10 Ho Chi Minh City).
    """
    provinces = ["Hanoi", "Ho Chi Minh City"]

    hanoi_names = [
        "Hoan Kiem Market",
        "Ba Dinh Supermarket",
        "Dong Da Mall",
        "Tay Ho Store",
        "Long Bien Shop",
    ]
    hcmc_names = [
        "District 1 Market",
        "Ben Thanh Store",
        "Saigon Supermarket",
        "Phu Nhuan Shop",
        "Binh Thanh Market",
    ]

    # Build the store records per province; ids run 1..5 then 6..10.
    stores = [
        {"id": idx + 1, "name": name, "province": "Hanoi"}
        for idx, name in enumerate(hanoi_names)
    ]
    stores += [
        {"id": idx + 6, "name": name, "province": "Ho Chi Minh City"}
        for idx, name in enumerate(hcmc_names)
    ]

    return provinces, stores
33
+
34
+
35
+ def generate_item_data():
36
+ """Generate item data"""
37
+
38
+ # Define categories and items
39
+ categories = [
40
+ "Staples",
41
+ "Dairy & Frozen",
42
+ "Beverages & Snacks",
43
+ "Household & Personal Care",
44
+ "Baby & Health",
45
+ ]
46
+
47
+ items = [
48
+ # Staples
49
+ {
50
+ "id": 1,
51
+ "name": "Rice",
52
+ "category": "Staples",
53
+ "base_price": 20.0,
54
+ "base_sales": 15,
55
+ "volatility": 0.3,
56
+ },
57
+ {
58
+ "id": 2,
59
+ "name": "Noodles",
60
+ "category": "Staples",
61
+ "base_price": 15.0,
62
+ "base_sales": 12,
63
+ "volatility": 0.25,
64
+ },
65
+ {
66
+ "id": 3,
67
+ "name": "Bread",
68
+ "category": "Staples",
69
+ "base_price": 10.0,
70
+ "base_sales": 20,
71
+ "volatility": 0.4,
72
+ },
73
+ {
74
+ "id": 4,
75
+ "name": "Flour",
76
+ "category": "Staples",
77
+ "base_price": 12.0,
78
+ "base_sales": 8,
79
+ "volatility": 0.2,
80
+ },
81
+ {
82
+ "id": 5,
83
+ "name": "Cooking Oil",
84
+ "category": "Staples",
85
+ "base_price": 25.0,
86
+ "base_sales": 10,
87
+ "volatility": 0.15,
88
+ },
89
+ {
90
+ "id": 6,
91
+ "name": "Sugar",
92
+ "category": "Staples",
93
+ "base_price": 8.0,
94
+ "base_sales": 7,
95
+ "volatility": 0.1,
96
+ },
97
+ # Dairy & Frozen
98
+ {
99
+ "id": 7,
100
+ "name": "Milk",
101
+ "category": "Dairy & Frozen",
102
+ "base_price": 18.0,
103
+ "base_sales": 30,
104
+ "volatility": 0.35,
105
+ },
106
+ {
107
+ "id": 8,
108
+ "name": "Cheese",
109
+ "category": "Dairy & Frozen",
110
+ "base_price": 35.0,
111
+ "base_sales": 12,
112
+ "volatility": 0.3,
113
+ },
114
+ {
115
+ "id": 9,
116
+ "name": "Yogurt",
117
+ "category": "Dairy & Frozen",
118
+ "base_price": 12.0,
119
+ "base_sales": 25,
120
+ "volatility": 0.4,
121
+ },
122
+ {
123
+ "id": 10,
124
+ "name": "Ice Cream",
125
+ "category": "Dairy & Frozen",
126
+ "base_price": 30.0,
127
+ "base_sales": 15,
128
+ "volatility": 0.5,
129
+ },
130
+ {
131
+ "id": 11,
132
+ "name": "Frozen Vegetables",
133
+ "category": "Dairy & Frozen",
134
+ "base_price": 22.0,
135
+ "base_sales": 10,
136
+ "volatility": 0.25,
137
+ },
138
+ # Beverages & Snacks
139
+ {
140
+ "id": 12,
141
+ "name": "Soda",
142
+ "category": "Beverages & Snacks",
143
+ "base_price": 15.0,
144
+ "base_sales": 40,
145
+ "volatility": 0.45,
146
+ },
147
+ {
148
+ "id": 13,
149
+ "name": "Juice",
150
+ "category": "Beverages & Snacks",
151
+ "base_price": 20.0,
152
+ "base_sales": 30,
153
+ "volatility": 0.4,
154
+ },
155
+ {
156
+ "id": 14,
157
+ "name": "Water",
158
+ "category": "Beverages & Snacks",
159
+ "base_price": 10.0,
160
+ "base_sales": 50,
161
+ "volatility": 0.3,
162
+ },
163
+ {
164
+ "id": 15,
165
+ "name": "Coffee",
166
+ "category": "Beverages & Snacks",
167
+ "base_price": 45.0,
168
+ "base_sales": 20,
169
+ "volatility": 0.25,
170
+ },
171
+ {
172
+ "id": 16,
173
+ "name": "Tea",
174
+ "category": "Beverages & Snacks",
175
+ "base_price": 35.0,
176
+ "base_sales": 15,
177
+ "volatility": 0.2,
178
+ },
179
+ {
180
+ "id": 17,
181
+ "name": "Chips",
182
+ "category": "Beverages & Snacks",
183
+ "base_price": 12.0,
184
+ "base_sales": 35,
185
+ "volatility": 0.45,
186
+ },
187
+ {
188
+ "id": 18,
189
+ "name": "Cookies",
190
+ "category": "Beverages & Snacks",
191
+ "base_price": 18.0,
192
+ "base_sales": 30,
193
+ "volatility": 0.4,
194
+ },
195
+ {
196
+ "id": 19,
197
+ "name": "Chocolate",
198
+ "category": "Beverages & Snacks",
199
+ "base_price": 22.0,
200
+ "base_sales": 25,
201
+ "volatility": 0.35,
202
+ },
203
+ # Household & Personal Care
204
+ {
205
+ "id": 20,
206
+ "name": "Soap",
207
+ "category": "Household & Personal Care",
208
+ "base_price": 8.0,
209
+ "base_sales": 20,
210
+ "volatility": 0.2,
211
+ },
212
+ {
213
+ "id": 21,
214
+ "name": "Shampoo",
215
+ "category": "Household & Personal Care",
216
+ "base_price": 25.0,
217
+ "base_sales": 15,
218
+ "volatility": 0.25,
219
+ },
220
+ {
221
+ "id": 22,
222
+ "name": "Toothpaste",
223
+ "category": "Household & Personal Care",
224
+ "base_price": 15.0,
225
+ "base_sales": 18,
226
+ "volatility": 0.15,
227
+ },
228
+ {
229
+ "id": 23,
230
+ "name": "Laundry Detergent",
231
+ "category": "Household & Personal Care",
232
+ "base_price": 40.0,
233
+ "base_sales": 12,
234
+ "volatility": 0.2,
235
+ },
236
+ {
237
+ "id": 24,
238
+ "name": "Paper Towels",
239
+ "category": "Household & Personal Care",
240
+ "base_price": 20.0,
241
+ "base_sales": 14,
242
+ "volatility": 0.3,
243
+ },
244
+ {
245
+ "id": 25,
246
+ "name": "Toilet Paper",
247
+ "category": "Household & Personal Care",
248
+ "base_price": 25.0,
249
+ "base_sales": 16,
250
+ "volatility": 0.25,
251
+ },
252
+ {
253
+ "id": 26,
254
+ "name": "Trash Bags",
255
+ "category": "Household & Personal Care",
256
+ "base_price": 18.0,
257
+ "base_sales": 10,
258
+ "volatility": 0.15,
259
+ },
260
+ {
261
+ "id": 27,
262
+ "name": "Dishwashing Liquid",
263
+ "category": "Household & Personal Care",
264
+ "base_price": 15.0,
265
+ "base_sales": 11,
266
+ "volatility": 0.2,
267
+ },
268
+ {
269
+ "id": 28,
270
+ "name": "All-Purpose Cleaner",
271
+ "category": "Household & Personal Care",
272
+ "base_price": 22.0,
273
+ "base_sales": 9,
274
+ "volatility": 0.15,
275
+ },
276
+ # Baby & Health
277
+ {
278
+ "id": 29,
279
+ "name": "Diapers",
280
+ "category": "Baby & Health",
281
+ "base_price": 45.0,
282
+ "base_sales": 25,
283
+ "volatility": 0.3,
284
+ },
285
+ {
286
+ "id": 30,
287
+ "name": "Baby Food",
288
+ "category": "Baby & Health",
289
+ "base_price": 20.0,
290
+ "base_sales": 15,
291
+ "volatility": 0.25,
292
+ },
293
+ {
294
+ "id": 31,
295
+ "name": "Baby Wipes",
296
+ "category": "Baby & Health",
297
+ "base_price": 15.0,
298
+ "base_sales": 20,
299
+ "volatility": 0.2,
300
+ },
301
+ {
302
+ "id": 32,
303
+ "name": "Pain Relievers",
304
+ "category": "Baby & Health",
305
+ "base_price": 30.0,
306
+ "base_sales": 10,
307
+ "volatility": 0.15,
308
+ },
309
+ {
310
+ "id": 33,
311
+ "name": "Vitamins",
312
+ "category": "Baby & Health",
313
+ "base_price": 40.0,
314
+ "base_sales": 8,
315
+ "volatility": 0.2,
316
+ },
317
+ {
318
+ "id": 34,
319
+ "name": "Cold & Flu Medicine",
320
+ "category": "Baby & Health",
321
+ "base_price": 35.0,
322
+ "base_sales": 7,
323
+ "volatility": 0.4,
324
+ },
325
+ {
326
+ "id": 35,
327
+ "name": "First Aid Kit",
328
+ "category": "Baby & Health",
329
+ "base_price": 50.0,
330
+ "base_sales": 5,
331
+ "volatility": 0.1,
332
+ },
333
+ ]
334
+
335
+ return categories, items
336
+
337
+
338
def calculate_daily_sales(date, store, item, weather_data=None):
    """
    Simulate the number of units sold for one (date, store, item) combination.

    The item's base demand is scaled by multiplicative factors (store size,
    weekday/weekend, month, quarter, holidays, weather, year-over-year
    growth) plus Gaussian noise, and floored at one unit.

    Args:
        date: datetime of the sales day.
        store: dict with at least "id" and "province".
        item: dict with "category", "base_sales" and "volatility".
        weather_data: optional mapping of (date string, province) ->
            {"temperature", "humidity", ...}; ignored when None.

    Returns:
        int: simulated sales quantity (>= 1).
    """
    demand = item["base_sales"]

    # Store size effect: ids map cyclically onto a 0.8 .. 1.7 multiplier.
    store_factor = 0.8 + (store["id"] % 10) / 10

    # Weekend uplift (Saturday=5, Sunday=6).
    weekday_factor = 1.3 if date.weekday() >= 5 else 1.0

    # Seasonal month effect: December boost, slight February dip.
    month = date.month
    month_factor = 1.0 + 0.3 * (month == 12) - 0.1 * (month == 2)

    # Mild quarterly business cycle (later quarters slightly higher).
    quarter = (month - 1) // 3 + 1
    quarter_factor = 1.0 + 0.05 * (quarter - 2.5)

    # Holiday uplifts (Vietnamese calendar).
    if (month == 1 and date.day >= 27) or (month == 2 and date.day <= 5):
        holiday_factor = 1.5  # Tet (Lunar New Year) window
    elif month == 9 and date.day == 2:
        holiday_factor = 1.3  # National Day
    elif month == 12 and date.day >= 20:
        holiday_factor = 1.4  # year-end shopping
    else:
        holiday_factor = 1.0

    # Weather-driven adjustment, only when a lookup table was supplied.
    weather_factor = 1.0
    if weather_data is not None:
        key = (date.strftime("%Y-%m-%d"), store["province"])
        day_weather = weather_data.get(key)
        if day_weather:
            temp = day_weather["temperature"]
            humidity = day_weather["humidity"]
            category = item["category"]

            # Hot days push drinks and frozen goods up; cold days down.
            if category == "Beverages & Snacks":
                if temp > 28:
                    weather_factor *= 1.3
                elif temp < 18:
                    weather_factor *= 0.9
            elif category == "Dairy & Frozen":
                if temp > 28:
                    weather_factor *= 1.4
                elif temp < 18:
                    weather_factor *= 0.8

            # High humidity approximates rain: indoor categories sell more.
            if humidity > 80 and category in [
                "Beverages & Snacks",
                "Household & Personal Care",
            ]:
                weather_factor *= 1.2

    # Category-specific year-over-year growth, applied to 2017 only.
    yoy_growth = 1.0
    if date.year == 2017:
        yoy_growth = {
            "Staples": 1.03,
            "Dairy & Frozen": 1.05,
            "Beverages & Snacks": 1.08,
            "Household & Personal Care": 1.05,
            "Baby & Health": 1.07,
        }.get(item["category"], 1.05)

    # Gaussian noise scaled by the item's volatility.
    random_factor = np.random.normal(1.0, item["volatility"])

    # Combine all multipliers.
    sales = (
        demand
        * store_factor
        * weekday_factor
        * month_factor
        * quarter_factor
        * holiday_factor
        * weather_factor
        * yoy_growth
        * random_factor
    )

    # Round to whole units and never report less than one sale.
    return max(1, int(round(sales)))
446
+
447
+
448
def generate_weather_data(start_date, end_date, provinces):
    """
    Create a synthetic daily weather table for the given provinces.

    For every day in [start_date, end_date] and every province, a temperature
    and a humidity value are drawn uniformly around month-specific baselines.

    Args:
        start_date: first day (inclusive), as a datetime.
        end_date: last day (inclusive), as a datetime.
        provinces: iterable of province names; must be keys of the climate
            table below ("Hanoi", "Ho Chi Minh City").

    Returns:
        tuple: (DataFrame with columns city/date/temperature/humidity/season,
                dict keyed by (date string, province) for fast lookups while
                generating sales).
    """
    # Month-indexed climate baselines per province (1 = January .. 12 = December).
    climate = {
        "Hanoi": {
            "base_temp": {1: 16, 2: 17, 3: 20, 4: 24, 5: 28, 6: 30,
                          7: 30, 8: 29, 9: 28, 10: 25, 11: 21, 12: 18},
            "temp_variation": 3.5,
            "base_humidity": {1: 80, 2: 83, 3: 85, 4: 85, 5: 80, 6: 80,
                              7: 83, 8: 85, 9: 83, 10: 78, 11: 75, 12: 77},
            "humidity_variation": 10,
            "seasons": {1: "winter", 2: "winter", 3: "spring", 4: "spring",
                        5: "summer", 6: "summer", 7: "summer", 8: "summer",
                        9: "fall", 10: "fall", 11: "fall", 12: "winter"},
        },
        "Ho Chi Minh City": {
            "base_temp": {1: 26, 2: 27, 3: 28, 4: 29, 5: 29, 6: 28,
                          7: 28, 8: 28, 9: 28, 10: 27, 11: 27, 12: 26},
            "temp_variation": 2.0,
            "base_humidity": {1: 70, 2: 70, 3: 70, 4: 75, 5: 80, 6: 83,
                              7: 85, 8: 85, 9: 88, 10: 85, 11: 80, 12: 75},
            "humidity_variation": 8,
            "seasons": {1: "dry", 2: "dry", 3: "dry", 4: "dry",
                        5: "wet", 6: "wet", 7: "wet", 8: "wet",
                        9: "wet", 10: "wet", 11: "wet", 12: "dry"},
        },
    }

    records = []  # rows for the returned DataFrame
    lookup = {}   # (date string, province) -> weather dict, used by sales generation

    day = start_date
    while day <= end_date:
        month = day.month
        date_str = day.strftime("%Y-%m-%d")
        for province in provinces:
            profile = climate[province]

            # Uniform noise around the monthly baseline, rounded to 0.1.
            temperature = round(
                profile["base_temp"][month]
                + np.random.uniform(
                    -profile["temp_variation"], profile["temp_variation"]
                ),
                1,
            )
            humidity = round(
                profile["base_humidity"][month]
                + np.random.uniform(
                    -profile["humidity_variation"], profile["humidity_variation"]
                ),
                1,
            )
            # Clamp humidity to a realistic range (applied after rounding).
            humidity = max(40, min(95, humidity))

            season = profile["seasons"][month]
            records.append(
                {
                    "city": province,
                    "date": date_str,
                    "temperature": temperature,
                    "humidity": humidity,
                    "season": season,
                }
            )
            lookup[(date_str, province)] = {
                "temperature": temperature,
                "humidity": humidity,
                "season": season,
            }
        day += timedelta(days=1)

    return pd.DataFrame(records), lookup
600
+
601
+
602
def generate_sales_data(start_date, end_date, stores, items, weather_dict):
    """
    Generate a synthetic daily sales table for every store/item pairing.

    Each store carries a deterministic subset of the items (roughly 80%,
    seeded by the store id), and each carried item gets one sales row per day
    computed by calculate_daily_sales().

    Args:
        start_date: first day (inclusive), as a datetime.
        end_date: last day (inclusive), as a datetime.
        stores: list of store dicts with "id", "name", "province".
        items: list of item dicts with "id", "name", "category", etc.
        weather_dict: (date string, province) -> weather dict, or None.

    Returns:
        pd.DataFrame: one row per (date, store, carried item).
    """
    # Build the date range once.
    dates = []
    current = start_date
    while current <= end_date:
        dates.append(current)
        current += timedelta(days=1)

    # Decide each store's assortment once rather than once per day: the
    # selection is seeded by the store id, so it was identical on every date
    # anyway. Hoisting this loop-invariant work avoids re-seeding and
    # re-drawing len(items) randoms for every single day.
    assortments = {}
    for store in stores:
        np.random.seed(store["id"] * 10)
        assortments[store["id"]] = [
            item for item in items if np.random.random() < 0.8  # 80% carry rate
        ]
    # Restore non-deterministic randomness for the sales noise below.
    np.random.seed(None)

    rows = []
    for date in dates:
        date_str = date.strftime("%Y-%m-%d")
        for store in stores:
            for item in assortments[store["id"]]:
                rows.append(
                    {
                        "date": date_str,
                        "province": store["province"],
                        "store_id": store["id"],
                        "store_name": store["name"],
                        "category": item["category"],
                        "item_id": item["id"],
                        "item_name": item["name"],
                        "sales": calculate_daily_sales(
                            date, store, item, weather_dict
                        ),
                    }
                )

    return pd.DataFrame(rows)
653
+
654
+
655
def add_outliers_and_nans(data, outlier_percentage=0.01, nan_percentage=0.1):
    """
    Return a copy of *data* with synthetic outliers and missing values
    injected into the 'sales' column.

    Args:
        data: input DataFrame containing a 'sales' column.
        outlier_percentage: share of rows (in percent) whose sales values
            are tripled to simulate outliers.
        nan_percentage: share of rows (in percent) whose sales values are
            blanked out to NaN.

    Returns:
        pd.DataFrame: modified copy; the input frame is left untouched.
    """
    noisy = data.copy()

    total = len(noisy)
    n_outliers = int(total * outlier_percentage / 100)
    n_nans = int(total * nan_percentage / 100)

    # Fixed seed so the injected noise is reproducible across runs.
    np.random.seed(2025)

    # Outliers first: triple the sales of a random row subset.
    outlier_rows = np.random.choice(total, n_outliers, replace=False)
    noisy.loc[outlier_rows, "sales"] *= 3

    # Then blank out a (possibly overlapping) random row subset.
    nan_rows = np.random.choice(total, n_nans, replace=False)
    noisy.loc[nan_rows, "sales"] = np.nan

    return noisy
677
+
678
+
679
def check_missing_values(df):
    """
    Summarise missing values per column.

    Returns a DataFrame indexed by column name with the absolute NaN count
    ('counts') and its share of all rows in percent ('ratio (%)'), the
    latter rounded to two decimal places.
    """
    nan_counts = df.isna().sum()
    nan_ratio = np.round(nan_counts / df.shape[0], 4) * 100
    return pd.DataFrame({"counts": nan_counts, "ratio (%)": nan_ratio})
688
+
689
+
690
def main():
    """Generate and save all synthetic datasets for the project.

    Produces stores/items in memory, then writes three CSV files under
    data/: weather_data.csv (both years), 2016_sales.csv and 2017_sales.csv
    (each with injected outliers and NaNs), printing progress and summary
    statistics along the way.
    """
    print("Generating synthetic data for Sales Forecasting with XAI project...")

    # Create output directory if it doesn't exist
    os.makedirs("data", exist_ok=True)

    # Generate store and item data (in-memory reference tables)
    provinces, stores = generate_store_data()
    categories, items = generate_item_data()

    print(
        f"Created {len(stores)} stores and {len(items)} items across {len(categories)} categories"
    )

    # Define date ranges: two full calendar years
    start_date_2016 = datetime(2016, 1, 1)
    end_date_2016 = datetime(2016, 12, 31)

    start_date_2017 = datetime(2017, 1, 1)
    end_date_2017 = datetime(2017, 12, 31)

    # Generate weather data once for the whole 2016-2017 span; the dict form
    # is used for lookups during sales generation.
    print("Generating weather data...")
    weather_df, weather_dict = generate_weather_data(
        start_date_2016, end_date_2017, provinces
    )

    # Save weather data
    weather_df.to_csv("data/weather_data.csv", index=False)
    print(f"Saved weather data with {len(weather_df)} records")

    # Generate 2016 sales data
    print("Generating 2016 sales data...")
    sales_2016 = generate_sales_data(
        start_date_2016, end_date_2016, stores, items, weather_dict
    )

    # Inject synthetic noise: 0.5% outliers and 1% missing values
    sales_2016 = add_outliers_and_nans(
        sales_2016, outlier_percentage=0.5, nan_percentage=1
    )

    # Save 2016 sales data
    sales_2016.to_csv("data/2016_sales.csv", index=False)
    print(f"Saved 2016 sales data with {len(sales_2016)} records")

    # Generate 2017 sales data
    print("Generating 2017 sales data...")
    sales_2017 = generate_sales_data(
        start_date_2017, end_date_2017, stores, items, weather_dict
    )

    # Same noise injection as for 2016
    sales_2017 = add_outliers_and_nans(
        sales_2017, outlier_percentage=0.5, nan_percentage=1
    )

    # Save 2017 sales data
    sales_2017.to_csv("data/2017_sales.csv", index=False)
    print(f"Saved 2017 sales data with {len(sales_2017)} records")

    # Print statistics
    print("\nData Generation Complete!")
    print(f"Total weather records: {len(weather_df)}")
    print(f"Total 2016 sales records: {len(sales_2016)}")
    print(f"Total 2017 sales records: {len(sales_2017)}")
    print(
        f"Total combined records: {len(weather_df) + len(sales_2016) + len(sales_2017)}"
    )

    print("\nSales Statistics:")
    print(f"2016 Average Sales: {sales_2016['sales'].mean():.2f} units")
    print(f"2016 Max Sales: {sales_2016['sales'].max()} units")
    print(f"2017 Average Sales: {sales_2017['sales'].mean():.2f} units")
    print(f"2017 Max Sales: {sales_2017['sales'].max()} units")
    print(f"Missing values: {check_missing_values(sales_2016)}")
    print(f"Missing values: {check_missing_values(sales_2017)}")

    print("\nFiles saved to data/ directory:")
    print("- data/weather_data.csv")
    print("- data/2016_sales.csv")
    print("- data/2017_sales.csv")
771
+
772
+
773
# Script entry point: generate all synthetic datasets when run directly.
if __name__ == "__main__":
    main()
app/utils/data_loader.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pickle
3
+ import pandas as pd
4
+ import gradio as gr
5
+ import pyarrow.feather as feather
6
+ from functools import lru_cache
7
+
8
+ # --- Data & Model Loading Logic ---
9
+
10
def load_model():
    """Load the trained sales forecast model from its pickle file.

    Returns:
        The unpickled model object, or None when the file is missing.
    """
    try:
        with open("models/sales_forecast_model.pkl", "rb") as fh:
            return pickle.load(fh)
    except FileNotFoundError:
        # Keep startup resilient: report to stdout instead of raising so the
        # UI can still come up and surface the problem later.
        print("Error: 'models/sales_forecast_model.pkl' not found.")
        return None
21
+
22
def load_feature_stats():
    """Load the feature statistics (means/stds) used for normalization.

    Returns:
        dict: parsed JSON statistics, or an empty dict when the file is missing.
    """
    try:
        with open("models/feature_stats.json", "r") as fh:
            return json.load(fh)
    except FileNotFoundError:
        print("Error: 'models/feature_stats.json' not found.")
        return {}
31
+
32
@lru_cache(maxsize=1)
def load_data():
    """Load the preprocessed sales CSV, cached for the process lifetime.

    lru_cache plays the role Streamlit's @st.cache_data played before the
    Gradio port.

    Returns:
        pd.DataFrame: sales data with a parsed 'date' column, or an empty
        frame with the expected columns when the file is missing.
    """
    try:
        frame = pd.read_csv("data/sales_data_preprocessed.csv")
    except FileNotFoundError:
        print("Error: 'data/sales_data_preprocessed.csv' not found.")
        return pd.DataFrame(columns=["date", "store", "sales"])
    if "date" in frame.columns:
        frame["date"] = pd.to_datetime(frame["date"])
    return frame
44
def load_feature_engineered_data():
    """Load the 55-feature engineered dataset from its Feather file.

    Returns:
        pd.DataFrame: the feature table, or an empty frame on any failure.
    """
    try:
        return feather.read_feather(
            "data/feature_engineered_data_55_features.feather"
        )
    except Exception as exc:  # broad on purpose: missing file, bad schema, ...
        print(f"Error loading feature engineered data: {str(exc)}")
        return pd.DataFrame()
54
+
55
+ # --- Processing Logic ---
56
+
57
def preprocess_data(df, feature_stats=None):
    """Prepare raw sales rows for prediction (simplified pipeline).

    Adds calendar features derived from the 'date' column and z-score
    normalizes any column that has mean/std entries in *feature_stats*.

    Args:
        df: input DataFrame; may contain a 'date' column, either already
            datetime64 or as parseable date strings.
        feature_stats: optional mapping column -> {"mean": .., "std": ..}.

    Returns:
        pd.DataFrame: a new frame; the input is not modified.
    """
    processed_df = df.copy()

    if "date" in processed_df.columns:
        # Robustness: accept string dates as well as datetime64 columns —
        # the .dt accessor below would raise on a plain object column.
        if not pd.api.types.is_datetime64_any_dtype(processed_df["date"]):
            processed_df["date"] = pd.to_datetime(processed_df["date"])
        processed_df["day_of_week"] = processed_df["date"].dt.dayofweek
        processed_df["day_of_month"] = processed_df["date"].dt.day
        processed_df["month"] = processed_df["date"].dt.month
        processed_df["year"] = processed_df["date"].dt.year
        # Vectorized instead of a per-row apply (Saturday=5, Sunday=6).
        processed_df["is_weekend"] = (processed_df["day_of_week"] >= 5).astype(int)

    # Z-score normalization using precomputed training statistics.
    if feature_stats:
        for feature, stats in feature_stats.items():
            if feature in processed_df.columns and "mean" in stats and "std" in stats:
                processed_df[feature] = (
                    processed_df[feature] - stats["mean"]
                ) / stats["std"]

    return processed_df
81
+
82
+ # --- Gradio UI Implementation ---
83
+
84
# Load resources once at import time so every Gradio callback reuses them.
# Both loaders degrade gracefully (None / {}) when their files are absent.
model = load_model()
stats = load_feature_stats()
87
+
88
def predict_sales_ui(store_id):
    """Gradio callback: preview preprocessed rows for one store.

    Raises:
        gr.Error: shown in the UI when the model failed to load at startup.
    """
    if model is None:
        raise gr.Error("Model not loaded. Check server logs.")

    frame = load_data()
    prepared = preprocess_data(frame, stats)

    # Keep only the requested store.
    selected = prepared[prepared['store'] == store_id]

    # Placeholder: a real implementation would run model.predict here.
    return selected.head()
102
+
103
# Simple Gradio Interface
# Minimal demo UI: enter a store id, click Predict, and inspect the
# preprocessed rows that would be fed to the model.
with gr.Blocks() as demo:
    gr.Markdown("# Sales Forecast Prediction")
    store_input = gr.Number(label="Enter Store ID")
    output_table = gr.DataFrame(label="Preprocessed Data Preview")
    btn = gr.Button("Predict")

    btn.click(fn=predict_sales_ui, inputs=store_input, outputs=output_table)

# Launch the app only when executed directly (not on import).
if __name__ == "__main__":
    demo.launch()
app/utils/plots.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import pandas as pd
4
+ import seaborn as sns
5
+
6
+
7
def plot_sales(df, store_id=1, item_id=1):
    """Plot the sales series of one store/item pair and mark rows whose
    sales were missing (NaN) in the input with a red dash.

    Args:
        df: DataFrame with columns date, sales, store_id, item_id,
            store_name, item_name.
        store_id: store to plot.
        item_id: item to plot.
    """
    df_2plot = df.query("(store_id==@store_id)&(item_id==@item_id)")
    store_name = df_2plot["store_name"].iloc[-1]
    item_name = df_2plot["item_name"].iloc[-1]

    fig, ax = plt.subplots(figsize=(6, 3))
    df_2plot[["date", "sales"]].plot(x="date", y="sales", ax=ax, legend=False)

    # Remember where the gaps were before filling them, so the markers can
    # be placed on the filled curve.
    nan_indices = df_2plot[df_2plot["sales"].isna()].index

    if len(nan_indices) >= 1:
        # .ffill() replaces the deprecated fillna(method="ffill") call
        # (the method= keyword is removed in pandas 3.0).
        df_2plot = df_2plot.assign(sales=lambda df: df["sales"].ffill())
        # Draw a red dash at each originally-missing point.
        nan_dates = df_2plot.loc[nan_indices, "date"]
        nan_sales = df_2plot.loc[nan_indices, "sales"]
        for date, sales in zip(nan_dates, nan_sales):
            ax.annotate(
                "-",
                xy=(date, sales),
                color="red",  # Set text color to red
                size=20,
            )

    # Set plot labels and legend
    ax.set_xlabel("Date")
    ax.set_ylabel("Sales")
    ax.set_title(f"Store: {store_name} - Item: {item_name}")
    # NOTE(review): the line was drawn with legend=False and carries no
    # label, so this legend call has nothing to show — confirm intent.
    ax.legend()
    plt.show()
39
+
40
+
41
def plot_forecast_single(flat_df, store_item):
    """
    Plot actual vs predicted sales, with the confidence band, for a single
    store-item combination from flattened Prophet predictions.

    Args:
        flat_df: DataFrame with columns ds, y, yhat, yhat_lower, yhat_upper
            and store_item.
        store_item: the store-item key to plot.
    """
    subset = flat_df[flat_df["store_item"] == store_item].copy()

    if subset.empty:
        print(f"No data found for: {store_item}")
        return

    plt.figure(figsize=(12, 6))
    sns.lineplot(data=subset, x="ds", y="y", label="Actual", color="black")
    sns.lineplot(data=subset, x="ds", y="yhat", label="Forecast", color="blue")
    # Shade the uncertainty interval around the forecast.
    plt.fill_between(
        subset["ds"],
        subset["yhat_lower"],
        subset["yhat_upper"],
        color="blue",
        alpha=0.2,
        label="Confidence Interval",
    )
    plt.title(f"Forecast vs Actual for {store_item}")
    plt.xlabel("Date")
    plt.ylabel("Sales")
    plt.xticks(rotation=45)
    plt.legend()
    plt.tight_layout()
    plt.show()
70
+
71
+
72
def plot_sales_predictions(
    df_prediction, store_id=1, nrows=6, ncols=5, figsize=(20, 20)
):
    """
    Plots actual vs predicted sales for items in a given store, one subplot
    per item in an nrows x ncols grid (unused cells are hidden).

    Parameters:
        df_prediction (DataFrame): Must include ['store_id', 'item_id', 'date', 'sales', 'prediction']
        store_id (int): Store to filter on
        nrows (int): Rows of subplots
        ncols (int): Columns of subplots
        figsize (tuple): Size of the full figure
    """
    df_sample = df_prediction[df_prediction["store_id"] == store_id]
    # Store name taken from the last row of the filtered frame.
    store_name = df_sample["store_name"].iloc[-1]

    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    # Flatten the 2-D axes grid so items can be assigned sequentially.
    axes = axes.flatten()

    item_ids = sorted(df_sample["item_id"].unique())

    for i, ax in enumerate(axes):
        if i >= len(item_ids):
            ax.axis("off")  # Hide unused subplots
            continue

        item_id = item_ids[i]
        df2plot = df_sample[df_sample["item_id"] == item_id]
        item_name = df2plot["item_name"].iloc[-1]

        if df2plot.empty:
            ax.axis("off")
            continue

        # Plot actual and predicted sales
        ax.plot(df2plot["date"], df2plot["sales"], label="Actual", color="blue")
        ax.plot(
            df2plot["date"],
            df2plot["prediction"],
            label="Forecast",
            color="red",
            linestyle="--",
            marker=".",
        )

        ax.set_title(f"Item: {item_name}")
        ax.set_xlabel("")
        ax.set_ylabel("Sales")
        ax.tick_params(axis="x", rotation=45)
        ax.grid(True)

    # Only add legend to the first subplot
    # (one shared legend for the whole figure, taken from subplot 0).
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", ncol=2, fontsize=12)

    plt.tight_layout(rect=[0, 0, 1, 0.97])  # Leave space for the legend
    fig.suptitle(
        f"Sales Forecast vs Actual - Store {store_name}", fontsize=16, fontweight="bold"
    )
    plt.show()
app/utils/utils.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+
3
+ import lightgbm as lgbm
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import pandas as pd
7
+ import seaborn as sns
8
+
9
+
10
def fill_misisng_values(df):
    """Fill NaN values in the 'sales' column with the mean of non-NaN values.

    Args:
        df: DataFrame with a 'sales' column.

    Returns:
        pd.DataFrame: filled copy; the input frame is not modified.
    """
    df_filled = df.copy()
    df_filled["sales"] = df_filled["sales"].fillna(df_filled["sales"].mean())
    return df_filled


# Correctly spelled alias; the original (misspelled) name above is kept so
# existing callers keep working.
fill_missing_values = fill_misisng_values
15
+
16
+
17
def correct_outliers(df, factor=3):
    """Replace outliers in the 'sales' column with the column mean.

    A row counts as an outlier when the absolute z-score of its sales value
    exceeds *factor*.

    Args:
        df: DataFrame with a 'sales' column.
        factor: z-score threshold (default 3).

    Returns:
        pd.DataFrame: corrected copy; the input frame is not modified.
    """
    corrected = df.copy()

    sales = corrected["sales"]
    z_scores = (sales - sales.mean()) / sales.std()
    is_outlier = np.abs(z_scores) > factor

    # The replacement mean is computed from the still-uncorrected values.
    corrected.loc[is_outlier, "sales"] = sales.mean()

    return corrected
30
+
31
+
32
def get_sample_stores(df: pd.DataFrame, store_id: int = 1) -> pd.DataFrame:
    """Return the rows belonging to a single store.

    Args:
        df: DataFrame with a 'store_id' column.
        store_id: id of the store to extract.

    Returns:
        pd.DataFrame: the group for *store_id* (raises KeyError if absent).
    """
    by_store = df.groupby("store_id")
    return by_store.get_group(store_id)
37
+
38
+
39
def save_data(df, file_path, file_format="feather"):
    """
    Persist a DataFrame to disk in Feather or CSV format.

    Args:
        df (pd.DataFrame): frame to write.
        file_path (str): destination path.
        file_format (str): 'feather' (default) or 'csv', case-insensitive;
            any other value is reported as an error without raising.

    Example:
        ```python
        save_data(df, 'output_data.feather', file_format='feather')
        ```

    Note:
        The Feather path requires pyarrow / feather-format to be installed.
    """
    fmt = file_format.lower()
    if fmt == "feather":
        df.to_feather(file_path)
        print(f"DataFrame saved to {file_path} in Feather format.")
    elif fmt == "csv":
        df.to_csv(file_path, index=False)
        print(f"DataFrame saved to {file_path} in CSV format.")
    else:
        print(
            f"Error: Unsupported file format '{file_format}'. Supported formats: 'feather', 'csv'."
        )
69
+
70
+
71
def flatten_prophet_predictions(predictions_dict):
    """Concatenate per-combination Prophet prediction frames into one table.

    Args:
        predictions_dict: mapping of store-item key -> prediction DataFrame.

    Returns:
        pd.DataFrame: all frames stacked, with a 'store_item' column added
        recording which key each row came from.
    """
    frames = [
        frame.assign(store_item=key) for key, frame in predictions_dict.items()
    ]
    return pd.concat(frames, ignore_index=True)
80
+
81
+
82
def load_model(file_path):
    """
    Load a serialized model, trying pickle first and LightGBM second.

    Parameters:
        file_path: path to the saved model file.

    Returns:
        The deserialized model object (pickled sklearn-style model, or a
        LightGBM Booster when the file is a native LightGBM dump).
    """
    try:
        with open(file_path, "rb") as fh:
            loaded = pickle.load(fh)
        print(f"Sklearn model loaded from {file_path}")
    except (pickle.UnpicklingError, FileNotFoundError):
        # Not a pickle (or not found under that name): fall back to loading
        # a native LightGBM booster file.
        loaded = lgbm.Booster(model_file=file_path)
        print(f"LightGBM (scikit-learn API) model loaded from {file_path}")

    return loaded
104
+
105
+
106
+ # Function to calculate WAPE (Weighted Absolute Percentage Error)
107
def weighted_absolute_percentage_error(y_true, y_pred):
    """
    Calculate Weighted Absolute Percentage Error (WAPE).

    WAPE = 100 * sum(|actual - predicted|) / sum(|actual|), which weights
    each error by the magnitude of the actual value.

    Args:
        y_true: Actual values
        y_pred: Predicted values

    Returns:
        WAPE value (percentage)
    """
    actual = np.array(y_true)
    predicted = np.array(y_pred)
    total_error = np.sum(np.abs(actual - predicted))
    return 100 * total_error / np.sum(np.abs(actual))
app/utils/visualization_code.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib.ticker as ticker
5
+ import numpy as np
6
+ import pandas as pd
7
+ import seaborn as sns
8
+ from matplotlib.dates import DateFormatter
9
+
10
+ # Set up plotting style
11
+ plt.style.use("seaborn-v0_8-whitegrid")
12
+ sns.set_palette("deep")
13
+ plt.rcParams["figure.figsize"] = (14, 8)
14
+ plt.rcParams["font.size"] = 12
15
+
16
+
17
+ def visualize_predictions_by_store_item(test_results, output_dir="visualizations"):
18
+ """
19
+ Create visualizations of actual vs predicted values for each store-item combination.
20
+
21
+ Args:
22
+ test_results: DataFrame containing test results with columns:
23
+ 'date', 'store_name', 'item_name', 'sales', 'prediction'
24
+ output_dir: Directory to save the visualizations
25
+ """
26
+ # Create output directory if it doesn't exist
27
+ os.makedirs(output_dir, exist_ok=True)
28
+
29
+ # Create a time series plot for each store-item combination
30
+ store_items = test_results.groupby(["store_name", "item_name"])
31
+
32
+ # Get total number of combinations for progress tracking
33
+ total_combinations = len(store_items)
34
+ print(
35
+ f"Creating visualizations for {total_combinations} store-item combinations..."
36
+ )
37
+
38
+ # Counter for progress tracking
39
+ counter = 0
40
+
41
+ # For each store-item combination, create a plot
42
+ for (store, item), group in store_items:
43
+ # Sort by date to ensure proper time series order
44
+ group = group.sort_values("date")
45
+
46
+ # Convert date to datetime if it's not already
47
+ if not pd.api.types.is_datetime64_any_dtype(group["date"]):
48
+ group["date"] = pd.to_datetime(group["date"])
49
+
50
+ # Create the plot
51
+ fig, ax = plt.subplots(figsize=(14, 6))
52
+
53
+ # Plot actual and predicted values
54
+ ax.plot(
55
+ group["date"], group["sales"], "o-", label="Actual", alpha=0.7, linewidth=2
56
+ )
57
+ ax.plot(
58
+ group["date"],
59
+ group["prediction"],
60
+ "s--",
61
+ label="Predicted",
62
+ alpha=0.7,
63
+ linewidth=2,
64
+ )
65
+
66
+ # Calculate error metrics for this store-item
67
+ mae = np.mean(np.abs(group["sales"] - group["prediction"]))
68
+ mape = (
69
+ np.mean(np.abs((group["sales"] - group["prediction"]) / group["sales"]))
70
+ * 100
71
+ )
72
+
73
+ # Add title and labels
74
+ ax.set_title(f"Store: {store}, Item: {item}\nMAE: {mae:.2f}, MAPE: {mape:.2f}%")
75
+ ax.set_xlabel("Date")
76
+ ax.set_ylabel("Sales")
77
+
78
+ # Format x-axis dates
79
+ date_formatter = DateFormatter("%Y-%m-%d")
80
+ ax.xaxis.set_major_formatter(date_formatter)
81
+ # Rotate date labels for better readability
82
+ plt.xticks(rotation=45)
83
+
84
+ # Add grid for easier reading
85
+ ax.grid(True, linestyle="--", alpha=0.7)
86
+
87
+ # Add legend
88
+ ax.legend()
89
+
90
+ # Adjust layout
91
+ plt.tight_layout()
92
+
93
+ # Save the figure
94
+ safe_store = store.replace(" ", "_").replace("/", "_")
95
+ safe_item = item.replace(" ", "_").replace("/", "_")
96
+ filename = f"{safe_store}_{safe_item}.png"
97
+ plt.savefig(os.path.join(output_dir, filename))
98
+
99
+ # Close the figure to free memory
100
+ plt.close(fig)
101
+
102
+ # Update progress
103
+ counter += 1
104
+ if counter % 10 == 0:
105
+ print(f"Processed {counter}/{total_combinations} combinations")
106
+
107
+ print(f"All visualizations saved to {output_dir}/")
108
+
109
+
110
def _plot_actual_vs_predicted_totals(data, title, filepath):
    """Plot daily actual vs predicted totals from *data* and save to *filepath*."""
    fig, ax = plt.subplots(figsize=(14, 6))
    ax.plot(
        data["date"], data["sales"], "o-", label="Actual", alpha=0.7, linewidth=2
    )
    ax.plot(
        data["date"],
        data["prediction"],
        "s--",
        label="Predicted",
        alpha=0.7,
        linewidth=2,
    )
    ax.set_title(title)
    ax.set_xlabel("Date")
    ax.set_ylabel("Total Sales")
    # Format x-axis dates and rotate labels for readability.
    ax.xaxis.set_major_formatter(DateFormatter("%Y-%m-%d"))
    plt.xticks(rotation=45)
    ax.grid(True, linestyle="--", alpha=0.7)
    ax.legend()
    plt.tight_layout()
    plt.savefig(filepath)
    plt.close(fig)


def visualize_aggregated_predictions(test_results, output_dir="visualizations"):
    """
    Create aggregated visualizations of actual vs predicted values by store,
    item, and date, saving one PNG per aggregation level.

    Args:
        test_results: DataFrame containing test results (needs 'date',
            'store_name', 'item_name', 'sales', 'prediction' columns).
        output_dir: Directory to save the visualizations (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Normalize the date column WITHOUT mutating the caller's DataFrame
    # (the original assigned back into test_results in place).
    if not pd.api.types.is_datetime64_any_dtype(test_results["date"]):
        test_results = test_results.assign(date=pd.to_datetime(test_results["date"]))

    # 1. Total sales per day across all stores and items.
    daily_results = (
        test_results.groupby("date")
        .agg({"sales": "sum", "prediction": "sum"})
        .reset_index()
    )
    _plot_actual_vs_predicted_totals(
        daily_results,
        "Total Daily Sales: Actual vs Predicted",
        os.path.join(output_dir, "total_daily_sales.png"),
    )

    # 2. Daily totals per store.
    store_results = (
        test_results.groupby(["store_name", "date"])
        .agg({"sales": "sum", "prediction": "sum"})
        .reset_index()
    )
    for store in store_results["store_name"].unique():
        store_data = store_results[store_results["store_name"] == store]
        # Sanitize the store name so it is safe to use in a filename.
        safe_store = store.replace(" ", "_").replace("/", "_")
        _plot_actual_vs_predicted_totals(
            store_data,
            f"Store: {store} - Total Daily Sales",
            os.path.join(output_dir, f"store_{safe_store}_total.png"),
        )

    # 3. Daily totals per item.
    item_results = (
        test_results.groupby(["item_name", "date"])
        .agg({"sales": "sum", "prediction": "sum"})
        .reset_index()
    )
    for item in item_results["item_name"].unique():
        item_data = item_results[item_results["item_name"] == item]
        safe_item = item.replace(" ", "_").replace("/", "_")
        _plot_actual_vs_predicted_totals(
            item_data,
            f"Item: {item} - Total Daily Sales",
            os.path.join(output_dir, f"item_{safe_item}_total.png"),
        )

    print(f"Aggregated visualizations saved to {output_dir}/")
269
+
270
+
271
def create_interactive_dashboard(test_results, output_dir="visualizations"):
    """
    Create interactive HTML charts (overall performance + MAPE heatmap) for
    all store-item combinations. Requires the Plotly library; if it is not
    installed, prints install instructions instead of raising.

    Args:
        test_results: DataFrame containing test results (needs 'date',
            'store_name', 'item_name', 'sales', 'prediction' columns).
        output_dir: Directory to save the dashboard HTML files.
    """
    try:
        import plotly.express as px
        import plotly.graph_objects as go

        print("Creating interactive dashboard...")

        os.makedirs(output_dir, exist_ok=True)

        # Normalize the date column WITHOUT mutating the caller's DataFrame
        # (the original assigned back into test_results in place).
        if not pd.api.types.is_datetime64_any_dtype(test_results["date"]):
            test_results = test_results.assign(
                date=pd.to_datetime(test_results["date"])
            )

        # Overall performance: total actual vs predicted sales per day.
        daily_results = (
            test_results.groupby("date")
            .agg({"sales": "sum", "prediction": "sum"})
            .reset_index()
        )

        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=daily_results["date"],
                y=daily_results["sales"],
                mode="lines+markers",
                name="Actual",
                line=dict(color="blue"),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=daily_results["date"],
                y=daily_results["prediction"],
                mode="lines+markers",
                name="Predicted",
                line=dict(color="red", dash="dash"),
            )
        )

        fig.update_layout(
            title="Overall Sales Performance: Actual vs Predicted",
            xaxis_title="Date",
            yaxis_title="Total Sales",
            legend_title="Series",
            height=600,
        )

        fig.write_html(os.path.join(output_dir, "overall_performance.html"))

        def _safe_mape(x):
            # Guard against zero actual sales: compute MAPE over nonzero-sales
            # rows only (the original divided by zero, producing inf).
            mask = x["sales"] != 0
            if not mask.any():
                return np.nan
            actual = x.loc[mask, "sales"]
            predicted = x.loc[mask, "prediction"]
            return np.mean(np.abs((actual - predicted) / actual)) * 100

        # Error heatmap: MAPE per store-item combination.
        store_item_error = (
            test_results.groupby(["store_name", "item_name"])
            .apply(_safe_mape)
            .reset_index()
        )
        store_item_error.columns = ["store_name", "item_name", "mape"]

        # Pivot the data for the heatmap (stores on rows, items on columns).
        heatmap_data = store_item_error.pivot(
            index="store_name", columns="item_name", values="mape"
        )

        heatmap_fig = px.imshow(
            heatmap_data,
            labels=dict(x="Item", y="Store", color="MAPE (%)"),
            x=heatmap_data.columns,
            y=heatmap_data.index,
            color_continuous_scale="RdBu_r",
            title="Mean Absolute Percentage Error by Store and Item",
        )

        heatmap_fig.update_layout(height=800, width=1200)

        heatmap_fig.write_html(os.path.join(output_dir, "error_heatmap.html"))

        print(f"Interactive dashboard elements saved to {output_dir}/")

    except ImportError:
        print("Could not create interactive dashboard. Plotly library is required.")
        print("Install it with: pip install plotly dash")
368
+
369
+
370
def visualize_error_distribution(test_results, output_dir="visualizations"):
    """
    Visualize the distribution and patterns of prediction errors:
    histogram, error vs actual sales, error over time, error by day of week,
    and (if a 'category' column exists) error by category.

    Args:
        test_results: DataFrame containing test results (needs 'date',
            'sales', 'prediction' columns; 'category' optional).
        output_dir: Directory to save the visualizations (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Work on a copy: the original added error/day columns to the caller's
    # DataFrame in place, a surprising side effect.
    test_results = test_results.copy()

    # Derived error columns.
    test_results["error"] = test_results["sales"] - test_results["prediction"]
    test_results["abs_error"] = np.abs(test_results["error"])
    # Percentage error is undefined for zero actual sales; use NaN there
    # instead of dividing by zero (NaN rows are excluded by .mean() below).
    test_results["pct_error"] = np.where(
        test_results["sales"] != 0,
        (test_results["error"] / test_results["sales"]) * 100,
        np.nan,
    )

    # 1. Error distribution histogram
    plt.figure(figsize=(12, 6))
    sns.histplot(test_results["error"], kde=True, bins=50)
    plt.axvline(x=0, color="red", linestyle="--")
    plt.title("Distribution of Prediction Errors")
    plt.xlabel("Error (Actual - Predicted)")
    plt.ylabel("Frequency")
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "error_distribution.png"))
    plt.close()

    # 2. Error vs Actual Sales
    plt.figure(figsize=(12, 6))
    plt.scatter(test_results["sales"], test_results["error"], alpha=0.5)
    plt.axhline(y=0, color="red", linestyle="--")
    plt.title("Prediction Error vs Actual Sales")
    plt.xlabel("Actual Sales")
    plt.ylabel("Error (Actual - Predicted)")
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "error_vs_sales.png"))
    plt.close()

    # 3. Error over time
    plt.figure(figsize=(14, 6))
    if not pd.api.types.is_datetime64_any_dtype(test_results["date"]):
        test_results["date"] = pd.to_datetime(test_results["date"])

    # Group by date to see overall error trend
    daily_error = test_results.groupby("date")["error"].mean().reset_index()
    plt.plot(daily_error["date"], daily_error["error"], "o-")
    plt.axhline(y=0, color="red", linestyle="--")
    plt.title("Mean Prediction Error Over Time")
    plt.xlabel("Date")
    plt.ylabel("Mean Error")
    date_formatter = DateFormatter("%Y-%m-%d")
    plt.gca().xaxis.set_major_formatter(date_formatter)
    plt.xticks(rotation=45)
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "error_over_time.png"))
    plt.close()

    # 4. Error by day of week
    test_results["day_of_week"] = test_results["date"].dt.dayofweek
    test_results["day_name"] = test_results["date"].dt.day_name()

    plt.figure(figsize=(12, 6))
    day_error = (
        test_results.groupby("day_name")["pct_error"]
        .mean()
        .reindex(
            [
                "Monday",
                "Tuesday",
                "Wednesday",
                "Thursday",
                "Friday",
                "Saturday",
                "Sunday",
            ]
        )
    )
    sns.barplot(x=day_error.index, y=day_error.values)
    plt.title("Mean Percentage Error by Day of Week")
    plt.xlabel("Day of Week")
    plt.ylabel("Mean Percentage Error (%)")
    plt.axhline(y=0, color="red", linestyle="--")
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "error_by_day_of_week.png"))
    plt.close()

    # 5. Error by category - only if 'category' column exists
    if "category" in test_results.columns:
        plt.figure(figsize=(12, 6))
        cat_error = test_results.groupby("category")["pct_error"].mean().sort_values()
        sns.barplot(x=cat_error.index, y=cat_error.values)
        plt.title("Mean Percentage Error by Category")
        plt.xlabel("Category")
        plt.ylabel("Mean Percentage Error (%)")
        plt.axhline(y=0, color="red", linestyle="--")
        plt.xticks(rotation=45)
        plt.grid(True, linestyle="--", alpha=0.7)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "error_by_category.png"))
        plt.close()

    print(f"Error analysis visualizations saved to {output_dir}/")
477
+
478
+
479
def create_forecast_dashboard(
    model, X_test, y_test, test_results, data, output_dir="visualizations"
):
    """
    Build the full set of forecast visualizations in one call.

    Runs, in order: per-combination plots for the 20 highest-volume
    store-item pairs, aggregated plots, error analysis, and (when Plotly is
    available) the interactive dashboard.

    Args:
        model: Trained model (currently unused; kept for interface stability).
        X_test: Test features (currently unused).
        y_test: Test target values (currently unused).
        test_results: DataFrame with test results.
        data: Original data with date, store, item info (currently unused).
        output_dir: Directory to save visualizations.
    """
    print("Creating forecast visualizations...")

    # Rank store-item pairs by total sales volume and keep the top 20 so we
    # don't generate an unbounded number of individual plots.
    volume_by_pair = (
        test_results.groupby(["store_name", "item_name"])["sales"].sum().reset_index()
    )
    leaders = volume_by_pair.sort_values("sales", ascending=False).head(20)

    # Restrict the detailed per-pair plots to those leading combinations.
    leader_rows = pd.merge(
        test_results,
        leaders[["store_name", "item_name"]],
        on=["store_name", "item_name"],
    )
    visualize_predictions_by_store_item(leader_rows, output_dir)

    # Aggregate-level views use the full result set.
    visualize_aggregated_predictions(test_results, output_dir)
    visualize_error_distribution(test_results, output_dir)

    # Interactive dashboard is best-effort (no-op message if Plotly missing).
    create_interactive_dashboard(test_results, output_dir)

    print("Forecast visualization dashboard created successfully!")