Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from einops import rearrange | |
| from torch import Tensor | |
| import math | |
| from torchvision.utils import save_image | |
| from torchvision.io import read_image | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
def adaptive_attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, txt_shape: int, img_shape: int, cur_step:int, cur_block:int, info) -> Tensor:
    """Rotary-embed q/k, run the guided attention kernel, and merge the heads.

    Args:
        q, k, v: attention inputs; the output is rearranged as
            "B H L D -> B L (H D)", so the kernel result must be 4-D.
        pe: rotary position table passed straight to ``apply_rope``.
        txt_shape, img_shape, cur_step, cur_block, info: forwarded unchanged
            to this file's ``scaled_dot_product_attention``.

    Returns:
        The attention output with the head axis folded into the last axis.
    """
    rotated_q, rotated_k = apply_rope(q, k, pe)
    # Uses this module's own kernel (not torch's fused one) so the attention
    # weights are available for heatmap extraction and re-weighting.
    per_head = scaled_dot_product_attention(
        rotated_q, rotated_k, v, txt_shape, img_shape, cur_step, cur_block, info
    )
    return rearrange(per_head, "B H L D -> B L (H D)")
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Plain rotary-embedded attention using PyTorch's built-in kernel.

    Args:
        q, k, v: attention inputs; the result is rearranged as
            "B H L D -> B L (H D)", so it must be 4-D.
        pe: rotary position table passed straight to ``apply_rope``.

    Returns:
        The attention output with the head axis folded into the last axis.
    """
    rotated_q, rotated_k = apply_rope(q, k, pe)
    per_head = torch.nn.functional.scaled_dot_product_attention(rotated_q, rotated_k, v)
    return rearrange(per_head, "B H L D -> B L (H D)")
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Build 2x2 rotary-embedding rotation matrices for the given positions.

    Args:
        pos: positions; any number of leading axes is accepted, the last axis
            enumerates positions (matches the ``...n`` einsum subscript).
        dim: embedding width to cover; must be even, yielding ``dim // 2``
            frequencies.
        theta: base of the geometric frequency progression.

    Returns:
        A float32 tensor of shape ``pos.shape + (dim // 2, 2, 2)`` whose
        trailing 2x2 block is ``[[cos, -sin], [sin, cos]]`` of
        ``pos * theta ** (-2i / dim)``.
    """
    assert dim % 2 == 0, "rope dim must be even"
    # Frequencies are computed in float64 and only cast down at the very end.
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    angles = torch.einsum("...n,d->...nd", pos, omega)
    cos, sin = torch.cos(angles), torch.sin(angles)
    flat = torch.stack([cos, -sin, sin, cos], dim=-1)
    # Split the trailing 4 entries into a row-major 2x2 matrix. A plain reshape
    # yields the same layout as "b n d (i j) -> b n d i j" but, like the einsum
    # above, does not restrict `pos` to exactly two axes.
    rotation = flat.reshape(*flat.shape[:-1], 2, 2)
    return rotation.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Rotate query and key features pairwise with the given rotation table.

    Args:
        xq: query tensor; its last axis is consumed as consecutive pairs.
        xk: key tensor, handled exactly like ``xq``.
        freqs_cis: rotation table whose last two axes form a 2x2 matrix per
            feature pair; it must broadcast against the paired inputs.

    Returns:
        ``(xq_rotated, xk_rotated)``, each with the shape and dtype of its
        input. The arithmetic itself is done in float32.
    """
    first_col, second_col = freqs_cis[..., 0], freqs_cis[..., 1]

    def _rotate(t: Tensor) -> Tensor:
        # View the last axis as (pairs, 1, 2) so each pair broadcasts against
        # the two columns of its rotation matrix.
        pairs = t.float().reshape(*t.shape[:-1], -1, 1, 2)
        turned = first_col * pairs[..., 0] + second_col * pairs[..., 1]
        return turned.reshape(*t.shape).type_as(t)

    return _rotate(xq), _rotate(xk)
def _load_mask(img_path, reference):
    """Read one image as a grayscale mask scaled to [0, 1], on *reference*'s device and dtype."""
    # The context manager releases the file handle as soon as the pixels have
    # been decoded by convert(); the converted copy stays valid afterwards.
    with Image.open(img_path) as img:
        mask = transforms.PILToTensor()(img.convert('L'))
    mask = mask.to(device=reference.device, dtype=reference.dtype)
    mask /= 255.0
    return mask


