Instructions to use hansQAQ/icip_source_2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use hansQAQ/icip_source_2 with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("hansQAQ/icip_source_2", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| from typing import Callable, List, Optional, Tuple, Union, Any | |
| import torch | |
| import torch.nn.functional as F | |
| from diffusers.models.attention_processor import Attention | |
| from diffusers.utils import logging | |
| from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available | |
| from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph | |
| from einops import rearrange | |
| from torch import nn | |
# Module-level logger obtained from diffusers' logging utilities.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class TripoSGAttnProcessor2_0:
    r"""
    Scaled dot-product attention processor used by the TripoSG model (PyTorch >= 2.0).

    Besides plain attention it optionally normalizes queries/keys through
    ``attn.norm_q`` / ``attn.norm_k`` and applies rotary position embeddings:
    to the query always, and to the key only for self-attention.
    """

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError(
            "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
        )

    @staticmethod
    def _resplit_heads(projections, batch_size, num_heads):
        """Undo the head-major packing used by the pretrained weights.

        The projections are concatenated along the channel axis, viewed as
        ``(batch, seq, num_heads, parts * chunk)`` and cut into ``chunk``-wide
        pieces, one per projection (each shaped ``(batch, seq, num_heads, chunk)``).
        """
        parts = len(projections)
        packed = torch.cat(projections, dim=-1)
        chunk = packed.shape[-1] // num_heads // parts
        packed = packed.view(batch_size, -1, num_heads, chunk * parts)
        return torch.split(packed, chunk, dim=-1)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention for ``attn`` and return the projected (optionally residual) output."""
        from diffusers.models.embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        spatial_input = hidden_states.ndim == 4
        if spatial_input:
            # Flatten (batch, channel, height, width) into (batch, height * width, channel).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width)
            hidden_states = hidden_states.transpose(1, 2)

        source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention wants (batch, heads, query_len, key_len).
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            normed = attn.group_norm(hidden_states.transpose(1, 2))
            hidden_states = normed.transpose(1, 2)

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Pretrained checkpoints pack heads first and q/k/v second, i.e.
        # .view(..., heads, 3, dim) rather than .view(..., 3, heads, dim).
        if attn.is_cross_attention:
            key, value = self._resplit_heads((key, value), batch_size, attn.heads)
        else:
            query, key, value = self._resplit_heads(
                (query, key, value), batch_size, attn.heads
            )

        head_dim = key.shape[-1]
        query, key, value = (
            tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            for tensor in (query, key, value)
        )

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # Output layout: (batch, heads, seq_len, head_dim).
        # TODO: pass attn.scale once the minimum supported PyTorch is 2.1.
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = attended.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        ).to(query.dtype)

        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = out_dropout(out_proj(hidden_states))

        if spatial_input:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        return hidden_states / attn.rescale_output_factor
class FusedTripoSGAttnProcessor2_0:
    r"""
    Scaled dot-product attention processor (PyTorch >= 2.0) for TripoSG layers whose
    projections have been fused.

    Self-attention reads q/k/v from the single ``attn.to_qkv`` projection;
    cross-attention uses ``attn.to_q`` plus the fused ``attn.to_kv`` projection.
    Optional query/key normalization and rotary embeddings (key rotated only for
    self-attention) are applied before attention.
    """

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError(
            "FusedTripoSGAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
        )

    @staticmethod
    def _split_packed(packed, parts, batch_size, num_heads):
        """Split a fused projection packed head-first into ``parts`` per-head tensors."""
        chunk = packed.shape[-1] // num_heads // parts
        packed = packed.view(batch_size, -1, num_heads, chunk * parts)
        return torch.split(packed, chunk, dim=-1)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run fused-projection attention for ``attn`` and return its output."""
        from diffusers.models.embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        spatial_input = hidden_states.ndim == 4
        if spatial_input:
            # (batch, channel, height, width) -> (batch, height * width, channel)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width)
            hidden_states = hidden_states.transpose(1, 2)

        source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention wants (batch, heads, query_len, key_len).
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            normed = attn.group_norm(hidden_states.transpose(1, 2))
            hidden_states = normed.transpose(1, 2)

        # Pretrained weights split heads first, then q/k/v (or k/v).
        if encoder_hidden_states is None:
            query, key, value = self._split_packed(
                attn.to_qkv(hidden_states), 3, batch_size, attn.heads
            )
        else:
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states
                )
            query = attn.to_q(hidden_states)
            key, value = self._split_packed(
                attn.to_kv(encoder_hidden_states), 2, batch_size, attn.heads
            )

        head_dim = key.shape[-1]
        query, key, value = (
            tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            for tensor in (query, key, value)
        )

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # Output layout: (batch, heads, seq_len, head_dim).
        # TODO: pass attn.scale once the minimum supported PyTorch is 2.1.
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = attended.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        ).to(query.dtype)

        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = out_dropout(out_proj(hidden_states))

        if spatial_input:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        return hidden_states / attn.rescale_output_factor
