| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os.path |
| from typing import Any, Dict, Union |
|
|
| import lightning.pytorch as pl |
| import pytest |
| import torch |
| from lightning.pytorch import Callback, Trainer |
| from lightning.pytorch.utilities.exceptions import MisconfigurationException |
| from lightning.pytorch.utilities.types import STEP_OUTPUT |
| from omegaconf import DictConfig, OmegaConf |
|
|
| from nemo.collections.common.callbacks import EMA |
| from nemo.collections.common.callbacks.ema import EMAOptimizer |
| from nemo.core import ModelPT |
| from nemo.utils.exp_manager import exp_manager |
|
|
# Compute capability reported by torch for the current CUDA device.
# Stays None when CUDA is not available.
DEVICE_CAPABILITY = torch.cuda.get_device_capability() if torch.cuda.is_available() else None
|
|
|
|
def extract_ema_weights(pl_module, trainer):
    """Return copies of the module parameters taken while the EMA weights are swapped in.

    The first ``EMA`` callback found in ``trainer.callbacks`` is asked to swap
    the model weights, the parameters of ``pl_module`` are copied with
    ``extract_weights``, and the callback is asked to swap a second time
    (which is expected to undo the first swap).

    Args:
        pl_module: Module whose parameters are copied.
        trainer: Trainer whose ``callbacks`` list contains an ``EMA`` callback.

    Returns:
        The list produced by ``extract_weights`` while the weights were swapped.

    Raises:
        IndexError: If ``trainer.callbacks`` contains no ``EMA`` instance.
    """
    ema_callback = [cb for cb in trainer.callbacks if isinstance(cb, EMA)][0]
    ema_callback.swap_model_weights(trainer)
    try:
        weights = extract_weights(pl_module)
    finally:
        # Always perform the second swap, even when copying fails, so an error
        # here cannot leave the module holding the swapped weights.
        ema_callback.swap_model_weights(trainer)
    return weights
|
|
|
|
def extract_weights(pl_module):
    """Return a list with a detached clone of every parameter of ``pl_module``.

    The clones do not share storage with the module, so later updates to the
    module leave the returned tensors unchanged.
    """
    snapshot = []
    for parameter in pl_module.parameters():
        snapshot.append(parameter.detach().clone())
    return snapshot
|
|
|
|
class RandomDataset(torch.utils.data.Dataset):
    """Dataset of ``length`` rows, each a standard-normal vector with ``size`` entries."""

    def __init__(self, size, length):
        # The whole table is drawn once up front; indexing only slices it.
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self):
        """Return the number of rows requested at construction time."""
        return self.len

    def __getitem__(self, index):
        """Return the row stored at ``index``."""
        row = self.data[index]
        return row
