| | import inspect |
| | import math |
| | from typing import Callable, List, Optional, Tuple, Union |
| | from einops import rearrange |
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| | from torch import Tensor |
| | from diffusers.models.attention_processor import Attention |
| | import os |
| | import os.path as osp |
| | import numpy as np |
| |
|
class LoRALinearLayer(nn.Module):
    """Low-rank (down -> up) adapter that acts on a single condition segment.

    ``forward`` treats dimension 1 of its input as
    ``[block tokens | cond segment 0 | ... | cond segment n_loras - 1]`` where
    every condition segment holds ``cond_size`` tokens. All tokens outside
    segment ``number`` are zeroed before the projections; because ``down`` and
    ``up`` have no bias, the layer therefore outputs zeros for those tokens.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        rank: Inner dimension of the low-rank factorisation.
        network_alpha: When not ``None`` the output is scaled by
            ``network_alpha / rank``.
        device: Forwarded to both ``nn.Linear`` layers.
        dtype: Forwarded to both ``nn.Linear`` layers.
        cond_width: Used together with ``cond_height`` to derive the number of
            tokens in one condition segment (presumably the condition image
            size in pixels -- confirm against the caller).
        cond_height: See ``cond_width``.
        number: Index of the condition segment this layer is responsible for.
        n_loras: Total number of condition segments appended to the sequence.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 4,
        network_alpha: Optional[float] = None,
        device: Optional[Union[torch.device, str]] = None,
        dtype: Optional[torch.dtype] = None,
        cond_width=512,
        cond_height=512,
        number=0,
        n_loras=1,
    ):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)

        self.network_alpha = network_alpha
        self.rank = rank
        self.out_features = out_features
        self.in_features = in_features

        # ``up`` starts at zero, so a freshly created layer contributes nothing.
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

        self.cond_height = cond_height
        self.cond_width = cond_width
        self.number = number
        self.n_loras = n_loras

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the adapter to the tokens of condition segment ``number``.

        Args:
            hidden_states: Tensor whose dimension 1 is the token sequence and
                whose last dimension has ``in_features`` entries.

        Returns:
            Tensor in the dtype of the input, with ``out_features`` as last
            dimension; all-zero for every token outside the selected segment.
        """
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        seq_len = hidden_states.shape[1]
        # Expression kept verbatim (left-to-right floor divisions) so the
        # segment length matches the one computed by the attention processors.
        cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
        block_size = seq_len - cond_size * self.n_loras
        segment_start = block_size + self.number * cond_size
        segment_end = block_size + (self.number + 1) * cond_size

        # A (1, seq, 1) mask broadcasts over batch and features, so the layer
        # works for any feature width instead of a hard-coded 3072 and avoids
        # materialising a full-size mask tensor.
        mask = torch.ones((1, seq_len, 1), device=hidden_states.device, dtype=dtype)
        mask[:, :segment_start] = 0
        mask[:, segment_end:] = 0
        hidden_states = mask * hidden_states

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        if self.network_alpha is not None:
            up_hidden_states = up_hidden_states * (self.network_alpha / self.rank)

        return up_hidden_states.to(orig_dtype)
