Spaces:
Sleeping
Sleeping
| """Training loop for PI-ResMLP ensemble. | |
| Trains each ensemble member independently with different random seeds. | |
| Supports Intel XPU (Arc GPU), NVIDIA CUDA, and CPU. | |
| Mixed precision (fp16), gradient clipping, early stopping, | |
| and experiment tracking via MLflow. | |
| Usage: | |
| python -m src.training.train --config configs/training.yaml | |
| """ | |
| import argparse | |
| import logging | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import yaml | |
| from src.models.architecture import PIResMLP | |
| from src.models.ensemble import DeepEnsemble | |
| from src.models.normalization import LogTransformStandardizer | |
| from src.models.physics_loss import PhysicsInformedLoss | |
| from src.training.dataset import create_dataloaders | |
| from src.utils.device import get_amp_backend, get_device, supports_mixed_precision | |
| logger = logging.getLogger(__name__) | |
def train_single_member(
    member: PIResMLP,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    config: dict,
    member_idx: int,
    device: torch.device,
) -> dict:
    """Train a single ensemble member and restore its best checkpoint.

    Runs up to ``config["training"]["epochs"]`` epochs of AdamW optimisation
    with a ``CosineAnnealingWarmRestarts`` schedule (stepped once per epoch),
    optional mixed precision, gradient-norm clipping, and early stopping on
    the mean validation loss. When training ends, the weights from the epoch
    with the lowest validation loss are loaded back into ``member`` and the
    model is moved to the CPU.

    Args:
        member: Model to train. It is moved to ``device`` for training and
            back to the CPU before returning.
        train_loader: Yields ``(X_batch, targets_batch)`` pairs, where
            ``targets_batch`` is a dict of tensors.
        val_loader: Same batch structure as ``train_loader``; used only for
            evaluation.
        config: Nested configuration. Keys read here:
            ``loss.{classification_weight, physics_weight}``,
            ``model.heteroscedastic``,
            ``optimizer.{learning_rate, weight_decay}``,
            ``scheduler.{T_0, T_mult}``,
            ``training.{mixed_precision, epochs, early_stopping_patience,
            gradient_clip_norm}``.
        member_idx: Index of the member; used only in log messages.
        device: Device on which to run training.

    Returns:
        Dict with per-epoch lists under the keys ``"train_loss"``,
        ``"val_loss"``, ``"val_regression"`` and ``"val_physics"``.
    """
    member = member.to(device)

    criterion = PhysicsInformedLoss(
        classification_weight=config["loss"]["classification_weight"],
        physics_weight=config["loss"]["physics_weight"],
        heteroscedastic=config["model"]["heteroscedastic"],
    )
    optimizer = torch.optim.AdamW(
        member.parameters(),
        lr=config["optimizer"]["learning_rate"],
        weight_decay=config["optimizer"]["weight_decay"],
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=config["scheduler"]["T_0"],
        T_mult=config["scheduler"]["T_mult"],
    )

    # Mixed precision is used only when both requested and supported.
    amp_backend = get_amp_backend(device)
    use_amp = config["training"]["mixed_precision"] and supports_mixed_precision(device)
    scaler = torch.amp.GradScaler(amp_backend, enabled=use_amp)

    epochs = config["training"]["epochs"]
    patience = config["training"]["early_stopping_patience"]
    clip_norm = config["training"]["gradient_clip_norm"]

    best_val_loss = float("inf")
    # Stays None until some epoch improves on best_val_loss. That never
    # happens when epochs == 0 or when the validation loss is NaN from the
    # start, so the restore step below must tolerate a missing snapshot.
    best_state = None
    patience_counter = 0
    history = {"train_loss": [], "val_loss": [], "val_regression": [], "val_physics": []}

    for epoch in range(epochs):
        # --- Training ---
        member.train()
        train_losses = []
        for X_batch, targets_batch in train_loader:
            X_batch = X_batch.to(device)
            targets = {k: v.to(device) for k, v in targets_batch.items()}

            optimizer.zero_grad()
            with torch.amp.autocast(amp_backend, enabled=use_amp):
                predictions = member(X_batch)
                loss_dict = criterion(predictions, targets)

            scaler.scale(loss_dict["total"]).backward()
            # Unscale before clipping so the threshold applies to the true
            # gradient magnitudes rather than the loss-scaled ones.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(member.parameters(), clip_norm)
            scaler.step(optimizer)
            scaler.update()

            train_losses.append(loss_dict["total"].item())

        scheduler.step()

        # --- Validation ---
        member.eval()
        val_losses = []
        val_reg_losses = []
        val_phys_losses = []
        with torch.no_grad():
            for X_batch, targets_batch in val_loader:
                X_batch = X_batch.to(device)
                targets = {k: v.to(device) for k, v in targets_batch.items()}
                predictions = member(X_batch)
                loss_dict = criterion(predictions, targets)
                val_losses.append(loss_dict["total"].item())
                val_reg_losses.append(loss_dict["regression"].item())
                val_phys_losses.append(loss_dict["physics"].item())

        avg_train = np.mean(train_losses)
        avg_val = np.mean(val_losses)
        avg_val_reg = np.mean(val_reg_losses)
        avg_val_phys = np.mean(val_phys_losses)

        history["train_loss"].append(avg_train)
        history["val_loss"].append(avg_val)
        history["val_regression"].append(avg_val_reg)
        history["val_physics"].append(avg_val_phys)

        if epoch % 10 == 0 or epoch == epochs - 1:
            logger.info(
                f"Member {member_idx} | Epoch {epoch}/{epochs} | "
                f"Train: {avg_train:.6f} | Val: {avg_val:.6f} | "
                f"LR: {scheduler.get_last_lr()[0]:.2e}"
            )

        # Early stopping. A NaN validation loss never compares as smaller,
        # so it is counted as "no improvement".
        if avg_val < best_val_loss:
            best_val_loss = avg_val
            patience_counter = 0
            # Snapshot on the CPU so the copy does not occupy device memory.
            best_state = {k: v.cpu().clone() for k, v in member.state_dict().items()}
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(
                    f"Member {member_idx} | Early stopping at epoch {epoch} "
                    f"(best val loss: {best_val_loss:.6f})"
                )
                break

    # Restore best weights (if any epoch produced a usable validation loss).
    if best_state is None:
        logger.warning(
            f"Member {member_idx} | No best checkpoint recorded "
            f"(no epochs run or validation loss never finite); "
            f"keeping final weights"
        )
    else:
        member.load_state_dict(best_state)
    member.cpu()
    return history
