File size: 5,648 Bytes
c50dde6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 | import torch
import math
import torch.nn.functional as F
# Dotted attribute paths of UNet attention modules, ordered down blocks ->
# mid block -> up blocks (32 entries). The trailing comment on each entry is
# its position counted from the start / from the end of the list.
# NOTE(review): presumably resolved against a diffusers UNet by the caller
# (e.g. via an attribute-path lookup) — confirm against the call site.
DIFFUSION_LAYERS = [
    # down blocks 0-2, two attention modules each, attn1 + attn2
    "down_blocks[0].attentions[0].transformer_blocks[0].attn1",  # 0 / -32
    "down_blocks[0].attentions[0].transformer_blocks[0].attn2",  # 1 / -31
    "down_blocks[0].attentions[1].transformer_blocks[0].attn1",  # 2 / -30
    "down_blocks[0].attentions[1].transformer_blocks[0].attn2",  # 3 / -29
    "down_blocks[1].attentions[0].transformer_blocks[0].attn1",  # 4 / -28
    "down_blocks[1].attentions[0].transformer_blocks[0].attn2",  # 5 / -27
    "down_blocks[1].attentions[1].transformer_blocks[0].attn1",  # 6 / -26
    "down_blocks[1].attentions[1].transformer_blocks[0].attn2",  # 7 / -25
    "down_blocks[2].attentions[0].transformer_blocks[0].attn1",  # 8 / -24
    "down_blocks[2].attentions[0].transformer_blocks[0].attn2",  # 9 / -23
    "down_blocks[2].attentions[1].transformer_blocks[0].attn1",  # 10 / -22
    "down_blocks[2].attentions[1].transformer_blocks[0].attn2",  # 11 / -21
    # mid block
    "mid_block.attentions[0].transformer_blocks[0].attn1",  # 12 / -20
    "mid_block.attentions[0].transformer_blocks[0].attn2",  # 13 / -19
    # up blocks 1-3, three attention modules each, attn1 + attn2
    "up_blocks[1].attentions[0].transformer_blocks[0].attn1",  # 14 / -18
    "up_blocks[1].attentions[0].transformer_blocks[0].attn2",  # 15 / -17
    "up_blocks[1].attentions[1].transformer_blocks[0].attn1",  # 16 / -16
    "up_blocks[1].attentions[1].transformer_blocks[0].attn2",  # 17 / -15
    "up_blocks[1].attentions[2].transformer_blocks[0].attn1",  # 18 / -14
    "up_blocks[1].attentions[2].transformer_blocks[0].attn2",  # 19 / -13
    "up_blocks[2].attentions[0].transformer_blocks[0].attn1",  # 20 / -12
    "up_blocks[2].attentions[0].transformer_blocks[0].attn2",  # 21 / -11
    "up_blocks[2].attentions[1].transformer_blocks[0].attn1",  # 22 / -10
    "up_blocks[2].attentions[1].transformer_blocks[0].attn2",  # 23 / -9
    "up_blocks[2].attentions[2].transformer_blocks[0].attn1",  # 24 / -8
    "up_blocks[2].attentions[2].transformer_blocks[0].attn2",  # 25 / -7
    "up_blocks[3].attentions[0].transformer_blocks[0].attn1",  # 26 / -6
    "up_blocks[3].attentions[0].transformer_blocks[0].attn2",  # 27 / -5
    "up_blocks[3].attentions[1].transformer_blocks[0].attn1",  # 28 / -4
    "up_blocks[3].attentions[1].transformer_blocks[0].attn2",  # 29 / -3
    "up_blocks[3].attentions[2].transformer_blocks[0].attn1",  # 30 / -2
    "up_blocks[3].attentions[2].transformer_blocks[0].attn2",  # 31 / -1
]
class AttnProcessorForCallBack:
    """Attention processor that runs plain attention and records a resampled attention map.

    The returned hidden states are computed exactly as a non-fused attention
    processor would: project to query/key/value, split heads, softmax(QK^T) @ V,
    merge heads, output projection + dropout, optional residual, rescale.

    As a side effect, every call also computes attention probabilities on a
    bilinearly resampled copy of the query (and, for self-attention, the key)
    and stores them in ``model.attention_maps[layer]``.

    Args:
        model: Object exposing an ``attention_maps`` mapping that receives the
            resampled attention probabilities on every call.
        layer: Key under which this processor stores its map (presumably one
            of the names in ``DIFFUSION_LAYERS`` — confirm against the caller).
        low_res_size: ``(height, width)`` (or a single int for a square) that
            the token grid is resampled to before computing the stored map.
            Defaults to ``(35, 35)``, the value previously hard-coded.
    """

    def __init__(self, model, layer, low_res_size=(35, 35)):
        self.model = model
        self.layer = layer
        self.low_res_size = low_res_size

    @staticmethod
    def _square_grid(seq_len):
        """Return ``(side, side)`` for a token count that is a perfect square.

        Raises:
            ValueError: if ``seq_len`` is not a perfect square, since the
                spatial layout of a flattened non-square grid cannot be inferred.
        """
        side = math.isqrt(seq_len)
        if side * side != seq_len:
            raise ValueError(
                f"cannot infer a square spatial grid from {seq_len} tokens; "
                "pass 4-D (batch, channel, height, width) hidden states instead"
            )
        return side, side

    @staticmethod
    def _resample_tokens(tokens, grid, size):
        """Bilinearly resample ``(batch, h*w, dim)`` tokens laid out on ``grid`` to ``size``."""
        batch, _, dim = tokens.shape
        # (batch, h*w, dim) -> (batch, dim, h, w) so interpolate sees a feature map.
        feature_map = tokens.transpose(1, 2).reshape(batch, dim, *grid)
        feature_map = F.interpolate(feature_map, size=size, mode="bilinear", align_corners=False)
        # Back to (batch, new_h*new_w, dim).
        return feature_map.flatten(start_dim=2).transpose(1, 2)

    def _low_res_attention_probs(self, attn, query, key, grid, attention_mask):
        """Return attention probabilities computed on resampled query/key tokens.

        The query is always resampled from ``grid`` to ``self.low_res_size``.
        The key is resampled only when it has the same number of tokens as the
        query (self-attention); otherwise (cross-attention) its tokens have no
        spatial layout and are used unchanged, so the map is query-pixels x tokens.
        """
        low_res_query = self._resample_tokens(query, grid, self.low_res_size)
        if key.shape[1] == query.shape[1]:
            low_res_key = self._resample_tokens(key, grid, self.low_res_size)
        else:
            low_res_key = key
        low_res_query = attn.head_to_batch_dim(low_res_query)
        low_res_key = attn.head_to_batch_dim(low_res_key)
        # NOTE(review): attention_mask was prepared for the full-resolution key
        # length; in the self-attention path a non-None mask will not match the
        # resampled key length.
        return attn.get_attention_scores(low_res_query, low_res_key, attention_mask)

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run attention with ``attn``'s projections and record the resampled map.

        Args:
            attn: The attention module whose projections/helpers are used
                (``to_q``/``to_k``/``to_v``/``to_out``, ``head_to_batch_dim``,
                ``get_attention_scores``, ...).
            hidden_states: Either ``(batch, seq, dim)`` tokens or a 4-D
                ``(batch, channel, height, width)`` feature map.
            encoder_hidden_states: Context for cross-attention; ``None`` means
                self-attention.
            attention_mask: Optional mask, passed through
                ``attn.prepare_attention_mask``.
            temb: Passed to ``attn.spatial_norm`` when that is set.

        Returns:
            The attention output, in the same layout (3-D or 4-D) as the input.

        Raises:
            ValueError: for 3-D input whose token count is not a perfect square.
        """
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Spatial layout of the query tokens: exact for 4-D input (handles
        # non-square maps), otherwise assumed to be a square grid.
        if input_ndim == 4:
            grid = (height, width)
        else:
            grid = self._square_grid(query.shape[1])
        low_res_attention_probs = self._low_res_attention_probs(attn, query, key, grid, attention_mask)

        # Regular full-resolution attention; this is what the block returns.
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor

        self.model.attention_maps[self.layer] = low_res_attention_probs
        return hidden_states
|