| from torch import nn | |
class CachedModule(nn.Module):
    """Wrap a module (or callable) and reuse its last output on selected steps.

    On each forward call, ``select_cache_step_func(cur_step)`` decides whether
    to return the previously cached result instead of recomputing. Attribute
    access falls through to the wrapped ``block`` so the wrapper is mostly
    transparent to callers.
    """

    def __init__(self, block, select_cache_step_func) -> None:
        """Create the caching wrapper.

        Args:
            block: The wrapped module/callable whose outputs are cached.
            select_cache_step_func: Callable ``step -> bool``; truthy means
                "serve the cached result on this step".
                NOTE(review): it should return falsy for step 0 — otherwise
                the first call returns ``None`` (no result cached yet).
        """
        super().__init__()
        self.block = block
        self.select_cache_step_func = select_cache_step_func
        self.cur_step = 0          # steps taken while caching is enabled
        self.cached_results = None  # output of the most recent real compute
        self.enabled = True

    def __getattr__(self, name):
        # Only reached when normal lookup fails; try nn.Module's resolution
        # (parameters/buffers/submodules), then fall through to the wrapped
        # block for transparency.
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Guard: if "block" itself is missing (e.g. during deepcopy or
            # unpickling, where the instance exists before __init__ ran),
            # delegating would re-enter __getattr__("block") forever.
            # Re-raise the clean AttributeError instead of recursing.
            if name == "block":
                raise
            return getattr(self.block, name)

    def if_cache(self):
        """Return True when the cached result should be served this step."""
        return self.select_cache_step_func(self.cur_step) and self.enabled

    def enable_cache(self):
        """Turn caching back on (step counter is NOT reset here)."""
        self.enabled = True

    def disable_cache(self):
        """Turn caching off and reset the step counter to 0."""
        self.enabled = False
        self.cur_step = 0

    def forward(self, *args, **kwargs):
        """Run ``block`` or serve the cached output, per the step selector.

        The step counter advances only while caching is enabled, so disabled
        passes do not perturb the cache schedule.
        """
        if not self.if_cache():
            self.cached_results = self.block(*args, **kwargs)
        if self.enabled:
            self.cur_step += 1
        return self.cached_results