camenduru's picture
thanks to NVIDIA ❤
7934b29
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
from typing import Any, Dict, Union
import pytest
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT
from nemo.collections.common.callbacks import EMA
from nemo.collections.common.callbacks.ema import EMAOptimizer
from nemo.core import ModelPT
from nemo.utils.exp_manager import exp_manager
# CUDA compute capability of the current device as reported by torch, or None
# when CUDA is unavailable. Read by the bf16 skip condition in TestEMATrain.
DEVICE_CAPABILITY = torch.cuda.get_device_capability() if torch.cuda.is_available() else None
def extract_ema_weights(pl_module, trainer):
    """Snapshot the module parameters while the EMA callback has swapped weights in.

    The first ``EMA`` callback found on ``trainer`` is asked to swap the model
    weights, the parameters of ``pl_module`` are cloned via
    ``extract_weights``, and the callback is asked to swap a second time
    (presumably restoring the original weights -- confirm against
    ``EMA.swap_model_weights``).

    Args:
        pl_module: Module whose parameters are cloned between the two swaps.
        trainer: Trainer whose ``callbacks`` list contains an ``EMA`` callback.

    Returns:
        A list of detached, cloned parameter tensors captured between the swaps.

    Raises:
        IndexError: If ``trainer.callbacks`` contains no ``EMA`` instance.
    """
    ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0]
    ema_callback.swap_model_weights(trainer)
    try:
        weights = extract_weights(pl_module)
    finally:
        # Always perform the second swap, even if the snapshot fails, so the
        # module is never left holding the swapped-in weights.
        ema_callback.swap_model_weights(trainer)
    return weights
def extract_weights(pl_module):
    """Return detached clones of every parameter of ``pl_module``.

    Args:
        pl_module: Any object exposing ``parameters()`` that yields tensors.

    Returns:
        A list of tensors, in ``parameters()`` order, that share no storage
        with the live parameters and carry no autograd history.
    """
    params = pl_module.parameters()
    return [torch.clone(param.detach()) for param in params]
class RandomDataset(torch.utils.data.Dataset):
    """Map-style dataset holding ``length`` random vectors of ``size`` features."""

    def __init__(self, size, length):
        """Draw the whole ``(length, size)`` sample matrix up front.

        Args:
            size: Number of features per sample.
            length: Number of samples reported by ``len()``.
        """
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self):
        """Return the sample count given at construction."""
        return self.len

    def __getitem__(self, index):
        """Return the entry of the sample matrix stored at ``index``."""
        return self.data[index]
class ExampleModel(ModelPT):
    """Tiny ``ModelPT`` used by the EMA tests: BatchNorm1d followed by a Linear layer."""

    def __init__(self, *args, **kwargs):
        """Build the model from an empty structured config; extra arguments are ignored."""
        super().__init__(OmegaConf.structured({}))
        self.l1 = torch.nn.Linear(in_features=32, out_features=32)
        self.bn = torch.nn.BatchNorm1d(32)

    @staticmethod
    def _random_loader():
        """Return a batch-size-2 loader over a fresh 16-sample, 32-feature random dataset."""
        return torch.utils.data.DataLoader(RandomDataset(32, 16), batch_size=2)

    def train_dataloader(self):
        """Return a fresh random-data loader for training."""
        return self._random_loader()

    def val_dataloader(self):
        """Return a fresh random-data loader for validation."""
        return self._random_loader()

    def test_dataloader(self):
        """Return a fresh random-data loader and record one test name per batch."""
        loader = self._random_loader()
        self._test_names = list(map('test_{}_'.format, range(len(loader))))
        return loader

    def forward(self, batch):
        """Normalise ``batch``, apply the linear layer and reduce to a scalar sum."""
        normalised = self.bn(batch)
        projected = self.l1(normalised)
        return projected.sum()

    def training_step(self, batch, batch_idx):
        """Use the forward output directly as the training loss."""
        return self(batch)

    def validation_step(self, batch, batch_idx):
        """Use the forward output directly as the validation loss."""
        return self(batch)

    def test_step(self, batch, batch_idx):
        """Use the forward output directly as the test loss."""
        return self(batch)

    def configure_optimizers(self):
        """Return plain SGD over all parameters with a learning rate of 1e-3."""
        return torch.optim.SGD(self.parameters(), lr=1e-3)

    def list_available_models(self):
        """No-op; presumably present only to satisfy the ``ModelPT`` interface."""
        pass

    def setup_training_data(self, train_data_config: Union[DictConfig, Dict]):
        """No-op; the dataloaders above are built without any config."""
        pass

    def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]):
        """No-op; the dataloaders above are built without any config."""
        pass

    def setup_test_data(self, val_data_config: Union[DictConfig, Dict]):
        """No-op; the dataloaders above are built without any config."""
        pass

    def validation_epoch_end(self, loss):
        """Log the mean of the per-step validation losses as ``val_loss``."""
        stacked = torch.stack(loss)
        self.log("val_loss", stacked.mean())
