"""
register the attention controller into the UNet of stable diffusion
Build a customized attention function `_attention'
Replace the original attention function with `forward' and `spatial_temporal_forward' in attention_controlled_forward function
Most of spatial_temporal_forward is directly copy from `video_diffusion/models/attention.py'
TODO FIXME: merge redundant code with attention.py
"""
import math

import torch
import torch.nn.functional as F
from einops import rearrange
from diffusers.utils.import_utils import is_xformers_available

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None
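
# Illustrative usage (a sketch, not the repo's actual entry point): assuming a
# video-diffusion `pipeline` that exposes `.unet`, and a `controller` that
# implements `__call__(attn, is_cross, place_in_unet)`:
#
#   register_attention_control(
#       pipeline, controller, text_cond=text_embeddings,
#       clip_length=8, height=64, width=64, ddim_inversion=False,
#   )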
def register_attention_control(model, controller, text_cond, clip_length, height, width, ddim_inversion):
"Connect a model with a controller"
def attention_controlled_forward(self, place_in_unet, attention_type='cross'):
        # to_out may be a ModuleList of [Linear, Dropout]; keep a handle to the
        # linear projection.
        to_out = self.to_out
        if isinstance(to_out, torch.nn.ModuleList):
            to_out = self.to_out[0]
def _attention(query, key, value, is_cross, attention_mask=None):
if self.upcast_attention:
query = query.float()
key = key.float()
# print("query",query.shape)
# print("key",key.shape)
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
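            # With beta=0, baddbmm ignores the (uninitialized) empty tensor and
            # returns alpha * (query @ key^T), i.e. the usual scaled dot-product
            # scores; self.scale is typically 1/sqrt(head_dim).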
#print("attention_scores",attention_scores.shape)
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
if self.upcast_softmax:
attention_scores = attention_scores.float()
            # START OF CORE FUNCTION
            # Route the pre-softmax attention map through the controller so it can
            # be stored and/or edited, then fold the heads back into the batch dim.
            attention_probs = controller(reshape_batch_dim_to_temporal_heads(attention_scores),
                                         is_cross, place_in_unet)
            attention_probs = reshape_temporal_heads_to_batch_dim(attention_probs)
            # END OF CORE FUNCTION
attention_probs = attention_probs.softmax(dim=-1)
# cast back to the original dtype
attention_probs = attention_probs.to(value.dtype)
# compute attention output
hidden_states = torch.bmm(attention_probs, value)
# reshape hidden_states
hidden_states = reshape_batch_dim_to_heads(hidden_states)
return hidden_states
        def reshape_temporal_heads_to_batch_dim(tensor):
            head_size = self.heads
            return rearrange(tensor, "b h s t -> (b h) s t", h=head_size)

        def reshape_batch_dim_to_temporal_heads(tensor):
            head_size = self.heads
            return rearrange(tensor, "(b h) s t -> b h s t", h=head_size)
def reshape_heads_to_batch_dim3(tensor):
batch_size1, batch_size2, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size1, batch_size2, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 3, 1, 2, 4)
return tensor
def reshape_heads_to_batch_dim(tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
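        # Shape round-trip example (hypothetical values): with heads=8, a tensor of
        # shape (2, 4096, 320) becomes (16, 4096, 40) under reshape_heads_to_batch_dim
        # and returns to (2, 4096, 320) under reshape_batch_dim_to_heads.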
def _memory_efficient_attention_xformers(query, key, value, attention_mask):
# TODO attention_mask
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
hidden_states = reshape_batch_dim_to_heads(hidden_states)
return hidden_states
        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None):
            # hidden_states: e.g. torch.Size([16, 4096, 320])
            # encoder_hidden_states: e.g. torch.Size([16, 77, 768])
            is_cross = encoder_hidden_states is not None

            # Repeat the text condition once per frame. Note that
            # text_cond.repeat(clip_length, 1, 1) would be the wrong implementation
            # here: it tiles the whole batch instead of repeating each condition
            # clip_length times in place.
            text_cond_frames = text_cond.repeat_interleave(clip_length, 0)
            encoder_hidden_states = text_cond_frames
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
query = reshape_heads_to_batch_dim(query)
if self.added_kv_proj_dim is not None:
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
key = reshape_heads_to_batch_dim(key)
value = reshape_heads_to_batch_dim(value)
encoder_hidden_states_key_proj = reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
else:
                # encoder_hidden_states was already set to text_cond_frames above,
                # so this branch always uses the per-frame text condition.
                encoder_hidden_states = text_cond_frames if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = reshape_heads_to_batch_dim(key)
value = reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
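            # Only small/medium attention maps go through _attention below, where the
            # controller can inspect and edit them; the xformers path computes the
            # attention implicitly and never materializes the map.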
if self._use_memory_efficient_attention_xformers and query.shape[-2] > ((height//2) * (width//2)):
                # for large attention maps (e.g. 64x64), use xformers to save memory
hidden_states = _memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
hidden_states = _attention(query, key, value, is_cross=is_cross, attention_mask=attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
            # dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def spatial_temporal_forward(
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
clip_length: int = None,
            SparseCausalAttention_index: list = [-1, 'first']
):
"""
Most of spatial_temporal_forward is directly copy from `video_diffusion.models.attention.SparseCausalAttention'
We add two modification
1. use self defined attention function that is controlled by AttentionControlEdit module
2. remove the dropout to reduce randomness
FIXME: merge redundant code with attention.py
"""
if (
self.added_kv_proj_dim is not None
or encoder_hidden_states is not None
or attention_mask is not None
):
raise NotImplementedError
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
query = reshape_heads_to_batch_dim(query)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
if clip_length is not None:
key = rearrange(key, "(b f) d c -> b f d c", f=clip_length)
value = rearrange(value, "(b f) d c -> b f d c", f=clip_length)
# *********************** Start of Spatial-temporal attention **********
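                # Example (assumed semantics): with clip_length=4 and
                # SparseCausalAttention_index=[-1, 'first'], frame i attends to
                # frame max(i-1, 0) and to frame 0; key/value tokens of those
                # reference frames are gathered and concatenated below.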
                frame_index_list = []
                if len(SparseCausalAttention_index) > 0:
                    for index in SparseCausalAttention_index:
                        if isinstance(index, str):
                            if index == 'first':
                                frame_index = [0] * clip_length
                            elif index == 'last':
                                frame_index = [clip_length - 1] * clip_length
                            elif index in ('mid', 'middle'):
                                frame_index = [int((clip_length - 1) // 2)] * clip_length
                            else:
                                raise ValueError(f"unknown frame index keyword: {index}")
                        else:
                            assert isinstance(index, int), 'relative index must be int'
                            frame_index = torch.arange(clip_length) + index
                            frame_index = frame_index.clip(0, clip_length - 1)
                        frame_index_list.append(frame_index)
# print("frame_index_list",frame_index_list) [bz, frame, 4096, 320]
key = torch.cat([ key[:, frame_index] for frame_index in frame_index_list #[bz, frame, 8192, 320])
], dim=2)
value = torch.cat([ value[:, frame_index] for frame_index in frame_index_list
], dim=2)
# *********************** End of Spatial-temporal attention **********
                key = rearrange(key, "b f d c -> (b f) d c", f=clip_length)
                value = rearrange(value, "b f d c -> (b f) d c", f=clip_length)

            key = reshape_heads_to_batch_dim(key)
            value = reshape_heads_to_batch_dim(value)

            # Hard stop on NaNs so they do not propagate silently through sampling.
            if torch.isnan(query.reshape(-1)[0]):
                print("nan value query", query.reshape(-1)[:10])
                print("nan value key", key.reshape(-1)[:10])
                exit()
            # FIXME: there should be a single flag that controls whether xformers
            # is used; for large attention maps (e.g. 64x64) it saves memory.
            if self._use_memory_efficient_attention_xformers and query.shape[-2] > ((height//2) * (width//2)):
                hidden_states = _memory_efficient_attention_xformers(query, key, value, attention_mask)
                # Some versions of xformers return output in fp32; cast back to the input dtype.
                hidden_states = hidden_states.to(query.dtype)
            else:
                hidden_states = _attention(query, key, value, attention_mask=attention_mask, is_cross=False)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
        def _sliced_attention(query, key, value, sequence_length, dim, attention_mask):
            # query: (bz*heads, t*h*w, org_dim//heads)
            is_cross = False
batch_size_attention = query.shape[0] # bz * heads
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
            if ddim_inversion:
                # During DDIM inversion, additionally store per-frame (intra-frame)
                # attention maps: (bz*heads, clip_length, per_frame_len, per_frame_len).
                per_frame_len = sequence_length // clip_length
                attention_store = torch.zeros((batch_size_attention, clip_length, per_frame_len, per_frame_len),
                                              device=query.device, dtype=query.dtype)
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
if self.upcast_attention:
query_slice = query_slice.float()
key_slice = key_slice.float()
attn_slice = torch.baddbmm(
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
query_slice,
key_slice.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
if attention_mask is not None:
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
if self.upcast_softmax:
attn_slice = attn_slice.float()
                # With bz == 1, slices iterate over heads, so only the first
                # `self.heads` slices carry attention maps for the controller.
                if i < self.heads:
                    if not ddim_inversion:
                        attention_probs = controller(attn_slice.unsqueeze(1), is_cross, place_in_unet)
                        attn_slice = attention_probs.squeeze(1)
attn_slice = attn_slice.softmax(dim=-1)
# cast back to the original dtype
attn_slice = attn_slice.to(value.dtype)
                if ddim_inversion:
                    # attn_slice: (1, t*h*w, t*h*w)
                    bz, thw, _ = attn_slice.shape
                    t = clip_length
                    hw = thw // t
                    # per-frame attention buffer: (bz, t, h*w, h*w)
                    per_frame_attention = torch.empty((bz, t, hw, hw), device=attn_slice.device)
                    # Extract each frame's diagonal (intra-frame) attention block.
                    for idx in range(t):
                        start_idx_ = idx * hw
                        end_idx_ = (idx + 1) * hw
                        per_frame_attention[:, idx, :, :] = attn_slice[:, start_idx_:end_idx_, start_idx_:end_idx_]
                    per_frame_attention = rearrange(per_frame_attention, "b t h w -> (b t) h w")
                    attention_store[start_idx:end_idx] = per_frame_attention
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
            if ddim_inversion:
                # attention_store: (bz*heads, t, h, w) with h = w = per-frame resolution
                _ = controller(attention_store, is_cross, place_in_unet)
# reshape hidden_states
hidden_states = reshape_batch_dim_to_heads(hidden_states)
return hidden_states
        def fully_frame_forward(hidden_states, encoder_hidden_states=None, attention_mask=None, clip_length=None, inter_frame=False, **kwargs):
            # kwargs supplied by the sampling loop: 'height', 'width', 'flatten_res',
            # 'old_qk', 'traj' (flow trajectories) and 'mask' (trajectory validity).
            batch_size, sequence_length, _ = hidden_states.shape
            h = kwargs['height']
            w = kwargs['width']
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states) # (bf) x d(hw) x c
self.q = query
if self.inject_q is not None:
query = self.inject_q
dim = query.shape[-1]
query_old = query.clone()
            # All frames
            # query: (bz*t, h*w, dim) -> (bz, t*h*w, dim)
            query = rearrange(query, "(b f) d c -> b (f d) c", f=clip_length)
            query = reshape_heads_to_batch_dim(query)  # (bz*heads, t*h*w, dim//heads)
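            # Flattening all frames into a single sequence lets every spatial token
            # attend to tokens from every frame ("fully frame" attention).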
if self.added_kv_proj_dim is not None:
raise NotImplementedError
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
self.k = key
if self.inject_k is not None:
key = self.inject_k
key_old = key.clone()
value = self.to_v(encoder_hidden_states)
if inter_frame:
key = rearrange(key, "(b f) d c -> b f d c", f=clip_length)[:, [0, -1]]
value = rearrange(value, "(b f) d c -> b f d c", f=clip_length)[:, [0, -1]]
key = rearrange(key, "b f d c -> b (f d) c",)
value = rearrange(value, "b f d c -> b (f d) c")
else:
# All frames
key = rearrange(key, "(b f) d c -> b (f d) c", f=clip_length)
value = rearrange(value, "(b f) d c -> b (f d) c", f=clip_length)
key = reshape_heads_to_batch_dim(key)
value = reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
#print("query.shape[0]",query.shape[0]) # 16
self._slice_size = 1 ### 8
sequence_length_full_frame = query.shape[1]
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers and query.shape[-2] > clip_length*(32 ** 2):
hidden_states = _memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
            else:
                hidden_states = _sliced_attention(query, key, value, sequence_length_full_frame, dim, attention_mask)
if [h,w] in kwargs['flatten_res']:
hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=clip_length)
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
if kwargs["old_qk"] == 1:
query = query_old
key = key_old
else:
query = hidden_states
key = hidden_states
value = hidden_states
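                # Flow-guided temporal attention: `traj` holds, for each query
                # location, the (t, x, y) coordinates of its motion trajectory across
                # frames; keys/values are gathered along those trajectories and
                # invalid entries are masked with a -inf bias before the softmax.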
traj = kwargs["traj"]
traj = rearrange(traj, '(f n) l d -> f n l d', f=clip_length, n=sequence_length)
mask = rearrange(kwargs["mask"], '(f n) l -> f n l', f=clip_length, n=sequence_length)
mask = torch.cat([mask[:, :, 0].unsqueeze(-1), mask[:, :, -clip_length+1:]], dim=-1)
traj_key_sequence_inds = torch.cat([traj[:, :, 0, :].unsqueeze(-2), traj[:, :, -clip_length+1:, :]], dim=-2)
t_inds = traj_key_sequence_inds[:, :, :, 0]
x_inds = traj_key_sequence_inds[:, :, :, 1]
y_inds = traj_key_sequence_inds[:, :, :, 2]
query_tempo = query.unsqueeze(-2)
_key = rearrange(key, '(b f) (h w) d -> b f h w d', b=int(batch_size/clip_length), f=clip_length, h=h, w=w)
_value = rearrange(value, '(b f) (h w) d -> b f h w d', b=int(batch_size/clip_length), f=clip_length, h=h, w=w)
key_tempo = _key[:, t_inds, x_inds, y_inds]
value_tempo = _value[:, t_inds, x_inds, y_inds]
key_tempo = rearrange(key_tempo, 'b f n l d -> (b f) n l d')
value_tempo = rearrange(value_tempo, 'b f n l d -> (b f) n l d')
mask = rearrange(torch.stack([mask, mask]), 'b f n l -> (b f) n l')
mask = mask[:,None].repeat(1, self.heads, 1, 1).unsqueeze(-2)
                attn_bias = torch.zeros_like(mask, dtype=key_tempo.dtype)
                # Mask out invalid trajectory entries before the softmax.
                attn_bias[~mask] = -torch.inf
# flow attention
query_tempo = reshape_heads_to_batch_dim3(query_tempo)
key_tempo = reshape_heads_to_batch_dim3(key_tempo)
value_tempo = reshape_heads_to_batch_dim3(value_tempo)
attn_matrix2 = query_tempo @ key_tempo.transpose(-2, -1) / math.sqrt(query_tempo.size(-1)) + attn_bias
attn_matrix2 = F.softmax(attn_matrix2, dim=-1)
                out = (attn_matrix2 @ value_tempo).squeeze(-2)
hidden_states = rearrange(out,'(b f) k (h w) d -> b (f h w) (k d)', b=int(batch_size/clip_length), f=clip_length, h=h, w=w)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
# All frames
hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=clip_length)
return hidden_states
        if attention_type == 'CrossAttention':
            return forward
        elif attention_type == "SparseCausalAttention":
            return spatial_temporal_forward
        elif attention_type == "FullyFrameAttention":
            return fully_frame_forward
    class DummyController:
        """Pass-through controller used when no controller is supplied."""

        def __init__(self):
            self.num_att_layers = 0

        def __call__(self, *args):
            return args[0]
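    # A real controller (e.g. an AttentionControlEdit module, as referenced in the
    # docstrings above) implements the same interface: __call__ receives the
    # attention map, is_cross, and place_in_unet, and returns a possibly edited map.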
if controller is None:
controller = DummyController()
    def register_recr(net_, count, place_in_unet):
        if net_[1].__class__.__name__ in ('CrossAttention', 'FullyFrameAttention', 'SparseCausalAttention'):
            net_[1].forward = attention_controlled_forward(net_[1], place_in_unet, attention_type=net_[1].__class__.__name__)
            return count + 1
        elif hasattr(net_[1], 'children'):
            for net in net_[1].named_children():
                # skip temporal attention layers
                if net[0] != 'attn_temporal':
                    count = register_recr(net, count, place_in_unet)
        return count
cross_att_count = 0
sub_nets = model.unet.named_children()
for net in sub_nets:
if "down" in net[0]:
cross_att_count += register_recr(net, 0, "down")
elif "up" in net[0]:
cross_att_count += register_recr(net, 0, "up")
elif "mid" in net[0]:
cross_att_count += register_recr(net, 0, "mid")
#print(f"Number of attention layer registered {cross_att_count}")
controller.num_att_layers = cross_att_count