# xFormers memory-efficient attention processors for t2v_enhanced.
# (Scraped page header "Spaces: / Runtime error" removed — it was web-page
# chrome from the hosting site, not part of this module.)
# Standard library
from typing import Callable, Optional, Union

# Third-party
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from diffusers.utils.import_utils import is_xformers_available

# Project
from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention
from t2v_enhanced.model.pl_module_params_controlnet import AttentionMaskParams

# xformers is an optional dependency: bind the module when present,
# otherwise leave the name defined as None so attribute access fails loudly.
if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None
def set_use_memory_efficient_attention_xformers(
    model, num_frame_conditioning: int, num_frames: int, attention_mask_params: AttentionMaskParams, valid: bool = True, attention_op: Optional[Callable] = None
) -> None:
    """Install an ``XFormersAttnProcessor`` on every attention module in *model*.

    Walks the module tree depth-first (pre-order) and, for each submodule
    exposing ``set_processor``, replaces its processor with a freshly
    constructed :class:`XFormersAttnProcessor` configured from the given
    frame counts and mask parameters.

    Args:
        model: Root module whose children are traversed (the root itself
            is not probed for ``set_processor``).
        num_frame_conditioning: Number of leading conditioning frames.
        num_frames: Total number of frames.
        attention_mask_params: Mask configuration forwarded to each processor.
        valid: Present for interface compatibility; currently unused here.
        attention_op: Optional xformers attention op forwarded to processors.
    """

    def _install(node: torch.nn.Module) -> None:
        # Pre-order: configure this module first, then descend into children.
        if hasattr(node, "set_processor"):
            processor = XFormersAttnProcessor(
                attention_op=attention_op,
                num_frame_conditioning=num_frame_conditioning,
                num_frames=num_frames,
                attention_mask_params=attention_mask_params,
            )
            node.set_processor(processor)
        for child in node.children():
            _install(child)

    for top_level in model.children():
        if isinstance(top_level, torch.nn.Module):
            _install(top_level)