|
|
|
|
class ExampleModel(ModelPT):
    """Small ``ModelPT`` for the EMA tests: ``BatchNorm1d(32)`` followed by a 32x32 linear layer.

    Every dataloader serves a freshly generated ``RandomDataset(32, 16)`` in
    batches of two, and the forward pass reduces the network output to a
    single summed value that is also used as the loss.
    """

    def __init__(self, *args, **kwargs):
        # Extra arguments are accepted but unused; the model is always built
        # from an empty structured config.
        super().__init__(OmegaConf.structured({}))
        self.l1 = torch.nn.modules.Linear(in_features=32, out_features=32)
        self.bn = torch.nn.BatchNorm1d(32)

    def train_dataloader(self):
        """Return a loader over new random training data."""
        return torch.utils.data.DataLoader(RandomDataset(32, 16), batch_size=2)

    def val_dataloader(self):
        """Return a loader over new random validation data."""
        return torch.utils.data.DataLoader(RandomDataset(32, 16), batch_size=2)

    def test_dataloader(self):
        """Return a loader over new random test data and record one test name per batch."""
        loader = torch.utils.data.DataLoader(RandomDataset(32, 16), batch_size=2)
        names = []
        for idx in range(len(loader)):
            names.append('test_{}_'.format(idx))
        self._test_names = names
        return loader

    def forward(self, batch):
        """Normalize ``batch``, apply the linear layer and sum all outputs."""
        normalized = self.bn(batch)
        return self.l1(normalized).sum()

    def training_step(self, batch, batch_idx):
        """Use the summed forward output as the training loss."""
        return self(batch)

    def validation_step(self, batch, batch_idx):
        """Compute the loss and remember it for ``on_validation_epoch_end``."""
        val_loss = self(batch)
        self.validation_step_outputs.append(val_loss)
        return val_loss

    def test_step(self, batch, batch_idx):
        """Compute the loss and append it to ``test_step_outputs``."""
        test_loss = self(batch)
        self.test_step_outputs.append(test_loss)
        return test_loss

    def configure_optimizers(self):
        """Plain SGD over all parameters with a learning rate of 1e-3."""
        return torch.optim.SGD(self.parameters(), lr=1e-3)

    def list_available_models(self):
        """No models are listed for this test double."""
        return None

    def setup_training_data(self, train_data_config: Union[DictConfig, Dict]):
        """Intentionally a no-op; data comes from ``train_dataloader``."""

    def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]):
        """Intentionally a no-op; data comes from ``val_dataloader``."""

    def setup_test_data(self, val_data_config: Union[DictConfig, Dict]):
        """Intentionally a no-op; data comes from ``test_dataloader``."""

    def on_validation_epoch_end(self):
        """Log the mean of the collected validation losses as ``val_loss`` and reset the buffer."""
        mean_loss = torch.stack(self.validation_step_outputs).mean()
        self.log("val_loss", mean_loss)
        self.validation_step_outputs.clear()
