Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,193 Bytes
f0e942d 7f77208 4ca21ee 5700f1e f0e942d 1aafc08 f0e942d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import torch
from einops import rearrange
from torch import Tensor
import math
from torchvision.utils import save_image
from torchvision.io import read_image
from PIL import Image
import torchvision.transforms as transforms
def adaptive_attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, txt_shape: int, img_shape: int, cur_step:int, cur_block:int, info) -> Tensor:
    """Attention with rotary position embeddings, routed through the custom
    mask-guided ``scaled_dot_product_attention`` defined in this module
    (instead of PyTorch's fused kernel) so heatmaps can be extracted."""
    q, k = apply_rope(q, k, pe)
    attn_out = scaled_dot_product_attention(q, k, v, txt_shape, img_shape, cur_step, cur_block, info)
    # Fold the head axis back into the channel axis: (B, H, L, D) -> (B, L, H*D).
    return rearrange(attn_out, "B H L D -> B L (H D)")
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Standard attention path: apply RoPE, run PyTorch's fused SDPA kernel,
    then merge heads back into the feature dimension."""
    q, k = apply_rope(q, k, pe)
    attn_out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    return rearrange(attn_out, "B H L D -> B L (H D)")
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build the 2x2 rotation matrices for rotary position embeddings.

    For each position in ``pos`` and each of ``dim/2`` frequencies
    ``theta**(-2i/dim)``, produces the matrix [[cos, -sin], [sin, cos]].
    Returns a float32 tensor shaped (..., n, dim/2, 2, 2).
    """
    assert dim % 2 == 0
    # Exponents 2i/dim for i in [0, dim/2); computed in float64 for precision.
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    freqs = 1.0 / (theta**exponents)
    # Outer product of positions and frequencies -> rotation angles.
    angles = torch.einsum("...n,d->...nd", pos, freqs)
    rot = torch.stack(
        [torch.cos(angles), -torch.sin(angles), torch.sin(angles), torch.cos(angles)],
        dim=-1,
    )
    # Split the flat 4-vector into a 2x2 matrix per (position, frequency).
    rot = rearrange(rot, "b n d (i j) -> b n d i j", i=2, j=2)
    return rot.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate query/key channel pairs by the precomputed RoPE matrices.

    Each consecutive channel pair of ``xq``/``xk`` is treated as 2-D
    coordinates and multiplied by the corresponding 2x2 rotation row in
    ``freqs_cis``; outputs keep the input shape and dtype.
    """
    def _rotate(x: Tensor) -> Tensor:
        # View channel pairs as (..., dim/2, 1, 2) so they broadcast against
        # the (..., dim/2, 2, 2) rotation matrices in freqs_cis.
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)
def auto_mask(load_list, mask_accumulator, thre, info, mask_num = 4):
    """Load previously saved heatmap images, average a window of them, and
    binarize the result at threshold ``thre``.

    ``mask_accumulator`` is mutated in place (selected masks are summed into
    it); its device/dtype determine where the loaded masks live. The window
    of ``mask_num`` masks starts at index ``info['attn_guidance']`` in the
    activation-sorted list.
    """
    to_tensor = transforms.PILToTensor()  # hoisted: the transform is stateless
    masks = []
    for path in load_list:
        gray = Image.open(path).convert('L')
        tensor = to_tensor(gray).to(device=mask_accumulator.device, dtype=mask_accumulator.dtype)
        tensor /= 255.0  # PILToTensor yields [0, 255]; normalize to [0, 1]
        masks.append(tensor)

    # Most strongly activated masks first.
    masks.sort(key=lambda m: m.sum().item(), reverse=True)

    if len(masks) > mask_num:
        start = info['attn_guidance']
        stop = start + mask_num
        # If the window would run past the end, fall back to the weakest masks.
        selected = masks[-mask_num:] if stop > len(masks) - 1 else masks[start:stop]
    else:
        selected = masks

    for mask in selected:
        mask_accumulator += mask  # in-place: the caller's accumulator is updated

    averaged = (mask_accumulator / len(selected)).to(dtype=mask_accumulator.dtype)
    # Hard binarization at the threshold.
    averaged[averaged >= thre] = 1
    averaged[averaged < thre] = 0
    return averaged
# Efficient implementation equivalent to torch.nn.functional.scaled_dot_product_attention,
# extended with cross-attention heatmap extraction and mask-guided reweighting.
def scaled_dot_product_attention(query, key, value, txt_shape, img_shape, cur_step, cur_block, info,
        token_index=2, layer=range(19), attn_mask=None, dropout_p=0.0, coefficient=10, tau=0.5, thre=0.3,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    """Attention with optional heatmap extraction and mask-guided reweighting.

    When ``info['inverse']`` is falsy, the image-to-text cross-attention block
    for ``token_index`` is saved as a heatmap PNG; after step 3, previously
    saved heatmaps (one per entry in ``layer``) are averaged via ``auto_mask``
    and used to amplify/suppress attention onto the first text tokens.

    Returns ``attn_weight @ value`` with the same leading shape as ``query``.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # BUGFIX: allocate the bias on the query's device instead of hard-coding
    # .cuda(), so the function also works on CPU and non-default GPUs.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        # BUGFIX: build the causal mask on the query's device; the original
        # created it on CPU, which fails masked_fill_ against a CUDA bias.
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        # (the original's discarded `attn_bias.to(query.dtype)` no-op is removed)
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    if enable_gqa:
        # Grouped-query attention: repeat K/V heads to match the query heads.
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    if not info['inverse']:
        # --- generate and save a heatmap for the selected text token ---
        # Lower-left cross block: image queries attending to text keys.
        txt_img_cross = attn_weight[:, :, -img_shape:, :txt_shape]
        token_heatmap = txt_img_cross[:, :, :, token_index]  # one column = one token's heatmap
        token_heatmap = token_heatmap.mean(dim=1)[0]  # average heads, take batch element 0
        min_val, max_val = token_heatmap.min(), token_heatmap.max()
        # BUGFIX: guard against a constant heatmap (max == min) to avoid 0/0 NaNs.
        denom = (max_val - min_val).clamp_min(1e-12)
        norm_heatmap = (token_heatmap - min_val) / denom
        # Soft contrast boost around the midpoint before thresholding.
        mask_img = torch.sigmoid(coefficient*(norm_heatmap - 0.5))
        H = W = int(math.sqrt(mask_img.size(0)))  # assumes a square latent grid — TODO confirm
        mask_img = mask_img.reshape(H, W)
        save_path = f'heatmap/step_{cur_step}_layer_{cur_block}_token{token_index}.png'
        # Heatmaps from the PREVIOUS step, one per requested layer.
        load_path = [f'heatmap/step_{cur_step-1}_layer_{i}_token{token_index}.png' for i in layer]
        save_image(mask_img.unsqueeze(0), save_path)
        mask_img[mask_img >= thre] = 1
        mask_img[mask_img < thre] = 0
        mask_tensor = torch.zeros_like(mask_img)  # default: no reweighting
        if cur_step > 3:
            # Average last step's heatmaps into a binary guidance mask.
            mask_accumulator = torch.zeros_like(mask_tensor.unsqueeze(0), dtype=mask_img.dtype)
            mask_tensor = auto_mask(load_path, mask_accumulator, thre, info, mask_num=4)
            if cur_block == 1:
                save_image(mask_tensor, f'heatmap/average_heatmaps/step_{cur_step}_layer_{cur_block}_token{token_index}.png')
        if not torch.all(mask_tensor == 0):
            highlight_factor = 2.0  # boost weights inside the masked area
            reduce_factor = 0.8     # dampen weights outside it
            mask_tensor = mask_tensor.reshape(1, H * W)
            mask_tensor = mask_tensor.unsqueeze(1).unsqueeze(-1)  # (1, 1, H*W, 1), broadcasts over heads/tokens
            # BUGFIX: create the scalar factors on attn_weight's device/dtype;
            # a bare torch.tensor(...) lives on the CPU and breaks GPU runs.
            multiplier = torch.where(
                mask_tensor.bool(),
                torch.tensor(highlight_factor, device=attn_weight.device, dtype=attn_weight.dtype),
                torch.tensor(reduce_factor, device=attn_weight.device, dtype=attn_weight.dtype),
            )
            # NOTE(review): the column bound 15 is hard-coded rather than txt_shape — confirm intent.
            attn_weight[:, :, -img_shape:, :15] *= multiplier
    return attn_weight @ value
# NOTE(review): dead experimentation code kept as a never-used module-level
# string literal (it is evaluated and discarded, never executed). It saved
# per-token heatmaps for a few blocks at step 14; consider deleting it.
'''
    if cur_step == 14 and (cur_block == 2 or cur_block == 7 or cur_block == 12):
        mask_img = torch.zeros_like(mask_img)
        for j in range(5):
            token_heatmap = txt_img_cross[:, :, :, j]
            token_heatmap = token_heatmap.mean(dim=1)[0]
            min_val, max_val = token_heatmap.min(), token_heatmap.max()
            norm_heatmap = (token_heatmap - min_val) / (max_val - min_val)
            mask_img = torch.sigmoid(coefficient*(norm_heatmap - 0.5))
            H = W = int(math.sqrt(mask_img.size(0)))
            mask_img = mask_img.reshape(H, W)
            save_path = f'/home/hfle/personalization/FireFlow-Fast-Inversion-of-Rectified-Flow-for-Image-Semantic-Editing/heatmap/step_{cur_step}_layer_{cur_block}_token{j}.png'
            save_image(mask_img.unsqueeze(0), save_path)'''
|