AlgoX commited on
Commit
77a4f6b
·
1 Parent(s): b6f3206

feat : add training script for time-series model + classical models

Browse files
Files changed (1) hide show
  1. train/classical_train.py +652 -0
train/classical_train.py ADDED
@@ -0,0 +1,652 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ from sklearn.ensemble import (
6
+ RandomForestRegressor,
7
+ GradientBoostingRegressor,
8
+ ExtraTreesRegressor,
9
+ )
10
+ from sklearn.linear_model import Ridge, Lasso, ElasticNet
11
+ from sklearn.svm import SVR
12
+ from sklearn.preprocessing import StandardScaler
13
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
14
+ import xgboost as xgb
15
+ import lightgbm as lgb
16
+ from statsmodels.tsa.arima.model import ARIMA
17
+ from statsmodels.tsa.statespace.sarimax import SARIMAX
18
+ from statsmodels.tsa.holtwinters import ExponentialSmoothing
19
+ import warnings
20
+ import json
21
+ import os
22
+ from datetime import datetime
23
+ import pickle
24
+
25
+ warnings.filterwarnings("ignore")
26
+
27
+
28
+
29
+
30
+ def create_sequences(features, targets, seq_length=20):
31
+ """Create sequences for time series prediction"""
32
+ X, y = [], []
33
+ for i in range(len(features) - seq_length):
34
+ X.append(features[i : i + seq_length].flatten()) # Flatten sequence
35
+ y.append(targets[i + seq_length]) # Predict next value
36
+ return np.array(X), np.array(y)
37
+
38
+
39
+ def create_lagged_features(df, target_col, lags=[1, 2, 3, 5, 10, 20]):
40
+ """Create lagged features for time series"""
41
+ df_lagged = df.copy()
42
+ for lag in lags:
43
+ df_lagged[f"{target_col}_lag_{lag}"] = df_lagged[target_col].shift(lag)
44
+
45
+ # Add rolling statistics
46
+ for window in [5, 10, 20]:
47
+ df_lagged[f"{target_col}_rolling_mean_{window}"] = (
48
+ df_lagged[target_col].rolling(window).mean()
49
+ )
50
+ df_lagged[f"{target_col}_rolling_std_{window}"] = (
51
+ df_lagged[target_col].rolling(window).std()
52
+ )
53
+
54
+ # Drop NaN values created by lagging
55
+ df_lagged = df_lagged.dropna()
56
+ return df_lagged
57
+
58
+
59
+
60
+
61
+ class ModelTrainer:
62
+ def __init__(self, model_name, model, save_dir="./checkpoints_classical"):
63
+ self.model_name = model_name
64
+ self.model = model
65
+ self.save_dir = save_dir
66
+ self.metrics = {}
67
+ self.predictions = None
68
+
69
+ def train(self, X_train, y_train):
70
+ """Train the model"""
71
+ print(f"\nTraining {self.model_name}...")
72
+ self.model.fit(X_train, y_train)
73
+
74
+ def predict(self, X):
75
+ """Make predictions"""
76
+ return self.model.predict(X)
77
+
78
+ def evaluate(self, X_train, y_train, X_val, y_val):
79
+ """Evaluate model on train and validation sets"""
80
+ train_pred = self.predict(X_train)
81
+ val_pred = self.predict(X_val)
82
+
83
+ self.metrics = {
84
+ "train_mse": mean_squared_error(y_train, train_pred),
85
+ "train_rmse": np.sqrt(mean_squared_error(y_train, train_pred)),
86
+ "train_mae": mean_absolute_error(y_train, train_pred),
87
+ "train_r2": r2_score(y_train, train_pred),
88
+ "val_mse": mean_squared_error(y_val, val_pred),
89
+ "val_rmse": np.sqrt(mean_squared_error(y_val, val_pred)),
90
+ "val_mae": mean_absolute_error(y_val, val_pred),
91
+ "val_r2": r2_score(y_val, val_pred),
92
+ }
93
+
94
+ self.predictions = {"train": train_pred, "val": val_pred}
95
+
96
+ return self.metrics
97
+
98
+ def save_model(self, run_dir):
99
+ """Save model to disk"""
100
+ model_path = os.path.join(run_dir, f"{self.model_name}_model.pkl")
101
+ with open(model_path, "wb") as f:
102
+ pickle.dump(self.model, f)
103
+ print(f"✓ Saved {self.model_name} model")
104
+
105
+
106
+
107
+
108
+ class ARIMAModel:
109
+ def __init__(self, order=(1, 1, 1)):
110
+ self.order = order
111
+ self.model = None
112
+ self.model_fit = None
113
+
114
+ def fit(self, X_train, y_train):
115
+ """Fit ARIMA model - uses only target variable"""
116
+ # ARIMA works on univariate time series
117
+ self.model = ARIMA(y_train, order=self.order)
118
+ self.model_fit = self.model.fit()
119
+
120
+ def predict(self, X):
121
+ """Make predictions"""
122
+ n_periods = len(X)
123
+ forecast = self.model_fit.forecast(steps=n_periods)
124
+ return np.array(forecast)
125
+
126
+
127
+ class SARIMAXModel:
128
+ def __init__(self, order=(1, 1, 1), seasonal_order=(0, 0, 0, 0)):
129
+ self.order = order
130
+ self.seasonal_order = seasonal_order
131
+ self.model = None
132
+ self.model_fit = None
133
+
134
+ def fit(self, X_train, y_train):
135
+ """Fit SARIMAX model"""
136
+ self.model = SARIMAX(
137
+ y_train, order=self.order, seasonal_order=self.seasonal_order
138
+ )
139
+ self.model_fit = self.model.fit(disp=False)
140
+
141
+ def predict(self, X):
142
+ """Make predictions"""
143
+ n_periods = len(X)
144
+ forecast = self.model_fit.forecast(steps=n_periods)
145
+ return np.array(forecast)
146
+
147
+
148
+ class ExponentialSmoothingModel:
149
+ def __init__(self, seasonal_periods=None):
150
+ self.seasonal_periods = seasonal_periods
151
+ self.model = None
152
+ self.model_fit = None
153
+
154
+ def fit(self, X_train, y_train):
155
+ """Fit Exponential Smoothing model"""
156
+ self.model = ExponentialSmoothing(
157
+ y_train,
158
+ seasonal_periods=self.seasonal_periods,
159
+ trend="add",
160
+ seasonal="add" if self.seasonal_periods else None,
161
+ )
162
+ self.model_fit = self.model.fit()
163
+
164
+ def predict(self, X):
165
+ """Make predictions"""
166
+ n_periods = len(X)
167
+ forecast = self.model_fit.forecast(steps=n_periods)
168
+ return np.array(forecast)
169
+
170
+
171
+
172
+
173
+ def get_ml_models():
174
+ """Get dictionary of classical ML models"""
175
+ models = {
176
+ # Linear Models
177
+ "Ridge": Ridge(alpha=1.0),
178
+ "Lasso": Lasso(alpha=0.1),
179
+ "ElasticNet": ElasticNet(alpha=0.1, l1_ratio=0.5),
180
+ "RandomForest": RandomForestRegressor(
181
+ n_estimators=100,
182
+ max_depth=10,
183
+ min_samples_split=5,
184
+ random_state=42,
185
+ n_jobs=-1,
186
+ ),
187
+ "ExtraTrees": ExtraTreesRegressor(
188
+ n_estimators=100,
189
+ max_depth=10,
190
+ min_samples_split=5,
191
+ random_state=42,
192
+ n_jobs=-1,
193
+ ),
194
+ "GradientBoosting": GradientBoostingRegressor(
195
+ n_estimators=100, max_depth=5, learning_rate=0.1, random_state=42
196
+ ),
197
+ "XGBoost": xgb.XGBRegressor(
198
+ n_estimators=100, max_depth=5, learning_rate=0.1, random_state=42, n_jobs=-1
199
+ ),
200
+ "LightGBM": lgb.LGBMRegressor(
201
+ n_estimators=100,
202
+ max_depth=5,
203
+ learning_rate=0.1,
204
+ random_state=42,
205
+ n_jobs=-1,
206
+ verbose=-1,
207
+ ),
208
+ "SVR": SVR(kernel="rbf", C=1.0, epsilon=0.1),
209
+ }
210
+ return models
211
+
212
+
213
+ def get_time_series_models():
214
+ """Get dictionary of time series models"""
215
+ models = {
216
+ "ARIMA": ARIMAModel(order=(2, 1, 2)),
217
+ "SARIMAX": SARIMAXModel(order=(1, 1, 1), seasonal_order=(1, 1, 1, 5)),
218
+ "ExpSmoothing": ExponentialSmoothingModel(seasonal_periods=5),
219
+ }
220
+ return models
221
+
222
+
223
+
224
+
225
+ def train_ml_models(X_train, y_train, X_val, y_val, save_dir):
226
+ """Train all classical ML models"""
227
+ models = get_ml_models()
228
+ results = {}
229
+ trained_models = {}
230
+
231
+ print("\n" + "=" * 60)
232
+ print("TRAINING CLASSICAL ML MODELS")
233
+ print("=" * 60)
234
+
235
+ for name, model in models.items():
236
+ try:
237
+ trainer = ModelTrainer(name, model, save_dir)
238
+ trainer.train(X_train, y_train)
239
+ metrics = trainer.evaluate(X_train, y_train, X_val, y_val)
240
+ trainer.save_model(save_dir)
241
+
242
+ results[name] = metrics
243
+ trained_models[name] = trainer
244
+
245
+ print(f"\n{name}:")
246
+ print(
247
+ f" Train - RMSE: {metrics['train_rmse']:.6f}, MAE: {metrics['train_mae']:.6f}, R²: {metrics['train_r2']:.4f}"
248
+ )
249
+ print(
250
+ f" Val - RMSE: {metrics['val_rmse']:.6f}, MAE: {metrics['val_mae']:.6f}, R²: {metrics['val_r2']:.4f}"
251
+ )
252
+
253
+ except Exception as e:
254
+ print(f"\n{name}: FAILED - {str(e)}")
255
+ results[name] = None
256
+
257
+ return results, trained_models
258
+
259
+
260
+ def train_time_series_models(y_train, y_val, save_dir):
261
+ """Train time series models (univariate)"""
262
+ models = get_time_series_models()
263
+ results = {}
264
+ trained_models = {}
265
+
266
+ print("\n" + "=" * 60)
267
+ print("TRAINING TIME SERIES MODELS")
268
+ print("=" * 60)
269
+
270
+ for name, model in models.items():
271
+ try:
272
+ trainer = ModelTrainer(name, model, save_dir)
273
+ # Time series models use only target variable
274
+ trainer.train(None, y_train)
275
+
276
+ # Make predictions
277
+ train_pred = trainer.predict(np.arange(len(y_train)))
278
+ val_pred = trainer.predict(np.arange(len(y_val)))
279
+
280
+ # Calculate metrics
281
+ metrics = {
282
+ "train_mse": mean_squared_error(y_train, train_pred),
283
+ "train_rmse": np.sqrt(mean_squared_error(y_train, train_pred)),
284
+ "train_mae": mean_absolute_error(y_train, train_pred),
285
+ "train_r2": r2_score(y_train, train_pred),
286
+ "val_mse": mean_squared_error(y_val, val_pred),
287
+ "val_rmse": np.sqrt(mean_squared_error(y_val, val_pred)),
288
+ "val_mae": mean_absolute_error(y_val, val_pred),
289
+ "val_r2": r2_score(y_val, val_pred),
290
+ }
291
+
292
+ trainer.metrics = metrics
293
+ trainer.predictions = {"train": train_pred, "val": val_pred}
294
+ trainer.save_model(save_dir)
295
+
296
+ results[name] = metrics
297
+ trained_models[name] = trainer
298
+
299
+ print(f"\n{name}:")
300
+ print(
301
+ f" Train - RMSE: {metrics['train_rmse']:.6f}, MAE: {metrics['train_mae']:.6f}, R²: {metrics['train_r2']:.4f}"
302
+ )
303
+ print(
304
+ f" Val - RMSE: {metrics['val_rmse']:.6f}, MAE: {metrics['val_mae']:.6f}, R²: {metrics['val_r2']:.4f}"
305
+ )
306
+
307
+ except Exception as e:
308
+ print(f"\n{name}: FAILED - {str(e)}")
309
+ results[name] = None
310
+
311
+ return results, trained_models
312
+
313
+
314
+
315
+
316
+ def plot_model_comparison(results, save_dir):
317
+ """Plot comparison of all models"""
318
+ # Filter out failed models
319
+ results = {k: v for k, v in results.items() if v is not None}
320
+
321
+ if not results:
322
+ print("No successful models to plot")
323
+ return
324
+
325
+ models = list(results.keys())
326
+
327
+ # Extract metrics
328
+ train_rmse = [results[m]["train_rmse"] for m in models]
329
+ val_rmse = [results[m]["val_rmse"] for m in models]
330
+ train_mae = [results[m]["train_mae"] for m in models]
331
+ val_mae = [results[m]["val_mae"] for m in models]
332
+ train_r2 = [results[m]["train_r2"] for m in models]
333
+ val_r2 = [results[m]["val_r2"] for m in models]
334
+
335
+ # Create comparison plots
336
+ fig, axes = plt.subplots(2, 2, figsize=(16, 12))
337
+
338
+ # RMSE comparison
339
+ ax = axes[0, 0]
340
+ x = np.arange(len(models))
341
+ width = 0.35
342
+ ax.bar(x - width / 2, train_rmse, width, label="Train", alpha=0.8)
343
+ ax.bar(x + width / 2, val_rmse, width, label="Validation", alpha=0.8)
344
+ ax.set_xlabel("Model")
345
+ ax.set_ylabel("RMSE")
346
+ ax.set_title("Root Mean Squared Error Comparison")
347
+ ax.set_xticks(x)
348
+ ax.set_xticklabels(models, rotation=45, ha="right")
349
+ ax.legend()
350
+ ax.grid(True, alpha=0.3)
351
+
352
+ # MAE comparison
353
+ ax = axes[0, 1]
354
+ ax.bar(x - width / 2, train_mae, width, label="Train", alpha=0.8)
355
+ ax.bar(x + width / 2, val_mae, width, label="Validation", alpha=0.8)
356
+ ax.set_xlabel("Model")
357
+ ax.set_ylabel("MAE")
358
+ ax.set_title("Mean Absolute Error Comparison")
359
+ ax.set_xticks(x)
360
+ ax.set_xticklabels(models, rotation=45, ha="right")
361
+ ax.legend()
362
+ ax.grid(True, alpha=0.3)
363
+
364
+ # R² comparison
365
+ ax = axes[1, 0]
366
+ ax.bar(x - width / 2, train_r2, width, label="Train", alpha=0.8)
367
+ ax.bar(x + width / 2, val_r2, width, label="Validation", alpha=0.8)
368
+ ax.set_xlabel("Model")
369
+ ax.set_ylabel("R² Score")
370
+ ax.set_title("R² Score Comparison")
371
+ ax.set_xticks(x)
372
+ ax.set_xticklabels(models, rotation=45, ha="right")
373
+ ax.legend()
374
+ ax.grid(True, alpha=0.3)
375
+
376
+ # Validation RMSE sorted
377
+ ax = axes[1, 1]
378
+ sorted_idx = np.argsort(val_rmse)
379
+ sorted_models = [models[i] for i in sorted_idx]
380
+ sorted_rmse = [val_rmse[i] for i in sorted_idx]
381
+ colors = plt.cm.RdYlGn_r(np.linspace(0.3, 0.9, len(sorted_models)))
382
+ ax.barh(range(len(sorted_models)), sorted_rmse, color=colors)
383
+ ax.set_yticks(range(len(sorted_models)))
384
+ ax.set_yticklabels(sorted_models)
385
+ ax.set_xlabel("Validation RMSE")
386
+ ax.set_title("Models Ranked by Validation RMSE")
387
+ ax.grid(True, alpha=0.3, axis="x")
388
+
389
+ plt.tight_layout()
390
+ plt.savefig(
391
+ os.path.join(save_dir, "model_comparison.png"), dpi=300, bbox_inches="tight"
392
+ )
393
+ print(f"\n✓ Saved model comparison plot")
394
+ plt.close()
395
+
396
+
397
+ def plot_predictions_comparison(trained_models, y_val, save_dir, n_samples=200):
398
+ """Plot predictions from top models"""
399
+ # Get top 5 models by validation RMSE
400
+ model_scores = [
401
+ (name, trainer.metrics["val_rmse"])
402
+ for name, trainer in trained_models.items()
403
+ if trainer.metrics is not None
404
+ ]
405
+ model_scores.sort(key=lambda x: x[1])
406
+ top_models = model_scores[:5]
407
+
408
+ fig, axes = plt.subplots(len(top_models), 1, figsize=(14, 4 * len(top_models)))
409
+ if len(top_models) == 1:
410
+ axes = [axes]
411
+
412
+ plot_len = min(n_samples, len(y_val))
413
+
414
+ for i, (name, score) in enumerate(top_models):
415
+ ax = axes[i]
416
+ trainer = trained_models[name]
417
+ val_pred = trainer.predictions["val"]
418
+
419
+ ax.plot(y_val[:plot_len], label="Actual", alpha=0.7, linewidth=1.5)
420
+ ax.plot(val_pred[:plot_len], label="Predicted", alpha=0.7, linewidth=1.5)
421
+ ax.set_xlabel("Time Step")
422
+ ax.set_ylabel("Value")
423
+ ax.set_title(f"{name} Predictions (Val RMSE: {score:.6f})")
424
+ ax.legend()
425
+ ax.grid(True, alpha=0.3)
426
+
427
+ plt.tight_layout()
428
+ plt.savefig(
429
+ os.path.join(save_dir, "top_model_predictions.png"),
430
+ dpi=300,
431
+ bbox_inches="tight",
432
+ )
433
+ print(f"✓ Saved top model predictions plot")
434
+ plt.close()
435
+
436
+
437
+ def create_results_table(results, save_dir):
438
+ """Create and save results table"""
439
+ # Filter out failed models
440
+ results = {k: v for k, v in results.items() if v is not None}
441
+
442
+ df = pd.DataFrame(results).T
443
+ df = df.sort_values("val_rmse")
444
+
445
+ print("\n" + "=" * 80)
446
+ print("MODEL COMPARISON RESULTS (sorted by validation RMSE)")
447
+ print("=" * 80)
448
+ print(df.to_string())
449
+ print("=" * 80)
450
+
451
+ # Save to CSV
452
+ df.to_csv(os.path.join(save_dir, "results_comparison.csv"))
453
+ print(f"\n✓ Saved results table")
454
+
455
+ return df
456
+
457
+
458
+ # ========================= ABLATION STUDIES =========================
459
+
460
+
461
+ def run_ablation_study(X_train, y_train, X_val, y_val, save_dir):
462
+ """Run ablation studies on feature importance and model configurations"""
463
+
464
+ print("\n" + "=" * 60)
465
+ print("ABLATION STUDY: Feature Importance")
466
+ print("=" * 60)
467
+
468
+ # Train a Random Forest to get feature importances
469
+ rf_model = RandomForestRegressor(n_estimators=100, random_state=42, n_jobs=-1)
470
+ rf_model.fit(X_train, y_train)
471
+
472
+ # Get feature importances
473
+ importances = rf_model.feature_importances_
474
+
475
+ # Test with different number of features
476
+ n_features_list = [10, 20, 50, 100, X_train.shape[1]]
477
+ ablation_results = {}
478
+
479
+ for n_features in n_features_list:
480
+ if n_features > X_train.shape[1]:
481
+ continue
482
+
483
+ # Select top n features
484
+ top_indices = np.argsort(importances)[-n_features:]
485
+ X_train_subset = X_train[:, top_indices]
486
+ X_val_subset = X_val[:, top_indices]
487
+
488
+ # Train model with subset
489
+ model = RandomForestRegressor(n_estimators=100, random_state=42, n_jobs=-1)
490
+ model.fit(X_train_subset, y_train)
491
+
492
+ val_pred = model.predict(X_val_subset)
493
+ rmse = np.sqrt(mean_squared_error(y_val, val_pred))
494
+ mae = mean_absolute_error(y_val, val_pred)
495
+ r2 = r2_score(y_val, val_pred)
496
+
497
+ ablation_results[f"Top_{n_features}_features"] = {
498
+ "val_rmse": rmse,
499
+ "val_mae": mae,
500
+ "val_r2": r2,
501
+ }
502
+
503
+ print(
504
+ f"\nTop {n_features} features: RMSE={rmse:.6f}, MAE={mae:.6f}, R²={r2:.4f}"
505
+ )
506
+
507
+ # Save ablation results
508
+ ablation_df = pd.DataFrame(ablation_results).T
509
+ ablation_df.to_csv(os.path.join(save_dir, "ablation_feature_importance.csv"))
510
+
511
+ # Plot ablation results
512
+ fig, ax = plt.subplots(figsize=(10, 6))
513
+ x = range(len(ablation_results))
514
+ ax.plot(
515
+ list(ablation_results.keys()),
516
+ [v["val_rmse"] for v in ablation_results.values()],
517
+ "o-",
518
+ linewidth=2,
519
+ markersize=8,
520
+ )
521
+ ax.set_xlabel("Number of Features")
522
+ ax.set_ylabel("Validation RMSE")
523
+ ax.set_title("Ablation Study: Impact of Feature Count on Performance")
524
+ ax.grid(True, alpha=0.3)
525
+ plt.xticks(rotation=45, ha="right")
526
+ plt.tight_layout()
527
+ plt.savefig(os.path.join(save_dir, "ablation_feature_importance.png"), dpi=300)
528
+ plt.close()
529
+
530
+ print(f"\n✓ Saved ablation study results")
531
+
532
+ return ablation_results
533
+
534
+
535
+ # ========================= MAIN EXECUTION =========================
536
+
537
+
538
+ def main():
539
+ from data_prep.data_clean import clean_indicator
540
+ from data_prep.data_load import prepare_data
541
+
542
+ # Configuration
543
+ config = {
544
+ "data_path": "/home/aman/code/ml_fr/ml_stocks/data/NIFTY_5_years.csv",
545
+ "seq_length": 20,
546
+ "train_split": 0.8,
547
+ "save_dir": "./checkpoints_classical",
548
+ "target_col": "Daily_Return",
549
+ }
550
+
551
+ # Create save directory
552
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
553
+ save_dir = os.path.join(config["save_dir"], f"run_{timestamp}")
554
+ os.makedirs(save_dir, exist_ok=True)
555
+
556
+ print(f"\n{'='*60}")
557
+ print(f"CLASSICAL ML & TIME SERIES MODEL TRAINING")
558
+ print(f"{'='*60}")
559
+ print(f"Save directory: {save_dir}")
560
+ print(f"{'='*60}\n")
561
+
562
+ # Load and prepare data
563
+ print("Loading data...")
564
+ load_df = prepare_data(config["data_path"])
565
+ df = clean_indicator(load_df)
566
+
567
+ target_col = config["target_col"]
568
+ feature_cols = [col for col in df.columns if col != target_col]
569
+
570
+ # Split data
571
+ train_size = int(len(df) * config["train_split"])
572
+ train_df = df[:train_size]
573
+ val_df = df[train_size:]
574
+
575
+ print(f"Train samples: {len(train_df)}")
576
+ print(f"Validation samples: {len(val_df)}")
577
+ print(f"Number of features: {len(feature_cols)}")
578
+
579
+ # Prepare features for ML models (with sequences)
580
+ scaler = StandardScaler()
581
+ train_features = scaler.fit_transform(train_df[feature_cols].values)
582
+ val_features = scaler.transform(val_df[feature_cols].values)
583
+
584
+ train_targets = train_df[target_col].values
585
+ val_targets = val_df[target_col].values
586
+
587
+ # Create sequences
588
+ X_train, y_train = create_sequences(
589
+ train_features, train_targets, config["seq_length"]
590
+ )
591
+ X_val, y_val = create_sequences(val_features, val_targets, config["seq_length"])
592
+
593
+ print(f"\nSequence shape: {X_train.shape}")
594
+ print(f"Target shape: {y_train.shape}")
595
+
596
+ # Save config
597
+ with open(os.path.join(save_dir, "config.json"), "w") as f:
598
+ json.dump(config, f, indent=4)
599
+
600
+ # Train ML models
601
+ ml_results, ml_models = train_ml_models(X_train, y_train, X_val, y_val, save_dir)
602
+
603
+ # Train time series models (using non-sequenced data)
604
+ ts_results, ts_models = train_time_series_models(
605
+ train_targets[config["seq_length"] :], # Align with ML model targets
606
+ val_targets[config["seq_length"] :],
607
+ save_dir,
608
+ )
609
+
610
+ # Combine results
611
+ all_results = {**ml_results, **ts_results}
612
+ all_models = {**ml_models, **ts_models}
613
+
614
+ # Create visualizations
615
+ print("\n" + "=" * 60)
616
+ print("CREATING VISUALIZATIONS")
617
+ print("=" * 60)
618
+
619
+ plot_model_comparison(all_results, save_dir)
620
+ plot_predictions_comparison(all_models, y_val, save_dir)
621
+ results_df = create_results_table(all_results, save_dir)
622
+
623
+ # Run ablation study
624
+ ablation_results = run_ablation_study(X_train, y_train, X_val, y_val, save_dir)
625
+
626
+ print(f"\n{'='*60}")
627
+ print("TRAINING COMPLETE!")
628
+ print(f"Results saved to: {save_dir}")
629
+ print(f"{'='*60}\n")
630
+
631
+ # Print best model
632
+ best_model = results_df.index[0]
633
+ best_rmse = results_df.loc[best_model, "val_rmse"]
634
+ print(f"🏆 Best Model: {best_model}")
635
+ print(f" Validation RMSE: {best_rmse:.6f}")
636
+ print(f" Validation MAE: {results_df.loc[best_model, 'val_mae']:.6f}")
637
+ print(f" Validation R²: {results_df.loc[best_model, 'val_r2']:.4f}")
638
+
639
+ return all_results, all_models, save_dir
640
+
641
+
642
+ if __name__ == "__main__":
643
+ results, models, save_dir = main()
644
+
645
+ print("\n" + "=" * 60)
646
+ print("All models trained successfully!")
647
+ print("Check the save directory for:")
648
+ print(" - Model comparison plots")
649
+ print(" - Results CSV")
650
+ print(" - Saved model files (.pkl)")
651
+ print(" - Ablation study results")
652
+ print("=" * 60)