"""
train_lstm.py – LSTM model training for traffic flow prediction.

Features:
- LSTM model with configurable architecture
- Weighted loss for handling speed class imbalance
- Optional Huber loss (in our experience, more robust to outliers than squared-error loss)
- CLI interface for hyperparameter tuning
- Model and encoder saving
- Chronological train/val/test splits
"""

import argparse
from typing import Dict, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from encode import TrafficDataEncoder


# Select the compute device: prefer Apple-silicon MPS, then CUDA, else CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"Using device: {DEVICE}")
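
# To pin a specific device instead (e.g. when debugging MPS/CUDA issues), you
# could override here, for example: DEVICE = torch.device("cpu")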


class LSTMRegressor(nn.Module):
    """LSTM model for traffic speed prediction."""

    def __init__(
        self,
        n_features: int,
        hidden_size: int = 128,
        n_layers: int = 2,
        dropout: float = 0.3,
        bidirectional: bool = False
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.bidirectional = bidirectional

        self.lstm = nn.LSTM(
            input_size=n_features,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0,
            bidirectional=bidirectional
        )

        # A bidirectional LSTM concatenates forward and backward hidden states.
        lstm_output_size = hidden_size * (2 if bidirectional else 1)
        self.head = nn.Sequential(
            nn.Linear(lstm_output_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size // 2, 1)
        )

    def forward(self, x):
        """Forward pass: x has shape (batch, seq_len, n_features)."""
        lstm_out, _ = self.lstm(x)           # (batch, seq_len, lstm_output_size)
        last_output = lstm_out[:, -1, :]     # keep only the final timestep
        prediction = self.head(last_output)  # (batch, 1)
        return prediction
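
# Shape sanity check (illustrative numbers, not from the original script):
#   model = LSTMRegressor(n_features=5)
#   out = model(torch.randn(32, 12, 5))  # 32 sequences of 12 timesteps
#   assert out.shape == (32, 1)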


class WeightedHuberLoss(nn.Module):
    """Weighted Huber loss for handling speed class imbalance."""

    def __init__(self, weight_dict: Dict[str, float], delta: float = 1.0, boost_low: float = 1.0):
        super().__init__()
        self.delta = delta
        self.weight_low = weight_dict["weight_low"] * boost_low
        self.weight_medium = weight_dict["weight_medium"]
        self.weight_high = weight_dict["weight_high"]
        self.low_threshold = weight_dict["low_threshold"]
        self.high_threshold = weight_dict["high_threshold"]

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute weighted Huber loss."""
        # Squeeze only the trailing singleton dim so a batch of size 1 survives.
        if target.dim() > 1:
            target = target.squeeze(-1)
        if pred.dim() > 1:
            pred = pred.squeeze(-1)

        # Standard Huber loss: quadratic near zero, linear beyond delta.
        diff = torch.abs(pred - target)
        huber_loss = torch.where(
            diff <= self.delta,
            0.5 * diff ** 2,
            self.delta * (diff - 0.5 * self.delta)
        )

        # Per-sample weights based on the target's speed class.
        weights = torch.ones_like(target)
        low_mask = target <= self.low_threshold
        high_mask = target >= self.high_threshold
        medium_mask = ~(low_mask | high_mask)

        weights[low_mask] = self.weight_low
        weights[medium_mask] = self.weight_medium
        weights[high_mask] = self.weight_high

        weighted_loss = huber_loss * weights
        return weighted_loss.mean()
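
# Illustrative weight_dict (placeholder values; the real dict comes from
# TrafficDataEncoder.get_speed_weights):
#   weight_dict = {"weight_low": 3.0, "weight_medium": 1.0, "weight_high": 1.5,
#                  "low_threshold": 20.0, "high_threshold": 55.0}
#   loss_fn = WeightedHuberLoss(weight_dict, delta=1.0, boost_low=2.0)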


class FocalHuberLoss(nn.Module):
    """Focal variant of the Huber loss that up-weights hard examples."""

    def __init__(self, weight_dict: Dict[str, float], delta: float = 1.0, alpha: float = 2.0, gamma: float = 2.0):
        super().__init__()
        self.delta = delta
        self.alpha = alpha
        self.gamma = gamma
        self.weight_low = weight_dict["weight_low"]
        self.weight_medium = weight_dict["weight_medium"]
        self.weight_high = weight_dict["weight_high"]
        self.low_threshold = weight_dict["low_threshold"]
        self.high_threshold = weight_dict["high_threshold"]

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute focal Huber loss."""
        # Squeeze only the trailing singleton dim so a batch of size 1 survives.
        if target.dim() > 1:
            target = target.squeeze(-1)
        if pred.dim() > 1:
            pred = pred.squeeze(-1)

        # Standard Huber loss: quadratic near zero, linear beyond delta.
        diff = torch.abs(pred - target)
        huber_loss = torch.where(
            diff <= self.delta,
            0.5 * diff ** 2,
            self.delta * (diff - 0.5 * self.delta)
        )

        # Focal term: samples with larger Huber loss receive larger weights.
        focal_weights = self.alpha * (huber_loss ** self.gamma)

        # Per-sample weights based on the target's speed class.
        class_weights = torch.ones_like(target)
        low_mask = target <= self.low_threshold
        high_mask = target >= self.high_threshold
        medium_mask = ~(low_mask | high_mask)

        class_weights[low_mask] = self.weight_low
        class_weights[medium_mask] = self.weight_medium
        class_weights[high_mask] = self.weight_high

        total_weights = focal_weights * class_weights
        weighted_loss = huber_loss * total_weights

        return weighted_loss.mean()
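
# Effect of the focal term with alpha=2, gamma=2 (worked numbers, illustrative):
#   huber_loss = 0.1 -> focal weight = 2 * 0.1**2 = 0.02 (easy sample, damped)
#   huber_loss = 2.0 -> focal weight = 2 * 2.0**2 = 8.0  (hard sample, boosted)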


def create_data_loaders(
    X: np.ndarray,
    y: np.ndarray,
    timestamps: np.ndarray,
    batch_size: int,
    train_ratio: float = 0.7,
    val_ratio: float = 0.15
) -> Tuple[DataLoader, DataLoader, DataLoader, np.ndarray]:
    """
    Create chronological train/validation/test data loaders.

    Args:
        X: Input sequences (N, seq_len, n_features)
        y: Target values (N, horizon)
        timestamps: Timestamps for each sample
        batch_size: Batch size for data loaders
        train_ratio: Fraction of data for training
        val_ratio: Fraction of data for validation

    Returns:
        train_loader, val_loader, test_loader, test_indices
    """
    # Order samples chronologically so no future data leaks into training.
    sorted_indices = np.argsort(timestamps)

    n_total = len(X)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)

    train_indices = sorted_indices[:n_train]
    val_indices = sorted_indices[n_train:n_train + n_val]
    test_indices = sorted_indices[n_train + n_val:]

    timestamps_dt = pd.to_datetime(timestamps)

    print("Data split:")
    print(f"  Train: {len(train_indices):,} samples ({train_ratio*100:.0f}%)")
    if len(train_indices) > 0:
        train_dates = timestamps_dt[train_indices]
        print(f"    Date range: {train_dates.min()} to {train_dates.max()}")

    print(f"  Val: {len(val_indices):,} samples ({val_ratio*100:.0f}%)")
    if len(val_indices) > 0:
        val_dates = timestamps_dt[val_indices]
        print(f"    Date range: {val_dates.min()} to {val_dates.max()}")

    print(f"  Test: {len(test_indices):,} samples ({(1-train_ratio-val_ratio)*100:.0f}%)")
    if len(test_indices) > 0:
        test_dates = timestamps_dt[test_indices]
        print(f"    Date range: {test_dates.min()} to {test_dates.max()}")

    def create_loader(indices, shuffle=False):
        X_subset = torch.from_numpy(X[indices]).float()
        y_subset = torch.from_numpy(y[indices]).float()
        dataset = TensorDataset(X_subset, y_subset)
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    train_loader = create_loader(train_indices, shuffle=True)
    val_loader = create_loader(val_indices, shuffle=False)
    test_loader = create_loader(test_indices, shuffle=False)

    return train_loader, val_loader, test_loader, test_indices
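
# Worked split arithmetic (illustrative): with N = 1000 samples, train_ratio = 0.7
# and val_ratio = 0.15, chronological positions 0-699 train, 700-849 validate,
# and 850-999 test.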


def train_epoch(
    model: LSTMRegressor,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
    device: torch.device
) -> float:
    """Train the model for one epoch and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch_X, batch_y in train_loader:
        batch_X = batch_X.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        predictions = model(batch_X)
        loss = loss_fn(predictions, batch_y)

        loss.backward()
        # Clip gradients to stabilize LSTM training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches


def evaluate(
    model: LSTMRegressor,
    data_loader: DataLoader,
    loss_fn: nn.Module,
    device: torch.device
) -> float:
    """Evaluate the model on a dataset and return the mean batch loss."""
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch_X, batch_y in data_loader:
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)

            predictions = model(batch_X)
            loss = loss_fn(predictions, batch_y)

            total_loss += loss.item()
            num_batches += 1

    return total_loss / num_batches
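
# Minimal programmatic sketch of the loop that main() wires up (loader and
# hyperparameter names here are hypothetical):
#   model = LSTMRegressor(n_features=n_features).to(DEVICE)
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   loss_fn = nn.MSELoss()
#   for _ in range(n_epochs):
#       train_epoch(model, train_loader, optimizer, loss_fn, DEVICE)
#       val_loss = evaluate(model, val_loader, loss_fn, DEVICE)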


def main():
    """Main training function."""
    parser = argparse.ArgumentParser(description="Train LSTM model for traffic prediction")

    # Data arguments
    parser.add_argument("--csv", required=True, help="Path to CSV file with traffic data")
    parser.add_argument("--seq_len", type=int, default=12, help="Sequence length (default: 12)")
    parser.add_argument("--horizon", type=int, default=1, help="Prediction horizon (default: 1)")
    parser.add_argument("--target_col", default="speed_mph", help="Target column name")

    # Model architecture
    parser.add_argument("--hidden_size", type=int, default=128, help="LSTM hidden size")
    parser.add_argument("--n_layers", type=int, default=2, help="Number of LSTM layers")
    parser.add_argument("--dropout", type=float, default=0.3, help="Dropout rate")
    parser.add_argument("--bidirectional", action="store_true", help="Use bidirectional LSTM")

    # Optimization
    parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-5, help="Weight decay")

    # Loss function
    parser.add_argument("--loss_type", choices=["mse", "mae", "huber", "weighted_huber", "focal_huber"],
                        default="weighted_huber", help="Loss function type")
    parser.add_argument("--huber_delta", type=float, default=1.0, help="Huber loss delta")
    parser.add_argument("--boost_low", type=float, default=1.0, help="Additional boost for low-speed loss (weighted_huber only)")
    parser.add_argument("--focal_alpha", type=float, default=2.0, help="Focal loss alpha parameter")
    parser.add_argument("--focal_gamma", type=float, default=2.0, help="Focal loss gamma parameter")

    # Data split
    parser.add_argument("--train_ratio", type=float, default=0.7, help="Training data ratio")
    parser.add_argument("--val_ratio", type=float, default=0.15, help="Validation data ratio")

    # Output paths
    parser.add_argument("--model_out", help="Path to save the best model")
    parser.add_argument("--encoder_out", help="Path to save the fitted encoder")
    parser.add_argument("--pred_csv", help="Path to save test predictions")
    parser.add_argument("--log_file", help="Path to save training log")

    args = parser.parse_args()
print("Loading data...") |
|
|
df = pd.read_csv(args.csv) |
|
|
print(f"Loaded {len(df):,} rows from {args.csv}") |
|
|
|
|
|
|
|
|
encoder = TrafficDataEncoder( |
|
|
seq_len=args.seq_len, |
|
|
horizon=args.horizon, |
|
|
target_col=args.target_col |
|
|
) |
|
|
|
|
|
|
|
|
print("Encoding data...") |
|
|
X, y, target_indices, timestamps = encoder.fit_transform(df) |
|
|
print(f"Encoded data shapes: X={X.shape}, y={y.shape}") |
|
|
|
|
|
|
|
|
if args.encoder_out: |
|
|
encoder.save(args.encoder_out) |
|
|
|
|
|
|
|
|
print("Creating data loaders...") |
|
|
train_loader, val_loader, test_loader, test_indices = create_data_loaders( |
|
|
X, y, timestamps, args.batch_size, args.train_ratio, args.val_ratio |
|
|
) |
|
|
|
|
|
|
|
|
print("Initializing model...") |
|
|
model = LSTMRegressor( |
|
|
n_features=X.shape[2], |
|
|
hidden_size=args.hidden_size, |
|
|
n_layers=args.n_layers, |
|
|
dropout=args.dropout, |
|
|
bidirectional=args.bidirectional |
|
|
).to(DEVICE) |
|
|
|
|
|
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") |
|
|
|
|
|
|
|
|

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    # Select the loss function.
    if args.loss_type == "weighted_huber":
        weight_dict = encoder.get_speed_weights(y.flatten())
        loss_fn = WeightedHuberLoss(weight_dict, args.huber_delta, args.boost_low)
        print(f"Using weighted Huber loss with low-speed boost: {args.boost_low}")
    elif args.loss_type == "focal_huber":
        weight_dict = encoder.get_speed_weights(y.flatten())
        loss_fn = FocalHuberLoss(weight_dict, args.huber_delta, args.focal_alpha, args.focal_gamma)
        print(f"Using focal Huber loss (alpha={args.focal_alpha}, gamma={args.focal_gamma})")
    elif args.loss_type == "huber":
        # nn.HuberLoss matches the delta parameterization of the custom losses
        # above; nn.SmoothL1Loss(beta) computes the same shape rescaled by 1/beta.
        loss_fn = nn.HuberLoss(delta=args.huber_delta)
        print("Using Huber loss")
    elif args.loss_type == "mae":
        loss_fn = nn.L1Loss()
        print("Using MAE loss")
    else:
        loss_fn = nn.MSELoss()
        print("Using MSE loss")
print("Starting training...") |
|
|
best_val_loss = float('inf') |
|
|
best_model_state = None |
|
|
train_losses = [] |
|
|
val_losses = [] |
|
|
|
|
|
for epoch in range(1, args.epochs + 1): |
|
|
|
|
|
train_loss = train_epoch(model, train_loader, optimizer, loss_fn, DEVICE) |
|
|
|
|
|
|
|
|
val_loss = evaluate(model, val_loader, loss_fn, DEVICE) |
|
|
|
|
|
train_losses.append(train_loss) |
|
|
val_losses.append(val_loss) |
|
|
|
|
|
print(f"Epoch {epoch:3d}/{args.epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}") |
|
|
|
|
|
|
|
|
if val_loss < best_val_loss: |
|
|
best_val_loss = val_loss |
|
|
best_model_state = model.state_dict().copy() |
|
|
print(f" -> New best validation loss: {best_val_loss:.4f}") |
|
|
|
|
|
|
|
|
print("\nEvaluating on test set...") |
|
|
model.load_state_dict(best_model_state) |
|
|
test_loss = evaluate(model, test_loader, loss_fn, DEVICE) |
|
|
print(f"Test Loss: {test_loss:.4f}") |
|
|
|
|
|
|
|
|

    if args.model_out:
        torch.save(best_model_state, args.model_out)
        print(f"Best model saved to {args.model_out}")

    if args.pred_csv:
        print("Generating test predictions...")
        model.eval()
        predictions = []
        targets = []

        with torch.no_grad():
            for batch_X, batch_y in test_loader:
                batch_X = batch_X.to(DEVICE)
                batch_pred = model(batch_X).cpu().numpy()
                predictions.append(batch_pred)
                targets.append(batch_y.numpy())

        predictions = np.concatenate(predictions, axis=0)
        targets = np.concatenate(targets, axis=0)

        pred_df = pd.DataFrame({
            'prediction': predictions.flatten(),
            'target': targets.flatten(),
            'error': predictions.flatten() - targets.flatten(),
            'abs_error': np.abs(predictions.flatten() - targets.flatten())
        })

        pred_df.to_csv(args.pred_csv, index=False)
        print(f"Predictions saved to {args.pred_csv}")

        # Report headline error metrics on the test set.
        mae = pred_df['abs_error'].mean()
        rmse = np.sqrt((pred_df['error'] ** 2).mean())
        print(f"Test MAE: {mae:.4f}")
        print(f"Test RMSE: {rmse:.4f}")

    if args.log_file:
        log_df = pd.DataFrame({
            'epoch': range(1, len(train_losses) + 1),
            'train_loss': train_losses,
            'val_loss': val_losses
        })
        log_df.to_csv(args.log_file, index=False)
        print(f"Training log saved to {args.log_file}")


if __name__ == "__main__":
    main()