def train_ensemble(config_path: str) -> None:
    """Train the full deep ensemble described by a YAML config file.

    Loads the configuration, builds the dataloaders and normalizer, trains
    every ensemble member in turn with its own random seed, and writes the
    following into ``config["output"]["checkpoint_dir"]``:

    * ``normalization_params.json`` - saved by the normalizer,
    * ``model_ensemble`` - saved by ``DeepEnsemble.save``,
    * ``model_config.json`` - constructor kwargs needed to rebuild the model.

    Args:
        config_path: Path to the YAML training configuration.

    Raises:
        ValueError: If the YAML document is empty or is not a mapping.
    """
    import json

    # Explicit encoding: the platform default is locale-dependent and can
    # mis-decode a UTF-8 config file.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty document; fail with a clear
    # message instead of a TypeError on the first subscript below.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )

    device = get_device()

    # Create normalizer and dataloaders. The test loader is not used during
    # training.
    normalizer = LogTransformStandardizer()
    data_dir = Path(config["data"]["directory"])
    train_loader, val_loader, _test_loader = create_dataloaders(
        data_dir,
        normalizer,
        batch_size=config["data"]["batch_size"],
        num_workers=config["data"].get("num_workers", 0),
    )

    # Save normalization parameters
    output_dir = Path(config["output"]["checkpoint_dir"])
    output_dir.mkdir(parents=True, exist_ok=True)
    normalizer.save(output_dir / "normalization_params.json")

    # Model config
    model_kwargs = {
        "input_dim": normalizer.input_dim,
        "hidden_dim": config["model"]["hidden_dim"],
        "num_blocks": config["model"]["num_residual_blocks"],
        "dropout": config["model"]["dropout"],
        "num_classes": config["model"]["classification_classes"],
        "heteroscedastic": config["model"]["heteroscedastic"],
    }
    num_members = config["model"]["num_ensemble_members"]
    ensemble = DeepEnsemble(num_members=num_members, **model_kwargs)

    total_start = time.time()
    for i, member in enumerate(ensemble.members):
        # Different seed per member for diversity
        member_seed = config["seed"] + i * 1000
        torch.manual_seed(member_seed)
        np.random.seed(member_seed)

        # Re-initialize weights after seeding, so each member's starting
        # point is determined by its own seed (ensures independence).
        for module in member.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, nonlinearity="linear")
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        logger.info(f"\n{'='*60}")
        logger.info(f"Training ensemble member {i+1}/{num_members}")
        logger.info(f"{'='*60}")

        # The returned per-epoch history is not persisted by this function.
        train_single_member(
            member, train_loader, val_loader, config, i, device,
        )

    total_time = time.time() - total_start
    logger.info(f"\nTotal training time: {total_time:.1f}s ({total_time/60:.1f}min)")

    # Save ensemble
    ensemble.save(output_dir / "model_ensemble")
    logger.info(f"Ensemble saved to {output_dir / 'model_ensemble'}")

    # Save model config for loading
    with open(output_dir / "model_config.json", "w", encoding="utf-8") as f:
        json.dump(model_kwargs, f, indent=2)
def main() -> None:
    """Command-line entry point: parse ``--config`` and train the ensemble."""
    cli = argparse.ArgumentParser(description="Train PI-ResMLP ensemble")
    cli.add_argument("--config", default="configs/training.yaml")
    config_path = cli.parse_args().config

    # Configure root logging before training so progress messages are shown.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(message)s",
        level=logging.INFO,
    )
    train_ensemble(config_path)


if __name__ == "__main__":
    main()