File size: 6,971 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import dataclass, field

import torch
from lightning.pytorch.callbacks.callback import Callback

from nemo.lightning.io.mixin import IOMixin


def extract_module_attr_name(pl_module: "pl.LightningModule") -> str:
    """Returns the attribute name under which a pl.LightningModule holds its nn.Module.

    Checks for a ``.module`` attribute first, then ``.model``; fails otherwise.

    Args:
        pl_module (pl.LightningModule): the LightningModule used in training.

    Raises:
        ValueError: if the pl_module has neither a .model nor a .module attribute.

    Returns:
        str: the attr-name ('module' or 'model') of the held nn.Module.
    """
    if hasattr(pl_module, 'module'):
        return 'module'
    elif hasattr(pl_module, 'model'):
        return 'model'
    else:
        raise ValueError("Expected lightning_module to have a .model or .module attr.")


def listify(x):
    """Ensure the input is a list.

    Args:
        x (Anything): the input, can be anything.

    Returns:
        Anything | list(Anything): ``x`` unchanged when it is already a list,
        otherwise ``x`` wrapped in a single-element list.
    """
    return x if isinstance(x, list) else [x]


def get_modules_from_selector(model, module_selector):
    """Iterator over model's modules whose FQN match the module_selector.

    Args:
        model (nn.Module): the model to iterate over.
        module_selector (str): module selector, if empty or '*' will return the whole model. If
        there's an asterisk in the name will match it as a regexp.

    Raises:
        AttributeError: if the user provides an invalid selector.
        AttributeError: if user's selector selects a non-nn.Module attribute.

    Yields:
        Iterator(nn.Module): iterator over modules whose FQN matches module_selector
    """
    # Empty / wildcard-only selector: the whole model is the single match.
    if module_selector is None or module_selector == '' or module_selector == '*':
        yield model
        return

    assert isinstance(module_selector, str), module_selector
    atoms: list[str] = module_selector.split('.')
    tmp = model

    for item in atoms:
        if '*' in item:
            # handle wildcard selector: match direct children by regexp.
            # NOTE(review): re.match anchors only at the start, so 'conv*' also
            # prefix-matches e.g. 'conv_extra' — presumably intended; confirm.
            # TODO(@akoumparouli): support more complex selectors e.g. net_b.*.net_c.*.conv
            for name, module in tmp.named_children():
                if re.match(item.replace('*', '.*'), name):
                    yield module
            return

        if not hasattr(tmp, item):
            raise AttributeError(tmp._get_name() + " has no " "attribute `" + item + "`")
        tmp = getattr(tmp, item)

        # Selecting a plain attribute (e.g. a tensor or int) is a user error.
        if not isinstance(tmp, torch.nn.Module):
            raise AttributeError("`" + item + "` is not " "an nn.Module")

    yield tmp


def compile_module(config, module):
    """Jit-compiles an nn.Module

    Args:
        config (JitConfig): jit config
        module (nn.Module): the module to be compiled

    Returns:
        nn.Module: the (potentially) compiled module
    """
    if config.use_torch:
        module.compile(**config.torch_kwargs)
        return True
    elif config.use_thunder:
        import thunder
        import thunder.dynamo
        from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform

        # With this setting, Dynamo Graphs inline all the modules (so Dynamo FXGraph just
        # consists of `call_function` nodes only and no `call_module` node.
        # This is the default setting in PyTorch 2.5 onwards
        # (see https://github.com/pytorch/pytorch/pull/131275)
        torch._dynamo.config.inline_inbuilt_nn_modules = True

        xforms: list = [NvtxProfileTransform()] if config.profile_thunder else []
        module.compile(backend=thunder.dynamo.ThunderCompiler(transforms=xforms))
        return True
    else:
        return False


@dataclass
class JitConfig:
    """Config POD for Jit transforms (e.g. torch.compile or thunder)
    Options:
    - module_selector (str): reg-exp to match modules to apply JitTransform to, useful for multi-trunk
      models where you want to apply it on one of them only. If empty will apply transform to root
      module.
    - use_torch (bool): whether to use torch.compile or not.
    - torch_kwargs (dict): kwargs to pass to torch.compile.
    - use_thunder (bool): whether to use thunder or not.
    - profile_thunder (bool): toggle for thunder's profiler.

    Raises:
        ValueError: if both use_torch and use_thunder are enabled.
    """

    module_selector: str = ''
    use_torch: bool = False
    torch_kwargs: dict = field(default_factory=dict)
    use_thunder: bool = False
    profile_thunder: bool = False

    def __post_init__(self):
        # Raise (not assert) so the validation survives `python -O`.
        if self.use_torch and self.use_thunder:
            raise ValueError("use_torch cannot be used at the same time with use_thunder")


class JitTransform(Callback, IOMixin):
    """
    Apply JIT-compiling on PyTorch model

    Args:
        config (JitConfig): The jit-compiler config to use.

    Example:
        >>> from nemo.lightning.pytorch.callbacks import JitTransform
        >>> trainer = Trainer(callbacks=[JitTransform(JitConfig(use_torch=True))])
    """

    def __init__(self, config: JitConfig):
        super().__init__()
        # Validate eagerly with real exceptions so misconfiguration survives `python -O`.
        if config is None:
            raise ValueError("JitTransform requires a non-None JitConfig")
        if config.use_torch and config.use_thunder:
            raise ValueError("use_torch cannot be used at the same time with use_thunder")
        self.config = config

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Jit-compiles the model at the start of the epoch.
        While other events such as on_train_start are more suitable, we use on_train_epoch_start
        since that is what is used in peft (we want to jit after adding the adapters).

        Args:
            trainer (pl.Trainer): PTL trainer
            pl_module (pl.LightningModule): PTL module
        """
        if self.config is None:
            return
        # Nothing to do when no backend is enabled.
        if not self.config.use_thunder and not self.config.use_torch:
            return

        attr_name = extract_module_attr_name(pl_module)
        model = getattr(pl_module, attr_name)

        # Compile at most once per module (epochs re-trigger this hook).
        if getattr(pl_module, '_compiled', False):
            return

        # TODO(@akoumparouli): you want to concatenate (via regex OR-operator) all expressions
        # and trigger the compile if anyone matches, instead of iterating over all O(N^2).
        compiled = False
        for config in listify(self.config):
            for module in get_modules_from_selector(model, config.module_selector):
                compiled |= compile_module(config, module)

        setattr(pl_module, '_compiled', compiled)