def auto_mask(load_list, mask_accumulator, thre, info, mask_num = 4):
    """Average a window of saved heatmap images and binarize the result.

    The images are ranked by total activation (highest first). When more than
    ``mask_num`` are available, a window of ``mask_num`` consecutive masks
    starting at rank ``info['attn_guidance']`` is used; if that window would
    reach the end of the ranking, the ``mask_num`` least-activated masks are
    used instead. Otherwise every mask is used.

    Args:
        load_list: paths of the mask images to read.
        mask_accumulator: tensor the selected masks are summed into. It is
            modified in place and determines device, dtype and shape.
        thre: averaged values ``>= thre`` become 1, values below become 0.
        info: mapping; ``info['attn_guidance']`` is read only when more than
            ``mask_num`` masks were loaded.
        mask_num: number of masks to average.

    Returns:
        The binarized average mask. If no mask ends up selected (e.g. an empty
        ``load_list``), an all-zero mask shaped like ``mask_accumulator`` is
        returned; dividing by a zero count would otherwise yield an all-NaN
        mask that is neither 0 nor 1.
    """
    mask_list = [_load_mask(img_path, mask_accumulator) for img_path in load_list]

    # Rank masks from most to least activated.
    mask_list.sort(key=lambda m: m.sum().item(), reverse=True)

    num_masks = len(mask_list)
    if num_masks > mask_num:
        start_block = info['attn_guidance']
        end_block = start_block + mask_num
        if end_block > num_masks - 1:
            # Window would run off the end: fall back to the last mask_num masks.
            selected_masks = mask_list[-mask_num:]
        else:
            selected_masks = mask_list[start_block:end_block]
    else:
        selected_masks = mask_list

    if not selected_masks:
        return torch.zeros_like(mask_accumulator)

    for mask in selected_masks:
        mask_accumulator += mask

    mask_tensor = (mask_accumulator / len(selected_masks)).to(dtype=mask_accumulator.dtype)
    mask_tensor[mask_tensor >= thre] = 1
    mask_tensor[mask_tensor < thre] = 0
    return mask_tensor
