ZhouZJ36DL's picture
modified: src/flux/model.py
4ca21ee
import torch
from einops import rearrange
from torch import Tensor
import math
from torchvision.utils import save_image
from torchvision.io import read_image
from PIL import Image
import torchvision.transforms as transforms
def adaptive_attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, txt_shape: int, img_shape: int, cur_step:int, cur_block:int, info) -> Tensor:
q, k = apply_rope(q, k, pe)
#x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = scaled_dot_product_attention(q, k, v, txt_shape, img_shape, cur_step, cur_block, info)
x = rearrange(x, "B H L D -> B L (H D)")
return x
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def auto_mask(load_list, mask_accumulator, thre, info, mask_num = 4):
mask_list = []
for img_path in load_list:
load_mask_img = Image.open(img_path).convert('L')
# Define the transformation
transform = transforms.PILToTensor()
mask_tensor = transform(load_mask_img)
mask_tensor = mask_tensor.to(device=mask_accumulator.device, dtype=mask_accumulator.dtype) # Set device and dtype
mask_tensor /= 255.0
mask_list.append(mask_tensor) # Collect masks
# Sort masks based on their activation levels
mask_list.sort(key=lambda x: x.sum().item(), reverse=True)
# Select the 5 medium activated masks
num_masks = len(mask_list)
if num_masks > mask_num:
#selected_masks = mask_list[num_masks//2 - mask_num : num_masks//2]
start_block = info['attn_guidance']
end_block = info['attn_guidance'] + mask_num
if end_block > num_masks - 1:
selected_masks = mask_list[-mask_num: ]
else:
selected_masks = mask_list[start_block: end_block]
else:
selected_masks = mask_list
# Accumulate the selected masks
for mask in selected_masks:
mask_accumulator += mask
mask_tensor = (mask_accumulator / len(selected_masks)).to(dtype=mask_accumulator.dtype) # Average the masks and convert back to original dtype
mask_tensor[mask_tensor >= thre] = 1
mask_tensor[mask_tensor < thre] = 0
return mask_tensor
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, txt_shape, img_shape, cur_step, cur_block, info,
token_index=2, layer=range(19), attn_mask=None, dropout_p=0.0, coefficient=10, tau=0.5, thre=0.3,
is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype).cuda()
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
if enable_gqa:
key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
if not info['inverse']:
# GENERATE MASK
txt_img_cross = attn_weight[:, :, -img_shape:, :txt_shape] # lower left part
# each column maps to a token's heatmap
token_heatmap = txt_img_cross[:, :, :, token_index] # Shape: [1, 24, 1024]
token_heatmap = token_heatmap.mean(dim=1)[0] # Shape: [1024]
min_val, max_val = token_heatmap.min(), token_heatmap.max()
norm_heatmap = (token_heatmap - min_val) / (max_val - min_val)
mask_img = torch.sigmoid(coefficient*(norm_heatmap - 0.5))
H = W = int(math.sqrt(mask_img.size(0)))
mask_img = mask_img.reshape(H, W)
save_path = f'heatmap/step_{cur_step}_layer_{cur_block}_token{token_index}.png'
load_path = [f'heatmap/step_{cur_step-1}_layer_{i}_token{token_index}.png' for i in layer] #save_image(mask_img.unsqueeze(0), save_path)
save_image(mask_img.unsqueeze(0), save_path)
# Debug information
#print(f"[DEBUG] cur_step: {cur_step}, cur_block: {cur_block}")
#print(f"[DEBUG] norm_heatmap values:\n{norm_heatmap}")
#print(f"[DEBUG] mask_img (before thresholding) stats: min={mask_img.min().item()}, max={mask_img.max().item()}, mean={mask_img.mean().item()}")
#print(f"[DEBUG] thre value: {thre}")
#print(f"[DEBUG] mask_img (before thresholding) values:\n{mask_img}")
mask_img[mask_img >= thre] = 1
mask_img[mask_img < thre] = 0
#save_image(mask_img.unsqueeze(0), save_path)
#print(f"[DEBUG] mask_img (after thresholding) unique values: {mask_img.unique().tolist()}")
#print(f"[DEBUG] mask_img (after thresholding) values:\n{mask_img}")
mask_tensor = torch.zeros_like(mask_img) # Set mask_tensor as a zero tensor
if cur_step > 3:
mask_accumulator = torch.zeros_like(mask_tensor.unsqueeze(0), dtype=mask_img.dtype) # Accumulator for averaging masks
mask_tensor = auto_mask(load_path, mask_accumulator, thre, info, mask_num=4)
if cur_block == 1:
save_image(mask_tensor, f'heatmap/average_heatmaps/step_{cur_step}_layer_{cur_block}_token{token_index}.png')
if not torch.all(mask_tensor == 0):
highlight_factor = 2.0 # Factor to increase weights in the masked area
reduce_factor = 0.8 # Factor to decrease weights in the unmasked area
mask_tensor = mask_tensor.reshape(1, H * W)
mask_tensor = mask_tensor.unsqueeze(1).unsqueeze(-1)
# Create a multiplier tensor: 2.0 where mask is active, 0.5 where mask is inactive.
multiplier = torch.where(mask_tensor.bool(), torch.tensor(highlight_factor), torch.tensor(reduce_factor))
attn_weight[:, :, -img_shape:, :15] *= multiplier
return attn_weight @ value
'''
if cur_step == 14 and (cur_block == 2 or cur_block == 7 or cur_block == 12):
mask_img = torch.zeros_like(mask_img)
for j in range(5):
token_heatmap = txt_img_cross[:, :, :, j]
token_heatmap = token_heatmap.mean(dim=1)[0]
min_val, max_val = token_heatmap.min(), token_heatmap.max()
norm_heatmap = (token_heatmap - min_val) / (max_val - min_val)
mask_img = torch.sigmoid(coefficient*(norm_heatmap - 0.5))
H = W = int(math.sqrt(mask_img.size(0)))
mask_img = mask_img.reshape(H, W)
save_path = f'/home/hfle/personalization/FireFlow-Fast-Inversion-of-Rectified-Flow-for-Image-Semantic-Editing/heatmap/step_{cur_step}_layer_{cur_block}_token{j}.png'
save_image(mask_img.unsqueeze(0), save_path)'''