from typing import Any, Callable

from torch import nn


class CachedModule(nn.Module):
    """Wrap ``block`` and reuse its most recent output on selected steps.

    While caching is enabled, every call to :meth:`forward` advances an
    internal step counter (``cur_step``, starting at 0). On steps for which
    ``select_cache_step_func(cur_step)`` is truthy and an output has already
    been computed, that stored output is returned instead of running
    ``block`` again. On every other step ``block`` runs and its output
    replaces the stored one.

    Attribute lookups that fail on the wrapper are forwarded to ``block``,
    so the wrapper can stand in for the wrapped module.

    Args:
        block: The module (any callable is accepted) whose output is cached.
        select_cache_step_func: Called with the current step index; a
            truthy return means "reuse the cached output on this step".
    """

    def __init__(self, block, select_cache_step_func: Callable[[int], Any]) -> None:
        super().__init__()
        self.block = block
        self.select_cache_step_func = select_cache_step_func
        self.cur_step = 0  # number of forward calls made while caching was enabled
        self.cached_results = None  # last output produced by ``block``
        self.enabled = True

    def __getattr__(self, name: str) -> Any:
        """Resolve ``name`` through ``nn.Module`` first, then fall back to ``block``."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Guard against infinite recursion: if ``block`` itself is missing
            # (e.g. an instance created without running ``__init__``), the
            # ``self.block`` access below would re-enter this method forever.
            if name == "block":
                raise
            return getattr(self.block, name)

    def if_cache(self) -> bool:
        """Return True when the next ``forward`` call should reuse the cache.

        This requires three things:
        - caching is enabled;
        - an output has already been stored;
        - ``select_cache_step_func(cur_step)`` is truthy.

        The selector is not called when caching is disabled or the cache is
        empty.
        """
        # Checking for a stored output ensures a selected step that arrives
        # before any computation runs ``block`` instead of returning ``None``.
        return (
            self.enabled
            and self.cached_results is not None
            and bool(self.select_cache_step_func(self.cur_step))
        )

    def enable_cache(self) -> None:
        """Turn caching on; ``cur_step`` is left unchanged."""
        self.enabled = True

    def disable_cache(self) -> None:
        """Turn caching off and reset ``cur_step`` to 0."""
        self.enabled = False
        self.cur_step = 0

    def forward(self, *args, **kwargs):
        """Run ``block`` or return its stored output, depending on the step.

        The output of ``block`` is stored whenever it runs, even while caching
        is disabled. The step counter only advances while caching is enabled.
        """
        if not self.if_cache():
            self.cached_results = self.block(*args, **kwargs)
        if self.enabled:
            self.cur_step += 1
        return self.cached_results