# TrafCast / model_v3 / train_lstm.py
# Author: amitom — commit 73e9c25 ("Minimal app for HF Space")
"""
train_lstm.py – LSTM model training for traffic flow prediction
Features:
- LSTM model with configurable architecture
- Weighted loss for handling speed class imbalance
- Huber-based loss options (plain, class-weighted, focal), which in the author's experience worked better than plain (non-Huber) losses
- CLI interface for hyperparameter tuning
- Model and encoder saving
- Chronological train/val/test splits
"""
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path
import joblib
from typing import Dict, Tuple, Optional
from encode import TrafficDataEncoder
def _pick_device() -> torch.device:
    """Return the first available accelerator: Apple MPS, then CUDA, else CPU."""
    # Availability checks are passed as callables so CUDA is only probed when
    # MPS is unavailable.
    candidates = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for device_type, is_available in candidates:
        if is_available():
            return torch.device(device_type)
    return torch.device("cpu")


# Module-wide compute device, chosen once at import time and reported on stdout.
DEVICE = _pick_device()
print(f"Using device: {DEVICE}")
class LSTMRegressor(nn.Module):
    """LSTM model for traffic speed prediction.

    Runs a (optionally bidirectional) stacked LSTM over a batch-first input
    sequence, summarizes it with the final hidden state of each direction and
    maps that summary through a small MLP head to one value per sample.

    Args:
        n_features: Number of input features per timestep.
        hidden_size: LSTM hidden size (per direction).
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout rate between LSTM layers (only applied when
            ``n_layers > 1``) and inside the output head.
        bidirectional: Whether to run the LSTM in both directions.
    """

    def __init__(
        self,
        n_features: int,
        hidden_size: int = 128,
        n_layers: int = 2,
        dropout: float = 0.3,
        bidirectional: bool = False
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.bidirectional = bidirectional
        # LSTM layer; nn.LSTM only applies inter-layer dropout with >1 layer.
        self.lstm = nn.LSTM(
            input_size=n_features,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0,
            bidirectional=bidirectional
        )
        # Output head; the forward/backward summaries are concatenated when
        # bidirectional, doubling the head's input width.
        lstm_output_size = hidden_size * (2 if bidirectional else 1)
        # Keep at least one hidden unit so hidden_size == 1 does not create a
        # zero-width layer (whose output would ignore the input entirely).
        head_hidden = max(1, hidden_size // 2)
        self.head = nn.Sequential(
            nn.Linear(lstm_output_size, head_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict one value per sequence.

        Args:
            x: Batch-first input of shape ``(batch, seq_len, n_features)``.

        Returns:
            Predictions of shape ``(batch, 1)``.
        """
        lstm_out, _ = self.lstm(x)
        if self.bidirectional:
            # The forward direction has seen the whole sequence at the LAST
            # timestep, but the backward direction has seen it only at the
            # FIRST timestep (at the last step it has processed a single
            # element). Take each direction's full-sequence state.
            forward_final = lstm_out[:, -1, :self.hidden_size]
            backward_final = lstm_out[:, 0, self.hidden_size:]
            summary = torch.cat([forward_final, backward_final], dim=1)
        else:
            summary = lstm_out[:, -1, :]
        return self.head(summary)
class WeightedHuberLoss(nn.Module):
    """Huber loss whose per-sample terms are scaled by a speed-class weight.

    Each target is bucketed as low (``<= low_threshold``), high
    (``>= high_threshold``) or medium (everything else). The matching weight
    from ``weight_dict`` multiplies that sample's Huber term before the batch
    mean is taken. If the thresholds overlap, the high weight takes precedence.

    Args:
        weight_dict: Must provide ``weight_low``, ``weight_medium``,
            ``weight_high``, ``low_threshold`` and ``high_threshold``.
        delta: Point where the Huber loss switches from quadratic to linear.
        boost_low: Extra multiplier folded into the low-speed weight.
    """

    def __init__(self, weight_dict: Dict[str, float], delta: float = 1.0, boost_low: float = 1.0):
        super().__init__()
        self.delta = delta
        # Low-speed samples receive their base weight times an optional boost.
        self.weight_low = weight_dict["weight_low"] * boost_low
        self.weight_medium = weight_dict["weight_medium"]
        self.weight_high = weight_dict["weight_high"]
        self.low_threshold = weight_dict["low_threshold"]
        self.high_threshold = weight_dict["high_threshold"]

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the class-weighted mean Huber loss of ``pred`` against ``target``."""
        # Drop singleton dims so e.g. (N, 1) predictions align with (N, 1) targets.
        target = target.squeeze() if target.dim() > 1 else target
        pred = pred.squeeze() if pred.dim() > 1 else pred

        abs_err = (pred - target).abs()
        quadratic_part = 0.5 * abs_err ** 2
        linear_part = self.delta * (abs_err - 0.5 * self.delta)
        per_sample = torch.where(abs_err <= self.delta, quadratic_part, linear_part)

        is_low = target <= self.low_threshold
        is_high = target >= self.high_threshold
        # Start from the medium weight, then overwrite low entries and finally
        # high entries, so "high" wins wherever both masks are set.
        class_weight = (
            torch.full_like(target, self.weight_medium)
            .masked_fill(is_low, self.weight_low)
            .masked_fill(is_high, self.weight_high)
        )
        return (per_sample * class_weight).mean()
class FocalHuberLoss(nn.Module):
    """Huber loss re-weighted by a focal factor and by speed class.

    Each sample's Huber term ``h`` is multiplied by ``alpha * h ** gamma``
    (larger errors receive larger weights) and by the weight of its speed
    class: low (``<= low_threshold``), high (``>= high_threshold``) or medium
    otherwise, with high taking precedence if the thresholds overlap. The
    result is averaged over the batch, i.e.
    ``mean(h * (alpha * h ** gamma * class_weight))``.

    Args:
        weight_dict: Must provide ``weight_low``, ``weight_medium``,
            ``weight_high``, ``low_threshold`` and ``high_threshold``.
        delta: Point where the Huber loss switches from quadratic to linear.
        alpha: Scale of the focal factor.
        gamma: Exponent of the focal factor.
    """

    def __init__(self, weight_dict: Dict[str, float], delta: float = 1.0, alpha: float = 2.0, gamma: float = 2.0):
        super().__init__()
        self.delta = delta
        self.alpha = alpha
        self.gamma = gamma
        self.weight_low = weight_dict["weight_low"]
        self.weight_medium = weight_dict["weight_medium"]
        self.weight_high = weight_dict["weight_high"]
        self.low_threshold = weight_dict["low_threshold"]
        self.high_threshold = weight_dict["high_threshold"]

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the focal, class-weighted mean Huber loss of ``pred`` against ``target``."""
        # Drop singleton dims so e.g. (N, 1) predictions align with (N, 1) targets.
        target = target.squeeze() if target.dim() > 1 else target
        pred = pred.squeeze() if pred.dim() > 1 else pred

        gap = (pred - target).abs()
        huber = torch.where(
            gap <= self.delta,
            0.5 * gap ** 2,
            self.delta * (gap - 0.5 * self.delta),
        )

        # Focal factor grows with the per-sample loss. It is not detached, so it
        # also contributes to the gradient.
        focal = self.alpha * (huber ** self.gamma)

        # Medium everywhere, then low, then high (high wins on overlap).
        speed_weight = torch.full_like(target, self.weight_medium)
        speed_weight = speed_weight.masked_fill(target <= self.low_threshold, self.weight_low)
        speed_weight = speed_weight.masked_fill(target >= self.high_threshold, self.weight_high)

        combined = focal * speed_weight
        return (huber * combined).mean()
def create_data_loaders(
    X: np.ndarray,
    y: np.ndarray,
    timestamps: np.ndarray,
    batch_size: int,
    train_ratio: float = 0.7,
    val_ratio: float = 0.15
) -> Tuple[DataLoader, DataLoader, DataLoader, np.ndarray]:
    """
    Create chronological train/validation/test data loaders.

    Samples are ordered by timestamp and cut into contiguous blocks: the
    earliest ``train_ratio`` fraction for training, the next ``val_ratio``
    fraction for validation and the remainder for testing. Only the training
    loader is shuffled. A size/date-range summary of each split is printed.

    Args:
        X: Input sequences (N, seq_len, n_features)
        y: Target values, (N, horizon) or (N,); 1-D targets are returned as
            (N, 1) so they line up with a single-output model
        timestamps: Timestamps for each sample (anything ``np.argsort`` and
            ``pd.to_datetime`` accept)
        batch_size: Batch size for data loaders
        train_ratio: Fraction of data for training
        val_ratio: Fraction of data for validation

    Returns:
        train_loader, val_loader, test_loader, test_indices (positions into
        X / y of the test samples, in chronological order)

    Raises:
        ValueError: If X, y and timestamps differ in length, or if a ratio is
            outside [0, 1] or the two ratios sum to more than 1.
    """
    n_total = len(X)
    if len(y) != n_total or len(timestamps) != n_total:
        raise ValueError(
            f"X, y and timestamps must have the same length, got "
            f"{n_total}, {len(y)} and {len(timestamps)}"
        )
    ratios_valid = 0.0 <= train_ratio <= 1.0 and 0.0 <= val_ratio <= 1.0
    if not ratios_valid or train_ratio + val_ratio > 1.0 + 1e-9:
        raise ValueError(
            f"train_ratio ({train_ratio}) and val_ratio ({val_ratio}) must lie in "
            f"[0, 1] and sum to at most 1"
        )

    # Stable sort: samples sharing a timestamp keep their original relative
    # order, so the split boundaries are reproducible run to run.
    sorted_indices = np.argsort(timestamps, kind="stable")

    # Calculate split points (any rounding remainder goes to the test split)
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    train_indices = sorted_indices[:n_train]
    val_indices = sorted_indices[n_train:n_train + n_val]
    test_indices = sorted_indices[n_train + n_val:]

    # Convert timestamps to datetime for date range display
    timestamps_dt = pd.to_datetime(timestamps)

    def report(label: str, indices: np.ndarray, ratio: float) -> None:
        """Print the size, nominal share and date range of one split."""
        print(f" {label}: {len(indices):,} samples ({ratio*100:.0f}%)")
        if len(indices) > 0:
            dates = timestamps_dt[indices]
            print(f" Date range: {dates.min()} to {dates.max()}")

    print("Data split:")
    report("Train", train_indices, train_ratio)
    report("Val", val_indices, val_ratio)
    report("Test", test_indices, 1 - train_ratio - val_ratio)

    def make_loader(indices: np.ndarray, shuffle: bool) -> DataLoader:
        """Wrap the selected samples in a float32 TensorDataset loader."""
        features = torch.from_numpy(X[indices]).float()
        targets = torch.from_numpy(y[indices]).float()
        if targets.dim() == 1:
            # A 1-D target would broadcast against the model's (N, 1) output
            # into an (N, N) matrix in losses such as nn.MSELoss.
            targets = targets.unsqueeze(1)
        return DataLoader(TensorDataset(features, targets), batch_size=batch_size, shuffle=shuffle)

    train_loader = make_loader(train_indices, shuffle=True)
    val_loader = make_loader(val_indices, shuffle=False)
    test_loader = make_loader(test_indices, shuffle=False)
    return train_loader, val_loader, test_loader, test_indices
def train_epoch(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
    device: torch.device,
    max_grad_norm: Optional[float] = 1.0,
) -> float:
    """Train the model for one epoch.

    Args:
        model: Module to train; it is put into training mode.
        train_loader: Yields ``(inputs, targets)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        loss_fn: Called as ``loss_fn(predictions, targets)``.
        device: Device each batch is moved to before the forward pass.
        max_grad_norm: Gradient-norm clipping threshold applied before each
            optimizer step; ``None`` disables clipping.

    Returns:
        The per-batch losses averaged with each batch weighted by its size,
        which equals the per-sample epoch mean when ``loss_fn`` uses mean
        reduction (a plain mean of batch losses would over-weight a short
        final batch).

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_samples = 0
    for batch_X, batch_y in train_loader:
        batch_X = batch_X.to(device)
        batch_y = batch_y.to(device)
        # Forward pass
        optimizer.zero_grad()
        predictions = model(batch_X)
        loss = loss_fn(predictions, batch_y)
        # Backward pass (clipping guards against exploding LSTM gradients)
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()
        batch_size = batch_X.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size
    if n_samples == 0:
        raise ValueError("train_loader yielded no samples; check the training split size")
    return loss_sum / n_samples
def evaluate(
    model: nn.Module,
    data_loader: DataLoader,
    loss_fn: nn.Module,
    device: torch.device
) -> float:
    """Evaluate the model on a dataset without updating it.

    Args:
        model: Module to evaluate; it is put into eval mode.
        data_loader: Yields ``(inputs, targets)`` batches.
        loss_fn: Called as ``loss_fn(predictions, targets)``.
        device: Device each batch is moved to.

    Returns:
        The per-batch losses averaged with each batch weighted by its size,
        which equals the per-sample mean when ``loss_fn`` uses mean
        reduction. This matters because this value drives best-model
        selection, and a plain mean of batch losses over-weights a short
        final batch.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    n_samples = 0
    with torch.no_grad():
        for batch_X, batch_y in data_loader:
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)
            predictions = model(batch_X)
            loss = loss_fn(predictions, batch_y)
            batch_size = batch_X.size(0)
            loss_sum += loss.item() * batch_size
            n_samples += batch_size
    if n_samples == 0:
        raise ValueError("data_loader yielded no samples; check the split ratios")
    return loss_sum / n_samples
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line interface for training."""
    parser = argparse.ArgumentParser(description="Train LSTM model for traffic prediction")
    # Data parameters
    parser.add_argument("--csv", required=True, help="Path to CSV file with traffic data")
    parser.add_argument("--seq_len", type=int, default=12, help="Sequence length (default: 12)")
    parser.add_argument("--horizon", type=int, default=1, help="Prediction horizon (default: 1)")
    parser.add_argument("--target_col", default="speed_mph", help="Target column name")
    # Model parameters
    parser.add_argument("--hidden_size", type=int, default=128, help="LSTM hidden size")
    parser.add_argument("--n_layers", type=int, default=2, help="Number of LSTM layers")
    parser.add_argument("--dropout", type=float, default=0.3, help="Dropout rate")
    parser.add_argument("--bidirectional", action="store_true", help="Use bidirectional LSTM")
    # Training parameters
    parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-5, help="Weight decay")
    # Loss parameters
    parser.add_argument("--loss_type", choices=["mse", "mae", "huber", "weighted_huber", "focal_huber"],
                        default="weighted_huber", help="Loss function type")
    parser.add_argument("--huber_delta", type=float, default=1.0, help="Huber loss delta")
    parser.add_argument("--boost_low", type=float, default=1.0, help="Additional boost for low-speed loss (weighted_huber only)")
    parser.add_argument("--focal_alpha", type=float, default=2.0, help="Focal loss alpha parameter")
    parser.add_argument("--focal_gamma", type=float, default=2.0, help="Focal loss gamma parameter")
    # Data split parameters
    parser.add_argument("--train_ratio", type=float, default=0.7, help="Training data ratio")
    parser.add_argument("--val_ratio", type=float, default=0.15, help="Validation data ratio")
    # Output parameters
    parser.add_argument("--model_out", help="Path to save the best model")
    parser.add_argument("--encoder_out", help="Path to save the fitted encoder")
    parser.add_argument("--pred_csv", help="Path to save test predictions")
    parser.add_argument("--log_file", help="Path to save training log")
    return parser


def _build_loss_fn(
    args: argparse.Namespace,
    encoder: "TrafficDataEncoder",
    train_targets: np.ndarray,
) -> nn.Module:
    """Instantiate the loss selected by ``--loss_type`` and print which one is used.

    ``train_targets`` is only used by the class-weighted losses
    (``weighted_huber`` / ``focal_huber``), which take their weights and speed
    thresholds from ``encoder.get_speed_weights``.
    """
    if args.loss_type == "weighted_huber":
        weight_dict = encoder.get_speed_weights(train_targets.flatten())
        loss_fn = WeightedHuberLoss(weight_dict, args.huber_delta, args.boost_low)
        print(f"Using weighted Huber loss with low-speed boost: {args.boost_low}")
    elif args.loss_type == "focal_huber":
        weight_dict = encoder.get_speed_weights(train_targets.flatten())
        loss_fn = FocalHuberLoss(weight_dict, args.huber_delta, args.focal_alpha, args.focal_gamma)
        print(f"Using focal Huber loss (alpha={args.focal_alpha}, gamma={args.focal_gamma})")
    elif args.loss_type == "huber":
        # nn.HuberLoss(delta) is the standard Huber function
        # (0.5*d**2 inside delta, delta*(|d| - 0.5*delta) outside).
        # nn.SmoothL1Loss(beta) equals it divided by beta, which silently
        # rescaled the loss whenever --huber_delta != 1.
        loss_fn = nn.HuberLoss(delta=args.huber_delta)
        print("Using Huber loss")
    elif args.loss_type == "mae":
        loss_fn = nn.L1Loss()
        print("Using MAE loss")
    else:  # mse
        loss_fn = nn.MSELoss()
        print("Using MSE loss")
    return loss_fn


def _snapshot_state(model: nn.Module) -> Dict[str, torch.Tensor]:
    """Return a copy of ``model``'s state dict that later training cannot mutate."""
    # state_dict() holds references to the live parameter tensors, so a
    # shallow dict copy would keep tracking every later optimizer step.
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


def _save_test_predictions(
    model: nn.Module,
    test_loader: DataLoader,
    path: str,
    device: torch.device,
) -> None:
    """Write per-sample test predictions/errors to ``path`` as CSV and print MAE/RMSE."""
    print("Generating test predictions...")
    model.eval()
    prediction_batches = []
    target_batches = []
    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            prediction_batches.append(model(batch_X.to(device)).cpu().numpy())
            target_batches.append(batch_y.numpy())
    predictions = np.concatenate(prediction_batches, axis=0).flatten()
    targets = np.concatenate(target_batches, axis=0).flatten()
    errors = predictions - targets
    pred_df = pd.DataFrame({
        'prediction': predictions,
        'target': targets,
        'error': errors,
        'abs_error': np.abs(errors),
    })
    pred_df.to_csv(path, index=False)
    print(f"Predictions saved to {path}")
    mae = pred_df['abs_error'].mean()
    rmse = np.sqrt((pred_df['error'] ** 2).mean())
    print(f"Test MAE: {mae:.4f}")
    print(f"Test RMSE: {rmse:.4f}")


def main():
    """Main training function.

    Loads ``--csv`` and encodes it with ``TrafficDataEncoder``, then splits it
    chronologically. It trains for ``--epochs`` epochs, keeping the weights
    with the lowest validation loss, and evaluates those weights on the test
    split. Optionally it writes the encoder, the best weights, the test
    predictions and the per-epoch loss log.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.epochs < 1:
        parser.error("--epochs must be at least 1")

    # Create output directories up front so saving at the end of a long run
    # cannot fail on a missing folder.
    for out_path in (args.model_out, args.encoder_out, args.pred_csv, args.log_file):
        if out_path:
            Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    # Load and encode data
    print("Loading data...")
    df = pd.read_csv(args.csv)
    print(f"Loaded {len(df):,} rows from {args.csv}")
    encoder = TrafficDataEncoder(
        seq_len=args.seq_len,
        horizon=args.horizon,
        target_col=args.target_col
    )
    print("Encoding data...")
    # NOTE(review): the encoder is fitted on the full dataset before the
    # chronological split below. If fitting learns statistics (scalers,
    # vocabularies), those include validation/test rows. Consider fitting on
    # the training period only.
    X, y, _target_indices, timestamps = encoder.fit_transform(df)
    print(f"Encoded data shapes: X={X.shape}, y={y.shape}")
    if args.encoder_out:
        encoder.save(args.encoder_out)

    # Create data loaders
    print("Creating data loaders...")
    train_loader, val_loader, test_loader, test_indices = create_data_loaders(
        X, y, timestamps, args.batch_size, args.train_ratio, args.val_ratio
    )

    # Initialize model
    print("Initializing model...")
    model = LSTMRegressor(
        n_features=X.shape[2],
        hidden_size=args.hidden_size,
        n_layers=args.n_layers,
        dropout=args.dropout,
        bidirectional=args.bidirectional
    ).to(DEVICE)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    # Class weights are derived from the training targets only, so the
    # validation/test distribution does not shape the training loss.
    train_targets = train_loader.dataset.tensors[1].numpy()
    loss_fn = _build_loss_fn(args, encoder, train_targets)

    # Training loop
    print("Starting training...")
    best_val_loss = float('inf')
    best_model_state = None
    train_losses = []
    val_losses = []
    for epoch in range(1, args.epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, loss_fn, DEVICE)
        val_loss = evaluate(model, val_loader, loss_fn, DEVICE)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"Epoch {epoch:3d}/{args.epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = _snapshot_state(model)
            print(f" -> New best validation loss: {best_val_loss:.4f}")

    if best_model_state is None:
        # Only possible when every validation loss was NaN/inf (e.g. the run
        # diverged). Keep the final weights rather than crashing, so the log
        # below is still written.
        print("Warning: no finite validation loss was observed; using final-epoch weights.")
        best_model_state = _snapshot_state(model)

    # Load best model and evaluate on test set
    print("\nEvaluating on test set...")
    model.load_state_dict(best_model_state)
    test_loss = evaluate(model, test_loader, loss_fn, DEVICE)
    print(f"Test Loss: {test_loss:.4f}")

    if args.model_out:
        torch.save(best_model_state, args.model_out)
        print(f"Best model saved to {args.model_out}")

    if args.pred_csv:
        _save_test_predictions(model, test_loader, args.pred_csv, DEVICE)

    if args.log_file:
        log_df = pd.DataFrame({
            'epoch': range(1, len(train_losses) + 1),
            'train_loss': train_losses,
            'val_loss': val_losses
        })
        log_df.to_csv(args.log_file, index=False)
        print(f"Training log saved to {args.log_file}")


if __name__ == "__main__":
    main()