SowmiyaNagaraj committed on
Commit
d872cb5
·
verified ·
1 Parent(s): 8eb433e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +378 -0
app.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import gradio as gr
4
+ from statsmodels.tsa.arima.model import ARIMA
5
+ from sklearn.preprocessing import MinMaxScaler
6
+ from sklearn.metrics import r2_score
7
+ from tensorflow.keras.models import Sequential
8
+ from tensorflow.keras.layers import LSTM, Dense
9
+ from tensorflow.keras.optimizers import Adam
10
+ import warnings
11
+ import matplotlib.pyplot as plt
12
+ from matplotlib.ticker import MaxNLocator
13
+ import os
14
+
15
warnings.filterwarnings("ignore")

# Columns that the rest of the module reads from the dataset.
required_columns = ['Product_Name', 'Date', 'Sales']

# Location of the CSV. The default is the original hard-coded path; the
# SALES_DATA_PATH environment variable overrides it when set.
data_path = os.environ.get(
    'SALES_DATA_PATH',
    '/content/drive/MyDrive/enhanced_sales_data_for_arima_lstm.csv',
)

# Load Dataset with better error handling
try:
    # Print current directory to help debug file location issues
    print(f"Current working directory: {os.getcwd()}")
    print(f"Files in directory: {os.listdir()}")

    df = pd.read_csv(data_path)
    print("\nDataset loaded successfully!")
    print(f"Columns in dataset: {df.columns.tolist()}")

    # Validate the schema BEFORE touching the columns, so that a missing
    # column is reported by name instead of surfacing as a generic error.
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        print(f"\nERROR: Missing required columns: {missing}")
        df = None
    else:
        # Convert Date column to datetime and order rows per product by date
        df['Date'] = pd.to_datetime(df['Date'])
        df = df.sort_values(['Product_Name', 'Date'])
        product_names = df['Product_Name'].unique()
        print(f"\nFirst few products: {product_names[:5]}... (total: {len(product_names)} products)")

except FileNotFoundError:
    df = None
    print("\nERROR: Dataset file not found!")
    print("Please make sure the file exists in the specified path.")
except Exception as e:
    # Top-level boundary: any other load/parse failure disables the dataset
    # instead of crashing the application at import time.
    df = None
    print(f"\nERROR loading dataset: {str(e)}")

# Get product list with fallback
if df is not None and 'Product_Name' in df.columns:
    product_list = sorted(df['Product_Name'].unique().tolist())
    print(f"\nProducts loaded ({len(product_list)} total):")
    if len(product_list) > 5:
        print(product_list[:5], "...")
    else:
        print(product_list)
else:
    product_list = []
    print("\nNo products loaded - using empty list")
56
+
57
def prepare_data(product_name):
    """Return the sales series of one product, indexed by date.

    Reads the module-level ``df``. Returns ``None`` when the dataset is not
    loaded or when no row matches ``product_name``; otherwise returns the
    ``Sales`` column of the matching rows with ``Date`` as the index.
    """
    if df is None:
        print("ERROR: No data available (df is None)")
        return None

    print(f"\nPreparing data for product: {product_name}")
    matches_product = df['Product_Name'] == product_name
    product_rows = df[matches_product][['Date', 'Sales']]
    sales_series = product_rows.set_index('Date')['Sales']

    if sales_series.empty:
        print(f"WARNING: No sales data found for product: {product_name}")
        return None

    point_count = len(sales_series)
    print(f"Found {point_count} data points for {product_name}")
    return sales_series
71
+
72
def train_arima(data, steps=60):
    """Fit an ARIMA(5, 1, 0) model on ``data`` and forecast ``steps`` values.

    Returns the forecast produced by the fitted model, or ``None`` when
    ``data`` has fewer than 6 points or when fitting/forecasting raises.
    """
    point_count = len(data)
    if point_count < 6:
        print("ARIMA: Not enough data (need at least 6 points)")
        return None

    try:
        print(f"\nTraining ARIMA model on {point_count} data points...")
        fitted = ARIMA(data, order=(5, 1, 0)).fit()
        forecast = fitted.forecast(steps=steps)
        print("ARIMA training completed successfully")
        return forecast
    except Exception as e:
        # Best-effort: a failed fit is reported and signalled with None.
        print(f"ARIMA Error: {e}")
        return None
