|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import time |
|
|
from dataclasses import dataclass |
|
|
from types import MethodType |
|
|
from typing import Any, Dict |
|
|
|
|
|
import lightning.pytorch as pl |
|
|
import torch |
|
|
from lightning.pytorch import LightningModule |
|
|
from lightning.pytorch.callbacks import Callback |
|
|
from lightning.pytorch.loops.optimization.automatic import ClosureResult |
|
|
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection, _ResultMetric |
|
|
from lightning.pytorch.utilities import CombinedLoader, rank_zero_info |
|
|
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature |
|
|
from lightning.pytorch.utilities.types import STEP_OUTPUT |
|
|
from torch.nn.parallel import DistributedDataParallel |
|
|
|
|
|
__all__ = ["CUDAGraphCallback"] |
|
|
|
|
|
|
|
|
def struct_copy_one(src): |
|
|
if isinstance(src, tuple): |
|
|
return tuple(struct_copy_one(i) for i in src) |
|
|
elif isinstance(src, list): |
|
|
return list(struct_copy_one(i) for i in src) |
|
|
elif isinstance(src, dict): |
|
|
return {k: struct_copy_one(src[k]) for k in src} |
|
|
elif isinstance(src, torch.Tensor): |
|
|
return src.clone().detach().cuda() |
|
|
else: |
|
|
return src |
|
|
|
|
|
|
|
|
def struct_copy_two(tgt, src): |
|
|
if isinstance(src, tuple): |
|
|
raise Exception(f"Unsupported copy for tuple yet: {type(src)}") |
|
|
elif isinstance(src, list): |
|
|
for i in range(len(src)): |
|
|
if isinstance(src[i], (tuple, list, dict, torch.Tensor)): |
|
|
struct_copy_two(tgt[i], src[i]) |
|
|
else: |
|
|
tgt[i] = src[i] |
|
|
elif isinstance(src, dict): |
|
|
for k in src: |
|
|
if isinstance(src[k], (tuple, list, dict, torch.Tensor)): |
|
|
struct_copy_two(tgt[k], src[k]) |
|
|
else: |
|
|
tgt[k] = src[k] |
|
|
elif isinstance(src, torch.Tensor): |
|
|
tgt.copy_(src, non_blocking=True) |
|
|
else: |
|
|
raise Exception(f"Expect top-level as container type but got: {type(src)}") |
|
|
|
|
|
|
|
|
class StaticBufferLoader: |
|
|
"""Load data to static buffers.""" |
|
|
|
|
|
def __init__(self, loader): |
|
|
self.loader = loader |
|
|
self.stream = torch.cuda.Stream() |
|
|
self.static = None |
|
|
|
|
|
def __iter__(self): |
|
|
for inputs in self.loader: |
|
|
if self.static is None: |
|
|
with torch.cuda.stream(self.stream): |
|
|
self.static = struct_copy_one(inputs) |
|
|
|
|
|
with torch.cuda.stream(self.stream): |
|
|
struct_copy_two(self.static, inputs) |
|
|
torch.cuda.current_stream().wait_stream(self.stream) |
|
|
yield self.static |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.loader) |
|
|
|
|
|
|
|
|
def get_lr(lr_scheduler): |
|
|
lrs = lr_scheduler.__orig_get_lr__() |
|
|
if not hasattr(lr_scheduler, "static_lrs"): |
|
|
lr_scheduler.static_lrs = lrs |
|
|
for i in range(len(lrs)): |
|
|
lr_scheduler.static_lrs[i].copy_(lrs[i]) |
|
|
return lr_scheduler.static_lrs |
|
|
|
|
|
|
|
|
def zero_grad(optimizer, *args, **kwargs): |
|
|
|
|
|
if torch.cuda.is_current_stream_capturing(): |
|
|
rank_zero_info("CUDAGraphCallback: set optimizer.zero_grad as nop during graph capturing.") |
|
|
else: |
|
|
optimizer.__orig_zero_grad__(*args, **kwargs) |
|
|
|
|
|
|
|
|
def to_tensor(self, value, name): |
|
|
|
|
|
|
|
|
|
|
|
value = value.clone().detach() if isinstance(value, torch.Tensor) else torch.tensor(value) |
|
|
if not torch.numel(value) == 1: |
|
|
raise ValueError( |
|
|
f"`self.log({name}, {value})` was called, but the tensor must have a single element." |
|
|
f" You can try doing `self.log({name}, {value}.mean())`" |
|
|
) |
|
|
value = value.squeeze() |
|
|
return value |
|
|
|
|
|
|
|
|
def get_optimizer_step(state): |
|
|
def optimizer_step( |
|
|
self, |
|
|
epoch, |
|
|
batch_idx, |
|
|
optimizer, |
|
|
optimizer_closure=None, |
|
|
) -> None: |
|
|
|
|
|
if not hasattr(optimizer, "support_set_to_none"): |
|
|
optimizer.support_set_to_none = is_param_in_hook_signature( |
|
|
optimizer.zero_grad, "set_to_none", explicit=True |
|
|
) |
|
|
if optimizer.support_set_to_none: |
|
|
zero_grad_kwargs = {"set_to_none": True} |
|
|
else: |
|
|
zero_grad_kwargs = {} |
|
|
|
|
|
if 0 <= state.current_iteration < state.capture_iteration or state.capture_iteration < 0: |
|
|
state.stream.wait_stream(torch.cuda.current_stream()) |
|
|
with torch.cuda.stream(state.stream): |
|
|
optimizer.zero_grad(**zero_grad_kwargs) |
|
|
self.__orig_optimizer_step__( |
|
|
epoch, |
|
|
batch_idx, |
|
|
optimizer, |
|
|
optimizer_closure=optimizer_closure, |
|
|
) |
|
|
torch.cuda.current_stream().wait_stream(state.stream) |
|
|
|
|
|
if state.current_iteration == state.capture_iteration: |
|
|
torch.cuda.synchronize() |
|
|
|
|
|
time.sleep(1) |
|
|
rank_zero_info("CUDAGraphCallback: capturing CUDA graph for module %s.", self.__class__.__name__) |
|
|
with torch.cuda.graph(state.graph, stream=state.stream, capture_error_mode="global"): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
optimizer.zero_grad(**zero_grad_kwargs) |
|
|
self.__orig_optimizer_step__( |
|
|
epoch, |
|
|
batch_idx, |
|
|
optimizer, |
|
|
optimizer_closure=optimizer_closure, |
|
|
) |
|
|
torch.cuda.synchronize() |
|
|
|
|
|
|
|
|
if state.current_iteration >= state.capture_iteration >= 0: |
|
|
state.graph.replay() |
|
|
optimizer_closure._result = ClosureResult.from_training_step_output(state.output) |
|
|
|
|
|
|
|
|
if hasattr(self, "non_cuda_graph_capturable"): |
|
|
self.non_cuda_graph_capturable() |
|
|
|
|
|
state.current_iteration += 1 |
|
|
|
|
|
return optimizer_step |
|
|
|
|
|
|
|
|
def get_training_step(state): |
|
|
def training_step(self, batch): |
|
|
results = self.__orig_training_step__(batch) |
|
|
if state.output is None: |
|
|
state.output = struct_copy_one(results) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
struct_copy_two(state.output, results) |
|
|
return results |
|
|
|
|
|
return training_step |
|
|
|
|
|
|
|
|
def get_amp_autocast_init(state): |
|
|
def amp_autocast_init(self, *args, **kwargs): |
|
|
if "cache_enabled" not in kwargs: |
|
|
kwargs["cache_enabled"] = False |
|
|
if state.current_iteration == 0: |
|
|
rank_zero_info("CUDAGraphCallback: disable autocast cache.") |
|
|
return self.__orig_init__(*args, **kwargs) |
|
|
|
|
|
return amp_autocast_init |
|
|
|
|
|
|
|
|
def get_ddp_init(state): |
|
|
def init(self, *args, **kwargs): |
|
|
rank_zero_info("CUDAGraphCallback: init DDP on side stream.") |
|
|
with torch.cuda.stream(state.stream): |
|
|
self.__orig_init__(*args, **kwargs) |
|
|
|
|
|
return init |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class CUDAGraphState: |
|
|
current_iteration: int = 0 |
|
|
capture_iteration: int = -1 |
|
|
stream: torch.cuda.Stream = None |
|
|
graph: torch.cuda.CUDAGraph = None |
|
|
output: Any = None |
|
|
|
|
|
|
|
|
class CUDAGraphCallback(Callback): |
|
|
"""Full iteration CUDA graph callback. |
|
|
|
|
|
Dataloader and LR scheduler are not included in the CUDA graph with this callback. |
|
|
""" |
|
|
|
|
|
def __init__(self, capture_iteration=-1): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
|
|
|
if 0 <= capture_iteration <= 11: |
|
|
raise Exception("Warmup must run at least 11 DDP-enabled eager iterations before capture.") |
|
|
if torch.distributed.is_initialized(): |
|
|
raise Exception("CUDAGraphCallback should be initialized before process group.") |
|
|
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" |
|
|
|
|
|
self.state = CUDAGraphState(capture_iteration=capture_iteration) |
|
|
|
|
|
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: |
|
|
"""Called when fit, validate, test, predict, or tune begins.""" |
|
|
if self.state.capture_iteration < 0: |
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch.autocast.__orig_init__ = torch.autocast.__init__ |
|
|
torch.autocast.__init__ = get_amp_autocast_init(self.state) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DistributedDataParallel.__orig_init__ = DistributedDataParallel.__init__ |
|
|
DistributedDataParallel.__init__ = get_ddp_init(self.state) |
|
|
|
|
|
def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: |
|
|
"""Called when fit, validate, test, predict, or tune ends.""" |
|
|
if self.state.capture_iteration < 0: |
|
|
return |
|
|
|
|
|
torch.autocast.__init__ = torch.autocast.__orig_init__ |
|
|
del torch.autocast.__orig_init__ |
|
|
|
|
|
DistributedDataParallel.__init__ = DistributedDataParallel.__orig_init__ |
|
|
del DistributedDataParallel.__orig_init__ |
|
|
|
|
|
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: |
|
|
"""Called when fit begins.""" |
|
|
if self.state.capture_iteration < 0: |
|
|
return |
|
|
|
|
|
if is_param_in_hook_signature(pl_module.training_step, "dataloader_iter", explicit=True): |
|
|
raise Exception( |
|
|
"Found `dataloader_iter` argument in the `training_step`. This is " |
|
|
"not supported by full iteration CUDA graph capturing yet since " |
|
|
"dataloader will be within the CUDA graph capturing range.\n" |
|
|
"Try to change `dataloader_iter` to `batch` and remove " |
|
|
"`next(dataloader_iter)` from `training_step`." |
|
|
) |
|
|
|
|
|
|
|
|
self.state.stream = torch.cuda.Stream() |
|
|
self.state.graph = torch.cuda.CUDAGraph() |
|
|
|
|
|
def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: |
|
|
"""Called when fit ends.""" |
|
|
if self.state.capture_iteration < 0: |
|
|
return |
|
|
|
|
|
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: |
|
|
"""Called when the train begins.""" |
|
|
if self.state.capture_iteration < 0: |
|
|
return |
|
|
|
|
|
|
|
|
dataloader = trainer.fit_loop._combined_loader._iterables |
|
|
assert isinstance( |
|
|
dataloader, torch.utils.data.dataloader.DataLoader |
|
|
), f"Expect Dataloader type but got {type(dataloader)}" |
|
|
static_loader = StaticBufferLoader(dataloader) |
|
|
_mode = trainer.fit_loop._combined_loader._mode |
|
|
combined_loader = CombinedLoader(static_loader, mode=_mode) |
|
|
trainer.fit_loop.__orig_combined_loader__ = trainer.fit_loop._combined_loader |
|
|
trainer.fit_loop._combined_loader = combined_loader |
|
|
trainer.fit_loop._data_fetcher.setup(trainer.fit_loop._combined_loader) |
|
|
iter(trainer.fit_loop._data_fetcher) |
|
|
|
|
|
|
|
|
for optimizer in trainer.optimizers: |
|
|
assert isinstance(optimizer, torch.optim.Optimizer), f"Expect Optimizer type but got {type(optimizer)}" |
|
|
optimizer.__orig_zero_grad__ = optimizer.zero_grad |
|
|
optimizer.zero_grad = MethodType(zero_grad, optimizer) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for config in trainer.lr_scheduler_configs: |
|
|
assert isinstance( |
|
|
config.scheduler, torch.optim.lr_scheduler._LRScheduler |
|
|
), f"Expect _LRScheduler type but got {type(config.scheduler)}" |
|
|
config.scheduler.__orig_get_lr__ = config.scheduler.get_lr |
|
|
config.scheduler.get_lr = MethodType(get_lr, config.scheduler) |
|
|
|
|
|
|
|
|
LightningModule.__orig_to_tensor__ = LightningModule._LightningModule__to_tensor |
|
|
LightningModule._LightningModule__to_tensor = to_tensor |
|
|
|
|
|
|
|
|
pl_module.__orig_training_step__ = pl_module.training_step |
|
|
training_step = get_training_step(self.state) |
|
|
pl_module.training_step = MethodType(training_step, pl_module) |
|
|
|
|
|
|
|
|
pl_module.__orig_optimizer_step__ = pl_module.optimizer_step |
|
|
optimizer_step = get_optimizer_step(self.state) |
|
|
pl_module.optimizer_step = MethodType(optimizer_step, pl_module) |
|
|
|
|
|
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: |
|
|
"""Called when the train ends.""" |
|
|
if self.state.capture_iteration < 0: |
|
|
return |
|
|
|
|
|
trainer.fit_loop._combined_loader = trainer.fit_loop.__orig_combined_loader__ |
|
|
trainer.fit_loop._data_fetcher.setup(trainer.fit_loop._combined_loader) |
|
|
iter(trainer.fit_loop._data_fetcher) |
|
|
del trainer.fit_loop.__orig_combined_loader__ |
|
|
|
|
|
for optimizer in trainer.optimizers: |
|
|
optimizer.zero_grad = optimizer.__orig_zero_grad__ |
|
|
del optimizer.__orig_zero_grad__ |
|
|
|
|
|
for config in trainer.lr_scheduler_configs: |
|
|
config.scheduler.get_lr = config.scheduler.__orig_get_lr__ |
|
|
del config.scheduler.__orig_get_lr__ |
|
|
|
|
|
LightningModule._LightningModule__to_tensor = LightningModule.__orig_to_tensor__ |
|
|
del LightningModule.__orig_to_tensor__ |
|
|
|
|
|
pl_module.training_step = pl_module.__orig_training_step__ |
|
|
del pl_module.__orig_training_step__ |
|
|
|
|
|
pl_module.optimizer_step = pl_module.__orig_optimizer_step__ |
|
|
del pl_module.__orig_optimizer_step__ |
|
|
|
|
|
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: |
|
|
"""Called when the train epoch begins.""" |
|
|
pass |
|
|
|
|
|
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: |
|
|
"""Called when the train epoch ends. |
|
|
|
|
|
To access all batch outputs at the end of the epoch, either: |
|
|
|
|
|
1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR |
|
|
2. Cache data across train batch hooks inside the callback implementation to post-process in this hook. |
|
|
""" |
|
|
pass |
|
|
|
|
|
def on_train_batch_start( |
|
|
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int |
|
|
) -> None: |
|
|
"""Called when the train batch begins.""" |
|
|
pass |
|
|
|
|
|
def on_train_batch_end( |
|
|
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int |
|
|
) -> None: |
|
|
"""Called when the train batch ends. |
|
|
|
|
|
Note: |
|
|
The value ``outputs["loss"]`` here will be the normalized value w.r.t ``accumulate_grad_batches`` of the |
|
|
loss returned from ``training_step``. |
|
|
""" |
|
|
pass |
|
|
|
|
|
def on_save_checkpoint( |
|
|
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] |
|
|
) -> None: |
|
|
r""" |
|
|
Called when saving a checkpoint to give you a chance to store anything else you might want to save. |
|
|
|
|
|
Args: |
|
|
trainer: the current :class:`~lightning.pytorch.trainer.Trainer` instance. |
|
|
pl_module: the current :class:`~lightning.pytorch.core.module.LightningModule` instance. |
|
|
checkpoint: the checkpoint dictionary that will be saved. |
|
|
""" |
|
|
|
|
|
|
|
|
if "optimizer_states" in checkpoint: |
|
|
for optimizer_state in checkpoint["optimizer_states"]: |
|
|
for k in list(optimizer_state.keys()): |
|
|
v = optimizer_state[k] |
|
|
if isinstance(v, MethodType) and hasattr(v, "__self__"): |
|
|
del optimizer_state[k] |
|
|
if "lr_schedulers" in checkpoint: |
|
|
for lr_scheduler in checkpoint["lr_schedulers"]: |
|
|
for k in list(lr_scheduler.keys()): |
|
|
v = lr_scheduler[k] |
|
|
if isinstance(v, MethodType) and hasattr(v, "__self__"): |
|
|
del lr_scheduler[k] |
|
|
|