# LibreFlux-IP-Adapter-ControlNet / flux_ip_adapter.py
from itertools import chain
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.attention_processor import Attention
from diffusers.models.normalization import RMSNorm
class IPFluxAttnProcessor2_0(nn.Module):
    """Flux attention processor (PyTorch 2.0 `scaled_dot_product_attention`)
    with an added IP-Adapter branch that lets the latent queries attend to
    projected image tokens."""

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, num_heads=0):
        super().__init__()
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        # Key/value projections for the image (IP) tokens only; queries are
        # shared with the base attention.
        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        # Flux uses a per-head dim of 128; normalize the IP keys the same way
        # the base attention normalizes its keys.
        self.norm_added_k = RMSNorm(128, eps=1e-5, elementwise_affine=False)
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        ip_encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        layer_scale: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
ip_hidden_states = ip_encoder_hidden_states
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
        # IP-Adapter branch: compute the image cross-attention first, while the
        # shared query tensor is still in its pre-concatenation shape.
        if ip_hidden_states is not None:
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)
            # reshape to the multi-head layout to match the query
            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_key = self.norm_added_k(ip_key)
            # Flux-style attention over the image tokens
ip_hidden_states = F.scaled_dot_product_attention(
query,
ip_key,
ip_value,
dropout_p=0.0,
is_causal=False,
attn_mask=None,
)
# reshaping ip_hidden_states in the same way as hidden_states
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(
encoder_hidden_states_query_proj
)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(
encoder_hidden_states_key_proj
)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
        if attention_mask is not None:
            # broadcast over heads and query positions; keep the mask boolean so
            # scaled_dot_product_attention masks out False entries instead of
            # treating the values as an additive float bias
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = (attention_mask > 0).to(hidden_states.device)
hidden_states = F.scaled_dot_product_attention(
query,
key,
value,
dropout_p=0.0,
is_causal=False,
attn_mask=attention_mask,
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
        if layer_scale is None:
            # default to a neutral per-sample scale when none is supplied
            layer_scale = torch.ones(1, device=hidden_states.device, dtype=hidden_states.dtype)
        layer_scale = layer_scale.view(-1, 1, 1)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
            # final injection of the IP-Adapter hidden states
            if ip_hidden_states is not None:
                hidden_states = hidden_states + (self.scale * layer_scale) * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
            # final injection of the IP-Adapter hidden states
            if ip_hidden_states is not None:
                hidden_states = hidden_states + (self.scale * layer_scale) * ip_hidden_states
            # single-stream Flux blocks are built with `pre_only=True` and have
            # no output projection of their own
            if getattr(attn, "to_out", None) is not None:
                hidden_states = attn.to_out[0](hidden_states)
                hidden_states = attn.to_out[1](hidden_states)
return hidden_states
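
# Minimal shape-check sketch for the IP branch (illustrative only; not called
# anywhere in this file). The 3072 / 24x128 dimensions are Flux-like
# assumptions, not values read from a real model.
def _ip_branch_shape_check():
    proc = IPFluxAttnProcessor2_0(hidden_size=3072, cross_attention_dim=4096, num_tokens=16)
    query = torch.randn(2, 24, 512, 128)     # (batch, heads, latent tokens, head_dim)
    ip_tokens = torch.randn(2, 16, 4096)     # (batch, num_tokens, cross_attention_dim)
    ip_key = proc.to_k_ip(ip_tokens).view(2, -1, 24, 128).transpose(1, 2)
    ip_value = proc.to_v_ip(ip_tokens).view(2, -1, 24, 128).transpose(1, 2)
    ip_key = proc.norm_added_k(ip_key)
    out = F.scaled_dot_product_attention(query, ip_key, ip_value)
    assert out.shape == (2, 24, 512, 128)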
class ImageProjModel(nn.Module):
    """Projects a CLIP image embedding into `num_tokens` IP tokens of width
    `cross_attention_dim` (the Flux text-conditioning width)."""

    def __init__(self, clip_dim=768, cross_attention_dim=4096, num_tokens=16):
        super().__init__()
        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.clip_dim = clip_dim
        self.proj = nn.Sequential(
            nn.Linear(clip_dim, clip_dim * 2),
            nn.GELU(),
            nn.Linear(clip_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        raw_proj = self.proj(image_embeds)
        reshaped_proj = raw_proj.reshape(image_embeds.shape[0], self.num_tokens, self.cross_attention_dim)
        return self.norm(reshaped_proj)
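
# Minimal shape-check sketch (illustrative only; not called anywhere in this
# file). The dimensions are the defaults assumed above, not values read from
# a checkpoint.
def _image_proj_shape_check():
    proj = ImageProjModel(clip_dim=768, cross_attention_dim=4096, num_tokens=16)
    clip_embed = torch.randn(2, 768)   # (batch, clip_dim) pooled CLIP embedding
    tokens = proj(clip_embed)          # -> (batch, num_tokens, cross_attention_dim)
    assert tokens.shape == (2, 16, 4096)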
class LibreFluxIPAdapter(nn.Module):
    def __init__(self, transformer, image_proj_model, checkpoint=None):
        super().__init__()
        self.transformer = transformer
        self.image_proj_model = image_proj_model
        # Collect the attention modules from both the double-stream
        # (`transformer_blocks`) and single-stream (`single_transformer_blocks`)
        # Flux blocks; everything else is left untouched.
        self.culled_transformer_blocks = {}
        for name, module in self.transformer.named_modules():
            if isinstance(module, Attention):
                if name.startswith('transformer_blocks') or name.startswith('single_transformer_blocks'):
                    self.culled_transformer_blocks[name] = module
        # Apply the adapter to the culled blocks
        self.wrap_attention_blocks()
        if checkpoint:
            self.load_from_checkpoint(checkpoint)
    def wrap_attention_blocks(self, scale=1.0, num_tokens=16):
        """Inject the IP-Adapter attention processors into the transformer."""
        sample_attn = self.transformer.transformer_blocks[0].attn
        hidden_size = sample_attn.inner_dim
        num_heads = sample_attn.heads
        processor_list = []
        for name in self.culled_transformer_blocks:
            module = self.culled_transformer_blocks[name]
            # cross_attention_dim is fixed at 4096 to match the token width
            # produced by ImageProjModel
            module.processor = IPFluxAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=4096,
                num_heads=num_heads,
                scale=scale,
                num_tokens=num_tokens,
            )
            processor_list.append(module.processor)
        print(f"Added Attention IP Wrapper to {len(processor_list)} layers")
        # Store adapters as a module list for saving/loading
        self.adapter_modules = nn.ModuleList(processor_list)
    def parameters(self):
        """Return only the trainable parameters: the per-layer IP processors
        plus the image projection model (the base transformer is excluded)."""
        return chain(self.adapter_modules.parameters(), self.image_proj_model.parameters())
    def forward(self, ref_image, *args, layer_scale=None, **kwargs):
        """Project the reference image embedding, then run the wrapped transformer."""
        mod_dtype = next(self.image_proj_model.parameters()).dtype
        mod_device = next(self.image_proj_model.parameters()).device
        ip_encoder_hidden_states = None
        if ref_image is not None:
            ip_encoder_hidden_states = self.image_proj_model(ref_image)
        # avoid a shared tensor default argument; 1.0 keeps the IP scale neutral
        if layer_scale is None:
            layer_scale = torch.tensor([1.0])
        layer_scale = layer_scale.to(dtype=mod_dtype, device=mod_device)
        # route the IP hidden states and scale through joint_attention_kwargs
        if 'joint_attention_kwargs' not in kwargs:
            kwargs['joint_attention_kwargs'] = {}
        kwargs['joint_attention_kwargs']['ip_layer_scale'] = layer_scale
        kwargs['joint_attention_kwargs']['ip_hidden_states'] = ip_encoder_hidden_states
        return self.transformer(*args, **kwargs)
    def save_pretrained(self, ckpt_path):
        """Save the IP-Adapter weights (projection model + attention processors)."""
        state_dict = {
            "image_proj": self.image_proj_model.state_dict(),
            "ip_adapter": self.adapter_modules.state_dict(),
        }
        torch.save(state_dict, ckpt_path)
    def load_from_checkpoint(self, ckpt_path):
        """Loader adapted from the Tencent IP-Adapter reference implementation."""
        # record parameter sums so we can verify the load actually changed the weights
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
state_dict = torch.load(ckpt_path, map_location="cpu")
# Load state dict for image_proj_model and adapter_modules
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
        # recompute the parameter sums after loading
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
# Verify if the weights have changed
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
@property
def dtype(self):
return next(self.image_proj_model.parameters()).dtype
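
# Illustrative wiring sketch (an assumption-laden example, not the repo's
# training/inference entry point): the model id below is a stand-in; point it
# at the Flux-family transformer you actually use, and feed `ref_image` a
# pooled CLIP image embedding of shape (batch, clip_dim).
if __name__ == "__main__":
    from diffusers import FluxTransformer2DModel

    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # assumed id; swap for your checkpoint
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )
    image_proj = ImageProjModel(clip_dim=768, cross_attention_dim=4096, num_tokens=16)
    adapter = LibreFluxIPAdapter(transformer, image_proj)
    # Transformer args/kwargs pass straight through, e.g.:
    # out = adapter(ref_image, hidden_states=..., timestep=..., img_ids=..., txt_ids=...)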