AlgoX committed
Commit f69e29e · 1 Parent(s): d5d3c5b

feat: hawk training file

Files changed (1)
  1. train/hawk_train.py +587 -0
train/hawk_train.py ADDED
@@ -0,0 +1,587 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import os
from datetime import datetime
import json

def get_model_device(model):
    return next(model.parameters()).device

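# RG-LRU: the Real-Gated Linear Recurrent Unit used by the Hawk/Griffin
# architecture (De et al., 2024). Per timestep it computes
#   i_t = sigmoid(W_i x_t)                       input gate
#   r_t = sigmoid(W_r x_t)                       recurrence gate
#   a_t = lambda ** (c * r_t),  with lambda = sigmoid(Lambda) in (0, 1)
#   h_t = a_t * h_{t-1} + sqrt(1 - a_t^2) * (i_t * x_t)
# The sqrt(1 - a_t^2) factor normalizes the update so the state keeps a
# roughly constant scale however much history the gate retains.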
class RGLRU(nn.Module):
    def __init__(self, hidden_size: int, c: float = 8.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.c = c

        self.input_gate = nn.Linear(hidden_size, hidden_size, bias=False)
        self.recurrence_gate = nn.Linear(hidden_size, hidden_size, bias=False)

        self._base_param = nn.Parameter(torch.empty(hidden_size))
        nn.init.normal_(self._base_param, mean=0.0, std=1.0)  # ok to be any real

    def forward(self, x_t: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        batch_size, hidden_size = x_t.shape
        assert hidden_size == self.hidden_size
        assert state.shape[0] == batch_size

        i_t = torch.sigmoid(self.input_gate(x_t))
        r_t = torch.sigmoid(self.recurrence_gate(x_t))  # in (0, 1)

        eps = 1e-4
        base = torch.sigmoid(self._base_param).unsqueeze(0)  # shape (1, hidden)
        base = base.clamp(min=eps, max=1.0 - eps)

        # exponent = c * r_t (positive)
        a_t = base ** (self.c * r_t)  # shape (batch, hidden), safe because base in (0, 1)

        # ensure numerical stability for sqrt
        one_minus_sq = torch.clamp(1.0 - a_t * a_t, min=0.0)
        multiplier = torch.sqrt(one_minus_sq)

        new_state = (state * a_t) + (multiplier * (i_t * x_t))

        return new_state

    def init_state(self, batch_size: int, device: torch.device | None = None):
        if device is None:
            device = get_model_device(self)
        return torch.zeros(batch_size, self.hidden_size, device=device)

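# Depthwise causal convolution over time, applied one step at a time: the
# "state" is a rolling buffer of the previous kernel_size - 1 inputs, so the
# output at step t depends only on steps t - kernel_size + 1 .. t.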
class CausalConv1d(nn.Module):
    def __init__(self, hidden_size, kernel_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        self.conv = nn.Conv1d(
            hidden_size, hidden_size, kernel_size, groups=hidden_size, bias=True
        )

    def init_state(self, batch_size: int, device: torch.device | None = None):
        if device is None:
            device = get_model_device(self)
        return torch.zeros(
            batch_size, self.hidden_size, self.kernel_size - 1, device=device
        )

    def forward(self, x: torch.Tensor, state: torch.Tensor):
        x_with_state = torch.cat([state, x[:, :, None]], dim=-1)
        out = self.conv(x_with_state)
        new_state = x_with_state[:, :, 1:]
        return out.squeeze(-1), new_state

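# One Hawk recurrent block: a gating branch (linear + GeLU) multiplies the
# output of the recurrent branch (linear -> causal conv -> RG-LRU), followed
# by an output projection. Both the conv buffer and the RG-LRU hidden state
# are carried between timesteps.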
class Hawk(nn.Module):
    def __init__(self, hidden_size: int, conv_kernel_size: int = 4):
        super().__init__()

        self.conv_kernel_size = conv_kernel_size
        self.hidden_size = hidden_size

        self.gate_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.recurrent_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.conv = CausalConv1d(hidden_size, conv_kernel_size)
        self.rglru = RGLRU(hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(
        self, x: torch.Tensor, state: list[torch.Tensor]
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        conv_state, rglru_state = state

        batch_size, hidden_size = x.shape
        assert batch_size == conv_state.shape[0] == rglru_state.shape[0]
        assert self.hidden_size == hidden_size == rglru_state.shape[1]

        gate = F.gelu(self.gate_proj(x))
        x = self.recurrent_proj(x)

        x, new_conv_state = self.conv(x, conv_state)
        new_rglru_state = self.rglru(x, rglru_state)

        gated = gate * new_rglru_state
        out = self.out_proj(gated)

        new_state = [new_conv_state, new_rglru_state]
        return out, new_state

    def init_state(
        self, batch_size: int, device: torch.device | None = None
    ) -> list[torch.Tensor]:
        return [
            self.conv.init_state(batch_size, device),
            self.rglru.init_state(batch_size, device),
        ]

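# Stacks Hawk blocks with residual connections (post-LayerNorm) and dropout,
# then maps each timestep's hidden state to a scalar prediction. The sequence
# is processed step by step, so runtime is O(seq_len) with constant-size state.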
class HawkPredictor(nn.Module):
    """Full model with input projection and output head"""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 2,
        conv_kernel_size: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Input projection
        self.input_proj = nn.Linear(input_size, hidden_size)
        self.input_norm = nn.LayerNorm(hidden_size)

        # Hawk layers
        self.hawk_layers = nn.ModuleList(
            [Hawk(hidden_size, conv_kernel_size) for _ in range(num_layers)]
        )

        # Layer norms
        self.layer_norms = nn.ModuleList(
            [nn.LayerNorm(hidden_size) for _ in range(num_layers)]
        )

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Output head
        self.output_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, 1),
        )

    def forward(self, x: torch.Tensor, states=None):
        """
        Args:
            x: (batch_size, seq_len, input_size)
            states: list of per-layer states, or None to start from zeros
        Returns:
            predictions: (batch_size, seq_len, 1)
            states: list of final per-layer states
        """
        batch_size, seq_len, _ = x.shape
        device = x.device

        # Initialize states if needed
        if states is None:
            states = [
                layer.init_state(batch_size, device) for layer in self.hawk_layers
            ]

        # Input projection
        x = self.input_proj(x)  # (batch, seq, hidden)
        x = self.input_norm(x)

        outputs = []

        # Process the sequence one timestep at a time
        for t in range(seq_len):
            x_t = x[:, t, :]  # (batch, hidden)

            # Pass through Hawk layers
            new_states = []
            for i, (hawk_layer, layer_norm) in enumerate(
                zip(self.hawk_layers, self.layer_norms)
            ):
                residual = x_t
                x_t, state = hawk_layer(x_t, states[i])
                x_t = layer_norm(x_t + residual)
                x_t = self.dropout(x_t)
                new_states.append(state)

            states = new_states
            outputs.append(x_t)

        # Stack outputs
        outputs = torch.stack(outputs, dim=1)  # (batch, seq, hidden)

        # Generate predictions
        predictions = self.output_head(outputs)  # (batch, seq, 1)

        return predictions, states

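# Sliding-window dataset: sample i is the feature window [i, i + seq_length)
# paired with the targets over the same timesteps, so the model learns to
# predict each step's return from features up to and including that step.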
class TimeSeriesDataset(Dataset):
    def __init__(self, features, targets, seq_length=20):
        self.features = features
        self.targets = targets
        self.seq_length = seq_length

    def __len__(self):
        return len(self.features) - self.seq_length

    def __getitem__(self, idx):
        x = self.features[idx : idx + self.seq_length]
        y = self.targets[idx : idx + self.seq_length]
        return torch.FloatTensor(x), torch.FloatTensor(y).squeeze(-1)

class MetricsLogger:
    def __init__(self, save_dir):
        self.save_dir = save_dir
        self.metrics = {
            "train_loss": [],
            "val_loss": [],
            "train_mse": [],
            "val_mse": [],
            "train_mae": [],
            "val_mae": [],
            "learning_rates": [],
        }

    def update(self, epoch_metrics):
        for key, value in epoch_metrics.items():
            if key in self.metrics:
                self.metrics[key].append(value)

    def save(self):
        with open(os.path.join(self.save_dir, "metrics.json"), "w") as f:
            json.dump(self.metrics, f, indent=4)

    def plot_metrics(self):
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle("Training Metrics", fontsize=16)

        # Loss
        ax = axes[0, 0]
        ax.plot(self.metrics["train_loss"], label="Train Loss", marker="o")
        ax.plot(self.metrics["val_loss"], label="Val Loss", marker="s")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title("Training and Validation Loss")
        ax.legend()
        ax.grid(True)

        # MSE
        ax = axes[0, 1]
        ax.plot(self.metrics["train_mse"], label="Train MSE", marker="o")
        ax.plot(self.metrics["val_mse"], label="Val MSE", marker="s")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("MSE")
        ax.set_title("Mean Squared Error")
        ax.legend()
        ax.grid(True)

        # MAE
        ax = axes[1, 0]
        ax.plot(self.metrics["train_mae"], label="Train MAE", marker="o")
        ax.plot(self.metrics["val_mae"], label="Val MAE", marker="s")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("MAE")
        ax.set_title("Mean Absolute Error")
        ax.legend()
        ax.grid(True)

        # Learning Rate
        ax = axes[1, 1]
        ax.plot(self.metrics["learning_rates"], marker="o", color="purple")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Learning Rate")
        ax.set_title("Learning Rate Schedule")
        ax.grid(True)
        ax.set_yscale("log")

        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, "training_metrics.png"), dpi=300)
        plt.close()

def calculate_metrics(predictions, targets):
    """Calculate MSE and MAE"""
    mse = F.mse_loss(predictions, targets).item()
    mae = F.l1_loss(predictions, targets).item()
    return mse, mae

def save_checkpoint(
    model, optimizer, scheduler, epoch, metrics, save_dir, is_best=False
):
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
        "metrics": metrics,
    }

    # Save regular checkpoint
    checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{epoch}.pt")
    torch.save(checkpoint, checkpoint_path)

    # Save best model
    if is_best:
        best_path = os.path.join(save_dir, "best_model.pt")
        torch.save(checkpoint, best_path)
        print(f"✓ Saved best model at epoch {epoch}")

    # Keep only the last 5 checkpoints; sort numerically by epoch, since a
    # lexicographic sort would order epoch 10 before epoch 2
    checkpoints = sorted(
        [f for f in os.listdir(save_dir) if f.startswith("checkpoint_epoch_")],
        key=lambda name: int(name.split("_")[-1].split(".")[0]),
    )
    if len(checkpoints) > 5:
        for old_ckpt in checkpoints[:-5]:
            os.remove(os.path.join(save_dir, old_ckpt))

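# Training and validation epochs share the same bookkeeping: average the
# per-batch loss, then compute MSE/MAE once over all predictions concatenated
# across the epoch. Gradients are clipped to max-norm 1.0, a common safeguard
# for recurrent models.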
def train_epoch(model, train_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    all_predictions = []
    all_targets = []

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()

        # Forward pass
        predictions, _ = model(x)
        predictions = predictions.squeeze(-1)

        # Calculate loss
        loss = criterion(predictions, y)

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        all_predictions.append(predictions.detach())
        all_targets.append(y.detach())

    avg_loss = total_loss / len(train_loader)
    all_predictions = torch.cat(all_predictions, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    mse, mae = calculate_metrics(all_predictions, all_targets)

    return avg_loss, mse, mae

def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)

            predictions, _ = model(x)
            predictions = predictions.squeeze(-1)

            loss = criterion(predictions, y)

            total_loss += loss.item()
            all_predictions.append(predictions)
            all_targets.append(y)

    avg_loss = total_loss / len(val_loader)
    all_predictions = torch.cat(all_predictions, dim=0)
    all_targets = torch.cat(all_targets, dim=0)
    mse, mae = calculate_metrics(all_predictions, all_targets)

    return avg_loss, mse, mae

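# Orchestrates training: AdamW with weight decay, ReduceLROnPlateau halving
# the learning rate after 5 epochs without validation improvement, periodic
# and best-model checkpointing, and metric logging/plots under a timestamped
# run directory.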
def train_model(model, train_loader, val_loader, config):
    """Main training loop"""
    device = config["device"]
    model = model.to(device)

    # Setup
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["learning_rate"],
        weight_decay=config["weight_decay"],
    )
    # (the current LR is printed each epoch below, so no verbose flag needed)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )

    # Create save directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = os.path.join(config["save_dir"], f"run_{timestamp}")
    os.makedirs(save_dir, exist_ok=True)

    # Save config
    with open(os.path.join(save_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=4)

    # Initialize logger
    logger = MetricsLogger(save_dir)
    best_val_loss = float("inf")

    print("=" * 60)
    print(f"Training started at {timestamp}")
    print(f"Model: {config['model_name']}")
    print(f"Device: {device}")
    print(f"Save directory: {save_dir}")
    print("=" * 60 + "\n")

    # Training loop
    for epoch in range(1, config["num_epochs"] + 1):
        # Train
        train_loss, train_mse, train_mae = train_epoch(
            model, train_loader, optimizer, criterion, device
        )

        # Validate
        val_loss, val_mse, val_mae = validate(model, val_loader, criterion, device)

        # Update scheduler
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        # Log metrics
        epoch_metrics = {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "train_mse": train_mse,
            "val_mse": val_mse,
            "train_mae": train_mae,
            "val_mae": val_mae,
            "learning_rates": current_lr,
        }
        logger.update(epoch_metrics)

        # Print progress
        print(f"Epoch {epoch}/{config['num_epochs']}")
        print(
            f"  Train - Loss: {train_loss:.6f}, MSE: {train_mse:.6f}, MAE: {train_mae:.6f}"
        )
        print(f"  Val   - Loss: {val_loss:.6f}, MSE: {val_mse:.6f}, MAE: {val_mae:.6f}")
        print(f"  LR: {current_lr:.2e}")

        # Save checkpoint
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss

        if epoch % config["save_every"] == 0 or is_best:
            save_checkpoint(
                model, optimizer, scheduler, epoch, epoch_metrics, save_dir, is_best
            )

        # Plot metrics every 10 epochs
        if epoch % 10 == 0:
            logger.plot_metrics()

        print()

    # Final save
    logger.save()
    logger.plot_metrics()

    print("=" * 60)
    print("Training completed!")
    print(f"Best validation loss: {best_val_loss:.6f}")
    print(f"Results saved to: {save_dir}")
    print("=" * 60)

    return model, logger

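# Entry point: load and clean the data (data_prep is a project-local package),
# split chronologically so validation follows training in time, fit the scaler
# on the training portion only to avoid leakage, then train.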
if __name__ == "__main__":
    from data_prep.data_clean import clean_indicator
    from data_prep.data_load import prepare_data

    # Debugging aid: surfaces the op that produced a NaN/Inf gradient.
    # It slows training noticeably, so disable it once the model is stable.
    torch.autograd.set_detect_anomaly(True)

    # Configuration
    config = {
        "model_name": "HawkPredictor",
        "seq_length": 20,
        "hidden_size": 128,
        "num_layers": 3,
        "conv_kernel_size": 4,
        "dropout": 0.2,
        "batch_size": 64,
        "num_epochs": 100,
        "learning_rate": 0.001,
        "weight_decay": 1e-5,
        "train_split": 0.8,
        "save_every": 5,
        "save_dir": "./checkpoints",
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    }

    print("Loading data...")
    test_dir = "/home/aman/code/ml_fr/ml_stocks/data/NIFTY_5_years.csv"

    load_df = prepare_data(test_dir)
    df = clean_indicator(load_df)

    # Prepare features and target
    target_col = "Daily_Return"
    feature_cols = [col for col in df.columns if col != target_col]

    train_size = int(len(df) * config["train_split"])
    train_df = df[:train_size]
    val_df = df[train_size:]

    scaler = StandardScaler()
    train_features = scaler.fit_transform(train_df[feature_cols].values)
    val_features = scaler.transform(val_df[feature_cols].values)

    train_targets = train_df[target_col].values.reshape(-1, 1)
    val_targets = val_df[target_col].values.reshape(-1, 1)

    # Create datasets
    train_dataset = TimeSeriesDataset(
        train_features, train_targets, config["seq_length"]
    )
    val_dataset = TimeSeriesDataset(val_features, val_targets, config["seq_length"])

    train_loader = DataLoader(
        train_dataset, batch_size=config["batch_size"], shuffle=True, num_workers=0
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=0
    )

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Input features: {len(feature_cols)}")

    # Initialize model
    model = HawkPredictor(
        input_size=len(feature_cols),
        hidden_size=config["hidden_size"],
        num_layers=config["num_layers"],
        conv_kernel_size=config["conv_kernel_size"],
        dropout=config["dropout"],
    )

    print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Train model
    trained_model, metrics_logger = train_model(
        model, train_loader, val_loader, config
    )

    print(
        "\nTraining complete! Check the checkpoints directory for saved models "
        "and metrics."
    )