class TestEMAConfig:
    """Tests for EMA callback configuration and EMA checkpoint handling through ``exp_manager``."""

    @pytest.mark.unit
    def test_ema_value(self):
        """A decay outside the accepted range must be rejected at construction."""
        with pytest.raises(MisconfigurationException, match="between 0 and 1"):
            EMA(decay=2)

    @pytest.mark.unit
    @pytest.mark.run_only_on('GPU')
    def test_ema_saved_state(self, tmpdir, caplog):
        """Check that resuming from a checkpoint reloads both the model and the EMA weights."""
        temp_path = os.path.join(tmpdir, 'saved_state')
        # Keyword arguments shared by every Trainer created in this test.
        shared_trainer_kwargs = dict(logger=False, enable_checkpointing=False, accelerator='gpu', devices=1)

        def build_exp_cfg(**ema_overrides):
            # A fresh dict per call, so no config object is shared between exp_manager invocations.
            return {
                "ema": {"enable": True, **ema_overrides},
                "explicit_log_dir": str(temp_path),
                "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"},
            }

        class TerminateCallback(Callback):
            """Records both weight sets at the end of the first epoch, then aborts the run."""

            def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
                self.saved_ema_weights = extract_ema_weights(pl_module, trainer)
                self.pl_module_weights = extract_weights(pl_module)
                raise SystemExit

        model = ExampleModel()
        terminate_callback = TerminateCallback()

        trainer = Trainer(
            max_epochs=2,
            limit_train_batches=16,
            limit_val_batches=1,
            val_check_interval=0.5,
            callbacks=[terminate_callback],
            **shared_trainer_kwargs,
        )
        exp_manager(trainer, build_exp_cfg())
        with pytest.raises(SystemExit):
            trainer.fit(model=model)

        resume_path = os.path.join(temp_path, 'checkpoints/epoch=0-step=8.ckpt')
        model = ExampleModel()

        class CheckStateCallback(Callback):
            """Compares the restored state with what TerminateCallback recorded."""

            def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
                restored = extract_weights(pl_module)
                for got, want in zip(restored, terminate_callback.pl_module_weights):
                    assert torch.allclose(got.cpu(), want.cpu())

                restored_ema = extract_ema_weights(pl_module, trainer)
                for got, want in zip(restored_ema, terminate_callback.saved_ema_weights):
                    assert torch.allclose(got.cpu(), want.cpu())

                for optimizer in trainer.optimizers:
                    assert isinstance(optimizer, EMAOptimizer)
                    assert optimizer.current_step == 8

        trainer = Trainer(max_epochs=2, limit_train_batches=16, limit_val_batches=0, **shared_trainer_kwargs)
        exp_manager(trainer, build_exp_cfg())
        # The checking callback is registered only after exp_manager has modified the trainer.
        trainer.callbacks.append(CheckStateCallback())
        trainer.fit(model, ckpt_path=resume_path)

        # Resuming directly from the EMA checkpoint must work as well.
        ema_path = os.path.join(temp_path, 'checkpoints/epoch=0-step=8-EMA.ckpt')
        trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0, **shared_trainer_kwargs)
        exp_manager(trainer, build_exp_cfg())
        trainer.fit(model, ckpt_path=ema_path)

        # With the EMA checkpoint deleted, resuming from the regular checkpoint must raise.
        os.remove(ema_path)
        trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0, **shared_trainer_kwargs)
        exp_manager(trainer, build_exp_cfg(validate_original_weights=True))
        with pytest.raises(
            MisconfigurationException, match="Unable to find the associated EMA weights when re-loading"
        ):
            trainer.fit(model, ckpt_path=resume_path)

    @pytest.mark.unit
    @pytest.mark.run_only_on('GPU')
    def test_exp_manager_ema_weights(self, tmpdir):
        """Check that exp_manager registers the EMA callback and an extra ``-EMA`` checkpoint is saved."""
        tmp_path = tmpdir / "exp_manager_test"
        model = ExampleModel()
        trainer = Trainer(max_epochs=1, enable_checkpointing=False, logger=False, accelerator='gpu', devices=1)
        exp_cfg = {
            "ema": {"enable": True, "validate_original_weights": True},
            "explicit_log_dir": str(tmp_path),
            "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"},
        }
        exp_manager(trainer, exp_cfg)
        assert any(isinstance(callback, EMA) for callback in trainer.callbacks)

        trainer.fit(model)
        ema_weights = extract_ema_weights(model, trainer)

        assert os.path.exists(tmp_path / "checkpoints/epoch=0-step=8.ckpt")
        ema_path = tmp_path / "checkpoints/epoch=0-step=8-EMA.ckpt"
        assert os.path.exists(ema_path)

        duplicate_model = ExampleModel.load_from_checkpoint(str(ema_path))
        saved_tensors = duplicate_model.state_dict().values()
        # zip stops at the shorter sequence, so only the leading state_dict entries are compared.
        for saved_weight, ema_weight in zip(saved_tensors, ema_weights):
            assert torch.allclose(saved_weight.cpu(), ema_weight.cpu())

    @pytest.mark.unit
    def test_exp_manager_ema_weights_topk(self, tmpdir):
        """Check that only the top-k checkpoints (plus their EMA companions) are kept."""
        tmp_path = tmpdir / "exp_manager_test"
        model = ExampleModel()
        save_top_k = 3

        trainer = Trainer(max_epochs=10, enable_checkpointing=False, logger=False, devices=1)
        exp_cfg = {
            "ema": {"enable": True},
            "explicit_log_dir": str(tmp_path),
            "checkpoint_callback_params": {"save_top_k": save_top_k},
        }
        exp_manager(trainer, exp_cfg)
        trainer.fit(model)

        # Expect top-k model checkpoints, their EMA companions, the last checkpoint and the nemo model.
        saved_files = os.listdir(tmp_path / "checkpoints/")
        assert len(saved_files) == (save_top_k + 1) * 2 + 1

    @pytest.mark.unit
    def test_exp_manager_ema_weights_topk_resume(self, tmpdir):
        """Check that the top-k limit is still honoured after resuming with a smaller k."""
        tmp_path = tmpdir / "exp_manager_test"
        model = ExampleModel()

        def fit_and_count_checkpoints(top_k, **extra_cfg):
            # Run a full fit in the shared log dir and report how many files the checkpoint dir holds.
            trainer = Trainer(max_epochs=10, enable_checkpointing=False, logger=False, devices=1)
            exp_manager(
                trainer,
                {
                    "ema": {"enable": True},
                    "explicit_log_dir": str(tmp_path),
                    **extra_cfg,
                    "checkpoint_callback_params": {"save_top_k": top_k},
                },
            )
            trainer.fit(model)
            return len(os.listdir(tmp_path / "checkpoints/"))

        # First run: 3 model checkpoints, 3 EMA companions, the last checkpoint and the nemo model.
        save_top_k = 3
        assert fit_and_count_checkpoints(save_top_k) == (save_top_k + 1) * 2 + 1

        # Resume with a smaller k: one top-k checkpoint (and its EMA companion) should be removed.
        save_top_k = 2
        assert fit_and_count_checkpoints(save_top_k, resume_if_exists=True) == (save_top_k + 1) * 2 + 1
