# fea-surrogate/src/training/train.py
# Uploaded by WolfDavid via huggingface_hub (commit 8e5ba9e, verified).
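"""Training loop for the PI-ResMLP deep ensemble.

Trains each ensemble member independently, using a different random seed
for each member. Supports Intel XPU (Arc GPU), NVIDIA CUDA, and CPU devices.

Features: mixed-precision autocast (when enabled in the config and supported
by the device; the backend's default low-precision dtype is used), gradient
norm clipping, and early stopping with restoration of the best weights.
(MLflow experiment tracking is not yet wired into this module.)

Usage:
    python -m src.training.train --config configs/training.yaml
"""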
import argparse
import logging
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import yaml
from src.models.architecture import PIResMLP
from src.models.ensemble import DeepEnsemble
from src.models.normalization import LogTransformStandardizer
from src.models.physics_loss import PhysicsInformedLoss
from src.training.dataset import create_dataloaders
from src.utils.device import get_amp_backend, get_device, supports_mixed_precision
# Module-level logger; its level and format are configured in main() via
# logging.basicConfig.
logger = logging.getLogger(__name__)
def _mean_or_nan(values: list) -> float:
    """Return the mean of *values*, or NaN when *values* is empty.

    Avoids numpy's "Mean of empty slice" RuntimeWarning when a loader yields
    no batches.
    """
    return float(np.mean(values)) if values else float("nan")


def _train_one_epoch(
    member: PIResMLP,
    loader: torch.utils.data.DataLoader,
    criterion: PhysicsInformedLoss,
    optimizer: torch.optim.Optimizer,
    scaler: torch.amp.GradScaler,
    amp_backend: str,
    use_amp: bool,
    clip_norm: float,
    device: torch.device,
) -> float:
    """Run one optimisation pass over *loader*; return the mean total loss."""
    member.train()
    batch_losses = []
    for X_batch, targets_batch in loader:
        X_batch = X_batch.to(device)
        # targets_batch is a mapping of tensors (moved key-by-key to device).
        targets = {k: v.to(device) for k, v in targets_batch.items()}
        optimizer.zero_grad()
        with torch.amp.autocast(amp_backend, enabled=use_amp):
            predictions = member(X_batch)
            loss_dict = criterion(predictions, targets)
        scaler.scale(loss_dict["total"]).backward()
        # Unscale first so the clipping threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(member.parameters(), clip_norm)
        scaler.step(optimizer)
        scaler.update()
        batch_losses.append(loss_dict["total"].item())
    return _mean_or_nan(batch_losses)


def _validate(
    member: PIResMLP,
    loader: torch.utils.data.DataLoader,
    criterion: PhysicsInformedLoss,
    device: torch.device,
) -> tuple:
    """Evaluate *member* on *loader* without gradients.

    Returns:
        ``(mean_total, mean_regression, mean_physics)`` losses; each is NaN
        when the loader yields no batches.
    """
    member.eval()
    total_losses = []
    reg_losses = []
    phys_losses = []
    with torch.no_grad():
        for X_batch, targets_batch in loader:
            X_batch = X_batch.to(device)
            targets = {k: v.to(device) for k, v in targets_batch.items()}
            predictions = member(X_batch)
            loss_dict = criterion(predictions, targets)
            total_losses.append(loss_dict["total"].item())
            reg_losses.append(loss_dict["regression"].item())
            phys_losses.append(loss_dict["physics"].item())
    return (
        _mean_or_nan(total_losses),
        _mean_or_nan(reg_losses),
        _mean_or_nan(phys_losses),
    )


def train_single_member(
    member: PIResMLP,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    config: dict,
    member_idx: int,
    device: torch.device,
) -> dict:
    """Train a single ensemble member with AdamW, cosine warm restarts and early stopping.

    The member is moved to *device*, trained in place, restored to the weights
    with the lowest validation loss seen, and finally moved back to CPU.

    Args:
        member: The model to train (mutated in place).
        train_loader: Yields ``(X_batch, targets_batch)`` where
            ``targets_batch`` is a dict of tensors consumed by the loss.
        val_loader: Same batch structure as *train_loader*.
        config: Parsed training config. Reads ``loss.classification_weight``,
            ``loss.physics_weight``, ``model.heteroscedastic``,
            ``optimizer.learning_rate``, ``optimizer.weight_decay``,
            ``scheduler.T_0``, ``scheduler.T_mult``,
            ``training.mixed_precision``, ``training.epochs``,
            ``training.early_stopping_patience`` and
            ``training.gradient_clip_norm``.
        member_idx: Index used only in log messages.
        device: Device to train on.

    Returns:
        Dict with per-epoch lists ``train_loss``, ``val_loss``,
        ``val_regression`` and ``val_physics``.
    """
    member = member.to(device)
    criterion = PhysicsInformedLoss(
        classification_weight=config["loss"]["classification_weight"],
        physics_weight=config["loss"]["physics_weight"],
        heteroscedastic=config["model"]["heteroscedastic"],
    )
    optimizer = torch.optim.AdamW(
        member.parameters(),
        lr=config["optimizer"]["learning_rate"],
        weight_decay=config["optimizer"]["weight_decay"],
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=config["scheduler"]["T_0"],
        T_mult=config["scheduler"]["T_mult"],
    )
    amp_backend = get_amp_backend(device)
    use_amp = config["training"]["mixed_precision"] and supports_mixed_precision(device)
    # With enabled=False the scaler's scale/unscale_/step/update are pass-throughs.
    scaler = torch.amp.GradScaler(amp_backend, enabled=use_amp)

    epochs = config["training"]["epochs"]
    patience = config["training"]["early_stopping_patience"]
    clip_norm = config["training"]["gradient_clip_norm"]

    best_val_loss = float("inf")
    # Stays None if no epoch ever improves (epochs == 0, or every validation
    # loss is NaN/inf); previously this left the name unbound and crashed.
    best_state = None
    patience_counter = 0
    history = {"train_loss": [], "val_loss": [], "val_regression": [], "val_physics": []}

    for epoch in range(epochs):
        avg_train = _train_one_epoch(
            member, train_loader, criterion, optimizer, scaler,
            amp_backend, use_amp, clip_norm, device,
        )
        # Stepped once per epoch, so T_0 / T_mult are measured in epochs.
        scheduler.step()

        avg_val, avg_val_reg, avg_val_phys = _validate(member, val_loader, criterion, device)
        history["train_loss"].append(avg_train)
        history["val_loss"].append(avg_val)
        history["val_regression"].append(avg_val_reg)
        history["val_physics"].append(avg_val_phys)

        if epoch % 10 == 0 or epoch == epochs - 1:
            logger.info(
                f"Member {member_idx} | Epoch {epoch}/{epochs} | "
                f"Train: {avg_train:.6f} | Val: {avg_val:.6f} | "
                f"LR: {scheduler.get_last_lr()[0]:.2e}"
            )

        # Early stopping (a NaN val loss never compares as an improvement).
        if avg_val < best_val_loss:
            best_val_loss = avg_val
            patience_counter = 0
            # Snapshot on CPU so later optimiser steps cannot alias it.
            best_state = {k: v.cpu().clone() for k, v in member.state_dict().items()}
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(
                    f"Member {member_idx} | Early stopping at epoch {epoch} "
                    f"(best val loss: {best_val_loss:.6f})"
                )
                break

    # Restore best weights when we have them; otherwise keep the final weights.
    if best_state is not None:
        member.load_state_dict(best_state)
    else:
        logger.warning(
            "Member %d | validation loss never improved (no epochs run or "
            "non-finite losses); keeping final weights",
            member_idx,
        )
    member.cpu()
    return history
def _reinitialize_linear_layers(model: nn.Module) -> None:
    """Re-draw every ``nn.Linear`` weight (Kaiming-normal, linear gain) and zero its bias."""
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, nonlinearity="linear")
            if module.bias is not None:
                nn.init.zeros_(module.bias)


def train_ensemble(config_path: str) -> None:
    """Train the full deep ensemble described by the YAML file at *config_path*.

    Writes to ``config["output"]["checkpoint_dir"]``:
    ``normalization_params.json``, the ensemble under ``model_ensemble``, and
    ``model_config.json`` (the keyword arguments used to build each member).
    """
    import json

    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
    # mis-decode a UTF-8 YAML file.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    device = get_device()

    # Create normalizer and dataloaders (the test split is not used here).
    normalizer = LogTransformStandardizer()
    data_dir = Path(config["data"]["directory"])
    train_loader, val_loader, _test_loader = create_dataloaders(
        data_dir,
        normalizer,
        batch_size=config["data"]["batch_size"],
        num_workers=config["data"].get("num_workers", 0),
    )

    # Save normalization parameters
    output_dir = Path(config["output"]["checkpoint_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    normalizer.save(output_dir / "normalization_params.json")

    # Model config
    model_kwargs = {
        "input_dim": normalizer.input_dim,
        "hidden_dim": config["model"]["hidden_dim"],
        "num_blocks": config["model"]["num_residual_blocks"],
        "dropout": config["model"]["dropout"],
        "num_classes": config["model"]["classification_classes"],
        "heteroscedastic": config["model"]["heteroscedastic"],
    }
    num_members = config["model"]["num_ensemble_members"]
    ensemble = DeepEnsemble(num_members=num_members, **model_kwargs)

    total_start = time.time()
    for i, member in enumerate(ensemble.members):
        # Different seed per member for diversity; re-drawing the weights right
        # after seeding makes each member's initialization reproducible from
        # config["seed"] alone.
        member_seed = config["seed"] + i * 1000
        torch.manual_seed(member_seed)
        np.random.seed(member_seed)
        _reinitialize_linear_layers(member)

        logger.info(f"\n{'='*60}")
        logger.info(f"Training ensemble member {i+1}/{num_members}")
        logger.info(f"{'='*60}")
        train_single_member(
            member, train_loader, val_loader, config, i, device,
        )
    total_time = time.time() - total_start
    logger.info(f"\nTotal training time: {total_time:.1f}s ({total_time/60:.1f}min)")

    # Save ensemble
    ensemble.save(output_dir / "model_ensemble")
    logger.info(f"Ensemble saved to {output_dir / 'model_ensemble'}")

    # Save model config for loading
    with open(output_dir / "model_config.json", "w", encoding="utf-8") as f:
        json.dump(model_kwargs, f, indent=2)
def main() -> None:
    """Command-line entry point: parse ``--config``, set up logging, run training."""
    cli = argparse.ArgumentParser(description="Train PI-ResMLP ensemble")
    cli.add_argument("--config", default="configs/training.yaml")
    config_path = cli.parse_args().config
    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(message)s",
        level=logging.INFO,
    )
    train_ensemble(config_path)


if __name__ == "__main__":
    main()