| from .misc_4ddpm import * |
|
|
|
|
| from ldm.modules.attention import BasicTransformerBlock |
| class Bank: |
| def __init__(self,reader:nn.Module, writer:nn.Module) -> None: |
| """ |
| For the DFS model, mark every BasicTransformerBlock with name_4bank and isReader_4bank flags. |
| Similar logic applies for the writer while checking for BasicTransformerBlock instances. |
| """ |
| self.name2data = {} |
| self.name2count = {} |
| self.WHEN_clear_a_field = 2 |
| skip_names = [ |
| 'input_blocks.1.1.transformer_blocks.0', |
| 'input_blocks.2.1.transformer_blocks.0', |
| |
| |
| |
| |
| |
| 'middle_block.1.transformer_blocks.0', |
| 'output_blocks.3.1.transformer_blocks.0', |
| 'output_blocks.4.1.transformer_blocks.0', |
| 'output_blocks.5.1.transformer_blocks.0', |
| 'output_blocks.6.1.transformer_blocks.0', |
| 'output_blocks.7.1.transformer_blocks.0', |
| 'output_blocks.8.1.transformer_blocks.0', |
| 'output_blocks.9.1.transformer_blocks.0', |
| 'output_blocks.10.1.transformer_blocks.0', |
| 'output_blocks.11.1.transformer_blocks.0', |
| ] |
| |
| |
| l_name = [] |
| for name, _module in writer.named_modules(): |
| if isinstance(_module, BasicTransformerBlock): |
| if DEBUG: |
| print(f"{name=}") |
| if name in skip_names: |
| |
| continue |
| _module.bank = self |
| _module.name4bank = name |
| _module.isReader_4bank = False |
| l_name.append(name) |
| |
| |
| for name, _module in reader.named_modules(): |
| if isinstance(_module, BasicTransformerBlock): |
| if name not in l_name: |
| continue |
| _module.bank = self |
| _module.name4bank = name |
| _module.isReader_4bank = True |
| def set(self,name,data): |
| self.name2data[name] = data |
| |
| def get(self,name): |
| printC('bank get', name) |
| if name in self.name2data: |
| if name not in self.name2count: |
| self.name2count[name] = 0 |
| self.name2count[name] += 1 |
| data = self.name2data[name] |
| if self.name2count[name] >= self.WHEN_clear_a_field: |
| del self.name2data[name] |
| del self.name2count[name] |
| return data |
| raise Exception(f"{name}\n{list(self.name2data.keys())}") |
| return None |
| def clear(self,): |
| printC('clear') |
| printC('mean ct:', sum( self.name2count.values() ) / len( self.name2count.values() ) if len( self.name2count.values() )>0 else 'null' ) |
| self.name2data.clear() |
| self.name2count.clear() |
|
|