class MultiSingleStreamBlockLoraProcessor(nn.Module):
    """Attention processor adding per-condition LoRA deltas to q/k/v.

    The input sequence is treated as
    ``[block tokens | cond segment 0 | ... | cond segment n_loras - 1]``.
    LoRA ``i`` only contributes to condition segment ``i`` (see
    ``LoRALinearLayer``), and an additive attention mask lets block tokens
    attend to everything while each condition segment attends only to itself
    (optionally extended through ``call_ids`` / ``cuboids_segmasks``).

    Args:
        dim: Input and output feature size of every LoRA layer.
        ranks: Rank of each LoRA, indexed by condition number.
        lora_weights: Scalar multiplier of each LoRA delta, indexed by
            condition number. Stored on the instance as given.
        network_alphas: ``network_alpha`` of each LoRA, indexed by condition
            number.
        device: Forwarded to the LoRA layers.
        dtype: Forwarded to the LoRA layers.
        cond_width: Used with ``cond_height`` to derive the tokens per
            condition segment.
        cond_height: See ``cond_width``.
        n_loras: Number of condition segments (and LoRAs per projection).
    """

    def __init__(self, dim: int, ranks=None, lora_weights=None, network_alphas=None, device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
        super().__init__()
        # ``None`` defaults avoid sharing one mutable default list between
        # instances (``lora_weights`` is stored on ``self``).
        ranks = [] if ranks is None else ranks
        network_alphas = [] if network_alphas is None else network_alphas

        self.n_loras = n_loras
        self.cond_width = cond_width
        self.cond_height = cond_height

        def build_loras() -> nn.ModuleList:
            # One LoRA per condition segment; ``number=i`` selects the segment.
            return nn.ModuleList([
                LoRALinearLayer(
                    dim, dim, ranks[i], network_alphas[i],
                    device=device, dtype=dtype,
                    cond_width=cond_width, cond_height=cond_height,
                    number=i, n_loras=n_loras,
                )
                for i in range(n_loras)
            ])

        self.q_loras = build_loras()
        self.k_loras = build_loras()
        self.v_loras = build_loras()
        self.lora_weights = [] if lora_weights is None else lora_weights

    @staticmethod
    def _build_attention_mask(scaled_seq_len, scaled_block_size, scaled_cond_size, num_cond_blocks, device, call_ids, cuboids_segmasks):
        """Return a 0/1 mask where 1 marks a (query row, key column) pair to block."""
        mask = torch.ones((scaled_seq_len, scaled_seq_len), device=device)

        # Block-token rows may attend to every key.
        mask[:scaled_block_size, :] = 0
        # Each condition segment may attend to itself.
        for i in range(num_cond_blocks):
            start = i * scaled_cond_size + scaled_block_size
            end = (i + 1) * scaled_cond_size + scaled_block_size
            mask[start:end, start:end] = 0

        assert mask.shape[0] == scaled_block_size + num_cond_blocks*scaled_cond_size, f"{mask.shape = }, {scaled_block_size=}, {num_cond_blocks=}, {scaled_cond_size=}"

        if call_ids is None:
            return mask

        # One mask per batch element, with a singleton head dimension.
        mask = mask.unsqueeze(0).unsqueeze(0).repeat(len(call_ids), 1, 1, 1)
        for batch_idx in range(len(call_ids)):
            for subject_idx, call_ids_this_subject in enumerate(call_ids[batch_idx]):
                # Boolean selector over the rows of one condition segment;
                # it does not depend on the segment index, so compute it once.
                cuboid_rows = cuboids_segmasks[batch_idx][subject_idx].to(torch.bool).flatten()

                for i in range(num_cond_blocks):
                    start = scaled_block_size + i * scaled_cond_size
                    end = scaled_block_size + (i + 1) * scaled_cond_size

                    # Advanced indexing returns a copy, hence the explicit
                    # read-modify-write: selected condition rows are allowed
                    # to attend to the key columns listed for this subject.
                    mask_subset = mask[batch_idx, :, start:end, call_ids_this_subject]
                    mask_subset[:, cuboid_rows] = 0
                    mask[batch_idx, :, start:end, call_ids_this_subject] = mask_subset

        return mask

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        use_cond=False,
        call_ids=None,
        cuboids_segmasks: torch.Tensor = None,
    ) -> torch.FloatTensor:
        """Run masked attention with LoRA-augmented q/k/v projections.

        Args:
            attn: Object providing ``to_q``/``to_k``/``to_v``, ``heads``,
                ``norm_q`` and ``norm_k``.
            hidden_states: ``(batch, seq, dim)`` tokens, block tokens first and
                the ``n_loras`` condition segments last.
            encoder_hidden_states: Only used to read the batch size when given.
            attention_mask: Unused.
            image_rotary_emb: Optional rotary embedding applied to query and key.
            use_cond: When true, also return the condition-token outputs.
            call_ids: Optional; per batch element, one entry per subject
                holding the key columns that subject's condition rows may
                additionally attend to.
            cuboids_segmasks: Required with ``call_ids``; indexed
                ``[batch][subject]``, each entry is cast to bool and flattened
                to select rows inside a condition segment.

        Returns:
            The block-token outputs, or ``(block_outputs, cond_outputs)`` when
            ``use_cond`` is true.
        """
        reference = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size = reference.shape[0]

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        for i in range(self.n_loras):
            query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
            key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
            value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # Same expression as in LoRALinearLayer so both agree on segment size.
        cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
        block_size = hidden_states.shape[1] - cond_size * self.n_loras

        mask = self._build_attention_mask(
            query.shape[2], block_size, cond_size, self.n_loras,
            hidden_states.device, call_ids, cuboids_segmasks,
        )
        # Turn the 0/1 mask into an additive bias: blocked pairs get a very
        # large negative score.
        mask = mask * -1e20
        mask = mask.to(query.dtype)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        cond_hidden_states = hidden_states[:, block_size:, :]
        hidden_states = hidden_states[:, :block_size, :]

        return hidden_states if not use_cond else (hidden_states, cond_hidden_states)
