Spaces:
Build error
Build error
| from typing import Optional | |
| from einops import rearrange | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from diffusers.models.attention import Attention | |
class InflatedConv3d(nn.Conv2d):
    """Apply an ordinary 2D convolution independently to every frame of a video tensor.

    Input and output use the layout ``(batch, channels, frames, height, width)``.
    Frames are folded into the batch axis, run through :class:`nn.Conv2d`, and then
    unfolded again, so no information is mixed across frames.
    """

    def forward(self, x):
        num_frames = x.shape[2]
        per_frame = rearrange(x, "b c f h w -> (b f) c h w")
        per_frame = super().forward(per_frame)
        return rearrange(per_frame, "(b f) c h w -> b c f h w", f=num_frames)
class FFInflatedConv3d(nn.Conv2d):
    """Per-frame 2D convolution followed by a zero-initialised temporal residual layer.

    Input and output use the layout ``(batch, channels, frames, height, width)``.
    Each frame is first convolved with the inherited :class:`nn.Conv2d`. Then, at every
    spatial position, three feature vectors are concatenated along channels and passed
    through ``conv_temp`` (a ``Linear(3 * C, C)``):

    - the conv output of the first frame,
    - the conv output of the previous frame (frame 0 uses itself),
    - the conv output of the current frame.

    The ``conv_temp`` result is added back as a residual. Because ``conv_temp`` starts
    at zero, a freshly built module is numerically identical to a plain per-frame Conv2d.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            **kwargs,
        )
        # Mixes [first, previous, current] frame features (3 * C) back down to C channels.
        self.conv_temp = nn.Linear(3 * out_channels, out_channels)
        # Zero init (not ones): the temporal residual is 0 at start, so the layer
        # initially reproduces the plain per-frame convolution.
        nn.init.zeros_(self.conv_temp.weight)
        nn.init.zeros_(self.conv_temp.bias)

    def forward(self, x):
        video_length = x.shape[2]
        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        *_, h, w = x.shape
        # One length-`f` sequence of C-dim features per (batch, row, column) position.
        x = rearrange(x, "(b f) c h w -> (b h w) f c", f=video_length)

        # Build the gather indices directly on x's device (no per-call host->device copy)
        # and with integer ops only.
        frame_ids = torch.arange(video_length, device=x.device)
        head_frame_index = torch.zeros_like(frame_ids)
        prev_frame_index = (frame_ids - 1).clamp(min=0)

        # The "current frame" slot is x itself; gathering it with arange would only copy x.
        conv_temp_nn_input = torch.cat(
            [x[:, head_frame_index], x[:, prev_frame_index], x],
            dim=2,
        )
        x = x + self.conv_temp(conv_temp_nn_input)
        return rearrange(x, "(b h w) f c -> b c f h w", h=h, w=w)
class FFAttention(Attention):
    r"""
    Attention layer whose forward pass takes an extra ``video_length`` argument and hands
    it to the attention processor.

    After the base :class:`Attention` is constructed, the given ``processor`` is installed.
    If none is given, a new :class:`FFAttnProcessor` is installed instead.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        scale_qk (`bool`, *optional*, defaults to True): Forwarded unchanged to `Attention`.
        processor (`FFAttnProcessor`, *optional*):
            Processor to install; a new `FFAttnProcessor` is used when omitted.
    """

    def __init__(
        self,
        *args,
        scale_qk: bool = True,
        processor: Optional["FFAttnProcessor"] = None,
        **kwargs
    ):
        super().__init__(*args, scale_qk=scale_qk, processor=processor, **kwargs)
        # Whatever the base constructor set up, install the caller's processor or, by
        # default, the frame-aware FFAttnProcessor (its __call__ takes `video_length`).
        chosen_processor = FFAttnProcessor() if processor is None else processor
        self.set_processor(chosen_processor)

    def forward(self, hidden_states, video_length, encoder_hidden_states=None, attention_mask=None,
                **cross_attention_kwargs):
        """Delegate to ``self.processor``, passing ``video_length`` as a keyword argument.

        Any extra keyword arguments are forwarded to the processor unchanged; a processor
        whose ``__call__`` does not accept them will raise ``TypeError``.
        """
        processor_kwargs = dict(
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            video_length=video_length,
        )
        return self.processor(self, hidden_states, **processor_kwargs, **cross_attention_kwargs)
class FFAttnProcessor:
    """Attention processor in which every frame attends to the keys/values of frame 0.

    Built on ``torch.nn.functional.scaled_dot_product_attention`` (PyTorch >= 2.0).
    The key/value projections are regrouped with the pattern ``(b f) d c``. So inputs are
    expected with frames folded into the batch axis; ``video_length`` is the number of
    frames ``f``.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FFAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    @staticmethod
    def _first_frame_only(tensor, video_length):
        """Return a ``((b f), d, c)`` tensor where every frame's slice is a copy of frame 0's."""
        tensor = rearrange(tensor, "(b f) d c -> b f d c", f=video_length)
        tensor = tensor[:, [0] * video_length].contiguous()
        return rearrange(tensor, "b f d c -> (b f) d c")

    def __call__(self, attn: Attention, hidden_states, video_length, encoder_hidden_states=None, attention_mask=None):
        """Run first-frame (cross-frame) attention and return the projected output.

        Args:
            attn: The owning attention module; supplies ``to_q``/``to_k``/``to_v``/``to_out``,
                ``heads`` and mask/normalisation helpers.
            hidden_states: Query-side input.
            video_length: Number of frames folded into the batch axis.
            encoder_hidden_states: Optional key/value-side input; defaults to ``hidden_states``.
            attention_mask: Optional mask, prepared via ``attn.prepare_attention_mask``.
        """
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Cross-frame attention: every frame uses frame 0's keys and values.
        # NOTE(review): this is also applied when encoder_hidden_states is given
        # (cross-attention); confirm that is intended.
        key = self._first_frame_only(key, video_length)
        value = self._first_frame_only(value, video_length)

        # The per-head width must come from the projected key (heads * dim_head), not
        # from the input channels. The two differ whenever query_dim != heads * dim_head,
        # and using the input width would make the .view() calls below mis-split
        # heads and tokens.
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states