scy639's picture
Upload folder using huggingface_hub
2b534de verified
from .misc_4ddpm import *
from ldm.modules.attention import BasicTransformerBlock
class Bank:
def __init__(self,reader:nn.Module, writer:nn.Module) -> None:
"""
For the DFS model, mark every BasicTransformerBlock with name_4bank and isReader_4bank flags.
Similar logic applies for the writer while checking for BasicTransformerBlock instances.
"""
self.name2data = {}
self.name2count = {} # track how many times each name has been retrieved
self.WHEN_clear_a_field = 2 # clear the entry after this many gets
skip_names = [
'input_blocks.1.1.transformer_blocks.0',
'input_blocks.2.1.transformer_blocks.0',
# 'input_blocks.4.1.transformer_blocks.0',
# 'input_blocks.5.1.transformer_blocks.0',
# 'input_blocks.7.1.transformer_blocks.0',
# 'input_blocks.8.1.transformer_blocks.0',
##-----------all middle and output_blocks (everything outside input_blocks)----
'middle_block.1.transformer_blocks.0',
'output_blocks.3.1.transformer_blocks.0',
'output_blocks.4.1.transformer_blocks.0',
'output_blocks.5.1.transformer_blocks.0',
'output_blocks.6.1.transformer_blocks.0',
'output_blocks.7.1.transformer_blocks.0',
'output_blocks.8.1.transformer_blocks.0',
'output_blocks.9.1.transformer_blocks.0',
'output_blocks.10.1.transformer_blocks.0',
'output_blocks.11.1.transformer_blocks.0',
]
# print(f"{skip_names=}")
l_name = []
for name, _module in writer.named_modules():
if isinstance(_module, BasicTransformerBlock):
if DEBUG:
print(f"{name=}")
if name in skip_names:
# print(f"skip {name=}")
continue
_module.bank = self
_module.name4bank = name
_module.isReader_4bank = False
l_name.append(name)
# print(f"{l_name=}")
for name, _module in reader.named_modules():
if isinstance(_module, BasicTransformerBlock):
if name not in l_name:
continue
_module.bank = self
_module.name4bank = name
_module.isReader_4bank = True
def set(self,name,data):
self.name2data[name] = data
# self.name2count[name] = 0
def get(self,name):
printC('bank get', name)
if name in self.name2data:
if name not in self.name2count:
self.name2count[name] = 0
self.name2count[name] += 1
data = self.name2data[name]
if self.name2count[name] >= self.WHEN_clear_a_field: # once the max get count is reached, remove the entry
del self.name2data[name]
del self.name2count[name]
return data
raise Exception(f"{name}\n{list(self.name2data.keys())}")
return None
def clear(self,):
printC('clear')
printC('mean ct:', sum( self.name2count.values() ) / len( self.name2count.values() ) if len( self.name2count.values() )>0 else 'null' )
self.name2data.clear()
self.name2count.clear()