|
|
|
|
class TestEMAConfig:
    """Tests for EMA argument validation and for the checkpoints written when exp_manager enables EMA."""

    @pytest.mark.unit
    def test_ema_value(self):
        """A decay of 2 must be rejected with a ``MisconfigurationException``."""
        with pytest.raises(MisconfigurationException, match="between 0 and 1"):
            EMA(decay=2)

    @pytest.mark.unit
    @pytest.mark.run_only_on('GPU')
    def test_ema_saved_state(self, tmpdir, caplog):
        """Test to ensure that when we re-load the EMA callback, it loads the EMA weights correctly."""
        temp_path = os.path.join(tmpdir, 'saved_state')

        def ema_config(**ema_options):
            # A fresh exp_manager config for every trainer; only the EMA options vary.
            return {
                "ema": {"enable": True, **ema_options},
                "explicit_log_dir": str(temp_path),
                "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"},
            }

        def single_gpu_trainer(**options):
            # Settings shared by every trainer built in this test.
            return Trainer(
                logger=False,
                enable_checkpointing=False,
                accelerator='gpu',
                devices=1,
                **options,
            )

        class TerminateCallback(Callback):
            """Records the EMA and model weights at the end of the first epoch, then aborts the run."""

            def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
                self.saved_ema_weights = extract_ema_weights(pl_module, trainer)
                self.pl_module_weights = extract_weights(pl_module)
                raise SystemExit

        model = ExampleModel()
        terminate_callback = TerminateCallback()

        trainer = single_gpu_trainer(
            max_epochs=2,
            limit_val_batches=1,
            limit_train_batches=16,
            val_check_interval=0.5,
            callbacks=[terminate_callback],
        )
        exp_manager(trainer, ema_config())
        with pytest.raises(SystemExit):
            trainer.fit(model=model)
        resume_path = os.path.join(temp_path, 'checkpoints/epoch=0-step=8.ckpt')

        model = ExampleModel()

        class CheckStateCallback(Callback):
            """Compares the restored state with what ``TerminateCallback`` recorded."""

            def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
                restored_weights = extract_weights(pl_module)
                for restored, recorded in zip(restored_weights, terminate_callback.pl_module_weights):
                    assert torch.allclose(restored.cpu(), recorded.cpu())
                restored_ema_weights = extract_ema_weights(pl_module, trainer)
                for restored, recorded in zip(restored_ema_weights, terminate_callback.saved_ema_weights):
                    assert torch.allclose(restored.cpu(), recorded.cpu())

                # Every optimizer must be EMA-wrapped and report step 8.
                for optimizer in trainer.optimizers:
                    assert isinstance(optimizer, EMAOptimizer)
                    assert optimizer.current_step == 8

        trainer = single_gpu_trainer(max_epochs=2, limit_val_batches=0, limit_train_batches=16)
        exp_manager(trainer, ema_config())
        # The checking callback is appended only after exp_manager has set up the trainer.
        trainer.callbacks.append(CheckStateCallback())
        trainer.fit(model, ckpt_path=resume_path)

        # Fitting from the "-EMA" checkpoint directly must also work.
        ema_path = os.path.join(temp_path, 'checkpoints/epoch=0-step=8-EMA.ckpt')

        trainer = single_gpu_trainer(max_epochs=1, limit_val_batches=0, limit_train_batches=1)
        exp_manager(trainer, ema_config())
        trainer.fit(model, ckpt_path=ema_path)

        # With the EMA checkpoint removed, resuming from the regular checkpoint must raise.
        os.remove(ema_path)

        trainer = single_gpu_trainer(max_epochs=1, limit_val_batches=0, limit_train_batches=1)
        exp_manager(trainer, ema_config(validate_original_weights=True))
        with pytest.raises(
            MisconfigurationException, match="Unable to find the associated EMA weights when re-loading"
        ):
            trainer.fit(model, ckpt_path=resume_path)

    @pytest.mark.unit
    @pytest.mark.run_only_on('GPU')
    def test_exp_manager_ema_weights(self, tmpdir):
        """Test to ensure that the exp manager adds the EMA callback, and we save an additional EMA checkpoint."""
        tmp_path = tmpdir / "exp_manager_test"
        model = ExampleModel()
        trainer = Trainer(max_epochs=1, enable_checkpointing=False, logger=False, accelerator='gpu', devices=1)
        exp_manager_cfg = {
            "ema": {"enable": True, "validate_original_weights": True},
            "explicit_log_dir": str(tmp_path),
            "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"},
        }
        exp_manager(trainer, exp_manager_cfg)
        assert any(isinstance(cb, EMA) for cb in trainer.callbacks)
        trainer.fit(model)
        ema_weights = extract_ema_weights(model, trainer)

        # Both the regular checkpoint and its "-EMA" companion must exist.
        assert os.path.exists(tmp_path / "checkpoints/epoch=0-step=8.ckpt")
        ema_path = tmp_path / "checkpoints/epoch=0-step=8-EMA.ckpt"
        assert os.path.exists(ema_path)

        # zip stops at the shorter sequence, so only the leading state_dict
        # entries are compared with the extracted EMA parameters.
        duplicate_model = ExampleModel.load_from_checkpoint(str(ema_path))
        saved_weights = duplicate_model.state_dict().values()
        for saved_weight, ema_weight in zip(saved_weights, ema_weights):
            assert torch.allclose(saved_weight.cpu(), ema_weight.cpu())

    @pytest.mark.unit
    def test_exp_manager_ema_weights_topk(self, tmpdir):
        """Test to ensure that EMA correctly ensures we only keep topk checkpoints."""
        save_top_k = 3
        tmp_path = tmpdir / "exp_manager_test"
        model = ExampleModel()

        trainer = Trainer(max_epochs=10, enable_checkpointing=False, logger=False, devices=1)
        exp_manager_cfg = {
            "ema": {"enable": True},
            "explicit_log_dir": str(tmp_path),
            "checkpoint_callback_params": {"save_top_k": save_top_k},
        }
        exp_manager(trainer, exp_manager_cfg)
        trainer.fit(model)

        # NOTE(review): presumably top-k plus a last checkpoint, each with an EMA
        # copy, plus one extra file -- confirm the breakdown against exp_manager.
        checkpoint_files = os.listdir(tmp_path / "checkpoints/")
        assert len(checkpoint_files) == (save_top_k + 1) * 2 + 1

    @pytest.mark.unit
    def test_exp_manager_ema_weights_topk_resume(self, tmpdir):
        """Test to ensure that we always keep top_k checkpoints, even after resuming."""
        save_top_k = 3
        tmp_path = tmpdir / "exp_manager_test"
        model = ExampleModel()

        trainer = Trainer(max_epochs=10, enable_checkpointing=False, logger=False, devices=1)
        first_cfg = {
            "ema": {"enable": True},
            "explicit_log_dir": str(tmp_path),
            "checkpoint_callback_params": {"save_top_k": save_top_k},
        }
        exp_manager(trainer, first_cfg)
        trainer.fit(model)

        # Same file-count expectation as in the non-resuming top-k test.
        checkpoint_files = os.listdir(tmp_path / "checkpoints/")
        assert len(checkpoint_files) == (save_top_k + 1) * 2 + 1

        # Second run: resume in the same directory with a smaller top-k.
        save_top_k = 2

        trainer = Trainer(max_epochs=10, enable_checkpointing=False, logger=False, devices=1)
        resume_cfg = {
            "ema": {"enable": True},
            "explicit_log_dir": str(tmp_path),
            "resume_if_exists": True,
            "checkpoint_callback_params": {"save_top_k": save_top_k},
        }
        exp_manager(trainer, resume_cfg)
        trainer.fit(model)

        # The file count must now match the reduced top-k.
        checkpoint_files = os.listdir(tmp_path / "checkpoints/")
        assert len(checkpoint_files) == (save_top_k + 1) * 2 + 1
