import abc
import math
from inspect import isfunction
from typing import Union, Tuple, List

import cv2
import numpy as np
import torch
from torch import nn, einsum
from einops import rearrange, repeat
from IPython.display import display
from PIL import Image

from diffusers.utils import logging

try:
    from diffusers.models.unet_2d_condition import UNet2DConditionOutput
except:
    from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
try:
    from diffusers.models.cross_attention import CrossAttention
except:
    from diffusers.models.attention_processor import Attention as CrossAttention

MAX_NUM_WORDS = 77
LOW_RESOURCE = False


class CountingCrossAttnProcessor1:
    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        self.attnstore = attnstore
        self.place_in_unet = place_in_unet

    def __call__(self, attn_layer: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, dim = hidden_states.shape
        h = attn_layer.heads

        q = attn_layer.to_q(hidden_states)
        is_cross = encoder_hidden_states is not None
        context = encoder_hidden_states if is_cross else hidden_states
        k = attn_layer.to_k(context)
        v = attn_layer.to_v(context)

        # q = attn_layer.reshape_heads_to_batch_dim(q)
        # k = attn_layer.reshape_heads_to_batch_dim(k)
        # v = attn_layer.reshape_heads_to_batch_dim(v)
        # q = attn_layer.head_to_batch_dim(q)
        # k = attn_layer.head_to_batch_dim(k)
        # v = attn_layer.head_to_batch_dim(v)
        q = self.head_to_batch_dim(q, h)
        k = self.head_to_batch_dim(k, h)
        v = self.head_to_batch_dim(v, h)

        sim = torch.einsum("b i d, b j d -> b i j", q, k) * attn_layer.scale

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(batch_size, -1)
            max_neg_value = -torch.finfo(sim.dtype).max
            attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
            sim.masked_fill_(~attention_mask, max_neg_value)

        # attention, what we cannot get enough of
        attn_ = sim.softmax(dim=-1).clone()
        # softmax = nn.Softmax(dim=-1)
        # attn_ = softmax(sim)

        self.attnstore(attn_, is_cross, self.place_in_unet)

        out = torch.einsum("b i j, b j d -> b i d", attn_, v)
        # out = attn_layer.batch_to_head_dim(out)
        out = self.batch_to_head_dim(out, h)

        if type(attn_layer.to_out) is torch.nn.modules.container.ModuleList:
            to_out = attn_layer.to_out[0]
        else:
            to_out = attn_layer.to_out
        out = to_out(out)
        return out

    def batch_to_head_dim(self, tensor, head_size):
        # head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, head_size, out_dim=3):
        # head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)
        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor


def register_attention_control(model, controller):
    attn_procs = {}
    cross_att_count = 0
    for name in model.unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = model.unet.config.block_out_channels[-1]
            place_in_unet = "mid"
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id]
            place_in_unet = "up"
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = model.unet.config.block_out_channels[block_id]
            place_in_unet = "down"
        else:
            continue

        cross_att_count += 1
        # attn_procs[name] = AttendExciteCrossAttnProcessor(
        #     attnstore=controller, place_in_unet=place_in_unet
        # )
        attn_procs[name] = CountingCrossAttnProcessor1(
            attnstore=controller, place_in_unet=place_in_unet
        )

    model.unet.set_attn_processor(attn_procs)
    controller.num_att_layers = cross_att_count


def register_hier_output(model):
    self = model.unet
    from ldm.modules.diffusionmodules.util import checkpoint, timestep_embedding
    logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

    def forward(sample, timestep=None, encoder_hidden_states=None, class_labels=None, timestep_cond=None,
                attention_mask=None, cross_attention_kwargs=None, added_cond_kwargs=None,
                down_block_additional_residuals=None, mid_block_additional_residual=None,
                encoder_attention_mask=None, return_dict=True):
        out_list = []
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None
        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)
        t_emb = t_emb.to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kadinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)

        # 2. pre-process
        sample = self.conv_in(sample)  # 1, 320, 64, 64

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()
            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_block_additional_residuals) > 0
                and sample.shape == down_block_additional_residuals[0].shape
            ):
                sample += down_block_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    scale=lora_scale,
                )
            # if i in [1, 4, 7]:
            out_list.append(sample)

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample), out_list

    self.forward = forward


class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            # self.forward(attn, is_cross, place_in_unet)
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0


class EmptyControl(AttentionControl):

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        return attn


class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= self.max_size ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        self.attention_store = self.step_store
        if self.save_global_store:
            with torch.no_grad():
                if len(self.global_store) == 0:
                    self.global_store = self.step_store
                else:
                    for key in self.global_store:
                        for i in range(len(self.global_store[key])):
                            self.global_store[key][i] += self.step_store[key][i].detach()
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = self.attention_store
        return average_attention

    def get_average_global_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.global_store[key]]
                             for key in self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}

    def __init__(self, max_size=32, save_global_store=False):
        '''
        Initialize an empty AttentionStore
        :param step_index: used to visualize only a specific step in the diffusion process
        '''
        super(AttentionStore, self).__init__()
        self.save_global_store = save_global_store
        self.max_size = max_size
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
        self.curr_step_index = 0


def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str],
                        is_cross: bool, select: int):
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] == num_pixels:
                cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                out.append(cross_maps)
    out = torch.cat(out, dim=0)
    out = out.sum(0) / out.shape[0]
    return out


def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int,
                         from_where: List[str], select: int = 0):
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    # `prompts` must be forwarded: aggregate_attention needs it to reshape the maps per prompt
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    view_images(np.stack(images, axis=0))


def show_self_attention_comp(prompts, attention_store: AttentionStore, res: int, from_where: List[str],
                             max_com=10, select: int = 0):
    # `prompts` is required by aggregate_attention (same reason as in show_cross_attention)
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, False, select)
    attention_maps = attention_maps.numpy().reshape((res ** 2, res ** 2))
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    view_images(np.concatenate(images, axis=1))


def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
    img[:h] = image
    textsize = cv2.getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
    return img


def view_images(images, num_rows=1, offset_ratio=0.02):
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h,
                   j * (w + offset): j * (w + offset) + w] = images[i * num_cols + j]

    pil_img = Image.fromarray(image_)
    display(pil_img)


def self_cross_attn(self_attn, cross_attn):
    res = self_attn.shape[0]
    assert res == cross_attn.shape[0]
    # cross attn [res, res] -> [res*res]
    cross_attn_ = cross_attn.reshape([res * res])
    # self_attn [res, res, res*res]
    self_cross_attn = cross_attn_ * self_attn
    self_cross_attn = self_cross_attn.mean(-1).unsqueeze(0).unsqueeze(0)
    return self_cross_attn
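

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module; added for illustration only).
# Assumptions: a diffusers StableDiffusionPipeline is available, the checkpoint
# id below is just an example, and the `ldm` package imported inside
# register_hier_output is installed. `register_attention_control` only needs an
# object exposing `.unet`, so the pipeline itself can be passed. Note that after
# `register_hier_output` the UNet's forward returns a
# (UNet2DConditionOutput, out_list) tuple, so it is meant to be driven by a
# custom denoising loop rather than the stock pipeline __call__.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    prompts = ["a photo of three apples on a table"]

    controller = AttentionStore(max_size=32)
    register_attention_control(pipe, controller)  # routes every attention map into `controller`
    register_hier_output(pipe)                    # UNet forward now also returns up-block features

    # ... run a custom denoising loop here; each UNet call fills `controller.step_store`,
    # and `controller.between_steps()` (triggered automatically after the last attention
    # layer of a step) moves it into `controller.attention_store`.

    # Afterwards, e.g. visualize the aggregated 16x16 cross-attention maps per token:
    # show_cross_attention(pipe.tokenizer, prompts, controller, res=16,
    #                      from_where=["up", "down", "mid"])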