File size: 8,193 Bytes
f0e942d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f77208
4ca21ee
5700f1e
 
 
 
 
f0e942d
 
 
 
1aafc08
 
f0e942d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import torch
from einops import rearrange
from torch import Tensor
import math
from torchvision.utils import save_image
from torchvision.io import read_image
from PIL import Image
import torchvision.transforms as transforms


def adaptive_attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, txt_shape: int, img_shape: int, cur_step:int, cur_block:int, info) -> Tensor:
    """Rotary-embedding attention routed through the mask-guided manual
    scaled_dot_product_attention instead of the fused torch kernel."""
    q, k = apply_rope(q, k, pe)

    out = scaled_dot_product_attention(q, k, v, txt_shape, img_shape, cur_step, cur_block, info)
    # (B, H, L, D) -> (B, L, H*D): merge the head dim back into features.
    batch, heads, length, head_dim = out.shape
    return out.permute(0, 2, 1, 3).reshape(batch, length, heads * head_dim)


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Standard rotary-embedding attention using torch's fused SDPA kernel."""
    q, k = apply_rope(q, k, pe)

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # (B, H, L, D) -> (B, L, H*D): merge the head dim back into features.
    batch, heads, length, head_dim = out.shape
    return out.permute(0, 2, 1, 3).reshape(batch, length, heads * head_dim)


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Rotary position embedding: one 2x2 rotation matrix per (position,
    frequency) pair, computed in float64 and returned as float32."""
    assert dim % 2 == 0
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta ** exponents)
    angles = torch.einsum("...n,d->...nd", pos, omega)
    cos, sin = torch.cos(angles), torch.sin(angles)
    # Rows of [[cos, -sin], [sin, cos]] flattened, then unfolded to 2x2.
    rot = torch.stack([cos, -sin, sin, cos], dim=-1)
    rot = rot.reshape(*rot.shape[:-1], 2, 2)
    return rot.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate query/key features by precomputed 2x2 rotation matrices.

    Computation is done in float32; results are cast back to the inputs'
    dtypes and reshaped to the original layouts.
    """
    def _rotate(x: Tensor) -> Tensor:
        # View the feature dim as (d/2) pairs: (..., d/2, 1, 2).
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)


def auto_mask(load_list, mask_accumulator, thre, info, mask_num = 4):
    """Average a window of saved heatmap masks and binarize the result.

    Args:
        load_list: paths of previously saved grayscale heatmap images.
        mask_accumulator: zero tensor the selected masks are summed into
            (mutated in place); also fixes the result's device and dtype.
        thre: threshold used to binarize the averaged mask to {0, 1}.
        info: run configuration; ``info['attn_guidance']`` is the start rank
            of the selection window in the activation-sorted mask list.
        mask_num: number of masks to average (selection-window size).

    Returns:
        Binary (0/1) mask tensor on the accumulator's device/dtype.
    """
    to_tensor = transforms.PILToTensor()  # hoisted: one transform for all images

    mask_list = []
    for img_path in load_list:
        load_mask_img = Image.open(img_path).convert('L')
        mask_tensor = to_tensor(load_mask_img)
        mask_tensor = mask_tensor.to(device=mask_accumulator.device, dtype=mask_accumulator.dtype)
        mask_tensor /= 255.0  # PILToTensor yields 0..255; normalize to 0..1
        mask_list.append(mask_tensor)

    # Rank masks from most to least activated.
    mask_list.sort(key=lambda x: x.sum().item(), reverse=True)

    # Select `mask_num` medium-activated masks: a window starting at the
    # rank info['attn_guidance'], clamped to the list tail.
    num_masks = len(mask_list)
    if num_masks > mask_num:
        start_block = info['attn_guidance']
        end_block = info['attn_guidance'] + mask_num
        if end_block > num_masks - 1:
            selected_masks = mask_list[-mask_num: ]
        else:
            selected_masks = mask_list[start_block: end_block]
    else:
        selected_masks = mask_list

    # Guard: with no masks loaded (e.g. empty load_list), binarize the
    # zero accumulator instead of raising ZeroDivisionError below.
    if not selected_masks:
        mask_tensor = mask_accumulator.clone()
        mask_tensor[mask_tensor >= thre] = 1
        mask_tensor[mask_tensor < thre] = 0
        return mask_tensor

    # Accumulate the selected masks.
    for mask in selected_masks:
        mask_accumulator += mask

    # Average the masks and binarize with `thre`.
    mask_tensor = (mask_accumulator / len(selected_masks)).to(dtype=mask_accumulator.dtype)
    mask_tensor[mask_tensor >= thre] = 1
    mask_tensor[mask_tensor < thre] = 0

    return mask_tensor


# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, txt_shape, img_shape, cur_step, cur_block, info, 
        token_index=2, layer=range(19), attn_mask=None, dropout_p=0.0, coefficient=10, tau=0.5, thre=0.3, 
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    """Manual scaled-dot-product attention with attention-guidance re-weighting.

    Mirrors ``torch.nn.functional.scaled_dot_product_attention`` (same
    ``attn_mask`` / ``is_causal`` / ``dropout_p`` / ``scale`` / ``enable_gqa``
    semantics). When ``info['inverse']`` is falsy, it additionally derives a
    spatial mask from the image->text cross-attention of prompt token
    ``token_index``, saves it under ``heatmap/`` and, from step 4 onward,
    averages the previous step's per-layer masks (``auto_mask``) to amplify
    or damp attention weights.

    Args:
        query, key, value: (B, H, L, D) attention inputs.
        txt_shape: number of text tokens in the joint sequence.
        img_shape: number of image tokens in the joint sequence.
        cur_step, cur_block: denoising step / transformer-block index, used
            to name saved heatmap files.
        info: run-config dict; reads ``info['inverse']`` here and
            ``info['attn_guidance']`` inside ``auto_mask``.
        token_index: prompt token whose cross-attention drives the mask.
        layer: block indices whose saved heatmaps are averaged.
        coefficient: sharpness of the sigmoid soft-binarization.
        thre: hard-binarization threshold for the mask.
        tau: unused here; kept for interface compatibility.

    Returns:
        Attention output ``attn_weight @ value``.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Allocate the bias on the inputs' device instead of hard-coding .cuda(),
    # so the function also runs on CPU and non-default GPUs.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)  # was a discarded no-op: .to() is not in-place

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask

    if enable_gqa:
        # Grouped-query attention: expand K/V heads to match the query heads.
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

    if not info['inverse']:
        # GENERATE MASK: lower-left part of the joint map — image queries
        # attending to text keys; each column is one token's spatial heatmap.
        txt_img_cross = attn_weight[:, :, -img_shape:, :txt_shape]
        token_heatmap = txt_img_cross[:, :, :, token_index]  # (B, H, img_shape)
        token_heatmap = token_heatmap.mean(dim=1)[0]  # head-averaged, first batch item
        min_val, max_val = token_heatmap.min(), token_heatmap.max()
        norm_heatmap = (token_heatmap - min_val) / (max_val - min_val)

        # Soft-binarize around 0.5; `coefficient` controls sharpness.
        mask_img = torch.sigmoid(coefficient*(norm_heatmap - 0.5))

        # assumes the image token grid is square — TODO confirm for non-square inputs
        H = W = int(math.sqrt(mask_img.size(0)))
        mask_img = mask_img.reshape(H, W)

        # Persist this block's heatmap; previous-step heatmaps feed auto_mask.
        save_path = f'heatmap/step_{cur_step}_layer_{cur_block}_token{token_index}.png'
        load_path = [f'heatmap/step_{cur_step-1}_layer_{i}_token{token_index}.png' for i in layer]
        save_image(mask_img.unsqueeze(0), save_path)

        mask_img[mask_img >= thre] = 1
        mask_img[mask_img < thre] = 0

        mask_tensor = torch.zeros_like(mask_img)  # stays all-zero during warm-up steps
        if cur_step > 3:
            mask_accumulator = torch.zeros_like(mask_tensor.unsqueeze(0), dtype=mask_img.dtype)  # Accumulator for averaging masks
            mask_tensor = auto_mask(load_path, mask_accumulator, thre, info, mask_num=4)
            if cur_block == 1:
                save_image(mask_tensor, f'heatmap/average_heatmaps/step_{cur_step}_layer_{cur_block}_token{token_index}.png')

        if not torch.all(mask_tensor == 0):
            highlight_factor = 2.0  # boost attention inside the masked area
            reduce_factor = 0.8  # damp attention outside the masked area

            mask_tensor = mask_tensor.reshape(1, H * W)
            mask_tensor = mask_tensor.unsqueeze(1).unsqueeze(-1)
            # Build the multiplier on the attention tensor's device/dtype so the
            # scalar tensors cannot mix CPU with a CUDA attn_weight.
            multiplier = torch.where(
                mask_tensor.bool(),
                torch.tensor(highlight_factor, device=attn_weight.device, dtype=attn_weight.dtype),
                torch.tensor(reduce_factor, device=attn_weight.device, dtype=attn_weight.dtype),
            )
            # NOTE(review): re-weights only the first 15 text-key columns —
            # looks like it should be `:txt_shape`; confirm before changing.
            attn_weight[:, :, -img_shape:, :15] *= multiplier

    return attn_weight @ value

# NOTE(review): dead experiment kept for reference — it dumped per-token
# heatmaps for the first five prompt tokens at step 14, blocks 2/7/12.
# Previously a no-op module-level triple-quoted string; converted to
# comments so it cannot be mistaken for live code or a docstring.
#
#     if cur_step == 14 and (cur_block == 2 or cur_block == 7 or cur_block == 12):
#         mask_img = torch.zeros_like(mask_img)
#         for j in range(5):
#             token_heatmap = txt_img_cross[:, :, :, j]
#             token_heatmap = token_heatmap.mean(dim=1)[0]
#             min_val, max_val = token_heatmap.min(), token_heatmap.max()
#             norm_heatmap = (token_heatmap - min_val) / (max_val - min_val)
#
#             mask_img = torch.sigmoid(coefficient*(norm_heatmap - 0.5))
#
#             H = W = int(math.sqrt(mask_img.size(0)))
#             mask_img = mask_img.reshape(H, W)
#             save_path = f'/home/hfle/personalization/FireFlow-Fast-Inversion-of-Rectified-Flow-for-Image-Semantic-Editing/heatmap/step_{cur_step}_layer_{cur_block}_token{j}.png'
#             save_image(mask_img.unsqueeze(0), save_path)