def scaled_dot_product_attention(query, key, value, txt_shape, img_shape, cur_step, cur_block, info,
                                 token_index=2, layer=range(19), attn_mask=None, dropout_p=0.0, coefficient=10, tau=0.5, thre=0.3,
                                 is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    """Explicit softmax attention with optional heatmap-based re-weighting.

    The attention weights are materialized so that, when ``info['inverse']`` is
    falsy, the function can:

    1. take the rows of the last ``img_shape`` queries and the column
       ``token_index`` among the first ``txt_shape`` keys, average over heads
       (batch element 0 only), min-max normalize, sharpen with a sigmoid and
       save the result as ``heatmap/step_{cur_step}_layer_{cur_block}_token{token_index}.png``;
    2. for ``cur_step > 3``, build a binary mask with ``auto_mask`` from the
       previous step's saved heatmaps of the blocks listed in ``layer``;
    3. if that mask is not all zero, multiply the weights of the last
       ``img_shape`` queries towards the first 15 keys by 2.0 inside the mask
       and by 0.8 outside it. The rows are not renormalized afterwards.

    Args:
        query, key, value: attention inputs, batched matrices whose last two
            axes are (length, head_dim).
            NOTE(review): presumably (batch, heads, length, head_dim) - confirm.
        txt_shape: number of leading key columns treated as text tokens.
        img_shape: number of trailing query rows treated as image tokens; must
            be a perfect square so the heatmap can be reshaped to H x W.
        cur_step, cur_block: used for heatmap file names and gating.
        info: mapping providing ``'inverse'`` (and whatever ``auto_mask`` reads).
        token_index: text-token column the heatmap is extracted from.
        layer: block indices whose previous-step heatmaps are averaged.
        attn_mask: optional boolean (True = keep) or additive mask,
            broadcastable to the attention-weight shape.
        dropout_p: dropout probability, always applied in training mode.
        coefficient: steepness of the sigmoid applied to the heatmap.
        tau: accepted for interface compatibility; not used.
        thre: binarization threshold handed to ``auto_mask``.
        is_causal: apply a lower-triangular mask; requires ``attn_mask=None``.
        scale: softmax scale; defaults to ``1 / sqrt(head_dim)``.
        enable_gqa: repeat key/value heads to match the number of query heads.

    Returns:
        ``attn_weight @ value``.
    """
    import os

    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

    # Build the bias on the inputs' device instead of assuming CUDA.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    if attn_mask is not None:
        # Out-of-place so masks carrying batch/head axes can broadcast.
        if attn_mask.dtype == torch.bool:
            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_bias + attn_mask

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

    if not info['inverse']:
        # Image-query rows x text-key columns of the attention map.
        txt_img_cross = attn_weight[:, :, -img_shape:, :txt_shape]
        # One column is one token's heatmap; average heads, keep batch element 0.
        token_heatmap = txt_img_cross[:, :, :, token_index].mean(dim=1)[0]

        min_val, max_val = token_heatmap.min(), token_heatmap.max()
        # A perfectly flat heatmap would divide by zero and turn into NaNs;
        # clamping to the smallest normal float leaves every other case untouched.
        value_range = (max_val - min_val).clamp_min(torch.finfo(token_heatmap.dtype).tiny)
        norm_heatmap = (token_heatmap - min_val) / value_range
        mask_img = torch.sigmoid(coefficient * (norm_heatmap - 0.5))

        H = W = math.isqrt(mask_img.size(0))
        mask_img = mask_img.reshape(H, W)

        os.makedirs('heatmap', exist_ok=True)
        save_path = f'heatmap/step_{cur_step}_layer_{cur_block}_token{token_index}.png'
        save_image(mask_img.unsqueeze(0), save_path)

        # Guidance needs heatmaps from an earlier step, so it only runs after
        # the first few steps.
        if cur_step > 3:
            load_path = [f'heatmap/step_{cur_step-1}_layer_{i}_token{token_index}.png' for i in layer]
            mask_accumulator = torch.zeros_like(mask_img.unsqueeze(0))
            mask_tensor = auto_mask(load_path, mask_accumulator, thre, info, mask_num=4)

            if cur_block == 1:
                os.makedirs('heatmap/average_heatmaps', exist_ok=True)
                save_image(mask_tensor, f'heatmap/average_heatmaps/step_{cur_step}_layer_{cur_block}_token{token_index}.png')

            if not torch.all(mask_tensor == 0):
                highlight_factor = 2.0  # multiplier inside the mask
                reduce_factor = 0.8  # multiplier outside the mask
                # NOTE(review): 15 is hard-coded; presumably the number of
                # prompt tokens to re-weight - confirm against the caller.
                guided_txt_tokens = 15

                # Broadcast the mask over heads and over the guided columns.
                guide = mask_tensor.reshape(1, H * W).unsqueeze(1).unsqueeze(-1)
                multiplier = torch.where(
                    guide.bool(),
                    torch.tensor(highlight_factor, device=attn_weight.device),
                    torch.tensor(reduce_factor, device=attn_weight.device),
                )
                attn_weight[:, :, -img_shape:, :guided_txt_tokens] *= multiplier

    return attn_weight @ value