| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
import math
from typing import Dict, List, Tuple, Union

from lightning.pytorch.callbacks.callback import Callback

from nemo.lightning.io.mixin import IOMixin

ScheduleValue = Union[int, List[int], Tuple[int, int]]
|
|
|
|
| def _resolve_attr(root, path: str): |
| """ |
| Traverse dotted attribute path (“encoder.layer1”) from root. |
| """ |
| m = root |
| for part in path.split('.'): |
| m = getattr(m, part) |
| return m |
|
|
|
|
def make_start_end(name: str, spec: Union[int, List[int]]) -> Tuple[int, float]:
    """Translate a freeze *spec* into an inclusive (start, end) step window.

    Mapping:
        spec = -1            -> (0, math.inf)   frozen for the entire run
        spec = N (int > 0)   -> (0, N - 1)      frozen for the first N steps
        spec = [start, end]  -> (start, end)    start clamped to >= 0;
                                                a negative end means math.inf
                                                ("until end of training")

    Args:
        name (str): layer name; used only in the error message.
        spec (Union[int, List[int]]): schedule specification (int or a
            2-element list/tuple).

    Raises:
        ValueError: if spec is neither an int nor a 2-element list/tuple.

    Returns:
        Tuple[int, float]: inclusive (start, end); end may be math.inf.
    """

    if isinstance(spec, int):
        if spec == -1:
            # Sentinel for "frozen for the whole run".
            start, end = 0, math.inf
        else:
            # N means frozen for the first N steps, i.e. steps 0..N-1 inclusive.
            start, end = 0, spec - 1
    elif isinstance(spec, (list, tuple)) and len(spec) == 2:
        start, end = spec
        start = max(start, 0)  # clamp negative starts to step 0
        if end < 0:
            end = math.inf  # negative end -> frozen until training ends
    else:
        raise ValueError(f"Invalid schedule for '{name}': {spec}")
    return start, end
|
|
|
|
class LayerFreezer(Callback, IOMixin):
    """
    Freezes sub-modules of a LightningModule based on the list provided. The list of layers should
    be the full FQN.

    Instantiate
    -----------
    # to keep layers frozen for all training
    callback = LayerFreezer(['layer1', 'layer2',])
    # for some steps
    callback = LayerFreezer({'layer1': 10, 'layer2': (10, 100)})

    trainer = pl.Trainer(callbacks=[callback], ...)
    """

    def __init__(self, schedule: Union[List[str], Dict[str, ScheduleValue]]):
        """
        Args
        ----
        schedule: Union[list, dict]
            - dict
                key = attribute path of sub-module inside LightningModule
                value = one of
                    : -1 -> frozen for entire run
                    : N (int>0) -> frozen for first N steps (0..N-1)
                    : [start, end] -> frozen if start <= step <= end
                      use -1 for "until end of training"
            - list:
                key = attribute path of sub-module inside LightningModule
                value = -1 (hardcoded; use a dict if you want to specify manually).
        """
        super().__init__()
        assert isinstance(schedule, (list, dict)), type(schedule)
        if isinstance(schedule, list):
            # A bare list means "frozen for the entire run" for every entry.
            schedule = {item: -1 for item in schedule}

        # name -> inclusive (start, end) freeze window; end may be math.inf.
        self.schedule: Dict[str, Tuple[int, float]] = {}
        # name -> last state applied to the module (avoids redundant toggles).
        self.frozen_state: Dict[str, bool] = {}

        for name, spec in schedule.items():
            self.schedule[name] = make_start_end(name, spec)

    @staticmethod
    def _resolve_attr(root, path: str):
        """
        Traverse dotted attribute path ("encoder.layer1") from root.
        """
        m = root
        for part in path.split('.'):
            m = getattr(m, part)
        return m

    def _apply_freeze(self, module, freeze: bool):
        """
        Enable/disable gradients + switch (eval/train) mode.

        Args:
            module: sub-module whose parameters are toggled.
            freeze (bool): True to freeze (no grads, eval mode),
                False to unfreeze (grads on, train mode).
        """
        for p in module.parameters():
            p.requires_grad = not freeze
        # Explicit branch: the mode switch is a side effect, not a value.
        if freeze:
            module.eval()
        else:
            module.train()

    def on_train_batch_start(self, trainer, pl_module, *_):
        """
        Freezes/unfreezes scheduled layers according to the current global step.

        Args:
            trainer (Trainer): the trainer
            pl_module (LightningModule): model
        """
        step = trainer.global_step

        for name, (start, end) in self.schedule.items():
            should_be_frozen = start <= step <= end

            # Skip modules already in the desired state (first call always applies).
            if self.frozen_state.get(name, None) == should_be_frozen:
                continue

            submod = self._resolve_attr(pl_module, name)
            self._apply_freeze(submod, should_be_frozen)
            self.frozen_state[name] = should_be_frozen

    def on_train_start(self, trainer, pl_module):
        """
        on_train_start
        In case we resume from checkpoint, re-establish correct state

        Args:
            trainer (Trainer): the trainer
            pl_module (LightningModule): model
        """
        self.on_train_batch_start(trainer, pl_module, None, 0)
|
|