class TestEMATrain:
    """End-to-end training runs that check EMA behaviour on CPU and GPU."""

    @pytest.mark.unit
    @pytest.mark.parametrize(
        "precision",
        [
            32,
            16,
            pytest.param(
                "bf16",
                marks=pytest.mark.skipif(
                    not (DEVICE_CAPABILITY and DEVICE_CAPABILITY[0] >= 8),
                    reason='bfloat16 is not supported on this device',
                ),
            ),
        ],
    )
    @pytest.mark.parametrize("accumulate_grad_batches", [1, 2])
    @pytest.mark.parametrize("validate_original_weights", [True, False])
    @pytest.mark.run_only_on('GPU')
    def test_ema_run_cuda(
        self, test_data_dir, precision, accumulate_grad_batches, validate_original_weights, tmpdir,
    ):
        """Run the shared training check on GPU for each parametrized precision."""
        self.run_training_test(
            accelerator='gpu',
            precision=precision,
            accumulate_grad_batches=accumulate_grad_batches,
            validate_original_weights=validate_original_weights,
            tmpdir=tmpdir,
        )

    @pytest.mark.unit
    @pytest.mark.parametrize("accumulate_grad_batches", [1, 2])
    @pytest.mark.parametrize("validate_original_weights", [True, False])
    def test_ema_run_cpu(self, test_data_dir, accumulate_grad_batches, validate_original_weights, tmpdir):
        """Run the shared training check on CPU in full precision."""
        self.run_training_test(
            accelerator='cpu',
            precision=32,
            accumulate_grad_batches=accumulate_grad_batches,
            validate_original_weights=validate_original_weights,
            tmpdir=tmpdir,
        )

    def run_training_test(self, accumulate_grad_batches, validate_original_weights, accelerator, precision, tmpdir):
        """Train ``ExampleModel`` for one epoch with EMA enabled and the assertion callbacks attached.

        Args:
            accumulate_grad_batches: Forwarded to the Trainer.
            validate_original_weights: Forwarded to the ``ema`` section of the exp_manager config.
            accelerator: Trainer accelerator string (``'cpu'`` or ``'gpu'`` in the callers above).
            precision: Trainer precision setting.
            tmpdir: Directory used as the explicit log dir.
        """
        pl.seed_everything(123)
        model = ExampleModel()

        trainer = Trainer(
            accelerator=accelerator,
            devices=1,
            precision=precision,
            max_epochs=1,
            limit_train_batches=10,
            limit_val_batches=10,
            accumulate_grad_batches=accumulate_grad_batches,
            num_sanity_val_steps=0,
            logger=False,
            enable_model_summary=False,
            enable_checkpointing=False,
        )
        ema_section = {"enable": True, "validate_original_weights": validate_original_weights, "decay": 0.999}
        exp_manager(
            trainer,
            {
                "ema": ema_section,
                "explicit_log_dir": str(tmpdir),
                "checkpoint_callback_params": {"filename": f"{{epoch}}-{{step}}"},
            },
        )

        # Attach the checking callbacks only after exp_manager has modified the trainer:
        # the validation checker goes to the front of the list, the training checker to the back.
        trainer.callbacks.insert(0, EMAValidationAssertCallback())
        trainer.callbacks.append(EMAAssertCallback())

        trainer.fit(model=model, val_dataloaders=model.train_dataloader())

    @pytest.mark.unit
    def test_ema_run_with_save_best_model(self, tmpdir):
        """Check that a run completes when ``save_best_model`` is enabled together with EMA."""
        tmp_path = tmpdir / "exp_manager_test"
        model = ExampleModel()
        trainer = Trainer(max_epochs=1, enable_checkpointing=False, logger=False, devices=1, limit_train_batches=1)
        exp_cfg = {
            "ema": {"enable": True},
            "explicit_log_dir": str(tmp_path),
            "checkpoint_callback_params": {"save_best_model": True},
        }
        exp_manager(trainer, exp_cfg)
        trainer.fit(model)
