| """ | |
| Collect all function in prompt_attention folder. | |
| Provide a API `make_controller' to return an initialized AttentionControlEdit class object in the main validation loop. | |
| """ | |
from typing import Optional, Union, Tuple, List, Dict
import abc
import copy
import math
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from PIL import Image, ImageDraw

import video_diffusion.prompt_attention.ptp_utils as ptp_utils
from video_diffusion.prompt_attention.visualization import (
    show_cross_attention,
    show_cross_attention_plus_org_img,
    show_self_attention_comp,
    aggregate_attention,
)
from video_diffusion.prompt_attention.attention_store import AttentionStore, AttentionControl
from video_diffusion.prompt_attention.attention_register import register_attention_control
from video_diffusion.common.image_util import save_gif_mp4_folder_type, make_grid

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
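
# Minimal usage sketch (illustrative only, not part of this module): in the main validation
# loop one would typically build the layout-conditioned qk maps, instantiate a controller,
# and hook it into the UNet attention layers. `pipeline`, `scheduler`, `text_cond`,
# `sreg_maps`, `creg_maps`, `reg_sizes`, `reg_sizes_c`, and `video_frames` are hypothetical
# names for objects prepared elsewhere; `register_attention_control` is assumed to accept
# (model, controller), as in prompt-to-prompt style code.
#
#     controller = ST_Layout_Attn_Control(
#         end_step=15, total_steps=50,
#         text_cond=text_cond,                       # per-region text embeddings
#         sreg_maps=sreg_maps, creg_maps=creg_maps,  # self-/cross-attention qk condition maps
#         reg_sizes=reg_sizes, reg_sizes_c=reg_sizes_c,
#         time_steps=scheduler.timesteps,
#         clip_length=video_frames,                  # number of frames in the clip
#         attention_type="SparseCausalAttention",
#     )
#     register_attention_control(pipeline.unet, controller)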

class EmptyControl:
    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        return attn

def apply_jet_colormap(weight):
    # Normalize the weights to the 0-255 range
    weight = 255 * (weight - weight.min()) / (weight.max() - weight.min() + 1e-6)
    weight = weight.astype(np.uint8)
    # Apply the Jet colormap
    color_mapped_weight = cv2.applyColorMap(weight, cv2.COLORMAP_JET)
    return color_mapped_weight

def show_self_attention_comp(self_attention_map, video, h_index: int, w_index: int, res: int, frames: int, place_in_unet: List[str], step: int):
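    # Expected inputs (inferred from the reshape below): `self_attention_map` is reshaped to
    # (frames, res, res, frames, res, res), i.e. a (frames*res*res) x (frames*res*res)
    # self-attention matrix, and `video` is (frames, channels, height, width) normalized with
    # the CLIP image mean/std that are undone below.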
    attention_maps = self_attention_map.reshape(frames, res, res, frames, res, res)
    weights = attention_maps[0, h_index, w_index, :, :, :]
    attention_list = []
    video_frames = []
    # video: (frames, channels, height, width)
    for i in range(frames):
        weight = weights[i].cpu().numpy()
        weight_colored = apply_jet_colormap(weight)
        weight_colored = weight_colored[:, :, ::-1]  # BGR to RGB conversion
        weight_colored = np.array(Image.fromarray(weight_colored).resize((256, 256)))
        attention_list.append(weight_colored)
        frame = video[i].permute(1, 2, 0).cpu().numpy()
        mean = np.array((0.48145466, 0.4578275, 0.40821073)).reshape((1, 1, 3))  # [h, w, c]
        varas = np.array((0.26862954, 0.26130258, 0.27577711)).reshape((1, 1, 3))
        frame = frame * varas + mean
        frame = (frame - frame.min()) / (frame.max() - frame.min() + 1e-6) * 255
        frame = frame.astype(np.uint8)
        video_frames.append(frame)
    alpha = 0.5
    overlay_frames = []
    for frame, attention in zip(video_frames, attention_list):
        attention_resized = cv2.resize(attention, (frame.shape[1], frame.shape[0]))
        overlay_frame = cv2.addWeighted(frame, alpha, attention_resized, 1 - alpha, 0)
        overlay_frames.append(overlay_frame)
    print('vis self attn')
    save_path = "with_st_layout_vis_self_attn/vis_self_attn"
    os.makedirs(save_path, exist_ok=True)
    video_save_path = f'{save_path}/self-attn-{place_in_unet}-{step}-query-frame0-h{h_index}-w{w_index}.gif'
    save_gif_mp4_folder_type(overlay_frames, video_save_path, save_gif=False)

def draw_grid_on_image(image, grid_size, line_color="gray"):
    draw = ImageDraw.Draw(image)
    w, h = image.size
    for i in range(0, w, grid_size):
        draw.line([(i, 0), (i, h)], fill=line_color)
    for i in range(0, h, grid_size):
        draw.line([(0, i), (w, i)], fill=line_color)
    return image

def identify_self_attention_max_min(sim, video, h_index: int, w_index: int, res: int, frames: int, place_in_unet: str, step: int):
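    # For the query location (h_index, w_index) in frame 0, mark on every frame the key cell
    # receiving the highest attention (red box), the lowest attention (blue box), and the
    # query cell itself (yellow box), then save the annotated frames as a video.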
    attention_maps = sim.reshape(frames, res, res, frames, res, res)
    weights = attention_maps[0, h_index, w_index, :, :, :]
    flattened_weights = weights.reshape(-1)
    global_max_index = flattened_weights.argmax().cpu().numpy()
    global_min_index = flattened_weights.argmin().cpu().numpy()
    print('weights.shape', weights.shape)
    frame_max, h_max, w_max = np.unravel_index(global_max_index, weights.shape)
    frame_min, h_min, w_min = np.unravel_index(global_min_index, weights.shape)
    video_frames = []
    query_frame_index = 0
    query_h = h_index
    query_w = w_index
    for i in range(frames):
        frame = video[i].permute(1, 2, 0).cpu().numpy()
        mean = np.array((0.48145466, 0.4578275, 0.40821073)).reshape((1, 1, 3))
        varas = np.array((0.26862954, 0.26130258, 0.27577711)).reshape((1, 1, 3))
        frame = (frame * varas + mean) * 255
        frame = np.clip(frame, 0, 255).astype(np.uint8)
        frame_img = Image.fromarray(frame)
        grid_size = 512 // res
        frame_img = draw_grid_on_image(frame_img, grid_size)
        draw = ImageDraw.Draw(frame_img)
        if i == frame_max:
            max_pixel_pos = (w_max * grid_size, h_max * grid_size)
            draw.rectangle([max_pixel_pos, (max_pixel_pos[0] + grid_size, max_pixel_pos[1] + grid_size)], outline="red", width=2)
        if i == frame_min:
            min_pixel_pos = (w_min * grid_size, h_min * grid_size)
            draw.rectangle([min_pixel_pos, (min_pixel_pos[0] + grid_size, min_pixel_pos[1] + grid_size)], outline="blue", width=2)
        if i == query_frame_index:
            query_pixel_pos = (query_w * grid_size, query_h * grid_size)
            draw.rectangle([query_pixel_pos, (query_pixel_pos[0] + grid_size, query_pixel_pos[1] + grid_size)], outline="yellow", width=2)
        video_frames.append(frame_img)
    save_path = "/visualization/correspondence_with_query"
    os.makedirs(save_path, exist_ok=True)
    video_save_path = os.path.join(save_path, f'self-attn-{place_in_unet}-{step}-query-frame0-h{h_index}-w{w_index}.gif')
    save_gif_mp4_folder_type(video_frames, video_save_path, save_gif=False)

class ST_Layout_Attn_Control(AttentionControl, abc.ABC):
    def __init__(self, end_step=15, total_steps=50, step_idx=None, text_cond=None, sreg_maps=None, creg_maps=None, reg_sizes=None, reg_sizes_c=None, time_steps=None, clip_length=None, attention_type=None):
        """
        Spatial-Temporal Layout-guided Attention (ST-Layout Attn) for the Stable Diffusion model.
        Note: this variant does not include the cross-attention weight visualization function.
        Args:
            end_step: the step at which ST-Layout Attn control ends
            total_steps: the total number of denoising steps
            step_idx: list of denoising steps at which the attention modulation is applied
            text_cond: discrete text embedding for each region
            sreg_maps: spatial-temporal self-attention qk condition maps
            creg_maps: cross-attention qk condition maps
            reg_sizes/reg_sizes_c: size regularization maps for each instance in self-/cross-attention
            clip_length: number of frames in the video
            attention_type: FullyFrameAttention_sliced_attn / FullyFrameAttention / SparseCausalAttention
        """
        super().__init__()
        self.total_steps = total_steps
        self.step_idx = list(range(0, end_step))
        self.total_infer_steps = 50
        self.text_cond = text_cond
        self.sreg_maps = sreg_maps
        self.creg_maps = creg_maps
        self.reg_sizes = reg_sizes
        self.reg_sizes_c = reg_sizes_c
        self.clip_length = clip_length
        self.attention_type = attention_type
        self.sreg = .3
        self.creg = 1.
        self.count = 0
        self.reg_part = .3
        self.time_steps = time_steps
        print("Modulated Ctrl at denoising steps: ", self.step_idx)

    def forward(self, sim, is_cross, place_in_unet, **kwargs):
        """
        Attention forward function.
        """
        # print("self.cur_step", self.cur_step)
        if self.cur_step not in self.step_idx:
            return super().forward(sim, is_cross, place_in_unet, **kwargs)
        # sim for "SparseCausalAttention": (frames, heads=8, res, 2*res)
        # sim for "FullyFrameAttention": (1, heads, frames*res, frames*res), e.g. [1, 8, 12288, 12288]
        num_heads = sim.shape[1]
        if num_heads == 1:
            # A single head slice indicates sliced fully-frame attention; record the type.
            self.attention_type = "FullyFrameAttention_sliced_attn"
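        # Timestep-dependent modulation strength: (t - 1) / 1000 raised to the 5th power, so
        # the modulation is strongest at early (high-noise) timesteps and fades as denoising
        # proceeds.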
        treg = torch.pow((self.time_steps[self.cur_step] - 1) / 1000, 5)
        if not is_cross:
            min_value = sim.min(-1)[0].unsqueeze(-1)
            max_value = sim.max(-1)[0].unsqueeze(-1)
            if self.attention_type == "SparseCausalAttention":
                mask = self.sreg_maps[sim.size(2)].repeat(1, num_heads, 1, 1)
                size_reg = self.reg_sizes[sim.size(2)].repeat(1, num_heads, 1, 1)
            elif self.attention_type == "FullyFrameAttention":
                mask = self.sreg_maps[sim.size(2) // self.clip_length].repeat(1, num_heads, 1, 1)
                size_reg = self.reg_sizes[sim.size(2) // self.clip_length].repeat(1, num_heads, 1, 1)
            elif self.attention_type == "FullyFrameAttention_sliced_attn":
                mask = self.sreg_maps[sim.size(2) // self.clip_length]
                size_reg = self.reg_sizes[sim.size(2) // self.clip_length]
            else:
                raise ValueError(f"unknown attention type: {self.attention_type}")
            # if place_in_unet == "up" and res == 32:
            #     # h_index=11, w_index=15
            #     show_self_attention_comp(sim, video=self.video, h_index=11, w_index=15, res=32, frames=self.clip_length, place_in_unet="up", step=self.cur_step)
            # if place_in_unet == "up" and res == 8:
            #     identify_self_attention_max_min(sim, video=self.video, h_index=3, w_index=4, res=8, frames=self.clip_length, place_in_unet="up", step=self.cur_step)
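            # Region-aware modulation of the self-attention scores: query/key pairs inside the
            # same layout region (mask > 0) are pushed toward their row-wise maximum, while
            # pairs that cross regions are pushed toward the row-wise minimum; size_reg rescales
            # the push per instance size, self.sreg sets the overall strength, and treg fades it
            # over the denoising steps.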
            sim += (mask > 0) * size_reg * self.sreg * treg * (max_value - sim)
            sim -= ~(mask > 0) * size_reg * self.sreg * treg * (sim - min_value)
        else:
            # Modulate cross-attention with the same scheme, using the creg maps.
            min_value = sim.min(-1)[0].unsqueeze(-1)
            max_value = sim.max(-1)[0].unsqueeze(-1)
            mask = self.creg_maps[sim.size(2)].repeat(1, num_heads, 1, 1)
            size_reg = self.reg_sizes_c[sim.size(2)].repeat(1, num_heads, 1, 1)
            sim += (mask > 0) * size_reg * self.creg * treg * (max_value - sim)
            sim -= ~(mask > 0) * size_reg * self.creg * treg * (sim - min_value)
        self.count += 1
        return sim

class Attention_Record_Processor(AttentionStore, abc.ABC):
    """Record self-attention and cross-attention maps during DDIM inversion."""
    def __init__(self, additional_attention_store: AttentionStore = None, save_self_attention: bool = True, disk_store=False):
        super(Attention_Record_Processor, self).__init__(
            save_self_attention=save_self_attention,
            disk_store=disk_store)
        self.additional_attention_store = additional_attention_store
        self.attention_position_counter_dict = {
            'down_cross': 0,
            'mid_cross': 0,
            'up_cross': 0,
            'down_self': 0,
            'mid_self': 0,
            'up_self': 0,
        }

    def update_attention_position_dict(self, current_attention_key):
        self.attention_position_counter_dict[current_attention_key] += 1

    def forward(self, sim, is_cross: bool, place_in_unet: str, **kwargs):
        super(Attention_Record_Processor, self).forward(sim, is_cross, place_in_unet, **kwargs)
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        self.update_attention_position_dict(key)
        return sim

    def between_steps(self):
        super().between_steps()
        self.step_store = self.get_empty_store()
        self.attention_position_counter_dict = {
            'down_cross': 0,
            'mid_cross': 0,
            'up_cross': 0,
            'down_self': 0,
            'mid_self': 0,
            'up_self': 0,
        }
        return

class ST_Layout_Attn_ControlEdit(AttentionStore, abc.ABC):
    def __init__(self, end_step=15, total_steps=50, step_idx=None, text_cond=None, sreg_maps=None, creg_maps=None, reg_sizes=None, reg_sizes_c=None,
                 time_steps=None,
                 clip_length=None, attention_type=None,
                 additional_attention_store: AttentionStore = None,
                 save_self_attention: bool = True,
                 disk_store=False,
                 video=None,
                 ):
        """
        Spatial-Temporal Layout-guided Attention (ST-Layout Attn) for the Stable Diffusion model.
        Note: this variant includes the cross-attention weight visualization function.
        Args:
            end_step: the step at which ST-Layout Attn control ends
            total_steps: the total number of denoising steps
            step_idx: list of denoising steps at which the attention modulation is applied
            text_cond: discrete text embedding for each region
            sreg_maps: spatial-temporal self-attention qk condition maps
            creg_maps: cross-attention qk condition maps
            reg_sizes/reg_sizes_c: size regularization maps for each instance in self-/cross-attention
            clip_length: number of frames in the video
            attention_type: FullyFrameAttention_sliced_attn / FullyFrameAttention / SparseCausalAttention
        """
        super(ST_Layout_Attn_ControlEdit, self).__init__(
            save_self_attention=save_self_attention,
            disk_store=disk_store)
        self.total_steps = total_steps
        self.step_idx = list(range(0, end_step))
        self.total_infer_steps = 50
        self.text_cond = text_cond
        self.sreg_maps = sreg_maps
        self.creg_maps = creg_maps
        self.reg_sizes = reg_sizes
        self.reg_sizes_c = reg_sizes_c
        self.clip_length = clip_length
        self.attention_type = attention_type
        self.sreg = .3
        self.creg = 1.
        self.count = 0
        self.reg_part = .3
        self.time_steps = time_steps
        self.additional_attention_store = additional_attention_store
        self.attention_position_counter_dict = {
            'down_cross': 0,
            'mid_cross': 0,
            'up_cross': 0,
            'down_self': 0,
            'mid_self': 0,
            'up_self': 0,
        }
        self.video = video

    def update_attention_position_dict(self, current_attention_key):
        self.attention_position_counter_dict[current_attention_key] += 1

    def forward(self, sim, is_cross: bool, place_in_unet: str, **kwargs):
        super(ST_Layout_Attn_ControlEdit, self).forward(sim, is_cross, place_in_unet, **kwargs)
        # print("self.cur_step", self.cur_step)
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        self.update_attention_position_dict(key)
        if self.cur_step not in self.step_idx:
            return sim
        num_heads = sim.shape[1]
        if num_heads == 1:
            # A single head slice indicates sliced fully-frame attention; record the type.
            self.attention_type = "FullyFrameAttention_sliced_attn"
        treg = torch.pow((self.time_steps[self.cur_step] - 1) / 1000, 5)
        if not is_cross:
            # Modulate self-attention
            min_value = sim.min(-1)[0].unsqueeze(-1)
            max_value = sim.max(-1)[0].unsqueeze(-1)
            if self.attention_type == "SparseCausalAttention":
                mask = self.sreg_maps[sim.size(2)].repeat(1, num_heads, 1, 1)
                size_reg = self.reg_sizes[sim.size(2)].repeat(1, num_heads, 1, 1)
            elif self.attention_type == "FullyFrameAttention":
                mask = self.sreg_maps[sim.size(2) // self.clip_length].repeat(1, num_heads, 1, 1)
                size_reg = self.reg_sizes[sim.size(2) // self.clip_length].repeat(1, num_heads, 1, 1)
            elif self.attention_type == "FullyFrameAttention_sliced_attn":
                mask = self.sreg_maps[sim.size(2) // self.clip_length]
                size_reg = self.reg_sizes[sim.size(2) // self.clip_length]
            else:
                raise ValueError(f"unknown attention type: {self.attention_type}")
            sim += (mask > 0) * size_reg * self.sreg * treg * (max_value - sim)
            sim -= ~(mask > 0) * size_reg * self.sreg * treg * (sim - min_value)
        else:
            # Modulate cross-attention
            min_value = sim.min(-1)[0].unsqueeze(-1)
            max_value = sim.max(-1)[0].unsqueeze(-1)
            mask = self.creg_maps[sim.size(2)].repeat(1, num_heads, 1, 1)
            size_reg = self.reg_sizes_c[sim.size(2)].repeat(1, num_heads, 1, 1)
            sim += (mask > 0) * size_reg * self.creg * treg * (max_value - sim)
            sim -= ~(mask > 0) * size_reg * self.creg * treg * (sim - min_value)
        self.count += 1
        return sim

    def between_steps(self):
        super().between_steps()
        self.step_store = self.get_empty_store()
        self.attention_position_counter_dict = {
            'down_cross': 0,
            'mid_cross': 0,
            'up_cross': 0,
            'down_self': 0,
            'mid_self': 0,
            'up_self': 0,
        }
        return