| | import logging |
| | from transformers import BartForSequenceClassification |
| |
|
| | logger = logging.getLogger("ModelLogger") |
| |
|
class ModifiedBartForSequenceClassificationWithHook(BartForSequenceClassification):
    """BART sequence classifier that announces each forward pass.

    Construction is delegated unchanged to ``BartForSequenceClassification``;
    afterwards :meth:`forward_hook` is registered on the instance as a
    forward hook.
    """

    @staticmethod
    def forward_hook(module, inputs, outputs):
        """Report that ``module``'s forward method was called.

        The same message, naming the module's class, is sent to the
        module-level ``logger`` at INFO level and printed to standard output.

        Args:
            module: The module the hook is attached to.
            inputs: Arguments given to the forward call (unused).
            outputs: Result of the forward call (unused).
        """
        message = f"Called forward method of {module.__class__.__name__}"
        logger.info(message)
        print(message)

    def __init__(self, *args, **kwargs):
        """Build the underlying model, then attach :meth:`forward_hook`.

        All positional and keyword arguments are forwarded as-is to the
        parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.register_forward_hook(self.forward_hook)
| |
|
| |
|
| | import logging |
| | from transformers import AutoModel |
| |
|
| | logger = logging.getLogger("ModelLogger") |
| |
|
class ModifiedAutoModelWithHook(AutoModel):
    """``AutoModel`` factory whose loaded models announce each forward pass.

    ``AutoModel`` is a factory rather than a model: per the Transformers
    documentation its constructor always raises ``EnvironmentError`` and real
    models are obtained through ``from_pretrained`` / ``from_config``.
    Registering the hook inside ``__init__`` therefore could never work, so
    the hook is attached to the concrete model returned by those factory
    methods instead. Instantiating this class directly still raises the same
    error via the inherited constructor.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load a pretrained model and attach :meth:`forward_hook` to it.

        Arguments are forwarded unchanged to ``AutoModel.from_pretrained``
        and its return value is passed back as-is.
        """
        loaded = super().from_pretrained(*args, **kwargs)
        cls._attach_hook(loaded)
        return loaded

    @classmethod
    def from_config(cls, *args, **kwargs):
        """Build a model from a config and attach :meth:`forward_hook` to it.

        Arguments are forwarded unchanged to ``AutoModel.from_config`` and
        its return value is passed back as-is.
        """
        loaded = super().from_config(*args, **kwargs)
        cls._attach_hook(loaded)
        return loaded

    @classmethod
    def _attach_hook(cls, loaded):
        """Register :meth:`forward_hook` on the model contained in ``loaded``."""
        # With ``output_loading_info=True`` the loader returns a
        # ``(model, loading_info)`` tuple; the model is its first element.
        model = loaded[0] if isinstance(loaded, tuple) else loaded
        model.register_forward_hook(cls.forward_hook)

    @staticmethod
    def forward_hook(module, inputs, outputs):
        """Report that ``module``'s forward method was called.

        The same message, naming the module's class, is sent to the
        module-level ``logger`` at INFO level and printed to standard output.

        Args:
            module: The module the hook is attached to.
            inputs: Arguments given to the forward call (unused).
            outputs: Result of the forward call (unused).
        """
        message = f"Called forward method of {module.__class__.__name__}"
        logger.info(message)
        print(message)
| |
|
| |
|