87
+
88
def train_lstm(data, steps=60, window=5):
    """Train a two-layer LSTM on ``data`` and forecast ``steps`` values.

    The series is min-max scaled, cut into sliding windows of ``window``
    consecutive values (each predicting the following value), and the model
    is then rolled forward ``steps`` times, feeding every prediction back in
    as the newest input.

    Args:
        data: 1-D sequence of sales values (pandas Series, list or array).
        steps: Number of future values to predict.
        window: Number of past values the network sees per prediction.
            Defaults to 5, the value previously hard-coded.

    Returns:
        A 1-D NumPy array of ``steps`` predictions in the original scale, or
        ``None`` when there is not enough data, ``window`` is not positive,
        or training/prediction raises.
    """
    if window < 1:
        print("LSTM: window must be at least 1")
        return None

    # At least one (window -> next value) training pair is required.
    min_points = window + 1
    if len(data) < min_points:
        print(f"LSTM: Not enough data (need at least {min_points} points)")
        return None

    try:
        print(f"\nTraining LSTM model on {len(data)} data points...")
        scaler = MinMaxScaler()
        values = np.asarray(data, dtype=float).reshape(-1, 1)
        data_scaled = scaler.fit_transform(values)

        # Sliding windows: X[k] holds `window` consecutive values, y[k] the next one.
        X = np.array([
            data_scaled[i - window:i, 0]
            for i in range(window, len(data_scaled))
        ])
        y = data_scaled[window:, 0]
        X = X.reshape(X.shape[0], window, 1)

        model = Sequential([
            LSTM(50, activation='relu', return_sequences=True, input_shape=(window, 1)),
            LSTM(50, activation='relu'),
            Dense(1)
        ])
        model.compile(optimizer=Adam(learning_rate=0.01), loss='mse')
        model.fit(X, y, epochs=20, batch_size=4, verbose=0)

        last_sequence = data_scaled[-window:].reshape(1, window, 1)
        predictions = []

        for _ in range(steps):
            next_pred = model.predict(last_sequence, verbose=0)
            predictions.append(next_pred[0, 0])
            # Drop the oldest value and append the new prediction.
            last_sequence = np.append(last_sequence[:, 1:, :], next_pred.reshape(1, 1, 1), axis=1)

        print("LSTM training completed successfully")
        return scaler.inverse_transform(np.array(predictions).reshape(-1, 1)).flatten()
    except Exception as e:
        # Best-effort: a failed training run is reported and signalled with None.
        print(f"LSTM Error: {e}")
        return None
131
+
132
def hybrid_prediction(data, steps=60):
    """Average the ARIMA and LSTM forecasts for ``data``.

    Args:
        data: Historical sales series passed to both models.
        steps: Forecast horizon. Defaults to 60, the value previously
            hard-coded.

    Returns:
        A list of ``steps`` floats (the 50/50 blend of both forecasts), or a
        dict ``{"error": message}`` when a model fails, a forecast is shorter
        than ``steps``, or the blended forecast contains non-finite values.
    """
    print("\nStarting hybrid prediction...")
    arima_pred = train_arima(data, steps=steps)
    lstm_pred = train_lstm(data, steps=steps)

    failed = [
        name
        for name, pred in (("ARIMA", arima_pred), ("LSTM", lstm_pred))
        if pred is None
    ]
    if failed:
        error_msg = "Model training failed - " + " and ".join(f"{name} failed" for name in failed)
        print(error_msg)
        return {"error": error_msg}

    # Work on plain arrays so slicing is positional regardless of the index
    # the ARIMA forecast carries.
    arima_values = np.asarray(arima_pred, dtype=float)
    lstm_values = np.asarray(lstm_pred, dtype=float)

    min_length = min(len(arima_values), len(lstm_values))
    if min_length < steps:
        error_msg = f"Prediction length too short: {min_length} (need {steps})"
        print(error_msg)
        return {"error": error_msg}

    final_pred = 0.5 * arima_values[:steps] + 0.5 * lstm_values[:steps]
    if not np.all(np.isfinite(final_pred)):
        # NaN/inf values would otherwise be plotted as an empty chart.
        error_msg = "Model produced non-finite predictions"
        print(error_msg)
        return {"error": error_msg}

    print("Hybrid prediction completed successfully")
    return final_pred.tolist()
