|
|
import logging |
|
|
import os |
|
|
import sys |
|
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
|
|
import mmcv |
|
|
import torch.distributed as dist |
|
|
from mmcv.runner.hooks.hook import HOOKS, Hook |
|
|
from mmcv.runner.checkpoint import save_checkpoint |
|
|
|
|
|
|
|
|
|
|
|
@HOOKS.register_module()
class GradChecker(Hook):
    """Hook that warns about parameters which received no gradient.

    After every training iteration, scans ``runner.model``'s named
    parameters and reports any that require a gradient but whose
    ``.grad`` is still ``None`` — the usual symptom of a parameter that
    is never reached by the forward pass (which, for example, breaks
    DDP gradient synchronization).
    """

    def after_train_iter(self, runner):
        """Check all model parameters for missing gradients.

        Args:
            runner: The training runner; only ``runner.model`` is read.
        """
        for key, val in runner.model.named_parameters():
            # Use `is None`, not `== None`: tensor.__eq__ performs an
            # elementwise comparison, so equality against None is
            # unreliable for gradient tensors.
            if val.grad is None and val.requires_grad:
                print(f"WARNING: {key}'s parameters are not being used!!!!")
|
|
|
|
|
|
|
|
@HOOKS.register_module()
class SamplerSkipIterationHook(Hook):
    """Hook that fast-forwards the data sampler when resuming mid-epoch.

    When distributed training, it is only useful in conjunction with
    :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same
    purpose with :obj:`IterLoader`.
    """

    def __init__(self, out_dir=None):
        """Init routine."""
        self.out_dir = out_dir

    def before_train_epoch(self, runner):
        """Ask the sampler to skip the iterations already consumed.

        Supports samplers exposed either directly on the data loader or
        nested inside its batch sampler.
        """
        loader = runner.data_loader
        if hasattr(loader.sampler, 'skip_iter_at_epoch_x'):
            loader.sampler.skip_iter_at_epoch_x(runner._inner_iter)
        elif hasattr(loader.batch_sampler.sampler, 'skip_iter_at_epoch_x'):
            loader.batch_sampler.sampler.skip_iter_at_epoch_x(runner._inner_iter)
|
|
|
|
|
|
|
|
# Logger for the AutoResume import/initialization below.
_logger = logging.getLogger("autoresume_hook")

# The AutoResume SDK lives under $SUBMIT_SCRIPTS on the cluster; fall back
# to the current directory when the variable is unset.
sys.path.append(os.environ.get("SUBMIT_SCRIPTS", "."))
try:
    _logger.info("Importing AutoResume lib...")
    from userlib.auto_resume import AutoResume

    AutoResume.init()
    _logger.info("Found AutoResume SDK!")
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
    # propagate; any import or init failure simply disables the hook.
    _logger.info("Did not find AutoResume SDK!")
    AutoResume = None
|
|
|
|
|
|
|
|
@HOOKS.register_module()
class AutoResumeHook(Hook):
    """AutoResume hook.

    A hook to interface with ADLR's AutoResume SDK.

    In order to use this hook, you must first import the AutoResume SDK
    in the main training script:

        sys.path.append(os.environ.get("SUBMIT_SCRIPTS", "."))
        try:
            _logger.info("Importing AutoResume lib...")
            from userlib.auto_resume import AutoResume

            AutoResume.init()
            _logger.info("Success!")
        except Exception:
            _logger.info("Failed!")
            AutoResume = None

    Also make sure you import the code for the auto-resume hook:

        import autoresume_hook

    In the main initialization routine, set the `resume` flag
    in the MMCV configure depending on whether the job is being resumed:

        if AutoResume is not None:
            auto_resume_details = AutoResume.get_resume_details()
            if auto_resume_details is not None:
                print_log(f"AutoResume details: {auto_resume_details}")
                cfg.resume = True

    Finally, in your MMSEG config, add the following statements:

        # Hook for auto-suspend/resume on ADLR clusters.
        custom_hooks = [dict(type="AutoResumeHook", interval=2000)]

    Args:
        interval: interval (in number of iterations) between checks as to
            whether to suspend.
        out_dir: directory in which to save the preemption checkpoint and
            the ``latest_iter.pth`` symlink.
    """

    def __init__(self, interval: int = 1000, out_dir=None):
        """Init routine."""
        self.interval = interval
        self.out_dir = out_dir

    def every_n_train_iters(self, runner, n):
        """Return True every ``n`` training iterations.

        NOTE: unlike mmcv's ``Hook.every_n_iters`` (which uses
        ``(runner.iter + 1) % n``), this deliberately fires when
        ``runner.iter`` itself is a multiple of ``n``, including
        iteration 0. Kept as-is to preserve the original check cadence.
        """
        return runner.iter % n == 0 if n > 0 else False

    def after_train_iter(
        self,
        runner: mmcv.runner.Runner,
    ) -> None:
        """Training iteration post hook.

        Every ``interval`` iterations, asks the AutoResume SDK whether
        the cluster has requested termination. If so, rank 0 saves a
        resumable checkpoint and requests a resume, then every rank
        exits cleanly.

        Args:
            runner: The runner of the training process.
        """
        if not self.every_n_train_iters(runner, self.interval):
            return

        global_rank = dist.get_rank() if dist.is_initialized() else 0
        runner.logger.info("AutoResumeHook: Checking whether to suspend...")

        should_preempt = (
            AutoResume is not None and AutoResume.termination_requested()
        )
        if not should_preempt:
            return

        if global_rank == 0:
            # Only rank 0 writes the checkpoint: previously every rank
            # wrote to the same path, a race that can corrupt the file.
            # (mmcv's own CheckpointHook is master-only for this reason.)
            meta = {
                'epoch': runner.epoch,
                'iter': runner.iter,
                'inner_iter': runner.inner_iter,
            }
            runner.logger.info(f"AutoResumeHook: saving info {meta}")
            runner.logger.info("refresh the latest_iter.pth")
            filename = (
                f"iter_{runner.iter + 1:010d}"
                f"_epoch_{runner.epoch:04d}"
                f"_inneriter_{runner.inner_iter + 1:08d}.pth"
            )
            filepath = os.path.join(self.out_dir, filename)
            save_checkpoint(
                runner.model, filepath, optimizer=runner.optimizer, meta=meta
            )

            # Point `latest_iter.pth` at the checkpoint just written so the
            # resumed job can find it without knowing the iteration count.
            dst_file = os.path.join(self.out_dir, 'latest_iter.pth')
            mmcv.symlink(filename, dst_file)

            runner.logger.info("AutoResumeHook: Request resume...")
            AutoResume.request_resume()

        if dist.is_initialized():
            # Make every rank wait until rank 0 has finished saving, so
            # the launcher cannot tear the job down mid-checkpoint when
            # other ranks exit first.
            dist.barrier()
        runner.logger.info("AutoResumeHook: Suspend the job...")
        sys.exit(0)
|
|
|