class XFormersAttnProcessor:
    """Attention processor backed by ``xformers.ops.memory_efficient_attention``.

    Supports three execution paths in ``__call__``:
      1. temporal attention restricted to a spatial neighborhood of the
         conditioning frames,
      2. spatial attention where non-conditioning frames attend to the last
         conditioning frame,
      3. plain memory-efficient attention (the default).
    """

    def __init__(self,
                 attention_mask_params: AttentionMaskParams,
                 attention_op: Optional[Callable] = None,
                 num_frame_conditioning: Optional[int] = None,
                 num_frames: Optional[int] = None,
                 use_image_embedding: bool = False,
                 ):
        # Optional xformers op; None lets xformers pick one.
        self.attention_op = attention_op
        # Number of leading frames used as conditioning.
        self.num_frame_conditioning = num_frame_conditioning
        # Total frame count (conditioning + generated).
        self.num_frames = num_frames
        # Flags copied out of the mask-params object; they select the
        # special temporal / spatial paths in __call__.
        self.temp_attend_on_neighborhood_of_condition_frames = attention_mask_params.temp_attend_on_neighborhood_of_condition_frames
        self.spatial_attend_on_condition_frames = attention_mask_params.spatial_attend_on_condition_frames
        # NOTE(review): stored but __call__ reads attn.use_image_embedding
        # instead of this attribute — confirm whether this is intentional.
        self.use_image_embedding = use_image_embedding

    def __call__(self, attn: Attention, hidden_states, hidden_state_height=None, hidden_state_width=None, encoder_hidden_states=None, attention_mask=None):
        """Run attention for *attn* over *hidden_states*.

        Args:
            attn: The Attention module (provides projections and helpers).
            hidden_states: Query source; assumed (batch, seq, channels) —
                TODO confirm against callers.
            hidden_state_height / hidden_state_width: Spatial extent used
                only by the temporal-neighborhood path to un-flatten the
                batch axis.
            encoder_hidden_states: Cross-attention context; when None this
                is self-attention.
            attention_mask: Optional mask of shape (1, F, D), broadcast to
                the batch and passed to xformers as the attention bias.

        Returns:
            The attended hidden states after the output projection and dropout.
        """
        # Sequence length is taken from the context when cross-attending.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # NOTE(review): these three are never assigned again or read below.
        key_img = None
        value_img = None
        hidden_states_img = None
        if attention_mask is not None:
            # Broadcast the single-batch mask across the batch, then let the
            # Attention helper reshape it for the kernel.
            attention_mask = repeat(
                attention_mask, "1 F D -> B F D", B=batch_size)
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)
        # NOTE(review): computed but unused below.
        is_cross_attention = encoder_hidden_states is not None
        if encoder_hidden_states is None:
            # Self-attention: keys/values come from the hidden states.
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states)
        # Modules built against the older interface lack is_spatial_attention;
        # they may only use the plain attention path.
        default_attention = not hasattr(attn, "is_spatial_attention")
        if default_attention:
            assert not self.temp_attend_on_neighborhood_of_condition_frames, "special attention must be implemented with new interface"
            assert not self.spatial_attend_on_condition_frames, "special attention must be implemented with new interface"
        is_spatial_attention = attn.is_spatial_attention if hasattr(
            attn, "is_spatial_attention") else False
        use_image_embedding = attn.use_image_embedding if hasattr(
            attn, "use_image_embedding") else False
        if is_spatial_attention and use_image_embedding and attn.cross_attention_mode:
            assert not self.spatial_attend_on_condition_frames, "Not implemented together with image embedding"
            # Mix image-conditioned features into the text context, gated by
            # a learned scalar passed through SiLU.
            alpha = attn.alpha
            # First 77 tokens — presumably the CLIP text embedding length;
            # TODO confirm against the encoder used upstream.
            encoder_hidden_states_txt = encoder_hidden_states[:, :77, :]
            encoder_hidden_states_mixed = attn.conv(encoder_hidden_states)
            encoder_hidden_states_mixed = attn.conv_ln(encoder_hidden_states_mixed)
            encoder_hidden_states = encoder_hidden_states_txt + encoder_hidden_states_mixed * F.silu(alpha)
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)
        if not default_attention and not is_spatial_attention and self.temp_attend_on_neighborhood_of_condition_frames and not attn.cross_attention_mode:
            # --- Temporal self-attention, neighborhood-of-conditioning path.
            # Conditioning frames attend normally over all frames...
            query_condition = query[:, :self.num_frame_conditioning]
            query_condition = attn.head_to_batch_dim(
                query_condition).contiguous()
            key_condition = key
            value_condition = value
            key_condition = attn.head_to_batch_dim(key_condition).contiguous()
            value_condition = attn.head_to_batch_dim(
                value_condition).contiguous()
            hidden_states_condition = xformers.ops.memory_efficient_attention(
                query_condition, key_condition, value_condition, attn_bias=None, op=self.attention_op, scale=attn.scale
            )
            hidden_states_condition = hidden_states_condition.to(query.dtype)
            hidden_states_condition = attn.batch_to_head_dim(
                hidden_states_condition)
            # ...while the remaining frames attend only to the conditioning
            # frames, over a 3x3 spatial neighborhood built via torch.roll.
            query_uncondition = query[:, self.num_frame_conditioning:]
            key = key[:, :self.num_frame_conditioning]
            value = value[:, :self.num_frame_conditioning]
            # Un-flatten the (batch*width*height) axis so spatial shifts can
            # be applied; relies on hidden_state_height/width being set.
            key = rearrange(key, "(B W H) F C -> B W H F C",
                            H=hidden_state_height, W=hidden_state_width)
            value = rearrange(value, "(B W H) F C -> B W H F C",
                              H=hidden_state_height, W=hidden_state_width)
            keys = []
            values = []
            # Collect the 9 spatially shifted copies (3x3 neighborhood,
            # wrap-around at borders because torch.roll is circular).
            for shifts_width in [-1, 0, 1]:
                for shifts_height in [-1, 0, 1]:
                    keys.append(torch.roll(key, shifts=(
                        shifts_width, shifts_height), dims=(1, 2)))
                    values.append(torch.roll(value, shifts=(
                        shifts_width, shifts_height), dims=(1, 2)))
            # Concatenate neighborhood copies along the frame axis, then
            # flatten back to the kernel's (batch, tokens, channels) layout.
            key = rearrange(torch.cat(keys, dim=3), "B W H F C -> (B W H) F C")
            value = rearrange(torch.cat(values, dim=3),
                              'B W H F C -> (B W H) F C')
            query = attn.head_to_batch_dim(query_uncondition).contiguous()
            key = attn.head_to_batch_dim(key).contiguous()
            value = attn.head_to_batch_dim(value).contiguous()
            hidden_states = xformers.ops.memory_efficient_attention(
                query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
            )
            hidden_states = hidden_states.to(query.dtype)
            hidden_states = attn.batch_to_head_dim(hidden_states)
            # Reassemble: conditioning-frame outputs first, then the rest.
            hidden_states = torch.cat(
                [hidden_states_condition, hidden_states], dim=1)
        elif not default_attention and is_spatial_attention and self.spatial_attend_on_condition_frames and not attn.cross_attention_mode:
            # --- Spatial self-attention, attend-on-conditioning path.
            # Batch axis is (batch*frames); split frames out to separate the
            # conditioning frames from the generated ones.
            query_condition = rearrange(
                query, "(B F) S C -> B F S C", F=self.num_frames)
            query_condition = query_condition[:, :self.num_frame_conditioning]
            query_condition = rearrange(
                query_condition, "B F S C -> (B F) S C")
            query_condition = attn.head_to_batch_dim(
                query_condition).contiguous()
            key_condition = rearrange(
                key, "(B F) S C -> B F S C", F=self.num_frames)
            key_condition = key_condition[:, :self.num_frame_conditioning]
            key_condition = rearrange(key_condition, "B F S C -> (B F) S C")
            value_condition = rearrange(
                value, "(B F) S C -> B F S C", F=self.num_frames)
            value_condition = value_condition[:, :self.num_frame_conditioning]
            value_condition = rearrange(
                value_condition, "B F S C -> (B F) S C")
            key_condition = attn.head_to_batch_dim(key_condition).contiguous()
            value_condition = attn.head_to_batch_dim(
                value_condition).contiguous()
            # Conditioning frames attend within themselves.
            hidden_states_condition = xformers.ops.memory_efficient_attention(
                query_condition, key_condition, value_condition, attn_bias=None, op=self.attention_op, scale=attn.scale
            )
            hidden_states_condition = hidden_states_condition.to(query.dtype)
            hidden_states_condition = attn.batch_to_head_dim(
                hidden_states_condition)
            query_uncondition = rearrange(
                query, "(B F) S C -> B F S C", F=self.num_frames)
            query_uncondition = query_uncondition[:,
                                                  self.num_frame_conditioning:]
            key_uncondition = rearrange(
                key, "(B F) S C -> B F S C", F=self.num_frames)
            value_uncondition = rearrange(
                value, "(B F) S C -> B F S C", F=self.num_frames)
            # Keys/values for the generated frames come solely from the LAST
            # conditioning frame (index num_frame_conditioning-1); the None
            # index keeps the frame axis as a singleton for the rearrange.
            key_uncondition = key_uncondition[:,
                                              self.num_frame_conditioning-1, None]
            value_uncondition = value_uncondition[:,
                                                  self.num_frame_conditioning-1, None]
            query_uncondition = rearrange(
                query_uncondition, "B F S C -> (B F) S C")
            # Replicate the single conditioning frame's tokens for every
            # generated frame so batch axes line up with the queries.
            key_uncondition = repeat(rearrange(
                key_uncondition, "B F S C -> B (F S) C"), "B T C -> (B F) T C", F=self.num_frames-self.num_frame_conditioning)
            value_uncondition = repeat(rearrange(
                value_uncondition, "B F S C -> B (F S) C"), "B T C -> (B F) T C", F=self.num_frames-self.num_frame_conditioning)
            query_uncondition = attn.head_to_batch_dim(
                query_uncondition).contiguous()
            key_uncondition = attn.head_to_batch_dim(
                key_uncondition).contiguous()
            value_uncondition = attn.head_to_batch_dim(
                value_uncondition).contiguous()
            hidden_states_uncondition = xformers.ops.memory_efficient_attention(
                query_uncondition, key_uncondition, value_uncondition, attn_bias=None, op=self.attention_op, scale=attn.scale
            )
            hidden_states_uncondition = hidden_states_uncondition.to(
                query.dtype)
            hidden_states_uncondition = attn.batch_to_head_dim(
                hidden_states_uncondition)
            # Re-interleave along the frame axis, then flatten back to the
            # (batch*frames, seq, channels) layout the caller expects.
            hidden_states = torch.cat([rearrange(hidden_states_condition, "(B F) S C -> B F S C", F=self.num_frame_conditioning), rearrange(
                hidden_states_uncondition, "(B F) S C -> B F S C", F=self.num_frames-self.num_frame_conditioning)], dim=1)
            hidden_states = rearrange(hidden_states, "B F S C -> (B F) S C")
        else:
            # --- Default path: plain memory-efficient attention.
            query = attn.head_to_batch_dim(query).contiguous()
            key = attn.head_to_batch_dim(key).contiguous()
            value = attn.head_to_batch_dim(value).contiguous()
            hidden_states = xformers.ops.memory_efficient_attention(
                query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
            )
            hidden_states = hidden_states.to(query.dtype)
            hidden_states = attn.batch_to_head_dim(hidden_states)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states