|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
|
from typing import Callable, Optional |
|
|
from lightning.pytorch.callbacks import LambdaCallback |
|
|
|
|
|
|
|
|
class ModelCallback(LambdaCallback):
    """
    A callback that extends LambdaCallback to intelligently handle function parameters.

    Each hook function may accept any subset of the trainer and the model, in
    any order, identified (case-insensitively) by parameter name:

    - trainer, pl_trainer                  -> the Trainer instance
    - model, pl_model, pl_module, module   -> the LightningModule instance

    Functions whose parameter names are not recognized fall back to a
    positional convention: a single unknown parameter receives the model;
    with two parameters, unknown slots are filled with whichever of
    (trainer, model) is not already claimed by a recognized name, in that
    order. More than two unmappable parameters raise a ``ValueError``.

    Example:
        >>> # Using with torch.compile
        >>> callback = ModelCallback(on_train_start=torch.compile)
        >>>
        >>> # Using with thunder_compile
        >>> callback = ModelCallback(on_train_start=thunder_compile)
        >>>
        >>> # Mix different callbacks
        >>> callback = ModelCallback(
        ...     on_train_start=lambda model: torch.compile(model),
        ...     on_fit_start=lambda trainer, model: print(f"Starting fit with {model}")
        ... )
    """

    # Recognized aliases for each injectable object (compared lowercased).
    TRAINER_PARAMS = {'trainer', 'pl_trainer'}
    MODEL_PARAMS = {'model', 'pl_model', 'pl_module', 'module'}

    def __init__(
        self,
        setup: Optional[Callable] = None,
        teardown: Optional[Callable] = None,
        on_fit_start: Optional[Callable] = None,
        on_fit_end: Optional[Callable] = None,
        on_sanity_check_start: Optional[Callable] = None,
        on_sanity_check_end: Optional[Callable] = None,
        on_train_batch_start: Optional[Callable] = None,
        on_train_batch_end: Optional[Callable] = None,
        on_train_epoch_start: Optional[Callable] = None,
        on_train_epoch_end: Optional[Callable] = None,
        on_validation_epoch_start: Optional[Callable] = None,
        on_validation_epoch_end: Optional[Callable] = None,
        on_test_epoch_start: Optional[Callable] = None,
        on_test_epoch_end: Optional[Callable] = None,
        on_validation_batch_start: Optional[Callable] = None,
        on_validation_batch_end: Optional[Callable] = None,
        on_test_batch_start: Optional[Callable] = None,
        on_test_batch_end: Optional[Callable] = None,
        on_train_start: Optional[Callable] = None,
        on_train_end: Optional[Callable] = None,
        on_validation_start: Optional[Callable] = None,
        on_validation_end: Optional[Callable] = None,
        on_test_start: Optional[Callable] = None,
        on_test_end: Optional[Callable] = None,
        on_exception: Optional[Callable] = None,
        on_save_checkpoint: Optional[Callable] = None,
        on_load_checkpoint: Optional[Callable] = None,
        on_before_backward: Optional[Callable] = None,
        on_after_backward: Optional[Callable] = None,
        on_before_optimizer_step: Optional[Callable] = None,
        on_before_zero_grad: Optional[Callable] = None,
        on_predict_start: Optional[Callable] = None,
        on_predict_end: Optional[Callable] = None,
        on_predict_batch_start: Optional[Callable] = None,
        on_predict_batch_end: Optional[Callable] = None,
        on_predict_epoch_start: Optional[Callable] = None,
        on_predict_epoch_end: Optional[Callable] = None,
    ):
        # At this point `locals()` holds exactly the hook arguments, plus
        # `self` and the implicit `__class__` cell created by the zero-arg
        # `super()` call below; wrap every hook that was actually provided.
        callbacks = {
            name: self._wrap_func(func)
            for name, func in locals().items()
            if name not in ('self', '__class__') and func is not None
        }

        super().__init__(**callbacks)

    def _get_param_type(self, param_name: str) -> Optional[str]:
        """Return ``'trainer'``/``'model'`` for a recognized name, else ``None``."""
        param_name = param_name.lower()
        if param_name in self.TRAINER_PARAMS:
            return 'trainer'
        if param_name in self.MODEL_PARAMS:
            return 'model'
        return None

    def _wrap_func(self, func: Callable) -> Callable:
        """Wrap ``func`` to accept the standard Lightning hook signature.

        The returned wrapper takes ``(trainer, pl_module, *args, **kwargs)``
        and forwards only the objects ``func`` asks for by name; extra hook
        arguments (outputs, batch, batch_idx, ...) are intentionally dropped.

        Raises:
            ValueError: if a parameter name cannot be mapped to trainer/model.
        """
        sig = inspect.signature(func)
        # Only parameters that can be passed by keyword are injectable;
        # *args / **kwargs must be excluded or func(**call_args) would fail
        # with "unexpected keyword argument 'args'".
        params = [
            name
            for name, p in sig.parameters.items()
            if p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        ]

        def wrapped(trainer, pl_module, *args, **kwargs):
            call_args = {}
            unknown = []

            # First pass: bind every recognized name, so the positional
            # fallback below never hands out an object already claimed by
            # an explicitly named parameter (e.g. `lambda net, trainer: ...`
            # must give `net` the model, not a second copy of the trainer).
            for param_name in params:
                param_type = self._get_param_type(param_name)
                if param_type == 'trainer':
                    call_args[param_name] = trainer
                elif param_type == 'model':
                    call_args[param_name] = pl_module
                else:
                    unknown.append(param_name)

            # Second pass: map unrecognized names positionally.
            if unknown:
                if len(params) == 1:
                    # Single-argument functions (e.g. torch.compile) get the model.
                    call_args[unknown[0]] = pl_module
                elif len(params) == 2:
                    # Hand out whichever of (trainer, model) is still
                    # unassigned, in that order.
                    available = []
                    if not any(self._get_param_type(p) == 'trainer' for p in params):
                        available.append(trainer)
                    if not any(self._get_param_type(p) == 'model' for p in params):
                        available.append(pl_module)
                    for name, obj in zip(unknown, available):
                        call_args[name] = obj
                else:
                    raise ValueError(
                        f"Unable to determine parameter mapping for '{unknown[0]}'. "
                        f"Please use recognized parameter names: "
                        f"trainer/pl_trainer for trainer, "
                        f"model/pl_model/pl_module/module for model."
                    )

            try:
                return func(**call_args)
            except TypeError as e:
                # Re-raise with the attempted argument names for easier debugging.
                raise TypeError(
                    f"Failed to call callback function {func.__name__ if hasattr(func, '__name__') else func}. "
                    f"Attempted to pass arguments: {call_args.keys()}. Error: {str(e)}"
                ) from e

        return wrapped