# LibreFluxIPAdapterPipeline
from itertools import chain
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import apply_rotary_emb
from diffusers.models.normalization import RMSNorm


class IPFluxAttnProcessor2_0(nn.Module):
    """Flux attention processor with a decoupled IP-Adapter cross-attention branch.

    Image tokens get their own key/value projections (`to_k_ip`, `to_v_ip`);
    their attention output is added to the regular attention output, scaled by
    `scale` and a per-call `layer_scale`.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, num_heads=0):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens  # num_heads is accepted for API symmetry but unused

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

        # Flux attention uses a head dim of 128; normalize the IP keys the same
        # way the base attention normalizes its own keys.
        self.norm_added_k = RMSNorm(128, eps=1e-5, elementwise_affine=False)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        ip_encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        layer_scale: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        
        ip_hidden_states = ip_encoder_hidden_states
        
        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Handle the IP-Adapter branch first, while `query` still holds only
        # the image-latent tokens (before the text context is concatenated in).
        if ip_hidden_states is not None:
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            # reshaping to match query shape
            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            ip_key = self.norm_added_k(ip_key)

            # Flux-style scaled-dot-product attention over the IP tokens
            ip_hidden_states = F.scaled_dot_product_attention(
                query,
                ip_key,
                ip_value,
                dropout_p=0.0,
                is_causal=False,
                attn_mask=None,
            )

            # reshaping ip_hidden_states in the same way as hidden_states
            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)

            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(
                    encoder_hidden_states_query_proj
                )
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(
                    encoder_hidden_states_key_proj
                )

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)
        if attention_mask is not None:
            # Broadcast over heads and keep the mask boolean: SDPA treats a
            # float mask as an additive bias, so casting 0/1 values to the
            # query dtype would effectively disable the mask.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = (attention_mask > 0).to(device=hidden_states.device)

        hidden_states = F.scaled_dot_product_attention(
            query,
            key,
            value,
            dropout_p=0.0,
            is_causal=False,
            attn_mask=attention_mask,
        )
        
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # Per-call strength of the IP branch (broadcast over tokens/channels);
        # guard against callers that do not pass a layer_scale.
        if layer_scale is not None:
            layer_scale = layer_scale.view(-1, 1, 1)
        else:
            layer_scale = 1.0

        if encoder_hidden_states is not None:
            # Split the joint sequence back into context and latent parts.
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # Final injection of the IP-Adapter hidden states into the latent stream
            if ip_hidden_states is not None:
                hidden_states = hidden_states + (self.scale * layer_scale) * ip_hidden_states

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states

        else:
            # Single-stream (FluxSingleTransformerBlock) path.
            if ip_hidden_states is not None:
                hidden_states = hidden_states + (self.scale * layer_scale) * ip_hidden_states

            # Single blocks are built with `pre_only=True`, so `to_out` may not exist.
            if getattr(attn, "to_out", None) is not None:
                hidden_states = attn.to_out[0](hidden_states)
                hidden_states = attn.to_out[1](hidden_states)

            return hidden_states
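
# The decoupled attention above, in isolation: a minimal, self-contained sketch
# (all sizes below are illustrative assumptions, not values from this file).
# The IP tokens get their own K/V projections while reusing the base query; the
# result is what gets added, scaled, to the regular attention output.
def _ip_attention_sketch():
    batch, seq_len, heads, head_dim = 2, 8, 4, 32
    hidden_size = heads * head_dim
    num_ip_tokens = 16

    query = torch.randn(batch, heads, seq_len, head_dim)        # base attention query
    ip_tokens = torch.randn(batch, num_ip_tokens, hidden_size)  # projected image tokens

    to_k_ip = nn.Linear(hidden_size, hidden_size, bias=False)
    to_v_ip = nn.Linear(hidden_size, hidden_size, bias=False)

    ip_key = to_k_ip(ip_tokens).view(batch, num_ip_tokens, heads, head_dim).transpose(1, 2)
    ip_value = to_v_ip(ip_tokens).view(batch, num_ip_tokens, heads, head_dim).transpose(1, 2)

    ip_out = F.scaled_dot_product_attention(query, ip_key, ip_value)
    # (batch, seq_len, hidden_size), ready to be added to the base attention output
    return ip_out.transpose(1, 2).reshape(batch, seq_len, hidden_size)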


class ImageProjModel(nn.Module):
    """Projects a pooled CLIP image embedding into a sequence of IP tokens."""

    def __init__(self, clip_dim=768, cross_attention_dim=4096, num_tokens=16):
        super().__init__()

        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.clip_dim = clip_dim

        self.proj = nn.Sequential(
            nn.Linear(clip_dim, clip_dim * 2),
            nn.GELU(),
            nn.Linear(clip_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        raw_proj = self.proj(image_embeds)
        reshaped_proj = raw_proj.reshape(image_embeds.shape[0], self.num_tokens, self.cross_attention_dim)
        return self.norm(reshaped_proj)
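
# Shape sketch (illustrative): a pooled CLIP embedding of shape (batch, 768)
# becomes 16 pseudo-tokens of width 4096 for the IP cross-attention above.
def _image_proj_shape_sketch():
    proj = ImageProjModel()            # defaults: 768 -> 16 x 4096
    tokens = proj(torch.randn(2, 768))
    assert tokens.shape == (2, 16, 4096)
    return tokens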


class LibreFluxIPAdapter(nn.Module):
    def __init__(self, transformer, image_proj_model, checkpoint=None):
        super().__init__()
        self.transformer = transformer
        self.image_proj_model = image_proj_model

        # Collect the attention modules from both the double-stream
        # (`transformer_blocks`) and single-stream (`single_transformer_blocks`)
        # parts of the Flux transformer; everything else is left untouched.
        self.culled_transformer_blocks = {}
        for name, module in self.transformer.named_modules():
            if isinstance(module, Attention):
                if name.startswith('transformer_blocks') or name.startswith('single_transformer_blocks'):
                    self.culled_transformer_blocks[name] = module

        # Apply the adapter to the culled blocks
        self.wrap_attention_blocks()

        if checkpoint:
            self.load_from_checkpoint(checkpoint)

    def wrap_attention_blocks(self, scale=1.0, num_tokens=16, cross_attention_dim=4096):
        """Inject the IP-Adapter attention processors into the transformer.

        `cross_attention_dim` must match the width of the tokens produced by
        the image projection model (4096 for the default ImageProjModel).
        """
        sample_attn = self.transformer.transformer_blocks[0].attn
        hidden_size = sample_attn.inner_dim
        num_heads = sample_attn.heads

        processor_list = []
        for module in self.culled_transformer_blocks.values():
            module.processor = IPFluxAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                num_heads=num_heads,
                scale=scale,
                num_tokens=num_tokens,
            )
            processor_list.append(module.processor)
        print(f"Added Attention IP Wrapper to {len(processor_list)} layers")

        # Store adapters as a module list for saving/loading
        self.adapter_modules = torch.nn.ModuleList(processor_list)
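
    # Note: the freshly created processors live on CPU in float32. Because both
    # the transformer and `adapter_modules` are registered submodules, a single
    # `adapter.to(device, dtype)` moves everything together before inference.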
        
    def parameters(self):
        """Return only the trainable adapter parameters (processors + image projector).

        This deliberately overrides `nn.Module.parameters` so the frozen base
        transformer weights are never handed to an optimizer.
        """
        adapter_param_list = [
            module.processor.parameters()
            for module in self.culled_transformer_blocks.values()
        ]
        return chain(*adapter_param_list, self.image_proj_model.parameters())
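
    # Usage sketch (hypothetical hyperparameters): since `parameters()` above
    # excludes the base transformer, an optimizer built from it trains only the
    # IP-Adapter weights, e.g.:
    #   optimizer = torch.optim.AdamW(adapter.parameters(), lr=1e-4)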

    def forward(self, ref_image, *args, layer_scale=torch.tensor([1.0]), **kwargs):
        """Project the reference image embedding, then run the transformer forward."""
        mod_dtype = next(self.image_proj_model.parameters()).dtype
        mod_device = next(self.image_proj_model.parameters()).device

        ip_encoder_hidden_states = None
        if ref_image is not None:
            ip_encoder_hidden_states = self.image_proj_model(ref_image)

        # Route the IP tokens to the attention processors. The keys must match
        # the keyword arguments of IPFluxAttnProcessor2_0.__call__, because
        # diffusers forwards `joint_attention_kwargs` straight through to the
        # processors and silently drops keys they do not accept.
        if 'joint_attention_kwargs' not in kwargs:
            kwargs['joint_attention_kwargs'] = {}
        layer_scale = layer_scale.to(dtype=mod_dtype, device=mod_device)

        kwargs['joint_attention_kwargs']['layer_scale'] = layer_scale
        kwargs['joint_attention_kwargs']['ip_encoder_hidden_states'] = ip_encoder_hidden_states

        return self.transformer(*args, **kwargs)

    def save_pretrained(self, ckpt_path):
        """Save the adapter weights (projection model + attention processors)."""
        state_dict = {
            "image_proj": self.image_proj_model.state_dict(),
            "ip_adapter": self.adapter_modules.state_dict(),
        }
        torch.save(state_dict, ckpt_path)

    def load_from_checkpoint(self, ckpt_path):
        """Loader adapted from the Tencent IP-Adapter repository."""
        # Calculate original checksums
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for image_proj_model and adapter_modules
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Calculate new checksums
        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # Verify if the weights have changed
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")


    @property
    def dtype(self):
        return next(self.image_proj_model.parameters()).dtype
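

# End-to-end usage sketch (not part of the original file; the model id, dtype,
# checkpoint path, and call signature below are assumptions, for orientation only):
#
#   from diffusers import FluxTransformer2DModel
#
#   transformer = FluxTransformer2DModel.from_pretrained(
#       "black-forest-labs/FLUX.1-dev", subfolder="transformer",
#       torch_dtype=torch.bfloat16,
#   )
#   adapter = LibreFluxIPAdapter(transformer, ImageProjModel(),
#                                checkpoint="ip_adapter.bin")  # hypothetical path
#
#   clip_embeds = ...  # pooled CLIP image embedding, shape (batch, 768)
#   out = adapter(clip_embeds, hidden_states=latents, timestep=timesteps, ...)
#   # remaining kwargs are the usual FluxTransformer2DModel.forward arguments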