""" train_lstm.py – LSTM model training for traffic flow prediction Features: - LSTM model with configurable architecture - Weighted loss for handling speed class imbalance - Huber loss option (better than regular loss per user experience) - CLI interface for hyperparameter tuning - Model and encoder saving - Chronological train/val/test splits """ import argparse import numpy as np import pandas as pd import torch import torch.nn as nn from torch.utils.data import DataLoader, TensorDataset from pathlib import Path import joblib from typing import Dict, Tuple, Optional from encode import TrafficDataEncoder # Device selection if torch.backends.mps.is_available(): DEVICE = torch.device("mps") elif torch.cuda.is_available(): DEVICE = torch.device("cuda") else: DEVICE = torch.device("cpu") print(f"Using device: {DEVICE}") class LSTMRegressor(nn.Module): """LSTM model for traffic speed prediction.""" def __init__( self, n_features: int, hidden_size: int = 128, n_layers: int = 2, dropout: float = 0.3, bidirectional: bool = False ): super().__init__() self.hidden_size = hidden_size self.n_layers = n_layers self.bidirectional = bidirectional # LSTM layer self.lstm = nn.LSTM( input_size=n_features, hidden_size=hidden_size, num_layers=n_layers, batch_first=True, dropout=dropout if n_layers > 1 else 0, bidirectional=bidirectional ) # Output layer lstm_output_size = hidden_size * (2 if bidirectional else 1) self.head = nn.Sequential( nn.Linear(lstm_output_size, hidden_size // 2), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_size // 2, 1) ) def forward(self, x): """Forward pass through the LSTM.""" # LSTM forward pass lstm_out, _ = self.lstm(x) # Use the last timestep output last_output = lstm_out[:, -1, :] # Final prediction prediction = self.head(last_output) return prediction class WeightedHuberLoss(nn.Module): """Weighted Huber loss for handling speed class imbalance.""" def __init__(self, weight_dict: Dict[str, float], delta: float = 1.0, boost_low: float = 1.0): super().__init__() self.delta = delta self.weight_low = weight_dict["weight_low"] * boost_low # Additional boost for low speeds self.weight_medium = weight_dict["weight_medium"] self.weight_high = weight_dict["weight_high"] self.low_threshold = weight_dict["low_threshold"] self.high_threshold = weight_dict["high_threshold"] def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Compute weighted Huber loss.""" # Ensure target is 1D if target.dim() > 1: target = target.squeeze() if pred.dim() > 1: pred = pred.squeeze() # Compute Huber loss diff = torch.abs(pred - target) huber_loss = torch.where( diff <= self.delta, 0.5 * diff ** 2, self.delta * (diff - 0.5 * self.delta) ) # Compute weights based on speed classes weights = torch.ones_like(target) low_mask = target <= self.low_threshold high_mask = target >= self.high_threshold medium_mask = ~(low_mask | high_mask) weights[low_mask] = self.weight_low weights[medium_mask] = self.weight_medium weights[high_mask] = self.weight_high # Apply weights weighted_loss = huber_loss * weights return weighted_loss.mean() class FocalHuberLoss(nn.Module): """Focal loss variant for Huber loss to focus on hard examples.""" def __init__(self, weight_dict: Dict[str, float], delta: float = 1.0, alpha: float = 2.0, gamma: float = 2.0): super().__init__() self.delta = delta self.alpha = alpha self.gamma = gamma self.weight_low = weight_dict["weight_low"] self.weight_medium = weight_dict["weight_medium"] self.weight_high = weight_dict["weight_high"] self.low_threshold = 
weight_dict["low_threshold"] self.high_threshold = weight_dict["high_threshold"] def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Compute focal Huber loss.""" if target.dim() > 1: target = target.squeeze() if pred.dim() > 1: pred = pred.squeeze() # Compute Huber loss diff = torch.abs(pred - target) huber_loss = torch.where( diff <= self.delta, 0.5 * diff ** 2, self.delta * (diff - 0.5 * self.delta) ) # Compute focal weights (higher loss = harder example) focal_weights = self.alpha * (huber_loss ** self.gamma) # Apply class weights class_weights = torch.ones_like(target) low_mask = target <= self.low_threshold high_mask = target >= self.high_threshold medium_mask = ~(low_mask | high_mask) class_weights[low_mask] = self.weight_low class_weights[medium_mask] = self.weight_medium class_weights[high_mask] = self.weight_high # Combine focal and class weights total_weights = focal_weights * class_weights weighted_loss = huber_loss * total_weights return weighted_loss.mean() def create_data_loaders( X: np.ndarray, y: np.ndarray, timestamps: np.ndarray, batch_size: int, train_ratio: float = 0.7, val_ratio: float = 0.15 ) -> Tuple[DataLoader, DataLoader, DataLoader, np.ndarray]: """ Create chronological train/validation/test data loaders. Args: X: Input sequences (N, seq_len, n_features) y: Target values (N, horizon) timestamps: Timestamps for each sample batch_size: Batch size for data loaders train_ratio: Fraction of data for training val_ratio: Fraction of data for validation Returns: train_loader, val_loader, test_loader, test_indices """ # Sort by timestamp to ensure chronological order sorted_indices = np.argsort(timestamps) X_sorted = X[sorted_indices] y_sorted = y[sorted_indices] # Calculate split points n_total = len(X_sorted) n_train = int(n_total * train_ratio) n_val = int(n_total * val_ratio) # Split indices train_indices = sorted_indices[:n_train] val_indices = sorted_indices[n_train:n_train + n_val] test_indices = sorted_indices[n_train + n_val:] # Convert timestamps to datetime for date range display timestamps_dt = pd.to_datetime(timestamps) print(f"Data split:") print(f" Train: {len(train_indices):,} samples ({train_ratio*100:.0f}%)") if len(train_indices) > 0: train_dates = timestamps_dt[train_indices] print(f" Date range: {train_dates.min()} to {train_dates.max()}") print(f" Val: {len(val_indices):,} samples ({val_ratio*100:.0f}%)") if len(val_indices) > 0: val_dates = timestamps_dt[val_indices] print(f" Date range: {val_dates.min()} to {val_dates.max()}") print(f" Test: {len(test_indices):,} samples ({(1-train_ratio-val_ratio)*100:.0f}%)") if len(test_indices) > 0: test_dates = timestamps_dt[test_indices] print(f" Date range: {test_dates.min()} to {test_dates.max()}") # Create data loaders def create_loader(indices, shuffle=False): X_subset = torch.from_numpy(X[indices]).float() y_subset = torch.from_numpy(y[indices]).float() dataset = TensorDataset(X_subset, y_subset) return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) train_loader = create_loader(train_indices, shuffle=True) val_loader = create_loader(val_indices, shuffle=False) test_loader = create_loader(test_indices, shuffle=False) return train_loader, val_loader, test_loader, test_indices def train_epoch( model: LSTMRegressor, train_loader: DataLoader, optimizer: torch.optim.Optimizer, loss_fn: nn.Module, device: torch.device ) -> float: """Train the model for one epoch.""" model.train() total_loss = 0.0 num_batches = 0 for batch_X, batch_y in train_loader: batch_X = 


def train_epoch(
    model: LSTMRegressor,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
    device: torch.device,
) -> float:
    """Train the model for one epoch."""
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch_X, batch_y in train_loader:
        batch_X = batch_X.to(device)
        batch_y = batch_y.to(device)

        # Forward pass
        optimizer.zero_grad()
        predictions = model(batch_X)
        loss = loss_fn(predictions, batch_y)

        # Backward pass with gradient clipping to stabilize LSTM training
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches


def evaluate(
    model: LSTMRegressor,
    data_loader: DataLoader,
    loss_fn: nn.Module,
    device: torch.device,
) -> float:
    """Evaluate the model on a dataset."""
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch_X, batch_y in data_loader:
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)

            predictions = model(batch_X)
            loss = loss_fn(predictions, batch_y)

            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches
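

# Example invocation of this script (all paths are placeholders; every flag
# below is defined in the argument parser inside main()):
#
#   python train_lstm.py --csv data/traffic.csv --seq_len 12 --horizon 1 \
#       --loss_type weighted_huber --boost_low 2.0 \
#       --epochs 50 --batch_size 256 --lr 1e-3 \
#       --model_out models/lstm.pt --encoder_out models/encoder.pkl \
#       --pred_csv preds/test_predictions.csv --log_file logs/train_log.csv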


def main():
    """Main training function."""
    parser = argparse.ArgumentParser(description="Train LSTM model for traffic prediction")

    # Data parameters
    parser.add_argument("--csv", required=True, help="Path to CSV file with traffic data")
    parser.add_argument("--seq_len", type=int, default=12, help="Sequence length (default: 12)")
    parser.add_argument("--horizon", type=int, default=1, help="Prediction horizon (default: 1)")
    parser.add_argument("--target_col", default="speed_mph", help="Target column name")

    # Model parameters
    parser.add_argument("--hidden_size", type=int, default=128, help="LSTM hidden size")
    parser.add_argument("--n_layers", type=int, default=2, help="Number of LSTM layers")
    parser.add_argument("--dropout", type=float, default=0.3, help="Dropout rate")
    parser.add_argument("--bidirectional", action="store_true", help="Use bidirectional LSTM")

    # Training parameters
    parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-5, help="Weight decay")

    # Loss parameters
    parser.add_argument("--loss_type",
                        choices=["mse", "mae", "huber", "weighted_huber", "focal_huber"],
                        default="weighted_huber", help="Loss function type")
    parser.add_argument("--huber_delta", type=float, default=1.0, help="Huber loss delta")
    parser.add_argument("--boost_low", type=float, default=1.0,
                        help="Additional boost for low-speed loss (weighted_huber only)")
    parser.add_argument("--focal_alpha", type=float, default=2.0, help="Focal loss alpha parameter")
    parser.add_argument("--focal_gamma", type=float, default=2.0, help="Focal loss gamma parameter")

    # Data split parameters
    parser.add_argument("--train_ratio", type=float, default=0.7, help="Training data ratio")
    parser.add_argument("--val_ratio", type=float, default=0.15, help="Validation data ratio")

    # Output parameters
    parser.add_argument("--model_out", help="Path to save the best model")
    parser.add_argument("--encoder_out", help="Path to save the fitted encoder")
    parser.add_argument("--pred_csv", help="Path to save test predictions")
    parser.add_argument("--log_file", help="Path to save training log")

    args = parser.parse_args()

    # Load and encode data
    print("Loading data...")
    df = pd.read_csv(args.csv)
    print(f"Loaded {len(df):,} rows from {args.csv}")

    # Create encoder
    encoder = TrafficDataEncoder(
        seq_len=args.seq_len,
        horizon=args.horizon,
        target_col=args.target_col,
    )

    # Fit encoder and transform data
    print("Encoding data...")
    X, y, target_indices, timestamps = encoder.fit_transform(df)
    print(f"Encoded data shapes: X={X.shape}, y={y.shape}")

    # Save encoder if requested
    if args.encoder_out:
        encoder.save(args.encoder_out)

    # Create data loaders
    print("Creating data loaders...")
    train_loader, val_loader, test_loader, test_indices = create_data_loaders(
        X, y, timestamps, args.batch_size, args.train_ratio, args.val_ratio
    )

    # Initialize model
    print("Initializing model...")
    model = LSTMRegressor(
        n_features=X.shape[2],
        hidden_size=args.hidden_size,
        n_layers=args.n_layers,
        dropout=args.dropout,
        bidirectional=args.bidirectional,
    ).to(DEVICE)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Initialize optimizer
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    # Initialize loss function
    if args.loss_type == "weighted_huber":
        # Get speed class weights from the encoder
        weight_dict = encoder.get_speed_weights(y.flatten())
        loss_fn = WeightedHuberLoss(weight_dict, args.huber_delta, args.boost_low)
        print(f"Using weighted Huber loss with low-speed boost: {args.boost_low}")
    elif args.loss_type == "focal_huber":
        weight_dict = encoder.get_speed_weights(y.flatten())
        loss_fn = FocalHuberLoss(weight_dict, args.huber_delta, args.focal_alpha, args.focal_gamma)
        print(f"Using focal Huber loss (alpha={args.focal_alpha}, gamma={args.focal_gamma})")
    elif args.loss_type == "huber":
        loss_fn = nn.SmoothL1Loss(beta=args.huber_delta)
        print("Using Huber loss")
    elif args.loss_type == "mae":
        loss_fn = nn.L1Loss()
        print("Using MAE loss")
    else:  # mse
        loss_fn = nn.MSELoss()
        print("Using MSE loss")

    # Training loop
    print("Starting training...")
    best_val_loss = float("inf")
    best_model_state = None
    train_losses = []
    val_losses = []

    for epoch in range(1, args.epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, loss_fn, DEVICE)
        val_loss = evaluate(model, val_loader, loss_fn, DEVICE)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f"Epoch {epoch:3d}/{args.epochs}: "
              f"Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        # Save best model. Clone the tensors: state_dict() returns references
        # to the live parameters, so a shallow dict .copy() would be silently
        # overwritten by later epochs.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            print(f"  -> New best validation loss: {best_val_loss:.4f}")

    # Load best model and evaluate on test set
    print("\nEvaluating on test set...")
    model.load_state_dict(best_model_state)
    test_loss = evaluate(model, test_loader, loss_fn, DEVICE)
    print(f"Test Loss: {test_loss:.4f}")

    # Save best model
    if args.model_out:
        torch.save(best_model_state, args.model_out)
        print(f"Best model saved to {args.model_out}")

    # Save predictions if requested
    if args.pred_csv:
        print("Generating test predictions...")
        model.eval()
        predictions = []
        targets = []

        with torch.no_grad():
            for batch_X, batch_y in test_loader:
                batch_X = batch_X.to(DEVICE)
                batch_pred = model(batch_X).cpu().numpy()
                predictions.append(batch_pred)
                targets.append(batch_y.numpy())

        predictions = np.concatenate(predictions, axis=0)
        targets = np.concatenate(targets, axis=0)

        # Create prediction DataFrame
        pred_df = pd.DataFrame({
            "prediction": predictions.flatten(),
            "target": targets.flatten(),
            "error": predictions.flatten() - targets.flatten(),
            "abs_error": np.abs(predictions.flatten() - targets.flatten()),
        })
        pred_df.to_csv(args.pred_csv, index=False)
        print(f"Predictions saved to {args.pred_csv}")

        # Print summary statistics
        mae = pred_df["abs_error"].mean()
        rmse = np.sqrt((pred_df["error"] ** 2).mean())
        print(f"Test MAE: {mae:.4f}")
        print(f"Test RMSE: {rmse:.4f}")

    # Save training log if requested
    if args.log_file:
        log_df = pd.DataFrame({
            "epoch": range(1, len(train_losses) + 1),
            "train_loss": train_losses,
            "val_loss": val_losses,
        })
        log_df.to_csv(args.log_file, index=False)
        print(f"Training log saved to {args.log_file}")
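

# To reload the saved weights later (a sketch; the path is a placeholder and
# the constructor arguments must match those used at training time):
#
#   model = LSTMRegressor(n_features=X.shape[2], hidden_size=128, n_layers=2)
#   model.load_state_dict(torch.load("models/lstm.pt", map_location=DEVICE))
#   model.eval()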


if __name__ == "__main__":
    main()