|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
|
|
|
import lightning.pytorch as pl |
|
|
import pytest |
|
|
|
|
|
from nemo.utils.exp_manager import exp_manager |
|
|
|
|
|
try: |
|
|
from ptl_resiliency import FaultToleranceCallback |
|
|
|
|
|
HAVE_FT = True |
|
|
except (ImportError, ModuleNotFoundError): |
|
|
HAVE_FT = False |
|
|
|
|
|
|
|
|
@pytest.mark.skipif(not HAVE_FT, reason="requires resiliency package to be installed.") |
|
|
class TestFaultTolerance: |
|
|
|
|
|
@pytest.mark.unit |
|
|
def test_fault_tol_callback_not_created_by_default(self): |
|
|
"""There should be no FT callback by default""" |
|
|
test_conf = {"create_tensorboard_logger": False, "create_checkpoint_callback": False} |
|
|
test_trainer = pl.Trainer(accelerator='cpu') |
|
|
ft_callback_found = None |
|
|
exp_manager(test_trainer, test_conf) |
|
|
for cb in test_trainer.callbacks: |
|
|
if isinstance(cb, FaultToleranceCallback): |
|
|
ft_callback_found = cb |
|
|
assert ft_callback_found is None |
|
|
|
|
|
@pytest.mark.unit |
|
|
def test_fault_tol_callback_created(self): |
|
|
"""Verify that fault tolerance callback is created""" |
|
|
try: |
|
|
os.environ['FAULT_TOL_CFG_PATH'] = "/tmp/dummy" |
|
|
test_conf = { |
|
|
"create_tensorboard_logger": False, |
|
|
"create_checkpoint_callback": False, |
|
|
"create_fault_tolerance_callback": True, |
|
|
} |
|
|
test_trainer = pl.Trainer(accelerator='cpu') |
|
|
ft_callback_found = None |
|
|
exp_manager(test_trainer, test_conf) |
|
|
for cb in test_trainer.callbacks: |
|
|
if isinstance(cb, FaultToleranceCallback): |
|
|
ft_callback_found = cb |
|
|
assert ft_callback_found is not None |
|
|
finally: |
|
|
del os.environ['FAULT_TOL_CFG_PATH'] |
|
|
|