|
|
|
|
class TestEMATrain:
    """Runs short training jobs with EMA enabled and the assertion callbacks attached."""

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "precision",
        [
            32,
            16,
            pytest.param(
                "bf16",
                marks=pytest.mark.skipif(
                    not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8,
                    reason='bfloat16 is not supported on this device',
                ),
            ),
        ],
    )
    @pytest.mark.parametrize("accumulate_grad_batches", [1, 2])
    @pytest.mark.parametrize("validate_original_weights", [True, False])
    @pytest.mark.run_only_on('GPU')
    def test_ema_run_cuda(
        self,
        test_data_dir,
        precision,
        accumulate_grad_batches,
        validate_original_weights,
        tmpdir,
    ):
        """GPU run over every combination of precision, gradient accumulation and validation mode."""
        self.run_training_test(
            accelerator='gpu',
            precision=precision,
            accumulate_grad_batches=accumulate_grad_batches,
            validate_original_weights=validate_original_weights,
            tmpdir=tmpdir,
        )

    @pytest.mark.unit
    @pytest.mark.parametrize("accumulate_grad_batches", [1, 2])
    @pytest.mark.parametrize("validate_original_weights", [True, False])
    def test_ema_run_cpu(self, test_data_dir, accumulate_grad_batches, validate_original_weights, tmpdir):
        """CPU run at 32-bit precision for each gradient accumulation and validation mode."""
        self.run_training_test(
            accelerator='cpu',
            precision=32,
            accumulate_grad_batches=accumulate_grad_batches,
            validate_original_weights=validate_original_weights,
            tmpdir=tmpdir,
        )

    def run_training_test(self, accumulate_grad_batches, validate_original_weights, accelerator, precision, tmpdir):
        """Fit ``ExampleModel`` for one epoch with EMA (decay 0.999) and both assertion callbacks.

        Args:
            accumulate_grad_batches: Forwarded to the ``Trainer``.
            validate_original_weights: Forwarded to the EMA section of the exp_manager config.
            accelerator: Forwarded to the ``Trainer``.
            precision: Forwarded to the ``Trainer``.
            tmpdir: Directory used as ``explicit_log_dir``.
        """
        pl.seed_everything(123)
        model = ExampleModel()
        trainer = Trainer(
            accelerator=accelerator,
            devices=1,
            precision=precision,
            max_epochs=1,
            limit_train_batches=10,
            limit_val_batches=10,
            accumulate_grad_batches=accumulate_grad_batches,
            num_sanity_val_steps=0,
            logger=False,
            enable_model_summary=False,
            enable_checkpointing=False,
        )
        ema_cfg = {"enable": True, "validate_original_weights": validate_original_weights, "decay": 0.999}
        exp_manager(
            trainer,
            {
                "ema": ema_cfg,
                "explicit_log_dir": str(tmpdir),
                "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"},
            },
        )
        # The assertion callbacks are attached only after exp_manager has set up
        # the trainer: one goes at the end of the list, one at the very front.
        trainer.callbacks.append(EMAAssertCallback())
        trainer.callbacks.insert(0, EMAValidationAssertCallback())
        # Validation reuses a training dataloader of the model.
        validation_loader = model.train_dataloader()
        trainer.fit(model=model, val_dataloaders=validation_loader)

    @pytest.mark.unit
    def test_ema_run_with_save_best_model(self, tmpdir):
        """Test to ensure that we save the model correctly when save best model is set to True."""
        tmp_path = tmpdir / "exp_manager_test"
        model = ExampleModel()

        trainer = Trainer(max_epochs=1, enable_checkpointing=False, logger=False, devices=1, limit_train_batches=1)
        exp_manager_cfg = {
            "ema": {"enable": True},
            "explicit_log_dir": str(tmp_path),
            "checkpoint_callback_params": {"save_best_model": True},
        }
        exp_manager(trainer, exp_manager_cfg)
        # The test passes as long as fitting completes without raising.
        trainer.fit(model)