154
+
155
def create_monthly_plot(monthly_data, product_name):
    """Build a bar + line chart of the monthly forecast.

    Args:
        monthly_data: Sequence of forecast values, one per month.
        product_name: Name shown in the chart title.

    Returns:
        The matplotlib figure containing the chart.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    months = [f"Month {i+1}" for i in range(len(monthly_data))]

    # Bar plot
    ax.bar(months, monthly_data, color='skyblue', alpha=0.7, label='Monthly Forecast')

    # Line plot on top
    ax.plot(months, monthly_data, color='red', marker='o', linestyle='-', linewidth=2, markersize=5, label='Trend')

    ax.set_title(f"5-Year Monthly Sales Forecast for {product_name}", fontsize=14)
    ax.set_xlabel("Months", fontsize=12)
    ax.set_ylabel("Sales", fontsize=12)
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend()

    # Rotate x-axis labels and show only every 6th month to avoid crowding.
    # The labels are styled through `ax` rather than `plt.xticks`, which acts
    # on pyplot's global "current axes".
    for i, label in enumerate(ax.get_xticklabels()):
        label.set_rotation(45)
        label.set_horizontalalignment('right')
        if i % 6 != 0:
            label.set_visible(False)

    fig.tight_layout()
    return fig
179
+
180
def create_yearly_scatter(yearly_data, product_name):
    """Build a scatter chart comparing the forecast of each year month by month.

    Args:
        yearly_data: Sequence of per-year sequences of monthly values.
        product_name: Name shown in the chart title.

    Returns:
        The matplotlib figure containing the chart.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    colors = ['red', 'blue', 'green', 'purple', 'orange']
    markers = ['o', 's', 'D', '^', 'v']  # Different markers for each year

    for year_idx, year_data in enumerate(yearly_data):
        # Match the x positions to the number of values actually present so a
        # partial year does not raise a size-mismatch error.
        months = np.arange(1, len(year_data) + 1)
        ax.scatter(
            months,
            year_data,
            # Cycle through the palettes so more than five years cannot
            # index past the end of the lists.
            color=colors[year_idx % len(colors)],
            marker=markers[year_idx % len(markers)],
            s=100,
            label=f'Year {year_idx+1}',
            alpha=0.7,
        )

    ax.set_title(f"Yearly Sales Comparison for {product_name}", fontsize=14)
    ax.set_xlabel("Month of Year", fontsize=12)
    ax.set_ylabel("Sales", fontsize=12)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))  # Only integer months
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend()

    # Use the figure's own layout call instead of pyplot's global one.
    fig.tight_layout()
    return fig