class EMAAssertCallback(Callback):
    """Asserts that the EMA weights follow ``ema * decay + weight * (1 - decay)`` on each update."""

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Check that EMA weights start out equal to the model weights and remember them."""
        current_weights = extract_weights(pl_module)
        self.ema_weights = extract_ema_weights(pl_module, trainer)
        for model_weight, ema_weight in zip(current_weights, self.ema_weights):
            assert torch.allclose(model_weight, ema_weight)

    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        """Recompute the expected EMA weights by hand and compare them with the callback's."""
        if (batch_idx + 1) % trainer.accumulate_grad_batches:
            # Mid-accumulation batch: the EMA weights are not updated, so there is nothing to check.
            return

        ema_callback = [cb for cb in trainer.callbacks if isinstance(cb, EMA)][0]
        decay = ema_callback.decay

        current_weights = extract_weights(pl_module)
        expected = [
            previous_ema * decay + current * (1 - decay)
            for previous_ema, current in zip(self.ema_weights, current_weights)
        ]

        actual = extract_ema_weights(pl_module, trainer)
        for actual_weight, expected_weight in zip(actual, expected):
            assert torch.allclose(actual_weight, expected_weight)

        # Carry the hand-computed values forward as the baseline for the next update.
        self.ema_weights = expected
class EMAValidationAssertCallback(Callback):
    """Records model and EMA weights around validation and compares them with the live weights."""

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Snapshot original and EMA weights, then compare EMA weights with the module's weights."""
        ema_callback = [cb for cb in trainer.callbacks if isinstance(cb, EMA)][0]
        self._original_weights = extract_weights(pl_module)
        self._ema_weights = extract_ema_weights(pl_module, trainer)

        # Delegate to the base Callback hook (this class does not derive from EMA).
        super().on_validation_start(trainer, pl_module)

        if ema_callback.validate_original_weights or not ema_callback._ema_initialized:
            return

        # Intended check: the model weights are now the EMA weights.
        # NOTE(review): the result of torch.allclose is discarded (no assert), so a
        # mismatch cannot fail the test -- confirm whether an assert was intended.
        current_weights = extract_weights(pl_module)
        for ema_weight, module_weight in zip(self._ema_weights, current_weights):
            torch.allclose(ema_weight, module_weight)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Compare the weights recorded at validation start with the module's current weights."""
        ema_callback = [cb for cb in trainer.callbacks if isinstance(cb, EMA)][0]
        if ema_callback.validate_original_weights:
            return

        current_weights = extract_weights(pl_module)
        if not ema_callback._ema_initialized:
            return

        # NOTE(review): as above, the comparison result is discarded rather than asserted.
        for original_weight, module_weight in zip(self._original_weights, current_weights):
            torch.allclose(original_weight.cpu(), module_weight.cpu())