|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Callable, Dict, List, Optional, Union |
|
|
|
|
|
import lightning.pytorch as pl |
|
|
import torch |
|
|
from lightning.pytorch.callbacks import Callback |
|
|
|
|
|
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule |
|
|
from nemo.utils import logging |
|
|
|
|
|
|
|
|
def collect_precision(tensor: torch.Tensor) -> Dict[str, str]:
    """Returns tensor's precision"""
    # Non-tensors (e.g. a missing gradient placeholder) get a sentinel label.
    dtype_label = str(tensor.dtype) if isinstance(tensor, torch.Tensor) else "not-a-tensor"
    return {"Precision": dtype_label}
|
|
|
|
|
|
|
|
def collect_precision_and_shape(tensor: torch.Tensor) -> Dict[str, str]:
    """Returns tensor's shape & precision"""
    # Guard clause: anything that is not a tensor gets sentinel labels.
    if not isinstance(tensor, torch.Tensor):
        return {"Shape": "not-a-tensor", "Precision": "not-a-tensor"}
    return {"Shape": str(tensor.shape), "Precision": str(tensor.dtype)}
|
|
|
|
|
|
|
|
class ParameterDebugger(Callback):
    """
    Debugging tool to help inspect parameters and gradients at any callback event.

    This callback handles the boilerplate needed to iterate over the model parameters and gradients,
    and applies user specified functions to them. These functions can be used to log attributes or
    apply asserts on the param and grad tensors. Attributes are logged in a table, with a row for each parameter name.
    Default behavior is to log the precision and shapes of each parameter and its gradient.

    Args:
        param_fn: Function to apply to model parameters. Can be used to apply assertions on the tensor,
            or return a mapping of labels and values to log for each parameter.
        grad_fn: Function to apply to model gradients. Can be used to apply assertions on the tensor,
            or return a mapping of labels and values to log for each gradient.
        log_on_hooks: PTL callback hook name or list of hook names on which to apply param_fn and grad_fn.
            See `PTL docs <https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html#hooks>`_ for more info
            on callback hooks. Note that some hooks that occur before the model is constructed are invalid.

    Raises:
        ValueError: if any name in ``log_on_hooks`` is not a usable PTL callback hook.

    Example:
        >>> fn = lambda x: {"Norm": str(x.norm(2).item())}
        >>> callback = ParameterDebugger(param_fn=fn, log_on_hooks=["on_train_start", "on_train_end"])
        >>> trainer = Trainer(callbacks=[callback])
    """

    # Hooks on which the inspection routine may be registered. Hooks that fire before
    # the model is constructed (e.g. "setup", "on_fit_start") are deliberately excluded.
    _VALID_HOOKS = frozenset(
        {
            "teardown",
            "on_fit_end",
            "on_sanity_check_start",
            "on_sanity_check_end",
            "on_train_batch_start",
            "on_train_batch_end",
            "on_train_epoch_start",
            "on_train_epoch_end",
            "on_validation_epoch_start",
            "on_validation_epoch_end",
            "on_test_epoch_start",
            "on_test_epoch_end",
            "on_predict_epoch_start",
            "on_predict_epoch_end",
            "on_validation_batch_start",
            "on_validation_batch_end",
            "on_test_batch_start",
            "on_test_batch_end",
            "on_predict_batch_start",
            "on_predict_batch_end",
            "on_train_start",
            "on_train_end",
            "on_validation_start",
            "on_validation_end",
            "on_test_start",
            "on_test_end",
            "on_predict_start",
            "on_predict_end",
            "on_exception",
            "on_save_checkpoint",
            "on_load_checkpoint",
            "on_before_backward",
            "on_after_backward",
            "on_before_optimizer_step",
            "on_before_zero_grad",
        }
    )

    def __init__(
        self,
        param_fn: Optional[Callable[[torch.Tensor], Optional[Dict[str, str]]]] = collect_precision_and_shape,
        grad_fn: Optional[Callable[[torch.Tensor], Optional[Dict[str, str]]]] = collect_precision,
        log_on_hooks: Union[List[str], str] = "on_train_start",
    ):
        self.param_fn = param_fn
        self.grad_fn = grad_fn

        if isinstance(log_on_hooks, str):
            log_on_hooks = [log_on_hooks]
        for hook_name in log_on_hooks:
            # Raise explicitly rather than `assert`, which is stripped under `python -O`
            # and would silently skip this validation.
            if hook_name not in self._VALID_HOOKS:
                raise ValueError(
                    "Hook {} supplied to log_on_hooks is not valid or "
                    "can not be used. Valid hooks are {}".format(hook_name, sorted(self._VALID_HOOKS))
                )
            # Bind the inspection routine as the handler for each requested PTL hook.
            setattr(self, hook_name, self._apply_user_funcs)

    def _apply_user_funcs(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs) -> None:
        """
        Iterate over model parameters, find gradient tensor, apply and collect outputs of
        param_fn and grad_fn, and log outputs in a table.
        """

        def find_grad_tensor(param: torch.Tensor) -> Optional[torch.Tensor]:
            """If using MCore optimizer, search the grad buckets for param's grad tensor."""
            if not isinstance(getattr(pl_module, 'optim', None), MegatronOptimizerModule):
                return param.grad

            # NOTE(review): assumes pl_module exposes an iterable `buffers` of MCore grad
            # buffers, each carrying a `param_to_bucket` mapping — confirm against the
            # Megatron module this callback is used with.
            for buf in pl_module.buffers:
                if param in buf.param_to_bucket:
                    return buf.param_to_bucket[param].grad_data

            return None

        names_col, params_output, grads_output = [], [], []
        for param_name, param_tensor in pl_module.named_parameters():
            grad_tensor = find_grad_tensor(param_tensor)
            # Strip common prefixes/suffixes so the table stays readable.
            short_name = param_name.replace("module.", "").replace(".weight", "")
            names_col.append(short_name)

            for tensor, fn, out_col in zip(
                [param_tensor, grad_tensor], [self.param_fn, self.grad_fn], [params_output, grads_output]
            ):
                if fn is not None:
                    # A missing tensor (e.g. no grad yet) still contributes a row, so the
                    # output columns stay aligned with names_col.
                    out_col.append(fn(tensor) if tensor is not None else {})

        # Union of all labels the user functions returned; these become column headers.
        # (A fn may return None to signal "nothing to log" for a given tensor.)
        param_keys, grad_keys = set(), set()
        for output in params_output:
            if output is not None:
                param_keys.update(output.keys())
        for output in grads_output:
            if output is not None:
                grad_keys.update(output.keys())

        # Only build and log the table when at least one label was produced.
        if any(param_keys) or any(grad_keys):
            from prettytable import PrettyTable

            debug_table = PrettyTable()
            debug_table.add_column("Parameter", names_col)

            for prefix, keys, output_list in zip(
                ["Param ", "Grad "], [param_keys, grad_keys], [params_output, grads_output]
            ):
                for k in keys:
                    # Pad missing cells with None so every column matches names_col in length.
                    col_to_log = [output.get(k, None) if output is not None else None for output in output_list]
                    if col_to_log:
                        debug_table.add_column(prefix + k, col_to_log)

            debug_table.align = "l"
            logging.info("\n" + debug_table.get_string())
|
|
|