class MIAttnProcessor2_0:
    r"""
    Scaled dot-product attention processor used by the MIDI model (PyTorch >= 2.0).

    Optionally normalizes queries/keys (``attn.norm_q`` / ``attn.norm_k``) and applies
    rotary embeddings (query always, key only for self-attention). When ``use_mi`` is
    enabled and an instance layout is supplied, the keys/values of a group of
    consecutive batch rows are concatenated along the sequence axis so that every row
    in the group attends to the tokens of all rows in that group ("multi-instance"
    attention).

    Args:
        use_mi: enable multi-instance attention. When ``False`` -- or when a call does
            not provide ``num_instances`` -- ordinary attention is used.
    """

    def __init__(self, use_mi: bool = True):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
        self.use_mi = use_mi

    @staticmethod
    def _multi_instance_attention(query, key, value, num_instances):
        """Attend every row of a group to the keys/values of all rows in that group.

        ``query``/``key``/``value`` are (batch, heads, tokens, head_dim) with the batch
        laid out as consecutive groups of ``num_instances`` rows. Keys/values of each
        group are concatenated along the token axis and repeated once per row.
        """
        key = rearrange(
            key, "(b ni) h nt c -> b h (ni nt) c", ni=num_instances
        ).repeat_interleave(num_instances, dim=0)
        value = rearrange(
            value, "(b ni) h nt c -> b h (ni nt) c", ni=num_instances
        ).repeat_interleave(num_instances, dim=0)
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        return F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False
        )

    @classmethod
    def _padded_multi_instance_attention(
        cls, query, key, value, num_instances, num_instances_per_batch, batch_size
    ):
        """Training path: fixed-size slots of ``num_instances_per_batch`` rows per sample.

        Within each slot, the first ``num`` rows (``num`` taken from ``num_instances``)
        get multi-instance attention; the remaining padding rows get ordinary
        per-row self-attention. The outer ``while`` loop walks ``num_instances`` again
        when the batch holds more slots than entries in ``num_instances`` (e.g. a
        classifier-free-guidance duplicated batch).
        """
        if num_instances_per_batch <= 0 or len(num_instances) == 0:
            # Either condition would keep ``start_idx`` from advancing -> infinite loop.
            raise ValueError(
                "num_instances must be non-empty and num_instances_per_batch positive, "
                f"got {num_instances!r} and {num_instances_per_batch!r}"
            )
        chunks = []
        start_idx = 0
        while start_idx < batch_size:  # for classifier-free guidance
            for num in num_instances:
                stop = start_idx + num
                # Multi-object self-attention among the real instances.
                chunks.append(
                    cls._multi_instance_attention(
                        query[start_idx:stop],
                        key[start_idx:stop],
                        value[start_idx:stop],
                        num,
                    )
                )
                # Single-object self-attention for padding and regularization.
                slot_end = start_idx + num_instances_per_batch
                query_pad = query[stop:slot_end]
                if query_pad.shape[0] > 0:
                    chunks.append(
                        F.scaled_dot_product_attention(
                            query_pad,
                            key[stop:slot_end],
                            value[stop:slot_end],
                            dropout_p=0.0,
                            is_causal=False,
                        )
                    )
                start_idx += num_instances_per_batch
        return torch.cat(chunks, dim=0)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        num_instances: Optional[Union[int, torch.IntTensor]] = None,
        num_instances_per_batch: Optional[int] = None,
    ) -> torch.Tensor:
        """Run (multi-instance) attention for ``attn``.

        Args:
            num_instances: inference -- number of consecutive batch rows forming one
                group; training -- iterable with the real instance count of each slot.
            num_instances_per_batch: training only -- rows reserved per slot. When
                ``None`` the inference path is used.

        NOTE: ``attention_mask`` is only honoured by the ordinary-attention path; the
        multi-instance paths change the key length and do not apply it.
        """
        from diffusers.models.embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # NOTE that pre-trained models split heads first then split qkv or kv, like
        # .view(..., attn.heads, 3, dim) instead of .view(..., 3, attn.heads, dim).
        # So we need to re-split here.
        if not attn.is_cross_attention:
            qkv = torch.cat((query, key, value), dim=-1)
            split_size = qkv.shape[-1] // attn.heads // 3
            qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
            query, key, value = torch.split(qkv, split_size, dim=-1)
        else:
            kv = torch.cat((key, value), dim=-1)
            split_size = kv.shape[-1] // attn.heads // 2
            kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
            key, value = torch.split(kv, split_size, dim=-1)

        head_dim = key.shape[-1]
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        if not self.use_mi or num_instances is None:
            # Ordinary attention. Also the fallback when multi-instance attention is
            # enabled but no instance layout was passed; previously no branch ran in
            # that case and the un-attended input was propagated.
            hidden_states = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=attention_mask,
                dropout_p=0.0,
                is_causal=False,
            )
        elif num_instances_per_batch is None:
            # Inference: the whole batch is consecutive groups of num_instances rows.
            hidden_states = self._multi_instance_attention(
                query, key, value, num_instances
            )
        else:
            # Training: the same slot size is required for every sample.
            hidden_states = self._padded_multi_instance_attention(
                query, key, value, num_instances, num_instances_per_batch, batch_size
            )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
