JitTransform Class
==================

.. currentmodule:: nemo.lightning.pytorch.callbacks

.. autoclass:: JitTransform
   :members:
   :undoc-members:
   :show-inheritance:

Description
-----------

The `JitTransform` class is a callback for PyTorch Lightning that applies Just-In-Time (JIT) compilation to PyTorch models. It leverages either `torch.compile` or the Thunder compiler to optimize model performance during training. This transformation is particularly useful for enhancing the efficiency of models with complex architectures or when deploying models in production environments.

Parameters
----------

.. py:attribute:: config

   **Type:** `JitConfig`

   The configuration object specifying how JIT compilation should be applied. It includes options to select the compiler (`torch.compile` or Thunder), module selectors to target specific parts of the model, and additional parameters for profiling and compiler behavior.

Usage Example
-------------

Below is an example of how to integrate `JitTransform` into a PyTorch Lightning `Trainer`:

.. code-block:: python

   from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform
   from pytorch_lightning import Trainer

   # Define the JIT configuration
   jit_config = JitConfig(
       use_torch=True,
       torch_kwargs={'mode': 'default'},
       module_selector='*.layer*'
   )

   # Initialize the JitTransform callback
   jit_callback = JitTransform(config=jit_config)

   # Initialize the Trainer with the JitTransform callback
   trainer = Trainer(callbacks=[jit_callback])

   # Proceed with training as usual
   # (`model` is a LightningModule defined elsewhere)
   trainer.fit(model)

Attributes
----------

.. py:attribute:: config

   The `JitConfig` instance containing the JIT compilation settings.

Methods
-------

.. py:method:: on_train_epoch_start(trainer, pl_module)

   Called at the start of each training epoch. This method applies JIT compilation to the specified modules within the model based on the provided configuration.

   **Parameters:**

   - `trainer` (`pl.Trainer`): The PyTorch Lightning trainer instance.
   - `pl_module` (`pl.LightningModule`): The PyTorch Lightning module being trained.

Detailed Description
--------------------

The `JitTransform` callback hooks into the start of every training epoch, but it compiles the model only once. On the first invocation it identifies the target modules within the model based on the `module_selector` patterns provided in the `JitConfig`. Depending on the configuration, it uses either `torch.compile` or Thunder to compile the selected modules, enhancing their execution performance.

To ensure that compilation occurs only once, the callback sets a `_compiled` flag on the `pl_module`. The flag prevents redundant compilation in subsequent epochs. The callback also supports profiling through Thunder's profiler if enabled in the configuration.

Constraints
-----------

- **Mutually Exclusive Options:** The `use_torch` and `use_thunder` options in `JitConfig` are mutually exclusive. Enabling both simultaneously will raise an assertion error.
- **Module Selection:** The `module_selector` supports simple wildcard patterns. Complex selectors may require extending the matching logic.

See Also
--------

- :class:`JitConfig`
- PyTorch Lightning Callbacks
- `torch.compile`
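
Illustrative Sketch
-------------------

The wildcard selection and compile-once guard described under Detailed Description can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the NeMo implementation: `ToyModel`, `CompileOnceGuard`, and `select_module_names` are hypothetical names, the standard library's `fnmatchcase` stands in for whatever matching logic the callback actually uses, and `compile_fn` is a placeholder for a real compiler call such as `torch.compile` or Thunder.

```python
from fnmatch import fnmatchcase


def select_module_names(named_modules, pattern):
    """Return the dotted module names that match a simple wildcard pattern."""
    return [name for name, _ in named_modules if fnmatchcase(name, pattern)]


class ToyModel:
    """Hypothetical stand-in for a LightningModule with named submodules."""

    def __init__(self):
        self._submodules = {
            "encoder.layer0": object(),
            "encoder.layer1": object(),
            "head": object(),
        }

    def named_modules(self):
        # Mirrors the (name, module) pairs a real nn.Module would yield.
        return list(self._submodules.items())


class CompileOnceGuard:
    """Hypothetical sketch of a hook that compiles matching modules once."""

    def __init__(self, pattern, compile_fn):
        self.pattern = pattern
        self.compile_fn = compile_fn  # placeholder for the real compiler

    def on_train_epoch_start(self, model):
        # Later epochs see the flag and return without compiling again.
        if getattr(model, "_compiled", False):
            return []
        names = select_module_names(model.named_modules(), self.pattern)
        for name in names:
            self.compile_fn(name)
        model._compiled = True
        return names


model = ToyModel()
compiled_names = []  # records which modules the placeholder "compiled"
guard = CompileOnceGuard("*.layer*", compiled_names.append)
first_epoch = guard.on_train_epoch_start(model)
second_epoch = guard.on_train_epoch_start(model)
```

In this sketch the first call selects `encoder.layer0` and `encoder.layer1`, while the second call returns an empty list because the `_compiled` flag is already set. The pattern `*.layer*` is the one used in the Usage Example; it does not match `head` because that name contains no `.layer` segment.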