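"""Tests for the EMA callback and EMAOptimizer: weight tracking, checkpointing, and exp_manager integration."""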
import os.path
from typing import Any, Dict, Union

import pytest
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT

from nemo.collections.common.callbacks import EMA
from nemo.collections.common.callbacks.ema import EMAOptimizer
from nemo.core import ModelPT
from nemo.utils.exp_manager import exp_manager

DEVICE_CAPABILITY = None
if torch.cuda.is_available():
    DEVICE_CAPABILITY = torch.cuda.get_device_capability()


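# Snapshot the EMA weights by swapping them into the module and back, leaving
# the module's live weights untouched.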
def extract_ema_weights(pl_module, trainer):
    ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0]
    ema_callback.swap_model_weights(trainer)
    weights = extract_weights(pl_module)
    ema_callback.swap_model_weights(trainer)
    return weights


def extract_weights(pl_module):
    return [w.detach().clone() for w in pl_module.parameters()]


class RandomDataset(torch.utils.data.Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


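# Minimal ModelPT subclass used to exercise the EMA callback end to end.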
class ExampleModel(ModelPT):
    def __init__(self, *args, **kwargs):
        cfg = OmegaConf.structured({})
        super().__init__(cfg)
        self.l1 = torch.nn.Linear(in_features=32, out_features=32)
        self.bn = torch.nn.BatchNorm1d(32)

    def train_dataloader(self):
        dataset = RandomDataset(32, 16)
        return torch.utils.data.DataLoader(dataset, batch_size=2)

    def val_dataloader(self):
        dataset = RandomDataset(32, 16)
        return torch.utils.data.DataLoader(dataset, batch_size=2)

    def test_dataloader(self):
        dataset = RandomDataset(32, 16)
        dl = torch.utils.data.DataLoader(dataset, batch_size=2)
        self._test_names = ['test_{}_'.format(idx) for idx in range(len(dl))]
        return dl

    def forward(self, batch):
        return self.l1(self.bn(batch)).sum()

    def training_step(self, batch, batch_idx):
        return self(batch)

    def validation_step(self, batch, batch_idx):
        return self(batch)

    def test_step(self, batch, batch_idx):
        return self(batch)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-3)

    def list_available_models(self):
        pass

    def setup_training_data(self, train_data_config: Union[DictConfig, Dict]):
        pass

    def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]):
        pass

    def setup_test_data(self, test_data_config: Union[DictConfig, Dict]):
        pass

    def validation_epoch_end(self, outputs):
        self.log("val_loss", torch.stack(outputs).mean())


class TestEMAConfig:
    @pytest.mark.unit
    def test_ema_value(self):
        with pytest.raises(MisconfigurationException, match="between 0 and 1"):
            EMA(decay=2)

    @pytest.mark.unit
    @pytest.mark.run_only_on('GPU')
    def test_ema_saved_state(self, tmpdir, caplog):
        """Test to ensure that when we re-load the EMA callback, it loads the EMA weights correctly."""
        temp_path = os.path.join(tmpdir, 'saved_state')

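        # Captures both the live and the EMA weights at the end of the first
        # epoch, then raises SystemExit to simulate an interrupted run.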
        class TerminateCallback(Callback):
            def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
                self.saved_ema_weights = extract_ema_weights(pl_module, trainer)
                self.pl_module_weights = extract_weights(pl_module)
                raise SystemExit

        model = ExampleModel()
        terminate_callback = TerminateCallback()

        trainer = Trainer(
            max_epochs=2,
            limit_val_batches=1,
            limit_train_batches=16,
            logger=False,
            val_check_interval=0.5,
            enable_checkpointing=False,
            accelerator='gpu',
            devices=1,
            callbacks=[terminate_callback],
        )
        exp_manager(
            trainer,
            {
                "ema": {"enable": True},
                "explicit_log_dir": str(temp_path),
                "checkpoint_callback_params": {"filename": "{epoch}-{step}"},
            },
        )
        with pytest.raises(SystemExit):
            trainer.fit(model=model)
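        # 16 train batches with val_check_interval=0.5 yields a mid-epoch checkpoint at step 8.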
        resume_path = os.path.join(temp_path, 'checkpoints/epoch=0-step=8.ckpt')

        model = ExampleModel()

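        # On resume, the live weights, the EMA weights, and the EMA optimizer's
        # step count must all match the values captured before the crash.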
        class CheckStateCallback(Callback):
            def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
                weights = extract_weights(pl_module)
                for x, y in zip(weights, terminate_callback.pl_module_weights):
                    assert torch.allclose(x.cpu(), y.cpu())
                current_ema_weights = extract_ema_weights(pl_module, trainer)
                for x, y in zip(current_ema_weights, terminate_callback.saved_ema_weights):
                    assert torch.allclose(x.cpu(), y.cpu())

                for optimizer in trainer.optimizers:
                    assert isinstance(optimizer, EMAOptimizer)
                    assert optimizer.current_step == 8

        trainer = Trainer(
            max_epochs=2,
            limit_val_batches=0,
            limit_train_batches=16,
            logger=False,
            enable_checkpointing=False,
            accelerator='gpu',
            devices=1,
        )
        exp_manager(
            trainer,
            {
                "ema": {"enable": True},
                "explicit_log_dir": str(temp_path),
                "checkpoint_callback_params": {"filename": "{epoch}-{step}"},
            },
        )

        trainer.callbacks.append(CheckStateCallback())
        trainer.fit(model, ckpt_path=resume_path)

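        # Resuming directly from the EMA checkpoint should also work.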
        ema_path = os.path.join(temp_path, 'checkpoints/epoch=0-step=8-EMA.ckpt')

        trainer = Trainer(
            max_epochs=1,
            limit_val_batches=0,
            limit_train_batches=1,
            logger=False,
            enable_checkpointing=False,
            accelerator='gpu',
            devices=1,
        )
        exp_manager(
            trainer,
            {
                "ema": {"enable": True},
                "explicit_log_dir": str(temp_path),
                "checkpoint_callback_params": {"filename": "{epoch}-{step}"},
            },
        )
        trainer.fit(model, ckpt_path=ema_path)

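        # Once the EMA checkpoint is deleted, resuming from the regular
        # checkpoint should fail: the associated EMA weights cannot be found.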
        os.remove(ema_path)

        trainer = Trainer(
            max_epochs=1,
            limit_val_batches=0,
            limit_train_batches=1,
            logger=False,
            enable_checkpointing=False,
            accelerator='gpu',
            devices=1,
        )
        exp_manager(
            trainer,
            {
                "ema": {"enable": True, "validate_original_weights": True},
                "explicit_log_dir": str(temp_path),
                "checkpoint_callback_params": {"filename": "{epoch}-{step}"},
            },
        )
        with pytest.raises(
            MisconfigurationException, match="Unable to find the associated EMA weights when re-loading"
        ):
            trainer.fit(model, ckpt_path=resume_path)

    @pytest.mark.unit
    @pytest.mark.run_only_on('GPU')
    def test_exp_manager_ema_weights(self, tmpdir):
        """Test to ensure that the exp manager adds the EMA callback, and we save an additional EMA checkpoint."""
        tmp_path = tmpdir / "exp_manager_test"
        model = ExampleModel()
        trainer = Trainer(max_epochs=1, enable_checkpointing=False, logger=False, accelerator='gpu', devices=1)
        exp_manager(
            trainer,
            {
                "ema": {"enable": True, "validate_original_weights": True},
                "explicit_log_dir": str(tmp_path),
                "checkpoint_callback_params": {"filename": "{epoch}-{step}"},
            },
        )
        assert any(isinstance(callback, EMA) for callback in trainer.callbacks)
        trainer.fit(model)
        ema_weights = extract_ema_weights(model, trainer)

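        # Every regular checkpoint should have a matching `-EMA` counterpart.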
        assert os.path.exists(tmp_path / "checkpoints/epoch=0-step=8.ckpt")
        ema_path = tmp_path / "checkpoints/epoch=0-step=8-EMA.ckpt"
        assert os.path.exists(ema_path)

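        # The `-EMA` checkpoint should hold the averaged weights, not the live ones.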
        duplicate_model = ExampleModel.load_from_checkpoint(str(ema_path))
        for saved_weight, ema_weight in zip(duplicate_model.state_dict().values(), ema_weights):
            assert torch.allclose(saved_weight.cpu(), ema_weight.cpu())

    @pytest.mark.unit
    def test_exp_manager_ema_weights_topk(self, tmpdir):
        """Test to ensure that with EMA enabled we keep only the top-k checkpoints (plus their EMA copies)."""
        tmp_path = tmpdir / "exp_manager_test"
        model = ExampleModel()
        save_top_k = 3

        trainer = Trainer(max_epochs=10, enable_checkpointing=False, logger=False, devices=1)
        exp_manager(
            trainer,
            {
                "ema": {"enable": True},
                "explicit_log_dir": str(tmp_path),
                "checkpoint_callback_params": {"save_top_k": save_top_k},
            },
        )
        trainer.fit(model)

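        # (save_top_k + 1) checkpoints (top-k plus `-last`), each with an EMA copy;
        # the extra file is presumably the `.nemo` archive saved at train end.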
        assert len(os.listdir(tmp_path / "checkpoints/")) == (save_top_k + 1) * 2 + 1

    @pytest.mark.unit
    def test_exp_manager_ema_weights_topk_resume(self, tmpdir):
        """Test to ensure that we always keep top_k checkpoints, even after resuming."""
        tmp_path = tmpdir / "exp_manager_test"
        model = ExampleModel()
        save_top_k = 3

        trainer = Trainer(max_epochs=10, enable_checkpointing=False, logger=False, devices=1)
        exp_manager(
            trainer,
            {
                "ema": {"enable": True},
                "explicit_log_dir": str(tmp_path),
                "checkpoint_callback_params": {"save_top_k": save_top_k},
            },
        )
        trainer.fit(model)

        assert len(os.listdir(tmp_path / "checkpoints/")) == (save_top_k + 1) * 2 + 1

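        # Resume with a smaller save_top_k; the checkpoint count should track the new value.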
        save_top_k = 2

        trainer = Trainer(max_epochs=10, enable_checkpointing=False, logger=False, devices=1)
        exp_manager(
            trainer,
            {
                "ema": {"enable": True},
                "explicit_log_dir": str(tmp_path),
                "resume_if_exists": True,
                "checkpoint_callback_params": {"save_top_k": save_top_k},
            },
        )
        trainer.fit(model)

        assert len(os.listdir(tmp_path / "checkpoints/")) == (save_top_k + 1) * 2 + 1


class TestEMATrain:
    @pytest.mark.unit
    @pytest.mark.parametrize(
        "precision",
        [
            32,
            16,
            pytest.param(
                "bf16",
                marks=pytest.mark.skipif(
                    not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8,
                    reason='bfloat16 is not supported on this device',
                ),
            ),
        ],
    )
    @pytest.mark.parametrize("accumulate_grad_batches", [1, 2])
    @pytest.mark.parametrize("validate_original_weights", [True, False])
    @pytest.mark.run_only_on('GPU')
    def test_ema_run_cuda(
        self, test_data_dir, precision, accumulate_grad_batches, validate_original_weights, tmpdir,
    ):
        self.run_training_test(
            accumulate_grad_batches=accumulate_grad_batches,
            validate_original_weights=validate_original_weights,
            accelerator='gpu',
            precision=precision,
            tmpdir=tmpdir,
        )

    @pytest.mark.unit
    @pytest.mark.parametrize("accumulate_grad_batches", [1, 2])
    @pytest.mark.parametrize("validate_original_weights", [True, False])
    def test_ema_run_cpu(self, test_data_dir, accumulate_grad_batches, validate_original_weights, tmpdir):
        self.run_training_test(
            accumulate_grad_batches=accumulate_grad_batches,
            validate_original_weights=validate_original_weights,
            accelerator='cpu',
            precision=32,
            tmpdir=tmpdir,
        )

    def run_training_test(self, accumulate_grad_batches, validate_original_weights, accelerator, precision, tmpdir):
        pl.seed_everything(123)
        model = ExampleModel()
        trainer = Trainer(
            max_epochs=1,
            precision=precision,
            limit_train_batches=10,
            limit_val_batches=10,
            logger=False,
            accumulate_grad_batches=accumulate_grad_batches,
            num_sanity_val_steps=0,
            enable_model_summary=False,
            enable_checkpointing=False,
            accelerator=accelerator,
            devices=1,
        )
        exp_manager(
            trainer,
            {
                "ema": {"enable": True, "validate_original_weights": validate_original_weights, "decay": 0.999},
                "explicit_log_dir": str(tmpdir),
                "checkpoint_callback_params": {"filename": "{epoch}-{step}"},
            },
        )

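        # EMAValidationAssertCallback goes at index 0 so its hooks run before the
        # EMA callback's, letting it snapshot the original weights first.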
        trainer.callbacks.append(EMAAssertCallback())
        trainer.callbacks.insert(0, EMAValidationAssertCallback())
        trainer.fit(model=model, val_dataloaders=model.train_dataloader())

    @pytest.mark.unit
    def test_ema_run_with_save_best_model(self, tmpdir):
        """Test to ensure that we save the model correctly when `save_best_model` is set to True."""
        tmp_path = tmpdir / "exp_manager_test"
        model = ExampleModel()

        trainer = Trainer(max_epochs=1, enable_checkpointing=False, logger=False, devices=1, limit_train_batches=1)
        exp_manager(
            trainer,
            {
                "ema": {"enable": True},
                "explicit_log_dir": str(tmp_path),
                "checkpoint_callback_params": {"save_best_model": True},
            },
        )
        trainer.fit(model)


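# Verifies the EMA update rule after every optimizer step:
# ema = decay * ema + (1 - decay) * current_weights.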
class EMAAssertCallback(Callback):
    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        model_weights = extract_weights(pl_module)
        self.ema_weights = extract_ema_weights(pl_module, trainer)
        # Before the first update, the EMA weights are initialized to the model weights.
        for x, y in zip(model_weights, self.ema_weights):
            assert torch.allclose(x, y)

    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        if (batch_idx + 1) % trainer.accumulate_grad_batches != 0:
            # The EMA weights are only updated on optimizer steps, so skip
            # batches that merely accumulate gradients.
            return
        ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0]
        decay = ema_callback.decay
        expected_ema_weights = []

        new_weights = extract_weights(pl_module)

        for ema_weight, new_weight in zip(self.ema_weights, new_weights):
            expected_ema_weight = ema_weight * decay
            expected_ema_weight += new_weight * (1 - decay)
            expected_ema_weights.append(expected_ema_weight)
        ema_weights = extract_ema_weights(pl_module, trainer)
        for actual_ema_weight, expected_ema_weight in zip(ema_weights, expected_ema_weights):
            assert torch.allclose(actual_ema_weight, expected_ema_weight)
        self.ema_weights = expected_ema_weights


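# Verifies that validation runs against the EMA weights (unless
# validate_original_weights is set) and that the original weights come back
# once validation ends.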
class EMAValidationAssertCallback(Callback):
    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0]
        self._original_weights = extract_weights(pl_module)
        self._ema_weights = extract_ema_weights(pl_module, trainer)

        super().on_validation_start(trainer, pl_module)
        if not ema_callback.validate_original_weights:
            if ema_callback._ema_initialized:
                # Validation should be running against the EMA weights.
                for ema_weights, module_weights in zip(self._ema_weights, extract_weights(pl_module)):
                    assert torch.allclose(ema_weights, module_weights)

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        ema_callback = [x for x in trainer.callbacks if isinstance(x, EMA)][0]
        if not ema_callback.validate_original_weights:
            model_weights = extract_weights(pl_module)
            if ema_callback._ema_initialized:
                # After validation, the original weights should be restored.
                for orig_weights, module_weights in zip(self._original_weights, model_weights):
                    assert torch.allclose(orig_weights.cpu(), module_weights.cpu())