| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import gc |
| from typing import Any |
|
|
| import lightning.pytorch as pl |
| from nemo.utils import logging |
|
|
|
|
| class GarbageCollectionCallback(pl.Callback): |
| """Callback for synchronized manual Garbage Collection. This is required for distributed training |
| as all processes on different rank need to synchronize to garbage collect at the same time, without which |
| one process might hog or straggle all the rest of the processes. |
| |
| Migration from NeMo 1.0: |
| When mitrating from NeMo1, |
| - gc_interval = 0 implied no GC, simply do not add this callback to the trainer |
| - gc_interval > 0, this config is maps => gc_interval_train |
| |
| - env-var:NEMO_MANUAL_GC_IN_VALIDATION=0 or doesn't exist => Set gc_interval_val to a very high value that it does not practically run. |
| - env-var:NEMO_MANUAL_GC_IN_VALIDATION=1 => Set gc_interval_val to the same value as gc_interval |
| |
| Moving from boolean flag (NEMO_MANUAL_GC_IN_VALIDATION) to integer is to allow user to set a specific value based on the size of the |
| validation datasets. |
| |
| Note: This callback does not run gc at the start or the end of training or validation. |
| """ |
|
|
| def __init__(self, gc_interval_train, gc_interval_val) -> None: |
| """_summary_ |
| |
| Args: |
| gc_interval (int, mandatory): Number of global train steps at which garbage collection is done. |
| gc_interval_val (int, mandatory): Number of global validation steps at which garbage collection is done. |
| """ |
| assert gc_interval_train > 0, "gc_interval_train should be an integer value larger than 0." |
| assert gc_interval_val > 0, "gc_interval_val should be an integer value larger than 0." |
|
|
| super().__init__() |
| self.gc_interval_train = gc_interval_train |
| self.gc_interval_val = gc_interval_val |
| |
| gc.disable() |
| |
| self.validation_global_step = 0 |
|
|
| def on_train_batch_end( |
| self, |
| trainer: pl.Trainer, |
| pl_module: pl.LightningModule, |
| outputs: pl.utilities.types.STEP_OUTPUT, |
| batch: Any, |
| batch_idx: int, |
| ) -> None: |
| if trainer.global_step % self.gc_interval_train == 0: |
| logging.info(f"Running garbage collection at train global_step: {trainer.global_step}") |
| gc.collect() |
|
|
| def on_validation_batch_end( |
| self, |
| trainer: pl.Trainer, |
| pl_module: pl.LightningModule, |
| outputs: pl.utilities.types.STEP_OUTPUT, |
| batch: Any, |
| batch_idx: int, |
| ) -> None: |
| self.validation_global_step += 1 |
| if self.validation_global_step % self.gc_interval_val == 0: |
| logging.info(f"Running garbage collection at validation step: {self.validation_global_step}") |
| gc.collect() |
|
|