import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

class FluxBlendedAttnProcessor2_0(nn.Module):
    """Attention processor for the single-stream (self-attention) blocks of Flux.

    In addition to the standard self-attention, the noisy samples' queries attend
    to learned key/value projections of the clean reference samples' attention
    output ("blended attention"); the result is added back to the noisy samples'
    hidden states as a residual scaled by `ba_scale`.
    """

    def __init__(self, hidden_dim, ba_scale=1.0, num_ref=1):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FluxBlendedAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or higher."
            )

        # Learned projections that turn the reference samples' attention output
        # into keys/values for the blended-attention pass.
        self.blended_attention_k_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.blended_attention_v_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # Scale applied to the blended-attention residual, and number of clean
        # reference samples appended after the noisy samples in the batch.
        self.ba_scale = ba_scale
        self.num_ref = num_ref

    def __call__(
        self,
        attn, #: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        is_negative_prompt: bool = False
    ) -> torch.FloatTensor:
        assert encoder_hidden_states is None, (
            "encoder_hidden_states must be None: the it-blender processor is applied only to the single-stream blocks."
        )
        batch_size, _, _ = hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Keep the (normalized) query before rotary embeddings are applied;
        # it is reused below as the query of the blended-attention pass.
        normalized_query = query

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)


        # Batch layout: [noisy, clean]. The first `chunk` samples are noisy and the
        # remaining `num_ref * chunk` samples are clean references.
        chunk = batch_size // (1 + self.num_ref)
        ba_query = normalized_query[:chunk]  # noisy query (normalized, pre-RoPE)

        ba_key = self.blended_attention_k_proj(hidden_states[chunk:]) # clean key
        ba_value = self.blended_attention_v_proj(hidden_states[chunk:]) # clean value

        # The -1 dimension absorbs the tokens of all `num_ref` reference samples.
        ba_key = ba_key.view(chunk, -1, attn.heads, head_dim).transpose(1, 2)
        ba_value = ba_value.view(chunk, -1, attn.heads, head_dim).transpose(1, 2)

        # Noisy queries attend to the clean references' keys/values with the default
        # 1/sqrt(head_dim) scaling.
        ba_hidden_states = F.scaled_dot_product_attention(
            ba_query, ba_key, ba_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        ba_hidden_states = ba_hidden_states.transpose(1, 2).reshape(chunk, -1, attn.heads * head_dim)
        ba_hidden_states = ba_hidden_states.to(query.dtype)

        # Only the noisy samples receive the blended-attention residual; the clean
        # reference samples are padded with zeros so the batch dimension matches.
        zero_tensor_list = [torch.zeros_like(ba_hidden_states)] * self.num_ref
        ba_hidden_states = torch.cat([ba_hidden_states] + zero_tensor_list, dim=0)

        hidden_states = hidden_states + self.ba_scale * ba_hidden_states
        
        return hidden_states
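

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). `_DummyAttn` below is
# a stand-in for diffusers' `Attention` layer so the example is self-contained;
# in practice the processor would be assigned to the single-stream attention
# modules of a Flux transformer, whose module names and hidden sizes are
# checkpoint-specific and are not assumed here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _DummyAttn(nn.Module):
        """Tiny stand-in exposing only the attributes the processor relies on."""

        def __init__(self, dim: int, heads: int):
            super().__init__()
            self.heads = heads
            self.to_q = nn.Linear(dim, dim, bias=False)
            self.to_k = nn.Linear(dim, dim, bias=False)
            self.to_v = nn.Linear(dim, dim, bias=False)
            self.norm_q = None
            self.norm_k = None

    dim, heads, seq, num_ref = 64, 4, 16, 1
    attn = _DummyAttn(dim, heads)
    processor = FluxBlendedAttnProcessor2_0(dim, ba_scale=1.0, num_ref=num_ref)

    # Batch layout expected by the processor: [noisy, clean_ref_1, ..., clean_ref_num_ref].
    hidden = torch.randn(1 + num_ref, seq, dim)
    out = processor(attn, hidden)
    print(out.shape)  # torch.Size([2, 16, 64])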