class SketchFusionAttnProcessor:
    r"""
    Scaled dot-product attention processor (PyTorch >= 2.0) with optional sketch-based
    spatial gating of the keys.

    Like the TripoSG processor, it re-splits the head-first packed projections and
    applies optional query/key normalization and rotary embeddings (key rotated only
    for self-attention). In addition, when both ``gating_map`` and
    ``gating_intensity`` are passed, every key token ``j`` is multiplied by

        suppression[j] = 1 - (1 - gating_intensity) * (1 - gating_map[j])

    so tokens with ``gating_map == 1`` keep their key unchanged and tokens with
    ``gating_map == 0`` have it scaled by ``gating_intensity``. Scaling a key scales
    the attention logits ``q . k_j`` of that token.
    """

    # Design notes -- spatial gating attention: we push the model to focus on sketch
    # edges and lines, but the control of the focus area is not rigid. Sketch latents
    # in shallow layers may carry more low-level geometry information, while deep
    # layers carry semantic information. Because a patch's latent token is "blurred"
    # by information from other patches during the ViT forward pass, naively
    # suppressing the attention scores of "no-drawing-line" patch tokens in deep
    # layers may be unconvincing. Both approaches are to be implemented and their
    # effectiveness tested in practice.
    # TODO: (note left unfinished in the original) "For now we just ..."

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        num_instances: Optional[Union[int, torch.IntTensor]] = None,
        gating_map: Optional[torch.Tensor] = None,  # see sketch.sketch_utils.get_sketch_spatial_gating_map
        gating_intensity: Optional[torch.Tensor] = None,  # float or tensor; see the same helper
        num_instances_per_batch: Optional[Any] = None,
    ) -> torch.Tensor:
        """Run attention for ``attn``, optionally gating keys with ``gating_map``.

        ``num_instances`` and ``num_instances_per_batch`` are accepted for interface
        compatibility and are not used here.

        ``gating_map`` is expected as ``[B, L]``, ``[B, L, 1]`` or ``[B, 1, L]`` with
        ``L`` equal to the key sequence length. ``gating_intensity`` presumably is a
        scalar (float or 0-d tensor) -- TODO confirm; it must broadcast against the
        ``[B, L]`` map.
        """
        from diffusers.models.embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # NOTE that pre-trained models split heads first then split qkv or kv, like
        # .view(..., attn.heads, 3, dim) instead of .view(..., 3, attn.heads, dim).
        # So we need to re-split here.
        if not attn.is_cross_attention:
            qkv = torch.cat((query, key, value), dim=-1)
            split_size = qkv.shape[-1] // attn.heads // 3
            qkv = qkv.view(batch_size, -1, attn.heads, split_size * 3)
            query, key, value = torch.split(qkv, split_size, dim=-1)
        else:
            kv = torch.cat((key, value), dim=-1)
            split_size = kv.shape[-1] // attn.heads // 2
            kv = kv.view(batch_size, -1, attn.heads, split_size * 2)
            key, value = torch.split(kv, split_size, dim=-1)

        head_dim = key.shape[-1]
        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # Spatial gating: scaling the KEY of a token scales its attention logits.
        # The gradient w.r.t. the intensity differs from scaling the scores directly,
        # which is considered acceptable.
        if gating_map is not None and gating_intensity is not None:
            assert gating_map.shape[0] == key.shape[0]
            if gating_map.ndim == 3:
                # Accept [B, L, 1] or [B, 1, L]. Squeeze only the singleton non-batch
                # axis: a bare .squeeze() would also drop the batch axis when B == 1
                # (or the sequence axis when L == 1) and break the broadcast below.
                gating_map = gating_map.squeeze(-1 if gating_map.shape[-1] == 1 else 1)
            assert gating_map.shape[-1] == key.shape[-2], f"Unequal sequence length of gating map and key vectors. Gating map: {gating_map.shape[-1]} while key values: {key.shape[-2]}"
            # Move gating_map and intensity to the same device as key.
            gating_map = gating_map.to(key.device)
            # gating_intensity might be a float or a tensor.
            if isinstance(gating_intensity, torch.Tensor):
                gating_intensity = gating_intensity.to(key.device)
            suppression = 1.0 - (1.0 - gating_intensity) * (1.0 - gating_map)
            # Cast suppression back to key dtype to avoid precision upcasting
            # (float16 * float32 -> float32).
            suppression = suppression.to(key.dtype)
            # [B, L] -> [B, 1, L, 1] to broadcast over heads and head_dim of key.
            key = key * suppression.unsqueeze(1).unsqueeze(-1)

        key = key.to(query.dtype)
        value = value.to(query.dtype)
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
class SketchFusionAttnProcessor2:
    r"""
    Ablation variant of ``SketchFusionAttnProcessor`` (PyTorch >= 2.0).

    It accepts the same keyword arguments (``num_instances``, ``gating_map``,
    ``gating_intensity``, ``num_instances_per_batch``) so it can be swapped in without
    changing call sites, but ignores all of them. It performs plain scaled dot-product
    attention with optional query/key normalization and rotary embeddings (key
    rotated only for self-attention).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            message = (
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
            raise ImportError(message)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        num_instances: Optional[Union[int, torch.IntTensor]] = None,
        gating_map: Optional[torch.Tensor] = None,  # accepted but unused (ablation)
        gating_intensity: Optional[torch.Tensor] = None,  # accepted but unused (ablation)
        num_instances_per_batch: Optional[Any] = None,
    ) -> torch.Tensor:
        """Run ungated attention for ``attn``; the gating/instance kwargs are ignored."""
        from diffusers.models.embeddings import apply_rotary_emb

        n_heads = attn.heads
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        was_4d = hidden_states.ndim == 4
        if was_4d:
            # (batch, channel, height, width) -> (batch, height * width, channel)
            batch_size, channel, height, width = hidden_states.shape
            flat = hidden_states.view(batch_size, channel, height * width)
            hidden_states = flat.transpose(1, 2)

        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # Reshape to (batch, heads, query_len, key_len) for SDPA.
            key_len = attention_mask.shape[-1]
            attention_mask = attention_mask.view(batch_size, n_heads, -1, key_len)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)
            ).transpose(1, 2)

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Checkpoints pack heads first and q/k/v (or k/v) second, so concatenate the
        # projections and cut them again per head.
        cross = attn.is_cross_attention
        if cross:
            packed, n_parts = torch.cat((key, value), dim=-1), 2
        else:
            packed, n_parts = torch.cat((query, key, value), dim=-1), 3
        part_dim = packed.shape[-1] // n_heads // n_parts
        pieces = torch.split(
            packed.view(batch_size, -1, n_heads, part_dim * n_parts), part_dim, dim=-1
        )
        if cross:
            key, value = pieces
        else:
            query, key, value = pieces

        head_dim = key.shape[-1]

        def to_heads(tensor):
            # (batch, seq, ...) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)

        query, key, value = map(to_heads, (query, key, value))

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not cross:
                key = apply_rotary_emb(key, image_rotary_emb)

        # TODO: pass attn.scale once the minimum supported PyTorch is 2.1.
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        merged = attended.transpose(1, 2).reshape(batch_size, -1, n_heads * head_dim)
        hidden_states = merged.to(query.dtype)

        projected = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](projected)

        if was_4d:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        return hidden_states / attn.rescale_output_factor