Spaces:
Runtime error
Runtime error
| import math | |
| from typing import Optional, Callable | |
| import xformers | |
| from omegaconf import OmegaConf | |
| import yaml | |
| from .util import classify_blocks | |
def identify_blocks(block_list, name):
    """Return the first entry of ``block_list`` contained in ``name``.

    Args:
        block_list: Iterable of candidate block identifiers, searched in order.
        name: The value searched with the ``in`` operator (e.g. a module path).

    Returns:
        The first ``candidate`` for which ``candidate in name`` holds, or
        ``None`` when no candidate matches (including an empty ``block_list``).
    """
    return next((candidate for candidate in block_list if candidate in name), None)
class MySelfAttnProcessor:
    """Attention-processor object that records the query and key tensors it receives.

    Nothing here computes attention: every entry point only stores some of its
    arguments as attributes so they can be read back later. Attributes such as
    ``query``/``key`` or ``attn``/``attention_mask`` exist only after the
    corresponding method has been called at least once.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        # Stored for callers; no method of this class reads it.
        self.attention_op = attention_op

    def __call__(self, attn, hidden_states, query, key, value, attention_mask):
        """Record ``query`` and ``key``; all other arguments are ignored.

        Returns ``None`` implicitly.
        NOTE(review): if the calling attention module uses the return value as
        its output, it will receive ``None`` -- confirm this is intended.
        """
        self.query, self.key = query, key

    def record_qkv(self, attn, hidden_states, query, key, value, attention_mask):
        """Record ``query`` and ``key``; ``value`` is not stored despite the name."""
        self.query, self.key = query, key

    def record_attn_mask(self, attn, hidden_states, query, key, value, attention_mask):
        """Record ``attn`` and ``attention_mask``; all other arguments are ignored.

        ``attn`` is presumably the calling attention module -- TODO confirm.
        """
        self.attention_mask, self.attn = attention_mask, attn
def prep_unet_attention(unet, motion_gudiance_blocks):
    """Install a fresh ``MySelfAttnProcessor`` on selected attention modules.

    A submodule is selected when its class name contains
    ``"VersatileAttention"`` and ``classify_blocks(motion_gudiance_blocks, path)``
    is truthy for its module path. Every selected module gets its own processor
    instance, so recorded query/key tensors are not shared between modules.

    Args:
        unet: Model exposing ``named_modules()``; ``set_processor`` is called on
            each selected submodule.
        motion_gudiance_blocks: Block selector forwarded to ``classify_blocks``
            (misspelled name kept for caller compatibility).

    Returns:
        The same ``unet`` object that was passed in.
    """
    # Lazy generator: modules are filtered and patched one at a time, in the
    # same order named_modules() yields them.
    selected = (
        module
        for path, module in unet.named_modules()
        if "VersatileAttention" in type(module).__name__
        and classify_blocks(motion_gudiance_blocks, path)
    )
    for module in selected:
        module.set_processor(MySelfAttnProcessor())
    return unet
def _square_grid_side(num_tokens):
    """Return the side length of the square grid holding ``num_tokens`` tokens.

    Raises:
        ValueError: if ``num_tokens`` is not a perfect square. Without this check
            a non-square token count would either fail inside ``reshape`` with an
            opaque message or, when the sizes happen to divide, silently produce
            a scrambled spatial map.
    """
    side = int(math.sqrt(num_tokens))
    if side * side != num_tokens:
        raise ValueError(
            f"cannot lay out {num_tokens} tokens as a square spatial map"
        )
    return side


def _tokens_to_spatial_map(tokens, bs, side):
    """Move ``tokens`` to CPU, swap dims 1 and 2, and reshape to ``(bs, -1, side, side)``."""
    return tokens.cpu().permute(0, 2, 1).reshape(bs, -1, side, side)


def get_self_attn_feat(unet, injection_config, config):
    """Collect recorded self-attention tensors as spatial maps, keyed by module path.

    Visits every submodule of ``unet`` whose class name contains
    ``"CrossAttention"``, whose path contains ``"attn1"``, and which
    ``classify_blocks(injection_config.blocks, name=path)`` accepts. For each,
    the ``hidden_state``, ``query``, ``key`` and ``value`` attributes stored on
    ``module.processor`` are moved to CPU, have dims 1 and 2 swapped, and are
    reshaped to ``(bs, -1, side, side)``. The code treats dim 0 as the batch and
    dim 1 as a square number of spatial tokens (assumes 3-D
    ``(batch, tokens, features)`` tensors -- TODO confirm).

    Args:
        unet: Model exposing ``named_modules()``.
        injection_config: Object whose ``.blocks`` is forwarded to
            ``classify_blocks``.
        config: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(hidden_state_dict, query_dict, key_dict, value_dict)`` of dicts
        mapping module path to reshaped CPU tensor.

    Raises:
        ValueError: if a token count is not a perfect square.
        AttributeError: propagates if a processor has not recorded one of the
            attributes read here.
    """
    hidden_state_dict = {}
    query_dict = {}
    key_dict = {}
    value_dict = {}
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if not (
            "CrossAttention" in module_name
            and "attn1" in name
            and classify_blocks(injection_config.blocks, name=name)
        ):
            continue
        processor = module.processor
        hidden_state = processor.hidden_state
        # The batch size comes from hidden_state and is reused for q/k/v below.
        bs = hidden_state.shape[0]
        side = _square_grid_side(hidden_state.shape[1])
        hidden_state_dict[name] = _tokens_to_spatial_map(hidden_state, bs, side)
        # key/value are laid out with the query's grid size; attn1 is presumably
        # self-attention, so q/k/v share one token count -- TODO confirm.
        side = _square_grid_side(processor.query.shape[1])
        query_dict[name] = _tokens_to_spatial_map(processor.query, bs, side)
        key_dict[name] = _tokens_to_spatial_map(processor.key, bs, side)
        value_dict[name] = _tokens_to_spatial_map(processor.value, bs, side)
    return hidden_state_dict, query_dict, key_dict, value_dict
def clean_attn_buffer(unet):
    """Reset per-run buffers on the processors of ``Attention`` modules to ``None``.

    Only submodules whose class is named exactly ``"Attention"`` and whose
    module path contains ``"attn"`` are touched. On each such module's
    processor, every listed field that already exists as an instance attribute
    (i.e. is in its ``__dict__``) is set to ``None``. Missing fields are not
    created, and all other attributes are left alone.

    Args:
        unet: Model exposing ``named_modules()``.

    Returns:
        None.
    """
    buffer_fields = (
        "injection_config",
        "injection_mask",
        "obj_index",
        "pca_weight",
        "pca_weight_changed",
        "pca_info",
        "step",
    )
    for path, mod in unet.named_modules():
        if type(mod).__name__ != "Attention" or "attn" not in path:
            continue
        proc = mod.processor
        for field in buffer_fields:
            if field in proc.__dict__:
                setattr(proc, field, None)