# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass, field
import torch
from lightning.pytorch.callbacks.callback import Callback
from nemo.lightning.io.mixin import IOMixin
def extract_module_attr_name(pl_module: "pl.LightningModule") -> str:
    """Returns the attribute name under which a pl.LightningModule holds its nn.Module.

    Checks for a ``.module`` attribute first, then ``.model``, and fails otherwise.

    Args:
        pl_module (pl.LightningModule): the LightningModule used in training.
    Raises:
        ValueError: if the pl_module has neither a .model nor a .module attribute.
    Returns:
        str: the attr-name of the held nn.Module ('module' or 'model').
    """
    # Preference order matters: 'module' wins when both attributes exist.
    for attr_name in ('module', 'model'):
        if hasattr(pl_module, attr_name):
            return attr_name
    raise ValueError("Expected lightning_module to have a .model or .module attr.")
def listify(x):
    """Ensures the input is a list, wrapping it in one when necessary.

    Args:
        x (Anything): the input, can be anything.
    Returns:
        Anything | list(Anything): x unchanged if it is already a list, otherwise [x].
    """
    return x if isinstance(x, list) else [x]
def get_modules_from_selector(model, module_selector):
    """Iterator over model's modules whose FQN match the module_selector.

    Args:
        model (nn.Module): the model to iterate over.
        module_selector (str): module selector, if empty or '*' will return the whole model. If
            there's an asterisk in the name will match it as a regexp.
    Raises:
        AttributeError: if the user provides an invalid selector.
        AttributeError: if user's selector selects a non-nn.Module attribute.
    Yields:
        Iterator(nn.Module): iterator over modules whose FQN matches module_selector
    """
    # Trivial selectors select the root module itself.
    if module_selector is None or module_selector in ('', '*'):
        yield model
        return
    assert isinstance(module_selector, str), module_selector
    atoms = module_selector.split('.')
    tmp = model
    for item in atoms:
        if '*' in item:
            # handle wildcard selector
            # TODO(@akoumparouli): support more complex selectors e.g. net_b.*.net_c.*.conv
            # Compile once instead of per child; matching is anchored at the start (re.match).
            pattern = re.compile(item.replace('*', '.*'))
            for name, module in tmp.named_children():
                if pattern.match(name):
                    yield module
            return
        if not hasattr(tmp, item):
            raise AttributeError(tmp._get_name() + " has no " "attribute `" + item + "`")
        tmp = getattr(tmp, item)
        if not isinstance(tmp, torch.nn.Module):
            raise AttributeError("`" + item + "` is not " "an nn.Module")
    yield tmp
def compile_module(config, module):
    """Jit-compiles an nn.Module in place, using the backend selected in config.

    Args:
        config (JitConfig): jit config
        module (nn.Module): the module to be compiled
    Returns:
        nn.Module: the (potentially) compiled module
    """
    # torch.compile path: forward kwargs straight through.
    if config.use_torch:
        module.compile(**config.torch_kwargs)
        return True

    # thunder path: imported lazily so thunder is only required when enabled.
    if config.use_thunder:
        import thunder
        import thunder.dynamo
        from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform

        # With this setting, Dynamo Graphs inline all the modules (so Dynamo FXGraph just
        # consists of `call_function` nodes only and no `call_module` node.
        # This is the default setting in PyTorch 2.5 onwards
        # (see https://github.com/pytorch/pytorch/pull/131275)
        torch._dynamo.config.inline_inbuilt_nn_modules = True
        transforms: list = [NvtxProfileTransform()] if config.profile_thunder else []
        module.compile(backend=thunder.dynamo.ThunderCompiler(transforms=transforms))
        return True

    # Neither backend enabled: module is left untouched.
    return False
@dataclass
class JitConfig:
    """Plain config dataclass for Jit transforms (e.g. torch.compile or thunder).

    Options:
    - module_selector (str): reg-exp to match modules to apply JitTransform to, useful for multi-trunk
      models where you want to apply it on one of them only. If empty will apply transform to root
      module.
    - use_torch (bool): whether to use torch.compile or not.
    - torch_kwargs (dict): kwargs to pass to torch.compile.
    - use_thunder (bool): whether to use thunder or not.
    - profile_thunder (bool): toggle for thunder's profiler.
    """

    module_selector: str = ''
    use_torch: bool = False
    torch_kwargs: dict = field(default_factory=dict)
    use_thunder: bool = False
    profile_thunder: bool = False

    def __post_init__(self):
        # The two compiler backends are mutually exclusive.
        assert (not self.use_torch) or (
            not self.use_thunder
        ), "use_torch cannot be used at the same time with use_thunder"
class JitTransform(Callback, IOMixin):
    """
    Apply JIT-compiling on a PyTorch model.

    Args:
        config (JitConfig): The jit-compiler config to use.

    Example:
        >>> from nemo.lightning.pytorch.callbacks import JitTransform
        >>> trainer = Trainer(callbacks=[JitTransform(JitConfig(use_torch=True))])
    """

    def __init__(self, config: JitConfig):
        assert config is not None
        self.config = config
        assert not (self.config.use_torch and self.config.use_thunder)

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Jit-compiles the model at the start of the epoch.
        While other events such as on_train_start are more suitable, we use on_train_epoch_start
        since that is what is used in peft (we want to jit after adding the adapters).

        Args:
            trainer (pl.Trainer): PTL trainer
            pl_module (pl.LightningModule): PTL module
        """
        if self.config is None:
            return
        # Nothing to do when neither compiler backend is enabled.
        if not (self.config.use_thunder or self.config.use_torch):
            return

        attr_name = extract_module_attr_name(pl_module)
        model = getattr(pl_module, attr_name)

        # Compile at most once per LightningModule (epochs after the first are no-ops).
        if getattr(pl_module, '_compiled', False):
            return

        # TODO(@akoumparouli): you want to concatenate (via regex OR-operator) all expressions
        # and trigger the compile if anyone matches, instead of iterating over all O(N^2).
        compiled = False
        for config in listify(self.config):
            for module in get_modules_from_selector(model, config.module_selector):
                compiled |= compile_module(config, module)
        setattr(pl_module, '_compiled', compiled)
|