File size: 13,580 Bytes
1600698 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 |
from itertools import chain
import torch
from torch import nn
from diffusers.models.attention_processor import (
Attention,
AttentionProcessor,
)
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
import torch.nn.functional as F
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.attention_processor import Attention
import inspect
from functools import partial
from diffusers.models.normalization import RMSNorm
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
class IPFluxAttnProcessor2_0(nn.Module):
    """Joint-attention processor with an additional IP-Adapter attention branch.

    On top of the regular attention computed from ``attn``'s own projections,
    the processor projects ``ip_encoder_hidden_states`` with its own
    ``to_k_ip`` / ``to_v_ip`` layers, attends to them with the sample queries
    and adds the result (weighted by ``scale`` and ``layer_scale``) to the
    sample hidden states before the output projection.

    Args:
        hidden_size: Output width of the ``to_k_ip`` / ``to_v_ip`` projections.
        cross_attention_dim: Width of the IP tokens; falls back to
            ``hidden_size`` when ``None`` (or 0).
        scale: Global multiplier applied to the IP attention output.
        num_tokens: Stored on the instance; not read by this class.
        num_heads: Number of attention heads. When non-zero the key norm is
            built for ``hidden_size // num_heads``; otherwise 128 is used.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, num_heads=0):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.num_heads = num_heads

        ip_dim = cross_attention_dim or hidden_size
        self.to_k_ip = nn.Linear(ip_dim, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(ip_dim, hidden_size, bias=False)

        # The norm is applied per head, so size it from the head dimension.
        # 128 is kept as the fallback for callers that do not pass num_heads.
        head_dim = hidden_size // num_heads if num_heads else 128
        self.norm_added_k = RMSNorm(head_dim, eps=1e-5, elementwise_affine=False)

    @staticmethod
    def _split_heads(tensor, batch_size, heads, head_dim):
        """View ``tensor`` as (batch, seq, heads, head_dim) and swap seq/heads."""
        return tensor.view(batch_size, -1, heads, head_dim).transpose(1, 2)

    def _inject_ip(self, hidden_states, ip_hidden_states, layer_scale):
        """Add the weighted IP attention output to ``hidden_states``.

        Returns ``hidden_states`` untouched when there is no IP branch. A
        missing ``layer_scale`` is treated as a factor of 1.
        """
        if ip_hidden_states is None:
            return hidden_states
        ip_weight = self.scale
        if layer_scale is not None:
            ip_weight = ip_weight * layer_scale.view(-1, 1, 1)
        return hidden_states + ip_weight * ip_hidden_states

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        ip_encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        layer_scale: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Run attention for ``attn`` and mix in the IP-Adapter branch.

        Args:
            attn: Attention module providing the projections, norms and
                ``heads`` used below.
            hidden_states: Sample tokens (3-D; first dim is the batch).
            encoder_hidden_states: Optional context tokens. When given they
                are projected with ``attn.add_*_proj`` and prepended to the
                sample tokens for the joint attention.
            ip_encoder_hidden_states: Optional IP-Adapter tokens. ``None``
                disables the IP branch.
            attention_mask: Optional mask, see the review note below.
            image_rotary_emb: Optional rotary embedding applied to the joint
                query and key.
            layer_scale: Optional tensor multiplied with ``self.scale``; it is
                reshaped to ``(-1, 1, 1)``. ``None`` means no extra scaling.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` when
            ``encoder_hidden_states`` was given, otherwise ``hidden_states``.
        """
        batch_size, _, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = self._split_heads(query, batch_size, attn.heads, head_dim)
        key = self._split_heads(key, batch_size, attn.heads, head_dim)
        value = self._split_heads(value, batch_size, attn.heads, head_dim)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Handle the IP attention FIRST: it must use the sample-only query,
        # i.e. before the context query is prepended and before rotary
        # embeddings are applied.
        ip_hidden_states = None
        if ip_encoder_hidden_states is not None:
            ip_key = self.to_k_ip(ip_encoder_hidden_states)
            ip_value = self.to_v_ip(ip_encoder_hidden_states)

            # Reshape to match the query layout.
            ip_key = self._split_heads(ip_key, batch_size, attn.heads, head_dim)
            ip_value = self._split_heads(ip_value, batch_size, attn.heads, head_dim)
            ip_key = self.norm_added_k(ip_key)

            ip_hidden_states = F.scaled_dot_product_attention(
                query,
                ip_key,
                ip_value,
                dropout_p=0.0,
                is_causal=False,
                attn_mask=None,
            )
            # Same layout as `hidden_states` after the main attention.
            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )

        # The attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`.
        if encoder_hidden_states is not None:
            # `context` projections.
            context_query = self._split_heads(
                attn.add_q_proj(encoder_hidden_states), batch_size, attn.heads, head_dim
            )
            context_key = self._split_heads(
                attn.add_k_proj(encoder_hidden_states), batch_size, attn.heads, head_dim
            )
            context_value = self._split_heads(
                attn.add_v_proj(encoder_hidden_states), batch_size, attn.heads, head_dim
            )

            if attn.norm_added_q is not None:
                context_query = attn.norm_added_q(context_query)
            if attn.norm_added_k is not None:
                context_key = attn.norm_added_k(context_key)

            # Joint attention: context tokens come first along the sequence dim.
            query = torch.cat([context_query, query], dim=2)
            key = torch.cat([context_key, key], dim=2)
            value = torch.cat([context_value, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = (attention_mask > 0).bool()
            # NOTE(review): casting the boolean mask to the query dtype makes
            # it a float mask, which scaled_dot_product_attention adds to the
            # attention scores instead of using it to exclude positions.
            # Left unchanged to keep existing behaviour - confirm intent.
            attention_mask = attention_mask.to(
                device=hidden_states.device, dtype=query.dtype
            )

        hidden_states = F.scaled_dot_product_attention(
            query,
            key,
            value,
            dropout_p=0.0,
            is_causal=False,
            attn_mask=attention_mask,
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Split the joint sequence back into context and sample parts.
            context_len = encoder_hidden_states.shape[1]
            encoder_hidden_states, hidden_states = (
                hidden_states[:, :context_len],
                hidden_states[:, context_len:],
            )

            # Final injection of IP-Adapter hidden states.
            hidden_states = self._inject_ip(hidden_states, ip_hidden_states, layer_scale)

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states

        # Final injection of IP-Adapter hidden states.
        hidden_states = self._inject_ip(hidden_states, ip_hidden_states, layer_scale)

        if attn.to_out is not None:
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
class ImageProjModel(nn.Module):
    """Project one feature vector per sample into a sequence of IP tokens.

    The input of width ``clip_dim`` is passed through a Linear -> GELU ->
    Linear stack, reshaped to ``num_tokens`` tokens of width
    ``cross_attention_dim`` and layer-normalised over the last dimension.

    Args:
        clip_dim: Width of the input features (presumably a CLIP image
            embedding, judging by the name - confirm against the caller).
        cross_attention_dim: Width of each produced token.
        num_tokens: Number of tokens produced per sample.
    """

    def __init__(self, clip_dim=768, cross_attention_dim=4096, num_tokens=16):
        super().__init__()

        self.clip_dim = clip_dim
        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        widened_dim = clip_dim * 2
        flat_token_dim = cross_attention_dim * num_tokens
        layers = [
            nn.Linear(clip_dim, widened_dim),
            nn.GELU(),
            nn.Linear(widened_dim, flat_token_dim),
        ]
        self.proj = nn.Sequential(*layers)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, input):
        """Return tokens of shape ``(input.shape[0], num_tokens, cross_attention_dim)``."""
        token_shape = (input.shape[0], self.num_tokens, self.cross_attention_dim)
        tokens = self.proj(input).reshape(*token_shape)
        return self.norm(tokens)
class LibreFluxIPAdapter(nn.Module):
    """Attach IP-Adapter attention processors to a transformer and drive it.

    On construction every ``Attention`` module whose name starts with
    ``transformer_blocks`` or ``single_transformer_blocks`` gets an
    ``IPFluxAttnProcessor2_0`` as its processor. ``forward`` projects a
    reference image embedding with ``image_proj_model`` and hands the result to
    the transformer through ``joint_attention_kwargs``.

    Args:
        transformer: Model exposing ``named_modules()`` and
            ``transformer_blocks[0].attn``.
        image_proj_model: Module turning the reference image input into IP
            tokens.
        checkpoint: Optional path passed to ``load_from_checkpoint``.
    """

    def __init__(self, transformer, image_proj_model, checkpoint=None):
        super().__init__()
        self.transformer = transformer
        self.image_proj_model = image_proj_model

        # Collect the attention modules of both the `transformer_blocks` and
        # the `single_transformer_blocks`; attention modules anywhere else in
        # the transformer are left alone.
        block_prefixes = ('transformer_blocks', 'single_transformer_blocks')
        self.culled_transformer_blocks = {
            name: module
            for name, module in self.transformer.named_modules()
            if isinstance(module, Attention) and name.startswith(block_prefixes)
        }

        # Apply the adapter to the culled blocks
        self.wrap_attention_blocks()

        if checkpoint:
            self.load_from_checkpoint(checkpoint)

    def wrap_attention_blocks(self, scale=1.0, num_tokens=16):
        """Inject the IP-Adapter modules into the Transformer model.

        Args:
            scale: Initial ``scale`` of every created processor.
            num_tokens: ``num_tokens`` stored on every created processor.
        """
        sample_attn = self.transformer.transformer_blocks[0].attn
        hidden_size = sample_attn.inner_dim
        num_heads = sample_attn.heads

        # The processors consume the tokens produced by the projection model,
        # so take the token width from it; 4096 is the historical default.
        cross_attention_dim = getattr(self.image_proj_model, "cross_attention_dim", None)
        if cross_attention_dim is None:
            cross_attention_dim = 4096

        processor_list = []
        for module in self.culled_transformer_blocks.values():
            module.processor = IPFluxAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                num_heads=num_heads,
                scale=scale,
                num_tokens=num_tokens,
            )
            processor_list.append(module.processor)

        lay_count = len(processor_list)
        print (f"Added Attention IP Wrapper to {lay_count} layers")

        # Store adapters as a module list for saving/loading
        self.adapter_modules = torch.nn.ModuleList(processor_list)

    def parameters(self, recurse=True):
        """Iterate over the adapter parameters only.

        Unlike ``nn.Module.parameters`` this yields the parameters of the
        injected processors followed by those of ``image_proj_model``; the
        wrapped transformer's own weights are not included. ``recurse`` is
        accepted for signature compatibility with ``nn.Module.parameters`` and
        forwarded to the sub-modules.
        """
        adapter_params = [
            module.processor.parameters(recurse=recurse)
            for module in self.culled_transformer_blocks.values()
        ]
        return chain(*adapter_params, self.image_proj_model.parameters(recurse=recurse))

    def forward(self, ref_image, *args, layer_scale=None, **kwargs):
        """Run projection and run forward.

        Args:
            ref_image: Input for ``image_proj_model``; ``None`` disables the IP
                branch (``ip_hidden_states`` is then passed as ``None``).
            *args: Forwarded to the transformer unchanged.
            layer_scale: Tensor (or number / sequence convertible to one)
                stored under ``ip_layer_scale``. ``None`` means ``[1.0]``.
            **kwargs: Forwarded to the transformer. ``joint_attention_kwargs``
                is copied and extended with ``ip_layer_scale`` and
                ``ip_hidden_states``; the caller's dict is not modified.

        Returns:
            Whatever the wrapped transformer returns.
        """
        proj_param = next(self.image_proj_model.parameters())

        ip_encoder_hidden_states = None
        if ref_image is not None:
            ip_encoder_hidden_states = self.image_proj_model(ref_image)

        if layer_scale is None:
            layer_scale = torch.tensor([1.0])
        layer_scale = torch.as_tensor(layer_scale).to(
            dtype=proj_param.dtype, device=proj_param.device
        )

        # Add ip hidden states to kwargs. The transformer is expected to hand
        # these entries on to the attention processors (not visible from
        # here - TODO confirm the key names against the transformer).
        joint_attention_kwargs = dict(kwargs.get('joint_attention_kwargs') or {})
        joint_attention_kwargs['ip_layer_scale'] = layer_scale
        joint_attention_kwargs['ip_hidden_states'] = ip_encoder_hidden_states
        kwargs['joint_attention_kwargs'] = joint_attention_kwargs

        return self.transformer(*args, **kwargs)

    def save_pretrained(self, ckpt_path):
        """Save the projection model and adapter weights to ``ckpt_path``.

        The file holds a dict with the keys ``"image_proj"`` and
        ``"ip_adapter"``.
        """
        state_dict = {
            "image_proj": self.image_proj_model.state_dict(),
            "ip_adapter": self.adapter_modules.state_dict(),
        }
        torch.save(state_dict, ckpt_path)

    @staticmethod
    def _weight_checksum(module):
        """Return the sum of all parameter values of ``module`` as a float."""
        with torch.no_grad():
            return sum(float(param.sum()) for param in module.parameters())

    def load_from_checkpoint(self, ckpt_path):
        """Load weights written by ``save_pretrained``.

        Both state dicts are loaded with ``strict=True``, so missing or
        unexpected keys raise. Afterwards a checksum comparison warns when the
        weights look unchanged; this is only a heuristic (reloading the same
        checkpoint legitimately leaves them unchanged), hence a warning rather
        than an error.
        """
        import warnings

        # Calculate original checksums
        orig_ip_proj_sum = self._weight_checksum(self.image_proj_model)
        orig_adapter_sum = self._weight_checksum(self.adapter_modules)

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources (or pass `weights_only=True` where supported).
        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for image_proj_model and adapter_modules
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Verify if the weights have changed
        if orig_ip_proj_sum == self._weight_checksum(self.image_proj_model):
            warnings.warn("Weights of image_proj_model did not change!")
        if orig_adapter_sum == self._weight_checksum(self.adapter_modules):
            warnings.warn("Weights of adapter_modules did not change!")

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")

    @property
    def dtype(self):
        """Dtype of the first parameter of ``image_proj_model``."""
        return next(self.image_proj_model.parameters()).dtype
|