199
+
200
def create_evaluation_plot(actual, predicted, product_name, r2_score):
    """Build a line chart of actual versus predicted sales.

    Args:
        actual: Sequence of observed values.
        predicted: Sequence of predicted values, plotted against the same
            month labels as ``actual``.
        product_name: Name shown in the chart title.
        r2_score: R² value shown in the title. Inside this function the
            parameter shadows the imported ``r2_score`` function.

    Returns:
        The matplotlib figure containing the chart.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    months = [f"Month {i+1}" for i in range(len(actual))]

    ax.plot(months, actual, 'b-', label='Actual Sales', marker='o')
    ax.plot(months, predicted, 'r--', label='Predicted Sales', marker='x')

    ax.set_title(f"Model Evaluation for {product_name}\nR² Score: {r2_score:.2f}", fontsize=14)
    ax.set_xlabel("Months", fontsize=12)
    ax.set_ylabel("Sales", fontsize=12)
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend()

    # Style the labels through `ax` rather than `plt.xticks`, which acts on
    # pyplot's global "current axes".
    for label in ax.get_xticklabels():
        label.set_rotation(45)
        label.set_horizontalalignment('right')

    fig.tight_layout()
    return fig
216
+
217
def predict(product_name):
    """Produce the 60-month forecast charts for one product.

    Returns a 3-tuple ``(error, monthly_figure, yearly_figure)``. On success
    ``error`` is ``None`` and both figures are set; on failure ``error`` is a
    dict with an ``"error"`` key and both figures are ``None``.
    """
    print(f"\nStarting prediction for: {product_name}")
    if df is None:
        error_msg = "Dataset not loaded or could not be processed"
        print(error_msg)
        return {"error": error_msg}, None, None

    sales_data = prepare_data(product_name)
    has_enough_history = sales_data is not None and len(sales_data) >= 6
    if not has_enough_history:
        error_msg = "Not enough historical data for prediction"
        print(error_msg)
        return {"error": error_msg}, None, None

    predictions = hybrid_prediction(sales_data)

    # hybrid_prediction signals failure with an {"error": ...} dict.
    if isinstance(predictions, dict) and "error" in predictions:
        return predictions, None, None

    monthly = predictions[:60]
    # Split the 60 months into five consecutive 12-month years.
    yearly = [monthly[start:start + 12] for start in range(0, 60, 12)]

    monthly_fig = create_monthly_plot(monthly, product_name)
    yearly_fig = create_yearly_scatter(yearly, product_name)

    print(f"Successfully generated forecast for {product_name}")
    return None, monthly_fig, yearly_fig
243
+
244
def evaluate_model(product_name, test_size=12):
    """Back-test the hybrid model on the last ``test_size`` points of a product.

    The series is split into a training part and a hold-out of ``test_size``
    points; both models are trained on the training part, their forecasts are
    averaged 50/50 and scored against the hold-out with R².

    Args:
        product_name: Product whose history is evaluated.
        test_size: Number of trailing points held out. Must be positive.

    Returns:
        A 2-tuple ``(error, figure)``. On success ``error`` is ``None`` and
        ``figure`` is the actual-vs-predicted chart; on failure ``error`` is
        a dict with an ``"error"`` key and ``figure`` is ``None``.
    """
    print(f"\nStarting evaluation for: {product_name}")
    if df is None:
        error_msg = "Dataset not loaded or could not be processed"
        print(error_msg)
        return {"error": error_msg}, None

    # A hold-out of 0 would make `[:-test_size]` an empty training set, and a
    # negative one would invert the split.
    if test_size < 1:
        error_msg = "test_size must be a positive integer"
        print(error_msg)
        return {"error": error_msg}, None

    data = prepare_data(product_name)
    if data is None or len(data) < test_size + 6:
        error_msg = "Not enough data to evaluate model"
        print(error_msg)
        return {"error": error_msg}, None

    # Explicit positional split: everything but the tail trains, the tail tests.
    train_data = data.iloc[:-test_size]
    test_data = data.iloc[-test_size:]

    arima_pred = train_arima(train_data, steps=test_size)
    lstm_pred = train_lstm(train_data, steps=test_size)

    if arima_pred is None or lstm_pred is None:
        error_msg = "Model training failed during evaluation"
        print(error_msg)
        return {"error": error_msg}, None

    hybrid_pred = 0.5 * np.asarray(arima_pred, dtype=float) + 0.5 * np.asarray(lstm_pred, dtype=float)
    score = r2_score(test_data.values, hybrid_pred)

    evaluation_plot = create_evaluation_plot(
        test_data.values,
        hybrid_pred,
        product_name,
        score
    )

    print(f"Evaluation completed for {product_name} with R² score: {score:.2f}")
    return None, evaluation_plot
280
+
281
# Create Gradio interface
def _error_update(error):
    """Return an update that shows an error box only when there is an error."""
    return gr.update(value=error, visible=error is not None)


def _forecast_with_visible_errors(product_name):
    """Run ``predict`` and reveal the forecast error box when it reports an error."""
    error, monthly_fig, yearly_fig = predict(product_name)
    return _error_update(error), monthly_fig, yearly_fig


def _evaluate_with_visible_errors(product_name):
    """Run ``evaluate_model`` and reveal the evaluation error box when it reports an error."""
    error, evaluation_fig = evaluate_model(product_name)
    return _error_update(error), evaluation_fig


with gr.Blocks(title="Sales Forecast Dashboard", theme="soft") as demo:
    gr.Markdown("# 🚀 Hybrid ARIMA-LSTM Sales Forecasting")
    gr.Markdown("Predict 5 years of monthly sales and evaluate model accuracy")

    with gr.Tabs():
        with gr.Tab("📈 Forecast Sales"):
            gr.Markdown("### Generate 5-Year Sales Forecast")
            with gr.Row():
                product_dropdown = gr.Dropdown(
                    choices=product_list,
                    label="Select Product",
                    interactive=True,
                    value=product_list[0] if product_list else None
                )
                forecast_btn = gr.Button("Generate Forecast", variant="primary")

            # Hidden until a handler reports an error; the id is unique per
            # tab so the page does not contain duplicate DOM ids.
            error_box = gr.JSON(
                label="Error Messages",
                visible=False,
                elem_id="forecast-error-box"
            )

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Monthly Forecast")
                    monthly_plot = gr.Plot(
                        label="Monthly Sales Forecast",
                        show_label=True
                    )
                with gr.Column():
                    gr.Markdown("### Yearly Comparison")
                    yearly_plot = gr.Plot(
                        label="Yearly Sales Pattern",
                        show_label=True
                    )

            # Examples section
            if product_list:
                gr.Examples(
                    examples=[[product] for product in product_list[:3]],
                    inputs=product_dropdown,
                    label="Try these products:"
                )

            forecast_btn.click(
                fn=_forecast_with_visible_errors,
                inputs=product_dropdown,
                outputs=[error_box, monthly_plot, yearly_plot],
                api_name="predict"
            )

        with gr.Tab("📊 Evaluate Accuracy"):
            gr.Markdown("### Evaluate Model Performance")
            with gr.Row():
                eval_product_dropdown = gr.Dropdown(
                    choices=product_list,
                    label="Select Product",
                    interactive=True,
                    value=product_list[0] if product_list else None
                )
                evaluate_btn = gr.Button("Evaluate Model", variant="primary")

            eval_error_box = gr.JSON(
                label="Error Messages",
                visible=False,
                elem_id="eval-error-box"
            )

            gr.Markdown("### Actual vs Predicted Sales")
            evaluation_plot = gr.Plot(
                label="Model Evaluation Results",
                show_label=True
            )

            evaluate_btn.click(
                fn=_evaluate_with_visible_errors,
                inputs=eval_product_dropdown,
                outputs=[eval_error_box, evaluation_plot],
                api_name="evaluate"
            )

    # Add some debug info if no products found
    if not product_list:
        gr.Markdown("## ⚠️ No Products Found")
        gr.Markdown("""
        The application couldn't load any products. This usually means:
        - The dataset file wasn't found at the specified path
        - The dataset doesn't contain the required columns (Product_Name, Date, Sales)
        - There was an error loading the data

        Check the console output for more details.
        """)

# Launch the application
if __name__ == "__main__":
    print("\nStarting Gradio application...")
    demo.launch()