from itertools import chain
from functools import partial
from typing import Any, Dict, List, Optional, Union
import inspect

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.attention_processor import Attention, AttentionProcessor
from diffusers.models.normalization import RMSNorm
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers

class IPFluxAttnProcessor2_0(nn.Module):
    """Flux-style joint attention processor with an additional IP-Adapter image branch
    (requires PyTorch 2.0 scaled_dot_product_attention)."""

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, num_heads=0):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        # Key/value projections for the IP-Adapter image tokens; the query is shared with the base attention.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

        # Flux attention uses a per-head dimension of 128, so the IP keys are RMS-normalized like the base keys.
        self.norm_added_k = RMSNorm(128, eps=1e-5, elementwise_affine=False)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        ip_encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        layer_scale: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        ip_hidden_states = ip_encoder_hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # IP-Adapter branch: attend from the current queries over the projected image tokens.
        if ip_hidden_states is not None:
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            ip_key = self.norm_added_k(ip_key)

            ip_hidden_states = F.scaled_dot_product_attention(
                query,
                ip_key,
                ip_value,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False,
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )

        if encoder_hidden_states is not None:
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        if attention_mask is not None:
            # Broadcast the padding mask over heads and query positions, and keep it boolean so that
            # scaled_dot_product_attention treats it as a keep/drop mask rather than an additive float bias.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = (attention_mask > 0).to(device=hidden_states.device)

        original_hidden_states = hidden_states

        hidden_states = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # Per-sample scaling of the IP-Adapter contribution; default to 1.0 when no scale is supplied.
        if layer_scale is None:
            layer_scale = hidden_states.new_ones(1)
        layer_scale = layer_scale.view(-1, 1, 1)

        if encoder_hidden_states is not None:
            # Split the joint sequence back into context (text) tokens and image tokens.
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            if ip_hidden_states is not None:
                hidden_states = hidden_states + (self.scale * layer_scale) * ip_hidden_states

            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states

        else:
            if ip_hidden_states is not None:
                hidden_states = hidden_states + (self.scale * layer_scale) * ip_hidden_states

            if attn.to_out is not None:
                hidden_states = attn.to_out[0](hidden_states)
                hidden_states = attn.to_out[1](hidden_states)

            return hidden_states

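
# Usage sketch (informal): each wrapped diffusers `Attention` layer gets one of these processors
# assigned to `attn.processor` (see `LibreFluxIPAdapter.wrap_attention_blocks` below). At call time
# the processor expects the usual Flux attention inputs plus `ip_encoder_hidden_states` (the projected
# image tokens produced by `ImageProjModel`) and a per-sample `layer_scale` tensor; both are routed in
# through the transformer's `joint_attention_kwargs` by the adapter's `forward`.
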
class ImageProjModel(nn.Module):
    """Projects a pooled CLIP image embedding into a sequence of IP-Adapter context tokens."""

    def __init__(self, clip_dim=768, cross_attention_dim=4096, num_tokens=16):
        super().__init__()

        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.clip_dim = clip_dim

        self.proj = nn.Sequential(
            nn.Linear(clip_dim, clip_dim * 2),
            nn.GELU(),
            nn.Linear(clip_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        # [batch, clip_dim] -> [batch, num_tokens, cross_attention_dim]
        raw_proj = self.proj(image_embeds)
        reshaped_proj = raw_proj.reshape(image_embeds.shape[0], self.num_tokens, self.cross_attention_dim)
        reshaped_proj = self.norm(reshaped_proj)

        return reshaped_proj

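
# Illustrative shape check for ImageProjModel (with the defaults above; hypothetical, not executed here):
#
#   proj = ImageProjModel(clip_dim=768, cross_attention_dim=4096, num_tokens=16)
#   tokens = proj(torch.randn(2, 768))   # -> torch.Size([2, 16, 4096])
#
# These tokens are what the attention processors receive as `ip_encoder_hidden_states`.
# A runnable version of this check is included at the bottom of the file.
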
class LibreFluxIPAdapter(nn.Module):
    """Wraps a Flux transformer and an image projection model, injecting IP-Adapter attention processors."""

    def __init__(self, transformer, image_proj_model, checkpoint=None):
        super().__init__()
        self.transformer = transformer
        self.image_proj_model = image_proj_model

        # Collect every attention module in the double- and single-stream transformer blocks.
        self.culled_transformer_blocks = {}
        for name, module in self.transformer.named_modules():
            if isinstance(module, Attention):
                if name.startswith('transformer_blocks') or name.startswith('single_transformer_blocks'):
                    self.culled_transformer_blocks[name] = module

        self.wrap_attention_blocks()

        if checkpoint:
            self.load_from_checkpoint(checkpoint)

    def wrap_attention_blocks(self, scale=1.0, num_tokens=16):
        """Inject the IP-Adapter attention processors into the transformer's attention layers."""
        sample_attn = self.transformer.transformer_blocks[0].attn

        hidden_size = sample_attn.inner_dim
        cross_attention_dim = sample_attn.cross_attention_dim
        num_heads = sample_attn.heads

        processor_list = []
        for name in self.culled_transformer_blocks:
            module = self.culled_transformer_blocks[name]
            # cross_attention_dim is fixed to 4096 to match the ImageProjModel output dimension.
            module.processor = IPFluxAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=4096,
                num_heads=num_heads,
                scale=scale,
                num_tokens=num_tokens,
            )
            processor_list.append(module.processor)

        lay_count = len(processor_list)
        print(f"Added Attention IP wrapper to {lay_count} layers")

        # Register the processors so their parameters are tracked, saved, and loaded with the adapter.
        self.adapter_modules = torch.nn.ModuleList(processor_list)

    def parameters(self):
        """Return only the trainable adapter parameters (IP attention processors plus the image
        projection model), excluding the wrapped transformer."""
        adapter_param_list = []
        for name in self.culled_transformer_blocks:
            module = self.culled_transformer_blocks[name]
            adapter_param_list.append(module.processor.parameters())

        all_params = chain(*adapter_param_list, self.image_proj_model.parameters())
        return all_params

    def forward(self, ref_image, *args, layer_scale=torch.tensor([1.0]), **kwargs):
        """Project the reference image embedding and run the wrapped transformer."""
        mod_dtype = next(self.image_proj_model.parameters()).dtype
        mod_device = next(self.image_proj_model.parameters()).device

        ip_encoder_hidden_states = None
        if ref_image is not None:
            ip_encoder_hidden_states = self.image_proj_model(ref_image)

        # Route the projected image tokens and their scale to the attention processors
        # through the transformer's joint_attention_kwargs.
        if 'joint_attention_kwargs' not in kwargs:
            kwargs['joint_attention_kwargs'] = {}
        layer_scale = layer_scale.to(dtype=mod_dtype, device=mod_device)

        kwargs['joint_attention_kwargs']['ip_layer_scale'] = layer_scale
        kwargs['joint_attention_kwargs']['ip_hidden_states'] = ip_encoder_hidden_states

        output = self.transformer(*args, **kwargs)

        return output

    def save_pretrained(self, ckpt_path):
        """Save the adapter weights (image projection model plus IP attention processors)."""
        state_dict = {}
        state_dict["image_proj"] = self.image_proj_model.state_dict()
        state_dict["ip_adapter"] = self.adapter_modules.state_dict()
        torch.save(state_dict, ckpt_path)

    def load_from_checkpoint(self, ckpt_path):
        """Checkpoint loader adapted from the Tencent IP-Adapter repository."""
        # Parameter sums before loading, used below to verify that the weights actually changed.
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        state_dict = torch.load(ckpt_path, map_location="cpu")

        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")

    @property
    def dtype(self):
        return next(self.image_proj_model.parameters()).dtype
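

# Minimal smoke test for the projection model only; it avoids loading any Flux checkpoint.
# Wiring up the full adapter is sketched in the comments inside the block and assumes a Flux-style
# transformer (e.g. diffusers' FluxTransformer2DModel) whose blocks forward `joint_attention_kwargs`
# down to these attention processors; names such as `model_path`, `latents`, `prompt_embeds`, and
# `timesteps` are illustrative placeholders, not values defined in this file.
if __name__ == "__main__":
    proj = ImageProjModel(clip_dim=768, cross_attention_dim=4096, num_tokens=16)
    image_embeds = torch.randn(2, 768)
    tokens = proj(image_embeds)
    print(tokens.shape)  # expected: torch.Size([2, 16, 4096])

    # Hypothetical end-to-end usage (not executed here):
    #
    #   transformer = FluxTransformer2DModel.from_pretrained(model_path, subfolder="transformer")
    #   ip_adapter = LibreFluxIPAdapter(transformer, proj, checkpoint=None)
    #   out = ip_adapter(
    #       image_embeds,                        # reference image embedding, projected internally
    #       hidden_states=latents,
    #       encoder_hidden_states=prompt_embeds,
    #       timestep=timesteps,
    #       layer_scale=torch.tensor([1.0]),
    #   )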