# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import inspect
from pathlib import Path
from typing import Any

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm

from physicsnemo import Module
from physicsnemo.active_learning import protocols as p
from physicsnemo.distributed import DistributedManager
from physicsnemo.launch.logging import LaunchLogger
from physicsnemo.utils.capture import StaticCaptureEvaluateNoGrad, StaticCaptureTraining

__all__ = ["DefaultTrainingLoop"]


def _recursive_data_device_cast(
    data: Any,
    device: torch.device | str | None = None,
    dtype: torch.dtype | None = None,
    **kwargs: Any,
) -> Any:
| """ | |
| Recursively moves/cast input data to a specified device and dtype. | |
| For iterable objects, we recurse through the elements depending on | |
| the type of iterable until we reach an object that either has a ``to`` | |
| method that can be called, or just returns the data unchanged. | |
| Parameters | |
| ---------- | |
| data: Any | |
| The data to move to the device. | |
| device: torch.device | str | None = None | |
| The device to move the data to. | |
| dtype: torch.dtype | None = None | |
| The dtype to move the data to. | |
| kwargs: Any | |
| Additional keyword arguments to pass to the `to` method. | |
| By default, `non_blocking` is set to `True` to allow | |
| asynchronous data transfers. | |
| Returns | |
| ------- | |
| Any | |
| The data moved to the device. | |
| """ | |
| kwargs.setdefault("non_blocking", True) | |
| if hasattr(data, "to"): | |
| # if there is a `to` method, then we can just call it | |
| return data.to(device=device, dtype=dtype, **kwargs) | |
| elif isinstance(data, dict): | |
| return { | |
| k: _recursive_data_device_cast(v, device, dtype) for k, v in data.items() | |
| } | |
| elif isinstance(data, list): | |
| return [_recursive_data_device_cast(v, device, dtype) for v in data] | |
| elif isinstance(data, tuple): | |
| return tuple(_recursive_data_device_cast(v, device, dtype) for v in data) | |
| else: | |
| return data | |
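
# Illustrative sketch (not part of the module's API): how the recursive cast
# treats nested containers. Tensors and other objects with a ``to`` method are
# moved/cast; everything else (e.g. plain strings and ints) passes through
# unchanged.
#
#   batch = {"coords": torch.randn(8, 3), "meta": ["sample-0", 42]}
#   moved = _recursive_data_device_cast(batch, device="cpu", dtype=torch.float32)
#   assert moved["meta"] == ["sample-0", 42]
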
class DefaultTrainingLoop(p.TrainingLoop):
    def __new__(cls, *args: Any, **kwargs: Any) -> DefaultTrainingLoop:
        """
        Wrapper for instantiating ``DefaultTrainingLoop``.

        This method captures the arguments used to instantiate the loop
        and stores them in the ``_args`` attribute for serialization.
        This follows the same pattern as ``ActiveLearningProtocol.__new__``.

        Parameters
        ----------
        args: Any
            Arguments to pass to the loop's constructor.
        kwargs: Any
            Keyword arguments to pass to the loop's constructor.

        Returns
        -------
        DefaultTrainingLoop
            A new instance with an ``_args`` attribute for serialization.
        """
        out = super().__new__(cls)
        # Get signature of __init__ function
        sig = inspect.signature(cls.__init__)
        # Bind args and kwargs to signature
        bound_args = sig.bind_partial(
            *([None] + list(args)), **kwargs
        )  # Add None to account for self
        bound_args.apply_defaults()
        # Collect args and kwargs (excluding self, unrolling variadic kwargs)
        instantiate_args = {}
        for param, (k, v) in zip(sig.parameters.values(), bound_args.arguments.items()):
            # Skip self
            if k == "self":
                continue
            # Add args and kwargs to instantiate_args
            if param.kind == param.VAR_KEYWORD:
                instantiate_args.update(v)
            else:
                # Special handling for device: convert torch.device to string
                if k == "device" and isinstance(v, torch.device):
                    instantiate_args[k] = str(v)
                # Special handling for dtype: convert to string representation
                elif k == "dtype" and isinstance(v, torch.dtype):
                    instantiate_args[k] = str(v)
                else:
                    instantiate_args[k] = v
        # Store args needed for instantiation
        out._args = {
            "__name__": cls.__name__,
            "__module__": cls.__module__,
            "__args__": instantiate_args,
        }
        return out

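    # For reference (illustrative, values abbreviated): given
    # ``DefaultTrainingLoop(device="cuda:0", checkpoint_frequency=5)``, the
    # ``_args`` record above would look roughly like
    #
    #   {
    #       "__name__": "DefaultTrainingLoop",
    #       "__module__": "...",  # this module's import path
    #       "__args__": {"train_step_fn": None, ..., "device": "cuda:0",
    #                    "checkpoint_frequency": 5},
    #   }
    #
    # which downstream serialization can consume to re-instantiate the loop.
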
    def __init__(
        self,
        train_step_fn: p.TrainingProtocol | None = None,
        validate_step_fn: p.ValidationProtocol | None = None,
        enable_static_capture: bool = True,
        use_progress_bars: bool = True,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None,
        checkpoint_frequency: int = 0,
        **capture_kwargs: Any,
    ) -> None:
        """
        Initialize the default training loop.

        The general usage of this loop is to construct it once, optionally
        with training/validation step functions, and then invoke it via
        ``__call__`` with a model, optimizer, and dataloader(s); see the
        usage sketch following this method and the ``__call__`` docstring.

        TODO: add support for early stopping

        Parameters
        ----------
        train_step_fn: TrainingProtocol | None = None
            A callable that implements the logic for performing a single
            training step. See ``protocols.TrainingProtocol`` for the expected
            interface; ultimately, the function should return a scalar loss
            value that has a ``backward`` method.
        validate_step_fn: ValidationProtocol | None = None
            A callable that implements the logic for performing a single
            validation step. See ``protocols.ValidationProtocol`` for the
            expected interface; in contrast to ``train_step_fn``, this function
            should not return anything.
        enable_static_capture: bool = True
            Whether to enable static capture for the training and validation steps.
        use_progress_bars: bool = True
            Whether to show ``tqdm`` progress bars to display epoch and step progress.
        device: str | torch.device | None = None
            The device used for performing the loop. If not provided, the device
            will default to the model's device at runtime.
        dtype: torch.dtype | None = None
            The dtype used for performing the loop. If not provided, the dtype
            will default to ``torch.get_default_dtype()``.
        checkpoint_frequency: int = 0
            How often to save checkpoints during training (every N epochs).
            If 0, no checkpoints are saved during training. Set via the Driver
            before training execution.
        capture_kwargs: Any
            Additional keyword arguments to pass to the static capture decorators.
        """
        self.train_step_fn = train_step_fn
        self.validate_step_fn = validate_step_fn
        self.enable_static_capture = enable_static_capture
        if isinstance(device, str):
            device = torch.device(device)
        # check to see if we can rely on DistributedManager
        if device is None and DistributedManager.is_initialized():
            # DistributedManager is a singleton; `device` is an instance property
            device = DistributedManager().device
        self.device = device
        if dtype is None:
            dtype = torch.get_default_dtype()
        self.dtype = dtype
        self.capture_kwargs = capture_kwargs
        self.use_progress_bars = use_progress_bars
        self.capture_functions = {}
        self.checkpoint_frequency = checkpoint_frequency
        self.checkpoint_base_dir: Path | None = None

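    # A minimal construction sketch (illustrative; the step-function names are
    # hypothetical placeholders):
    #
    #   loop = DefaultTrainingLoop(
    #       train_step_fn=my_train_step,    # conforms to p.TrainingProtocol
    #       validate_step_fn=my_val_step,   # conforms to p.ValidationProtocol
    #       enable_static_capture=False,    # skip CUDA-graph/AMP wrapping
    #       checkpoint_frequency=5,         # save every 5 epochs
    #   )                                   # (needs checkpoint_base_dir set)
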
    def save_training_checkpoint(
        self,
        checkpoint_dir: Path,
        model: Module | p.LearnerProtocol,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler | None = None,
        training_epoch: int | None = None,
    ) -> None:
        """
        Save training state to a checkpoint directory.

        Model weights are saved separately. Optimizer, scheduler, and epoch
        metadata are combined into a single ``training_state.pt`` file.

        Parameters
        ----------
        checkpoint_dir: Path
            Directory to save checkpoint files.
        model: Module | p.LearnerProtocol
            Model to save weights for.
        optimizer: Optimizer
            Optimizer to save state from.
        lr_scheduler: _LRScheduler | None
            Optional LR scheduler to save state from.
        training_epoch: int | None
            Current training epoch for metadata.
        """
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        # Save model weights separately
        if isinstance(model, Module):
            model_path = checkpoint_dir / "model.mdlus"
            model.save(str(model_path))
        else:
            model_path = checkpoint_dir / "model_state.pt"
            torch.save(model.state_dict(), model_path)
        # Combine optimizer, scheduler, and epoch metadata into a single file
        training_state = {
            "optimizer_state": optimizer.state_dict(),
            "lr_scheduler_state": lr_scheduler.state_dict() if lr_scheduler else None,
            "training_epoch": training_epoch,
        }
        training_state_path = checkpoint_dir / "training_state.pt"
        torch.save(training_state, training_state_path)

    def load_training_checkpoint(
        self,
        checkpoint_dir: Path,
        model: Module | p.LearnerProtocol,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler | None = None,
    ) -> int | None:
        """
        Load training state from a checkpoint directory.

        Model weights are loaded separately. Optimizer, scheduler, and epoch
        metadata are loaded from the combined ``training_state.pt`` file.

        Parameters
        ----------
        checkpoint_dir: Path
            Directory containing checkpoint files.
        model: Module | p.LearnerProtocol
            Model to load weights into.
        optimizer: Optimizer
            Optimizer to load state into.
        lr_scheduler: _LRScheduler | None
            Optional LR scheduler to load state into.

        Returns
        -------
        int | None
            Training epoch from metadata if available, else None.
        """
        # Load model weights separately
        if isinstance(model, Module):
            model_path = checkpoint_dir / "model.mdlus"
            if model_path.exists():
                model.load(str(model_path))
        else:
            model_state_path = checkpoint_dir / "model_state.pt"
            if model_state_path.exists():
                state_dict = torch.load(model_state_path, map_location="cpu")
                model.load_state_dict(state_dict)
        # Load combined training state (optimizer, scheduler, epoch)
        training_state_path = checkpoint_dir / "training_state.pt"
        if training_state_path.exists():
            training_state = torch.load(training_state_path, map_location="cpu")
            # Restore optimizer state
            if "optimizer_state" in training_state:
                optimizer.load_state_dict(training_state["optimizer_state"])
            # Restore scheduler state if present
            if lr_scheduler and training_state.get("lr_scheduler_state"):
                lr_scheduler.load_state_dict(training_state["lr_scheduler_state"])
            # Return epoch metadata
            return training_state.get("training_epoch", None)
        return None

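    # Round-trip sketch (illustrative): save at the end of an epoch, then
    # restore later to resume. ``loop``, ``model``, and ``optim`` are
    # hypothetical objects constructed elsewhere.
    #
    #   ckpt = Path("checkpoints") / "epoch_10"
    #   loop.save_training_checkpoint(ckpt, model, optim, training_epoch=10)
    #   ...
    #   resumed_epoch = loop.load_training_checkpoint(ckpt, model, optim)
    #   assert resumed_epoch == 10
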
    @property
    def amp_type(self) -> torch.dtype:
        """dtype used for autocast; falls back to ``float16`` for non-AMP dtypes."""
        # defined as a property, since it is consumed as ``self.amp_type``
        # (without a call) by the static capture wrappers below
        if self.dtype in [torch.float16, torch.bfloat16]:
            return self.dtype
        else:
            return torch.float16

    def _create_capture_functions(
        self,
        model: Module | p.LearnerProtocol,
        optimizer: Optimizer,
        train_step_fn: p.TrainingProtocol | None = None,
        validate_step_fn: p.ValidationProtocol | None = None,
    ) -> tuple[p.TrainingProtocol | None, p.ValidationProtocol | None]:
        """
        Attempt to create static capture functions based on the training and
        validation functions.

        This uses Python object IDs to uniquely identify functions, and adds the
        decorated functions to an internal ``capture_functions`` dictionary. If
        decorated functions already exist for the given IDs, the cached versions
        are returned instead of re-capturing.

        Parameters
        ----------
        model: Module | p.LearnerProtocol
            The model to train.
        optimizer: Optimizer
            The optimizer to use for training.
        train_step_fn: p.TrainingProtocol | None = None
            The training function to use for training.
        validate_step_fn: p.ValidationProtocol | None = None
            The validation function to use for validation.

        Returns
        -------
        tuple[p.TrainingProtocol | None, p.ValidationProtocol | None]
            The training and validation functions with static capture applied.
        """
        if not train_step_fn:
            train_step_fn = self.train_step_fn
        train_func_id = id(train_step_fn)
        if train_func_id not in self.capture_functions:
            try:
                train_step_fn = StaticCaptureTraining(
                    model=model,
                    optim=optimizer,
                    amp_type=self.amp_type,
                    **self.capture_kwargs,
                )(train_step_fn)
                self.capture_functions[train_func_id] = train_step_fn
            except Exception as e:
                raise RuntimeError(
                    "Failed to create static capture for `train_step_fn`."
                ) from e
        else:
            train_step_fn = self.capture_functions[train_func_id]
        if not validate_step_fn:
            validate_step_fn = self.validate_step_fn
        if validate_step_fn:
            val_func_id = id(validate_step_fn)
            if val_func_id not in self.capture_functions:
                try:
                    validate_step_fn = StaticCaptureEvaluateNoGrad(
                        model=model, amp_type=self.amp_type, **self.capture_kwargs
                    )(validate_step_fn)
                    self.capture_functions[val_func_id] = validate_step_fn
                except Exception as e:
                    raise RuntimeError(
                        "Failed to create static capture for `validate_step_fn`."
                    ) from e
            else:
                validate_step_fn = self.capture_functions[val_func_id]
        return train_step_fn, validate_step_fn

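    # Cache behavior sketch (illustrative): lookups key on ``id()`` of the
    # incoming function, so repeated calls with the same step function reuse
    # the previously decorated wrapper instead of re-capturing it.
    #
    #   wrapped_a, _ = loop._create_capture_functions(model, optim, step_fn)
    #   wrapped_b, _ = loop._create_capture_functions(model, optim, step_fn)
    #   assert wrapped_a is wrapped_b
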
    def __call__(
        self,
        model: Module | p.LearnerProtocol,
        optimizer: Optimizer,
        train_dataloader: DataLoader,
        max_epochs: int,
        validation_dataloader: DataLoader | None = None,
        train_step_fn: p.TrainingProtocol | None = None,
        validate_step_fn: p.ValidationProtocol | None = None,
        lr_scheduler: _LRScheduler | None = None,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Perform ``max_epochs`` epochs of training and, optionally, validation.

        Some of the arguments, such as ``train_step_fn`` and ``validate_step_fn``,
        are optional only if the ``model`` implements ``p.LearnerProtocol``.
        If they are passed, however, they take precedence over the functions
        originally provided to the constructor.

        The bare minimum required arguments for this loop to work are:

        1. A model to train
        2. An optimizer to step
        3. A training dataloader to iterate over
        4. The maximum number of epochs to train for

        If validation is required, then both ``validation_dataloader`` and
        ``validate_step_fn`` must be specified.

        Parameters
        ----------
        model: Module | p.LearnerProtocol
            The model to train.
        optimizer: torch.optim.Optimizer
            The optimizer to use for training.
        train_dataloader: DataLoader
            The dataloader to use for training.
        max_epochs: int
            The number of epochs to train for.
        validation_dataloader: DataLoader | None
            The dataloader to use for validation. If not provided, validation
            will not be performed.
        train_step_fn: p.TrainingProtocol | None = None
            The training function to use for training. If passed, it takes
            precedence over the function provided to the constructor.
        validate_step_fn: p.ValidationProtocol | None = None
            The validation function to use for validation.
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler | None = None
            The learning rate scheduler to use for training.
        device: str | torch.device | None = None
            The device used for performing the loop. If provided, it overrides
            the device specified in the constructor. If neither is provided,
            we fall back to PyTorch's default device.
        dtype: torch.dtype | None = None
            The dtype used for performing the loop. If provided, it overrides
            the dtype specified in the constructor. If neither is provided,
            we fall back to PyTorch's default dtype.
        args: Any
            Additional arguments to pass to the training and validation
            step functions.
        kwargs: Any
            Additional keyword arguments to pass to the training and validation
            step functions.
        """
        if not train_step_fn and not self.train_step_fn:
            raise RuntimeError(
                "No training step function provided. Either provide a "
                "`train_step_fn` to the constructor, or pass one to the "
                "`__call__` method."
            )
        # runtime arguments take precedence over constructor values, falling
        # back to PyTorch defaults if neither was provided
        device = device or self.device
        if device is None:
            device = torch.get_default_device()
        if isinstance(device, str):
            device = torch.device(device)
        dtype = dtype or self.dtype
        if dtype is None:
            dtype = torch.get_default_dtype()
        # if a device is specified, move the model; use getattr since plain
        # nn.Module models may not expose a `device` attribute
        if device and device != getattr(model, "device", None):
            # not 100% sure this will trigger issues with the optimizer
            # but allows a potentially different device to be used
            model = model.to(device)
        if self.enable_static_capture:
            # if static capture is enabled, we check for a cache hit based on
            # the incoming function IDs. If we miss, we then create new wrappers.
            train_step_fn, validate_step_fn = self._create_capture_functions(
                model, optimizer, train_step_fn, validate_step_fn
            )
        epoch_iter = range(1, max_epochs + 1)
        if self.use_progress_bars:
            epoch_iter = tqdm(epoch_iter, desc="Epoch", leave=False, position=0)
        ########### EPOCH LOOP ###########
        for epoch in epoch_iter:
            model.train()
            train_iter = iter(train_dataloader)
            if self.use_progress_bars:
                train_iter = tqdm(
                    train_iter, desc="Training step", leave=False, unit="batch"
                )
            ########### TRAINING STEP LOOP ###########
            with LaunchLogger(
                "train", epoch=epoch, num_mini_batch=len(train_dataloader)
            ) as log:
                for batch in train_iter:
                    batch = _recursive_data_device_cast(
                        batch, device=device, dtype=dtype
                    )
                    model.zero_grad(set_to_none=True)
                    loss = train_step_fn(model, batch, *args, **kwargs)
                    log.log_minibatch({"train_loss": loss.detach().item()})
                    # normally, static capture will call backward and step the
                    # optimizer internally as part of AMP handling
                    if not self.enable_static_capture:
                        loss.backward()
                        optimizer.step()
            if lr_scheduler:
                lr_scheduler.step()
            ########### VALIDATION STEP LOOP ###########
            if validate_step_fn and validation_dataloader:
                model.eval()
                val_iter = iter(validation_dataloader)
                if self.use_progress_bars:
                    val_iter = tqdm(
                        val_iter, desc="Validation step", leave=False, unit="batch"
                    )
                with LaunchLogger(
                    "validation", epoch=epoch, num_mini_batch=len(validation_dataloader)
                ) as log:
                    for batch in val_iter:
                        batch = _recursive_data_device_cast(
                            batch, device=device, dtype=dtype
                        )
                        validate_step_fn(model, batch, *args, **kwargs)
            ########### CHECKPOINT SAVE ###########
            # Save training state at the specified frequency
            if self.checkpoint_base_dir and self.checkpoint_frequency > 0:
                if epoch % self.checkpoint_frequency == 0:
                    epoch_checkpoint_dir = self.checkpoint_base_dir / f"epoch_{epoch}"
                    self.save_training_checkpoint(
                        checkpoint_dir=epoch_checkpoint_dir,
                        model=model,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        training_epoch=epoch,
                    )
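

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the library): runs
    # the loop on a toy ``torch.nn`` model with synthetic data. ``toy_train_step``
    # is a hypothetical step function written to match ``p.TrainingProtocol``.
    # Static capture and progress bars are disabled to keep the example
    # self-contained; this assumes ``LaunchLogger`` works with its console
    # defaults when no logging backend has been initialized.
    from torch.utils.data import TensorDataset

    def toy_train_step(model, batch, *args, **kwargs):
        # return a scalar loss with a ``backward`` method, per the protocol
        x, y = batch
        return torch.nn.functional.mse_loss(model(x), y)

    toy_model = torch.nn.Linear(4, 1)
    optim = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
    loader = DataLoader(
        TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8
    )
    loop = DefaultTrainingLoop(
        train_step_fn=toy_train_step,
        enable_static_capture=False,
        use_progress_bars=False,
        device="cpu",
    )
    loop(toy_model, optim, loader, max_epochs=2)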