|
|
|
|
class EMAAssertCallback(Callback):
    """Recomputes the expected EMA after training batches and compares it with the tracked EMA weights."""

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Check that the EMA weights equal the model weights when training starts."""
        model_weights = extract_weights(pl_module)
        self.ema_weights = extract_ema_weights(pl_module, trainer)
        for model_weight, ema_weight in zip(model_weights, self.ema_weights):
            assert torch.allclose(model_weight, ema_weight)

    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Advance the reference EMA by one step and compare it with the actual EMA weights."""
        completes_accumulation = (batch_idx + 1) % trainer.accumulate_grad_batches == 0
        if not completes_accumulation:
            # Presumably the EMA is not updated mid-accumulation, so nothing is checked.
            return

        ema_callback = [cb for cb in trainer.callbacks if isinstance(cb, EMA)][0]
        decay = ema_callback.decay

        # expected = previous_ema * decay + current_weight * (1 - decay)
        current_weights = extract_weights(pl_module)
        expected_ema_weights = [
            previous_ema * decay + current * (1 - decay)
            for previous_ema, current in zip(self.ema_weights, current_weights)
        ]

        actual_ema_weights = extract_ema_weights(pl_module, trainer)
        for actual, expected in zip(actual_ema_weights, expected_ema_weights):
            assert torch.allclose(actual, expected)

        # The recomputed values become the reference for the next check.
        self.ema_weights = expected_ema_weights
|
|
|
|
class EMAValidationAssertCallback(Callback):
    """Snapshots weights around validation and compares them with the EMA / original weights.

    NOTE(review): the results of the ``torch.allclose`` calls below are
    discarded, so these hooks currently cannot fail a test. Confirm whether
    ``assert`` was intended, and whether the comparisons hold given where this
    callback sits in ``trainer.callbacks``.
    """

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Record the current and EMA weights, then compare the EMA weights with the module weights."""
        ema_callback = [cb for cb in trainer.callbacks if isinstance(cb, EMA)][0]
        self._original_weights = extract_weights(pl_module)
        self._ema_weights = extract_ema_weights(pl_module, trainer)
        # Delegates to the base ``Callback`` hook.
        super().on_validation_start(trainer, pl_module)
        if not ema_callback.validate_original_weights and ema_callback._ema_initialized:
            current_weights = extract_weights(pl_module)
            for ema_weight, module_weight in zip(self._ema_weights, current_weights):
                # Result intentionally left unasserted to match existing behavior.
                torch.allclose(ema_weight, module_weight)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Compare the weights recorded at validation start with the module weights at validation end."""
        ema_callback = [cb for cb in trainer.callbacks if isinstance(cb, EMA)][0]
        if ema_callback.validate_original_weights:
            return
        model_weights = extract_weights(pl_module)
        if not ema_callback._ema_initialized:
            return
        for original_weight, module_weight in zip(self._original_weights, model_weights):
            # Result intentionally left unasserted to match existing behavior.
            torch.allclose(original_weight.cpu(), module_weight.cpu())
|
|