| | from .sd_unet import SDUNet, Attention, GEGLU |
| | import torch |
| | from einops import rearrange, repeat |
| |
|
| |
|
class TemporalTransformerBlock(torch.nn.Module):
    """Transformer layer whose attention runs along the frame axis.

    The layer is three pre-norm residual sub-layers: two self-attention
    layers (``attn1`` / ``attn2``), each of which adds its own positional
    parameter table (``pe1`` / ``pe2``) to the normalized input, followed by
    a GEGLU feed-forward layer (``act_fn`` + ``ff``).

    Args:
        dim: Channel width of the incoming tokens.
        num_attention_heads: Head count handed to ``Attention``.
        attention_head_dim: Per-head width handed to ``Attention``.
        max_position_embeddings: Rows in each positional table; this is the
            largest frame count ``forward`` can add positions for.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
        super().__init__()

        # Temporal self-attention sub-layer 1.
        self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
        self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)

        # Temporal self-attention sub-layer 2.
        self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
        self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)

        # Feed-forward sub-layer: GEGLU widens to 4 * dim, ``ff`` projects back to dim.
        self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
        self.act_fn = GEGLU(dim, dim * 4)
        self.ff = torch.nn.Linear(dim * 4, dim)

    def _temporal_attention(self, hidden_states, norm, attn, pos_table, batch_size):
        """Apply one pre-norm temporal self-attention sub-layer plus its residual."""
        normed = norm(hidden_states)
        # Move frames onto the sequence axis so attention mixes information across time.
        per_token = rearrange(normed, "(b f) h c -> (b h) f c", b=batch_size)
        frame_count = per_token.shape[1]
        # NOTE: if frame_count > max_position_embeddings the slice is shorter than the
        # frame axis and the addition below fails to broadcast.
        attended = attn(per_token + pos_table[:, :frame_count])
        attended = rearrange(attended, "(b h) f c -> (b f) h c", b=batch_size)
        return attended + hidden_states

    def forward(self, hidden_states, batch_size=1):
        """Run both temporal attention sub-layers and the feed-forward sub-layer.

        Args:
            hidden_states: Rank-3 tensor laid out as ``(b f) h c``, where the
                leading axis is split into ``b = batch_size`` groups of frames.
            batch_size: Number of groups the leading axis is divided into.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        hidden_states = self._temporal_attention(hidden_states, self.norm1, self.attn1, self.pe1, batch_size)
        hidden_states = self._temporal_attention(hidden_states, self.norm2, self.attn2, self.pe2, batch_size)

        projected = self.ff(self.act_fn(self.norm3(hidden_states)))
        return projected + hidden_states
| |
|
| |
|
class TemporalBlock(torch.nn.Module):
    """Motion module: GroupNorm -> proj_in -> temporal transformer stack -> proj_out, plus residual.

    Each spatial position of the 4-D input feature map becomes a token; the
    ``TemporalTransformerBlock`` stack then attends across frames for each
    token, and the result is projected back and added to the input.

    Args:
        num_attention_heads: Heads per temporal attention layer.
        attention_head_dim: Width of each head; ``inner_dim`` is
            ``num_attention_heads * attention_head_dim``.
        in_channels: Channel count of the incoming feature map.
        num_layers: Number of stacked ``TemporalTransformerBlock`` layers.
        norm_num_groups: Group count for the input ``GroupNorm``.
        eps: Epsilon for the input ``GroupNorm``.
        max_position_embeddings: Length of the positional tables in each
            transformer layer (the maximum supported frame count). Defaults
            to 32, the value previously hard-wired via the layer's default.
    """

    def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, max_position_embeddings=32):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
        self.proj_in = torch.nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = torch.nn.ModuleList([
            TemporalTransformerBlock(
                inner_dim,
                num_attention_heads,
                attention_head_dim,
                max_position_embeddings=max_position_embeddings,
            )
            for _ in range(num_layers)
        ])

        self.proj_out = torch.nn.Linear(inner_dim, in_channels)

    def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1):
        """Apply the motion module to a feature map.

        Args:
            hidden_states: 4-D tensor ``(N, C, H, W)``; ``N`` is presumably
                ``batch_size * num_frames`` (the transformer layers split it as
                ``(b f)``) — confirm against the caller.
            time_emb: Passed through unchanged (unused here).
            text_emb: Passed through unchanged (unused here).
            res_stack: Passed through unchanged (unused here).
            batch_size: Number of groups the leading axis is divided into.

        Returns:
            Tuple ``(hidden_states, time_emb, text_emb, res_stack)`` where
            ``hidden_states`` has the same shape as the input.
        """
        # ``channels`` is the *input* channel count (== in_channels), not the
        # attention width ``inner_dim``; proj_out maps back to it before reshaping.
        batch, channels, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        # (N, C, H, W) -> (N, H*W, C): every spatial position becomes one token.
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
        hidden_states = self.proj_in(hidden_states)

        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, batch_size=batch_size)

        hidden_states = self.proj_out(hidden_states)
        # (N, H*W, C) -> (N, C, H, W) so the residual add lines up with the input.
        hidden_states = hidden_states.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()
        hidden_states = hidden_states + residual

        return hidden_states, time_emb, text_emb, res_stack
