|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import signal |
|
|
import sys |
|
|
|
|
|
import torch |
|
|
from lightning.pytorch.callbacks import Callback |
|
|
|
|
|
from nemo.utils import logging |
|
|
|
|
|
|
|
|
class PreemptionCallback(Callback):
    """
    PreemptionCallback creates a callback that checks for preemption during training at the end of every step.

    Upon preemption the callback provides a function to gracefully exit the training immediately and also
    saves the current state in a checkpoint as *last.ckpt (to be able to start from the same step without
    wasting any compute while resuming the next time).

    PreemptionCallback is always enabled by default via the arg create_preemption_callback under
    ExpManagerConfig. To disable please pass create_preemption_callback: False in your config file.
    """

    def __init__(self, checkpoint_callback, sig=None):
        """
        Args:
            checkpoint_callback: checkpoint callback used to write the *last.ckpt checkpoint on
                preemption (must expose ``_monitor_candidates`` and ``_save_last_checkpoint``).
            sig: signal number to treat as the preemption signal; defaults to ``signal.SIGTERM``.
        """
        self.sig = signal.SIGTERM if sig is None else sig
        self.checkpoint_callback = checkpoint_callback
        # Enabled lazily in on_train_start, once torch.distributed is confirmed to be
        # initialized (the `interrupted` property needs a working process group).
        self.preemption_enabled = False

    @property
    def interrupted(self):
        """Whether the preemption signal was received, agreed upon across all ranks.

        Only rank 0's handler sets ``_interrupted``; its value is broadcast so that
        every rank takes the same exit path.
        """
        # Robustness fix: before on_train_start enables preemption, `_interrupted` does not
        # exist and torch.distributed may not be initialized — report "not interrupted"
        # instead of raising.
        if not self.preemption_enabled:
            return False
        interrupted = torch.tensor(self._interrupted, device=torch.cuda.current_device(), dtype=torch.int32)
        torch.distributed.broadcast(interrupted, 0)
        return bool(interrupted.item())

    def on_train_start(self, trainer, pl_module):
        """
        Defines custom handlers at the beginning of training to be executed when the
        preemption signal is received.
        """
        if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
            # The `interrupted` property relies on a distributed broadcast, so preemption
            # handling cannot work without an initialized process group.
            logging.info("Preemption requires torch distributed to be initialized, disabling preemption")
            return

        self.preemption_enabled = True
        self._interrupted = False
        self.released = False
        # Remember the pre-existing handler so release() can restore it.
        self.original_handler = signal.getsignal(self.sig)

        def master_handler(signum, frame):
            # Rank 0: record the preemption. Checkpointing happens later, in
            # on_train_batch_end, outside the restricted signal-handler context.
            self.release()
            self._interrupted = True

        def ignoring_handler(signum, frame):
            # Non-zero ranks ignore the signal; they learn about the preemption via
            # the broadcast performed by the `interrupted` property.
            self.release()

        self.private_rank = torch.distributed.get_rank()
        if self.private_rank == 0:
            signal.signal(self.sig, master_handler)
        else:
            signal.signal(self.sig, ignoring_handler)
        # Bug fix: the original ended with `return self`; Lightning ignores hook return
        # values, so on_train_start now returns None like every other hook.

    def on_train_end(self, trainer, pl_module):
        """Restore the original signal handler when training finishes normally."""
        if self.preemption_enabled:
            self.release()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int):
        """Check for preemption after every training step; on preemption, save a
        *last.ckpt checkpoint and exit the process."""
        if self.preemption_enabled:
            # `self.interrupted` is a collective call (broadcast); every rank must reach
            # it on every step so all ranks agree on whether to exit.
            if self.interrupted:
                # Message fix: report the configured signal instead of hard-coding SIGTERM.
                logging.info(f"Received signal {self.sig}, saving checkpoint and exiting")
                monitor_candidates = self.checkpoint_callback._monitor_candidates(trainer)
                self.checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates)
                sys.exit(0)

    def release(self):
        """Restore the original handler for ``self.sig``; idempotent.

        Returns:
            True if this call restored the handler, False if it was already released.
        """
        if self.released:
            return False

        signal.signal(self.sig, self.original_handler)
        self.released = True
        return True
|
|
|