class MultiDoubleStreamBlockLoraProcessor(nn.Module):
    """Joint text/image attention processor with per-condition LoRA deltas.

    ``hidden_states`` is treated as
    ``[block tokens | cond segment 0 | ... | cond segment n_loras - 1]`` and
    ``encoder_hidden_states`` is prepended to it for attention. LoRA ``i``
    only contributes to condition segment ``i`` (see ``LoRALinearLayer``).
    The additive attention mask lets encoder and block tokens attend to
    everything while each condition segment attends only to itself
    (optionally extended through ``call_ids`` / ``cuboids_segmasks``).

    Args:
        dim: Input and output feature size of every LoRA layer.
        ranks: Rank of each LoRA, indexed by condition number.
        lora_weights: Scalar multiplier of each LoRA delta, indexed by
            condition number. Stored on the instance as given.
        network_alphas: ``network_alpha`` of each LoRA, indexed by condition
            number.
        device: Forwarded to the LoRA layers.
        dtype: Forwarded to the LoRA layers.
        cond_width: Used with ``cond_height`` to derive the tokens per
            condition segment.
        cond_height: See ``cond_width``.
        n_loras: Number of condition segments (and LoRAs per projection).
    """

    def __init__(self, dim: int, ranks=None, lora_weights=None, network_alphas=None, device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
        super().__init__()
        # ``None`` defaults avoid sharing one mutable default list between
        # instances (``lora_weights`` is stored on ``self``).
        ranks = [] if ranks is None else ranks
        network_alphas = [] if network_alphas is None else network_alphas

        self.n_loras = n_loras
        self.cond_width = cond_width
        self.cond_height = cond_height

        def build_loras() -> nn.ModuleList:
            # One LoRA per condition segment; ``number=i`` selects the segment.
            return nn.ModuleList([
                LoRALinearLayer(
                    dim, dim, ranks[i], network_alphas[i],
                    device=device, dtype=dtype,
                    cond_width=cond_width, cond_height=cond_height,
                    number=i, n_loras=n_loras,
                )
                for i in range(n_loras)
            ])

        self.q_loras = build_loras()
        self.k_loras = build_loras()
        self.v_loras = build_loras()
        self.proj_loras = build_loras()
        self.lora_weights = [] if lora_weights is None else lora_weights

    @staticmethod
    def _build_attention_mask(scaled_seq_len, scaled_block_size, scaled_cond_size, num_cond_blocks, device, call_ids, cuboids_segmasks):
        """Return a 0/1 mask where 1 marks a (query row, key column) pair to block."""
        mask = torch.ones((scaled_seq_len, scaled_seq_len), device=device)

        # Encoder and block-token rows may attend to every key.
        mask[:scaled_block_size, :] = 0
        # Each condition segment may attend to itself.
        for i in range(num_cond_blocks):
            start = i * scaled_cond_size + scaled_block_size
            end = (i + 1) * scaled_cond_size + scaled_block_size
            mask[start:end, start:end] = 0

        assert mask.shape[0] == scaled_block_size + num_cond_blocks*scaled_cond_size, f"{mask.shape = }, {scaled_block_size=}, {num_cond_blocks=}, {scaled_cond_size=}"

        if call_ids is None:
            return mask

        # One mask per batch element, with a singleton head dimension.
        mask = mask.unsqueeze(0).unsqueeze(0).repeat(len(call_ids), 1, 1, 1)
        for batch_idx in range(len(call_ids)):
            for subject_idx, call_ids_this_subject in enumerate(call_ids[batch_idx]):
                # Boolean selector over the rows of one condition segment;
                # it does not depend on the segment index, so compute it once.
                cuboid_rows = cuboids_segmasks[batch_idx][subject_idx].to(torch.bool).flatten()

                for i in range(num_cond_blocks):
                    start = scaled_block_size + i * scaled_cond_size
                    end = scaled_block_size + (i + 1) * scaled_cond_size

                    # Advanced indexing returns a copy, hence the explicit
                    # read-modify-write: selected condition rows are allowed
                    # to attend to the key columns listed for this subject.
                    mask_subset = mask[batch_idx, :, start:end, call_ids_this_subject]
                    mask_subset[:, cuboid_rows] = 0
                    mask[batch_idx, :, start:end, call_ids_this_subject] = mask_subset

        return mask

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        use_cond=False,
        call_ids=None,
        cuboids_segmasks: torch.Tensor = None,
    ) -> torch.FloatTensor:
        """Run masked joint attention with LoRA-augmented projections.

        Args:
            attn: Object providing ``to_q``/``to_k``/``to_v``,
                ``add_q_proj``/``add_k_proj``/``add_v_proj``, ``heads``,
                ``norm_q``/``norm_k``, ``norm_added_q``/``norm_added_k``,
                ``to_out`` (indexable, two entries) and ``to_add_out``.
            hidden_states: ``(batch, seq, dim)`` tokens, block tokens first and
                the ``n_loras`` condition segments last.
            encoder_hidden_states: ``(batch, text_seq, dim)`` tokens that are
                projected separately and prepended for attention. Although the
                default is ``None``, the projections are applied
                unconditionally, so a tensor is required in practice.
            attention_mask: Unused.
            image_rotary_emb: Optional rotary embedding applied to the
                concatenated query and key.
            use_cond: Selects the return layout (see Returns).
            call_ids: Optional; per batch element, one entry per subject
                holding the key columns that subject's condition rows may
                additionally attend to.
            cuboids_segmasks: Required with ``call_ids``; indexed
                ``[batch][subject]``, each entry is cast to bool and flattened
                to select rows inside a condition segment.

        Returns:
            ``(hidden_states, encoder_hidden_states, cond_hidden_states)`` when
            ``use_cond`` is true, otherwise
            ``(encoder_hidden_states, hidden_states)`` -- note the different
            ordering of the two layouts.
        """
        reference = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size = reference.shape[0]

        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # Take the width from the projection output instead of assuming 3072.
        inner_dim = encoder_hidden_states_query_proj.shape[-1]
        head_dim = inner_dim // attn.heads

        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)

        if attn.norm_added_q is not None:
            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
        if attn.norm_added_k is not None:
            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        for i in range(self.n_loras):
            query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
            key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
            value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Joint sequence: encoder tokens first, then block + condition tokens.
        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # Same expression as in LoRALinearLayer so both agree on segment size.
        cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
        # Block tokens within ``hidden_states`` (without encoder tokens) ...
        block_size = hidden_states.shape[1] - cond_size * self.n_loras
        # ... versus everything before the condition segments in the joint
        # sequence (encoder + block tokens), which is what the mask needs.
        joint_seq_len = query.shape[2]
        joint_block_size = joint_seq_len - cond_size * self.n_loras

        mask = self._build_attention_mask(
            joint_seq_len, joint_block_size, cond_size, self.n_loras,
            hidden_states.device, call_ids, cuboids_segmasks,
        )
        # Turn the 0/1 mask into an additive bias: blocked pairs get a very
        # large negative score.
        mask = mask * -1e20
        mask = mask.to(query.dtype)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        text_len = encoder_hidden_states.shape[1]
        encoder_hidden_states, hidden_states = (
            hidden_states[:, :text_len],
            hidden_states[:, text_len:],
        )

        hidden_states = attn.to_out[0](hidden_states)
        # The projection LoRAs are applied on top of the ``to_out[0]`` output,
        # one after another; this order is kept exactly as in the original.
        for i in range(self.n_loras):
            hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)

        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        cond_hidden_states = hidden_states[:, block_size:, :]
        hidden_states = hidden_states[:, :block_size, :]

        return (hidden_states, encoder_hidden_states, cond_hidden_states) if use_cond else (encoder_hidden_states, hidden_states)