| |
|
| |
|
class SDMotionModel(torch.nn.Module):
    """Container for the 21 ``TemporalBlock`` motion modules.

    ``motion_modules`` is ordered the same way the state-dict converter assigns
    module ids (all ``down_blocks`` keys, then ``mid_block``, then
    ``up_blocks``). ``call_block_id`` maps an integer key to a module index;
    presumably the key is the index of the UNet block after which that motion
    module is run — confirm against the pipeline that consumes it.
    """

    def __init__(self):
        super().__init__()

        # (attention_head_dim, channels) for each module, in index order.
        # Every module uses 8 heads, so 8 * head_dim == channels throughout.
        layout = (
            [(40, 320)] * 2
            + [(80, 640)] * 2
            + [(160, 1280)] * 11
            + [(80, 640)] * 3
            + [(40, 320)] * 3
        )
        self.motion_modules = torch.nn.ModuleList([
            TemporalBlock(8, head_dim, channels, eps=1e-6)
            for head_dim, channels in layout
        ])

        # The i-th entry here is served by motion_modules[i].
        call_sites = (1, 4, 9, 12, 17, 20, 24, 26, 29, 32, 34, 36, 40, 43, 46, 50, 53, 56, 60, 63, 66)
        self.call_block_id = {site: module_index for module_index, site in enumerate(call_sites)}

    def forward(self):
        """Intentionally a no-op; the modules are presumably invoked individually by a caller."""
        pass

    @staticmethod
    def state_dict_converter():
        """Return the converter that maps external checkpoints onto this model's keys."""
        return SDMotionModelStateDictConverter()
| |
|
| |
|
class SDMotionModelStateDictConverter:
    """Rename motion-module checkpoint keys into ``SDMotionModel``'s layout."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Convert a motion-module state dict to ``motion_modules.<id>.<name>`` keys.

        Only keys starting with ``down_blocks.``, ``mid_block.`` or
        ``up_blocks.`` are converted; everything else is dropped. Keys are
        processed group by group in that order, sorted within each group, and
        every new ``...temporal_transformer`` prefix encountered gets the next
        consecutive module id starting at 0.

        Args:
            state_dict: Mapping from checkpoint key to tensor. Every selected key
                must contain a ``temporal_transformer`` path component.

        Returns:
            New dict with converted keys; values are the original objects.

        Raises:
            ValueError: A selected key has no ``temporal_transformer`` component.
            KeyError: The sub-path after ``temporal_transformer`` is not in the
                rename table.
        """
        # Sub-module path below ``temporal_transformer`` -> path inside a TemporalBlock.
        rename_dict = {
            "norm": "norm",
            "proj_in": "proj_in",
            "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
            "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
            "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
            "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
            "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
            "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
            "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
            "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
            "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
            "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
            "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
            "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
            "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
            "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
            "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
            "proj_out": "proj_out",
        }

        ordered_keys = []
        for group in ("down_blocks.", "mid_block.", "up_blocks."):
            ordered_keys.extend(sorted(key for key in state_dict if key.startswith(group)))

        converted = {}
        current_prefix = ""
        module_index = -1
        for key in ordered_keys:
            parts = key.split(".")
            cut = parts.index("temporal_transformer") + 1
            prefix = ".".join(parts[:cut])
            if prefix != current_prefix:
                # A new temporal_transformer prefix means the next motion module.
                current_prefix = prefix
                module_index += 1
            target_parts = ["motion_modules", str(module_index), rename_dict[".".join(parts[cut:-1])]]
            # pos_encoder entries map directly onto a Parameter (pe1/pe2), so the
            # trailing tensor name is dropped instead of appended.
            if "pos_encoder" not in parts:
                target_parts.append(parts[-1])
            converted[".".join(target_parts)] = state_dict[key]
        return converted

    def from_civitai(self, state_dict):
        """Convert a CivitAI-format checkpoint; the key layout is shared with diffusers."""
        return self.from_diffusers(state_dict)
| |
|