# Spaces: Running on Zero
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| import math | |
| from typing import List, Literal, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn.utils.rnn import pad_sequence | |
| from ...configuration_utils import ConfigMixin, register_to_config | |
| from ...loaders import PeftAdapterMixin | |
| from ...loaders.single_file_model import FromOriginalModelMixin | |
| from ...models.attention_processor import Attention | |
| from ...models.normalization import RMSNorm | |
| from ...utils.torch_utils import maybe_allow_in_graph | |
| from ..attention_dispatch import dispatch_attention_fn | |
| from ..controlnets.controlnet import zero_module | |
| from ..modeling_utils import ModelMixin | |
# Upper bound on the input width of the adaLN modulation projections; blocks use
# ``min(dim, ADALN_EMBED_DIM)`` as that width.
ADALN_EMBED_DIM: int = 256
# Image-patch and caption token sequences are padded up to a multiple of this length.
SEQ_MULTI_OF: int = 32
# Adapted from diffusers.models.transformers.transformer_z_image.TimestepEmbedder
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps: sinusoidal features followed by a two-layer SiLU MLP.

    Args:
        out_size: Width of the returned embedding.
        mid_size: Hidden width of the MLP; defaults to ``out_size``.
        frequency_embedding_size: Width of the sinusoidal features fed to the MLP.
    """

    def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
        super().__init__()
        if mid_size is None:
            mid_size = out_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, mid_size, bias=True),
            nn.SiLU(),
            nn.Linear(mid_size, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return float32 sinusoidal features of shape ``(len(t), dim)`` for a 1-D tensor ``t``.

        The first ``dim // 2`` columns are cosines and the next ``dim // 2`` are sines of
        ``t * max_period ** (-k / (dim // 2))``; an odd ``dim`` gets one trailing zero column.

        This must be a static method: ``forward`` calls it as ``self.timestep_embedding(t, dim)``,
        and without the decorator ``self`` would be bound to ``t``.
        """
        # Keep the frequency math in float32 even inside a CUDA autocast region.
        with torch.amp.autocast("cuda", enabled=False):
            half = dim // 2
            freqs = torch.exp(
                -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
            )
            args = t[:, None].float() * freqs[None]
            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
            if dim % 2:
                embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
            return embedding

    def forward(self, t):
        """Embed timesteps ``t`` into ``(len(t), out_size)`` features."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        # Feed the MLP in its weight dtype. If the weights are not floating point (presumably a
        # quantized layer), fall back to the layer's `compute_dtype` attribute when it has one.
        weight_dtype = self.mlp[0].weight.dtype
        compute_dtype = getattr(self.mlp[0], "compute_dtype", None)
        if weight_dtype.is_floating_point:
            t_freq = t_freq.to(weight_dtype)
        elif compute_dtype is not None:
            t_freq = t_freq.to(compute_dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
# Adapted from diffusers.models.transformers.transformer_z_image.ZSingleStreamAttnProcessor
class ZSingleStreamAttnProcessor:
    """Self-attention processor that makes diffusers' ``Attention`` behave like the original
    Z-Image attention module.

    It projects the hidden states to queries, keys and values and splits the heads. It then
    applies the optional per-head Q/K norms and the optional rotary embedding, and attends
    through ``dispatch_attention_fn``. ``encoder_hidden_states`` is accepted for interface
    compatibility but is never read.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )

    @staticmethod
    def _apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # Pair up the last dim as (real, imag), rotate by the complex `freqs_cis` (given a new axis
        # at position 2, presumably the head axis), then cast back to `x_in`'s dtype.
        # The math always runs in float32, outside any CUDA autocast.
        with torch.amp.autocast("cuda", enabled=False):
            as_complex = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
            rotated = as_complex * freqs_cis.unsqueeze(2)
            return torch.view_as_real(rotated).flatten(3).type_as(x_in)

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Project, then split the last dimension into (heads, head_dim).
        q, k, v = [
            projection(hidden_states).unflatten(-1, (attn.heads, -1))
            for projection in (attn.to_q, attn.to_k, attn.to_v)
        ]

        if attn.norm_q is not None:
            q = attn.norm_q(q)
        if attn.norm_k is not None:
            k = attn.norm_k(k)

        if freqs_cis is not None:
            q = self._apply_rotary_emb(q, freqs_cis)
            k = self._apply_rotary_emb(k, freqs_cis)

        # Keep the key in the query's dtype.
        target_dtype = q.dtype
        k = k.to(target_dtype)

        # A [batch, seq_len] mask becomes [batch, 1, 1, seq_len] so it broadcasts over heads and queries.
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = attention_mask[:, None, None, :]

        attended = dispatch_attention_fn(
            q,
            k,
            v,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        # Merge the head axes back (assumes a (batch, seq, heads, head_dim) layout — TODO confirm).
        attended = attended.flatten(2, 3).to(target_dtype)

        result = attn.to_out[0](attended)
        # An optional second `to_out` entry (a dropout in the original module) is applied too.
        if len(attn.to_out) > 1:
            result = attn.to_out[1](result)
        return result
# Adapted from diffusers.models.transformers.transformer_z_image.FeedForward
class FeedForward(nn.Module):
    """Gated SiLU feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``.

    All three projections are bias-free. ``w1`` and ``w3`` map ``dim -> hidden_dim`` and
    ``w2`` maps back ``hidden_dim -> dim``.
    """

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # Registration order (w1, w2, w3) is kept so parameter init and state-dict order are unchanged.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def _forward_silu_gating(self, x1, x3):
        # SiLU-activated gate branch multiplied elementwise with the value branch.
        gate = F.silu(x1)
        return gate * x3

    def forward(self, x):
        gate_in = self.w1(x)
        value_in = self.w3(x)
        return self.w2(self._forward_silu_gating(gate_in, value_in))
# Adapted from diffusers.models.transformers.transformer_z_image.select_per_token
def select_per_token(
    value_noisy: torch.Tensor,
    value_clean: torch.Tensor,
    noise_mask: torch.Tensor,
    seq_len: int,
) -> torch.Tensor:
    """Broadcast two per-sample vectors over the sequence and pick one per token.

    Args:
        value_noisy: ``(batch, dim)`` values used where ``noise_mask == 1``.
        value_clean: ``(batch, dim)`` values used at every other token.
        noise_mask: ``(batch, seq_len)`` token mask.
        seq_len: Number of tokens to broadcast over.

    Returns:
        ``(batch, seq_len, dim)`` tensor.
    """
    take_noisy = (noise_mask == 1)[..., None]  # (batch, seq_len, 1)
    noisy_tokens = value_noisy[:, None].expand(-1, seq_len, -1)
    clean_tokens = value_clean[:, None].expand(-1, seq_len, -1)
    return torch.where(take_noisy, noisy_tokens, clean_tokens)
# Adapted from diffusers.models.transformers.transformer_z_image.ZImageTransformerBlock
class ZImageTransformerBlock(nn.Module):
    """Transformer block: self-attention then a gated SiLU FFN, each wrapped in RMS norms.

    Every branch normalizes its input (``*_norm1``) and its output (``*_norm2``) before the
    residual add. With ``modulation=True`` a linear layer maps the conditioning vector to an
    input scale and a tanh output gate for both branches. The parameters are either shared by
    all tokens (``adaln_input``) or chosen per token between ``adaln_noisy`` and ``adaln_clean``
    according to ``noise_mask``. ``n_kv_heads`` is accepted but not used.
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads

        # Original Z-Image attention (dim, n_heads, n_kv_heads, qk_norm) rebuilt on diffusers'
        # Attention with the Z-Image processor. Submodule creation order is kept unchanged.
        self.attention = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=self.head_dim,
            heads=n_heads,
            qk_norm="rms_norm" if qk_norm else None,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=ZSingleStreamAttnProcessor(),
        )
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
        self.layer_id = layer_id

        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.modulation = modulation
        if modulation:
            self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))

    def _modulation_params(self, seq_len, adaln_input, noise_mask, adaln_noisy, adaln_clean):
        """Return ``(scale_msa, gate_msa, scale_mlp, gate_mlp)`` ready to broadcast over the tokens."""
        if noise_mask is None:
            # Global modulation: one set shared by every token, each of shape (batch, 1, dim).
            scale_msa, gate_msa, scale_mlp, gate_mlp = (
                self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
            )
            return 1.0 + scale_msa, gate_msa.tanh(), 1.0 + scale_mlp, gate_mlp.tanh()

        # Per-token modulation: tokens with noise_mask == 1 use the noisy set, others the clean set.
        def per_sample(cond):
            scale_a, gate_a, scale_m, gate_m = self.adaLN_modulation(cond).chunk(4, dim=1)
            return 1.0 + scale_a, gate_a.tanh(), 1.0 + scale_m, gate_m.tanh()

        noisy_params = per_sample(adaln_noisy)
        clean_params = per_sample(adaln_clean)
        return tuple(
            select_per_token(noisy, clean, noise_mask, seq_len)
            for noisy, clean in zip(noisy_params, clean_params)
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
        noise_mask: Optional[torch.Tensor] = None,
        adaln_noisy: Optional[torch.Tensor] = None,
        adaln_clean: Optional[torch.Tensor] = None,
    ):
        if not self.modulation:
            attended = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
            x = x + self.attention_norm2(attended)
            return x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))

        scale_msa, gate_msa, scale_mlp, gate_mlp = self._modulation_params(
            x.shape[1], adaln_input, noise_mask, adaln_noisy, adaln_clean
        )
        attended = self.attention(
            self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
        )
        x = x + gate_msa * self.attention_norm2(attended)
        ffn_out = self.feed_forward(self.ffn_norm1(x) * scale_mlp)
        return x + gate_mlp * self.ffn_norm2(ffn_out)
# Adapted from diffusers.models.transformers.transformer_z_image.RopeEmbedder
class RopeEmbedder:
    """Multi-axis rotary position embedding lookup.

    Axis ``i`` gets a complex64 table of unit phasors with ``axes_lens[i]`` rows (positions) and
    ``axes_dims[i] // 2`` columns (frequencies). Calling the embedder with integer ids of shape
    ``(num_tokens, len(axes_dims))`` gathers one row per axis and concatenates the rows along
    the last dimension.
    """

    def __init__(
        self,
        theta: float = 256.0,
        axes_dims: List[int] = (16, 56, 56),
        axes_lens: List[int] = (64, 128, 128),
    ):
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
        # Built lazily on first call, then kept on the device of the most recent ids.
        self.freqs_cis = None

    @staticmethod
    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
        """Return one complex64 table per axis, of shape ``(end[i], dim[i] // 2)`` (rounded up).

        Entry ``[p, k]`` of axis ``i`` is ``exp(1j * p * theta ** (-2k / dim[i]))``. Angles are
        computed in float64 on CPU before the float32 cast.

        This must be a static method: ``__call__`` invokes it through ``self`` with ``theta`` as
        a keyword, which would otherwise raise "got multiple values for argument".
        """
        with torch.device("cpu"):
            freqs_cis = []
            for d, e in zip(dim, end):
                freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
                timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
                angles = torch.outer(timestep, freqs).float()
                freqs_cis.append(torch.polar(torch.ones_like(angles), angles).to(torch.complex64))
            return freqs_cis

    def __call__(self, ids: torch.Tensor):
        """Gather the RoPE phasors for integer position ids of shape ``(num_tokens, num_axes)``."""
        assert ids.ndim == 2
        assert ids.shape[-1] == len(self.axes_dims)
        device = ids.device
        if self.freqs_cis is None:
            tables = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
            self.freqs_cis = [table.to(device) for table in tables]
        elif self.freqs_cis[0].device != device:
            # Ensure the cached tables live on the same device as ids.
            self.freqs_cis = [table.to(device) for table in self.freqs_cis]

        return torch.cat([table[ids[:, axis]] for axis, table in enumerate(self.freqs_cis)], dim=-1)
class ZImageControlTransformerBlock(nn.Module):
    """Control-branch counterpart of ``ZImageTransformerBlock`` (global modulation only).

    Control state moves between blocks as one stacked tensor. Each block appends a hint
    (``after_proj`` of its output) and then its output, so after ``k`` blocks the stack holds
    ``k`` hints followed by the current control state. The first block (``block_id == 0``)
    instead starts from ``before_proj(c) + x``. Both projections are wrapped in ``zero_module``
    (presumably zero-initialized — confirm in ``controlnet.zero_module``).
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        block_id=0,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads

        # Same sub-layers, in the same creation order, as ZImageTransformerBlock.
        self.attention = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=self.head_dim,
            heads=n_heads,
            qk_norm="rms_norm" if qk_norm else None,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=ZSingleStreamAttnProcessor(),
        )
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
        self.layer_id = layer_id
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.modulation = modulation
        if modulation:
            self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))

        # Control-specific projections: only the first block owns an input projection.
        self.block_id = block_id
        if block_id == 0:
            self.before_proj = zero_module(nn.Linear(self.dim, self.dim))
        self.after_proj = zero_module(nn.Linear(self.dim, self.dim))

    def forward(
        self,
        c: torch.Tensor,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """Advance the control stack by one block.

        ``c`` is the embedded control tokens for the first block and the stacked output of the
        previous block otherwise. ``x`` (main-stream hidden states) is only read by the first
        block. Returns ``torch.stack([*previous_hints, new_hint, new_state])``.
        """
        if self.block_id != 0:
            # Split the running stack into earlier hints and the current state (last entry).
            previous = list(torch.unbind(c))
            state = previous.pop()
        else:
            state = self.before_proj(c) + x
            previous = []

        # Same math as ZImageTransformerBlock, applied to the control state instead of x.
        if not self.modulation:
            attended = self.attention(self.attention_norm1(state), attention_mask=attn_mask, freqs_cis=freqs_cis)
            state = state + self.attention_norm2(attended)
            state = state + self.ffn_norm2(self.feed_forward(self.ffn_norm1(state)))
        else:
            assert adaln_input is not None
            modulation = self.adaLN_modulation(adaln_input).unsqueeze(1)
            scale_msa, gate_msa, scale_mlp, gate_mlp = modulation.chunk(4, dim=2)
            attended = self.attention(
                self.attention_norm1(state) * (1.0 + scale_msa), attention_mask=attn_mask, freqs_cis=freqs_cis
            )
            state = state + gate_msa.tanh() * self.attention_norm2(attended)
            ffn_out = self.feed_forward(self.ffn_norm1(state) * (1.0 + scale_mlp))
            state = state + gate_mlp.tanh() * self.ffn_norm2(ffn_out)

        hint = self.after_proj(state)
        return torch.stack([*previous, hint, state])
class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """Z-Image ControlNet that produces per-layer control hints for the main transformer.

    Only control-specific weights live here:
    - the control blocks (``control_layers``);
    - the control patch embedders (``control_all_x_embedder``);
    - depending on ``add_control_noise_refiner``, a control noise refiner.

    The timestep/caption embedders, the RoPE embedder, the noise/context refiners and the pad
    tokens are borrowed by reference from a Z-Image transformer through :meth:`from_transformer`.
    They must be attached before :meth:`forward` is called.

    Args:
        control_layers_places: Indices of the control blocks; must contain 0. Each index is used
            as that block's ``block_id`` and as its key in the returned hint dict.
        control_refiner_layers_places: Noise-refiner layer indices that receive control hints
            when ``add_control_noise_refiner`` is set.
        control_in_dim: Channel count ``C`` of the control latents.
        add_control_noise_refiner: Selects how control tokens are refined before the main
            control blocks:
            - ``None``: a plain ``ZImageTransformerBlock`` refiner runs on the control tokens.
            - ``"control_layers"`` / ``"control_noise_refiner"``: the control tokens first go
              through ``control_layers`` or a dedicated control refiner, and the resulting hints
              are added to the shared noise refiner's outputs.
        all_patch_size, all_f_patch_size: Supported (spatial, frame) patch-size pairs.
        dim, n_refiner_layers, n_heads, n_kv_heads, norm_eps, qk_norm: Block hyper-parameters.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        control_layers_places: Optional[List[int]] = None,
        control_refiner_layers_places: Optional[List[int]] = None,
        control_in_dim=None,
        add_control_noise_refiner: Optional[Literal["control_layers", "control_noise_refiner"]] = None,
        all_patch_size=(2,),
        all_f_patch_size=(1,),
        dim=3840,
        n_refiner_layers=2,
        n_heads=30,
        n_kv_heads=30,
        norm_eps=1e-5,
        qk_norm=True,
    ):
        super().__init__()
        self.control_layers_places = control_layers_places
        self.control_in_dim = control_in_dim
        self.control_refiner_layers_places = control_refiner_layers_places
        self.add_control_noise_refiner = add_control_noise_refiner
        # The block with block_id 0 is the one that consumes the raw control embedding.
        if self.control_layers_places is None or 0 not in self.control_layers_places:
            raise ValueError(f"`control_layers_places` must contain 0, got {self.control_layers_places}.")

        # control blocks
        self.control_layers = nn.ModuleList(
            [
                ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i)
                for i in self.control_layers_places
            ]
        )

        # control patch embeddings, keyed "{patch_size}-{f_patch_size}"
        all_x_embedder = {}
        for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size):
            x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True)
            all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
        self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)

        if self.add_control_noise_refiner == "control_layers":
            # The control layers themselves refine the control tokens.
            self.control_noise_refiner = None
        elif self.add_control_noise_refiner == "control_noise_refiner":
            self.control_noise_refiner = nn.ModuleList(
                [
                    ZImageControlTransformerBlock(
                        1000 + layer_id,
                        dim,
                        n_heads,
                        n_kv_heads,
                        norm_eps,
                        qk_norm,
                        modulation=True,
                        block_id=layer_id,
                    )
                    for layer_id in range(n_refiner_layers)
                ]
            )
        else:
            self.control_noise_refiner = nn.ModuleList(
                [
                    ZImageTransformerBlock(
                        1000 + layer_id,
                        dim,
                        n_heads,
                        n_kv_heads,
                        norm_eps,
                        qk_norm,
                        modulation=True,
                    )
                    for layer_id in range(n_refiner_layers)
                ]
            )

        # Shared with the main transformer; populated by `from_transformer`.
        self.t_scale: Optional[float] = None
        self.t_embedder: Optional[TimestepEmbedder] = None
        self.all_x_embedder: Optional[nn.ModuleDict] = None
        self.cap_embedder: Optional[nn.Sequential] = None
        self.rope_embedder: Optional[RopeEmbedder] = None
        self.noise_refiner: Optional[nn.ModuleList] = None
        self.context_refiner: Optional[nn.ModuleList] = None
        self.x_pad_token: Optional[nn.Parameter] = None
        self.cap_pad_token: Optional[nn.Parameter] = None

    @classmethod
    def from_transformer(cls, controlnet, transformer):
        """Attach the modules shared with ``transformer`` to ``controlnet`` and return ``controlnet``.

        The modules are shared by reference (no copy), and ``controlnet`` is modified in place.
        """
        controlnet.t_scale = transformer.t_scale
        controlnet.t_embedder = transformer.t_embedder
        controlnet.all_x_embedder = transformer.all_x_embedder
        controlnet.cap_embedder = transformer.cap_embedder
        controlnet.rope_embedder = transformer.rope_embedder
        controlnet.noise_refiner = transformer.noise_refiner
        controlnet.context_refiner = transformer.context_refiner
        controlnet.x_pad_token = transformer.x_pad_token
        controlnet.cap_pad_token = transformer.cap_pad_token
        return controlnet

    # Adapted from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.create_coordinate_grid
    @staticmethod
    def create_coordinate_grid(size, start=None, device=None):
        """Return an int32 tensor of shape ``(*size, len(size))`` holding each point's coordinates.

        Coordinates along axis ``i`` run from ``start[i]`` (default 0) to ``start[i] + size[i] - 1``.
        """
        if start is None:
            start = (0 for _ in size)
        axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
        grids = torch.meshgrid(axes, indexing="ij")
        return torch.stack(grids, dim=-1)

    # Adapted from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._patchify_image
    def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
        """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim).

        Returns ``(patches, (F, H, W), (F_tokens, H_tokens, W_tokens))``. Patches are ordered
        frame-major, then row, then column; each patch is flattened as (pf, ph, pw, c).
        """
        pH, pW, pF = patch_size, patch_size, f_patch_size
        C, F, H, W = image.size()
        F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
        image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
        image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
        return image, (F, H, W), (F_tokens, H_tokens, W_tokens)

    # Adapted from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._pad_with_ids
    def _pad_with_ids(
        self,
        feat: torch.Tensor,
        pos_grid_size: Tuple,
        pos_start: Tuple,
        device: torch.device,
        noise_mask_val: Optional[int] = None,
    ):
        """Pad feature to SEQ_MULTI_OF, create position IDs and pad mask.

        Padding rows repeat the last row of ``feat``.

        Returns:
            ``(padded_feat, pos_ids, pad_mask, total_len, noise_mask)``, where:
            - ``pos_ids`` are the flattened coordinates of the ``pos_grid_size`` grid starting at
              ``pos_start``, followed by one ``(0, 0, 0)`` id per padding row;
            - ``pad_mask`` is True on padding rows;
            - ``noise_mask`` is ``[noise_mask_val] * total_len``, or None.
        """
        ori_len = len(feat)
        pad_len = (-ori_len) % SEQ_MULTI_OF
        total_len = ori_len + pad_len
        # Pos IDs
        ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
        if pad_len > 0:
            pad_pos_ids = (
                self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
                .flatten(0, 2)
                .repeat(pad_len, 1)
            )
            pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
            padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
            pad_mask = torch.cat(
                [
                    torch.zeros(ori_len, dtype=torch.bool, device=device),
                    torch.ones(pad_len, dtype=torch.bool, device=device),
                ]
            )
        else:
            pos_ids = ori_pos_ids
            padded_feat = feat
            pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
        noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None  # token level
        return padded_feat, pos_ids, pad_mask, total_len, noise_mask

    # Adapted from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed
    def patchify_and_embed(
        self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int
    ):
        """Patchify for basic mode: single image per batch item.

        Captions get positions ``(1.., 0, 0)``. Each image's frame positions start right after
        its padded caption.

        Returns per-sample lists:
            ``(img_tokens, cap_tokens, img_sizes, img_pos_ids, cap_pos_ids, img_pad_masks, cap_pad_masks)``.
        """
        device = all_image[0].device
        all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], []
        all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []
        for image, cap_feat in zip(all_image, all_cap_feats):
            # Caption.
            # NOTE: the grid passed here already spans the padded caption length, and
            # `_pad_with_ids` appends `pad_len` more (0, 0, 0) ids. So when the caption length is
            # not a multiple of SEQ_MULTI_OF, the returned ids are longer than the features.
            # `forward` trims the RoPE tensor to the feature length and then keeps only the first
            # `cap_len` rows per sample, so only the leading sequential ids are used.
            cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids(
                cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device
            )
            all_cap_out.append(cap_out)
            all_cap_pos_ids.append(cap_pos_ids)
            all_cap_pad_mask.append(cap_pad_mask)
            # Image
            img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size)
            img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids(
                img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device
            )
            all_img_out.append(img_out)
            all_img_size.append(size)
            all_img_pos_ids.append(img_pos_ids)
            all_img_pad_mask.append(img_pad_mask)
        return (
            all_img_out,
            all_cap_out,
            all_img_size,
            all_img_pos_ids,
            all_cap_pos_ids,
            all_img_pad_mask,
            all_cap_pad_mask,
        )

    def patchify(
        self,
        all_image: List[torch.Tensor],
        patch_size: int,
        f_patch_size: int,
    ):
        """Patchify each (C, F, H, W) control image and pad it to a multiple of SEQ_MULTI_OF tokens.

        Padding repeats the last patch row. Unlike ``patchify_and_embed``, no position ids or
        masks are produced.
        """
        all_image_out = []
        for image in all_image:
            patches, _, _ = self._patchify_image(image, patch_size, f_patch_size)
            pad_len = (-len(patches)) % SEQ_MULTI_OF
            all_image_out.append(torch.cat([patches, patches[-1:].repeat(pad_len, 1)], dim=0))
        return all_image_out

    def _call_layer(self, layer, *args):
        """Run ``layer(*args)``, through gradient checkpointing when it is enabled and grads are on."""
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            return self._gradient_checkpointing_func(layer, *args)
        return layer(*args)

    def forward(
        self,
        x: List[torch.Tensor],
        t,
        cap_feats: List[torch.Tensor],
        control_context: List[torch.Tensor],
        conditioning_scale: float = 1.0,
        patch_size=2,
        f_patch_size=1,
    ):
        """Compute the control hints for one denoising step.

        Args:
            x: Per-sample latent images, each ``(C, F, H, W)``.
            t: Timesteps; scaled by ``t_scale`` and embedded with the shared ``t_embedder``.
            cap_feats: Per-sample 2-D caption features ``(seq_len, cap_dim)``.
            control_context: Per-sample control latents, each ``(control_in_dim, F, H, W)``.
                After patchifying, each must yield as many tokens as the matching ``x``.
            conditioning_scale: Factor applied to every returned hint.
            patch_size, f_patch_size: Patch sizes; must appear in the config's
                ``all_patch_size`` / ``all_f_patch_size``.

        Returns:
            Dict mapping each entry of ``control_layers_places`` to a hint tensor of shape
            ``(batch, unified_seq_len, dim)``.

        Raises:
            ValueError: If the shared modules were not attached via ``from_transformer``, or
                ``add_control_noise_refiner`` holds an unsupported value.
        """
        shared_modules = (
            self.t_scale,
            self.t_embedder,
            self.all_x_embedder,
            self.cap_embedder,
            self.rope_embedder,
            self.noise_refiner,
            self.context_refiner,
            self.x_pad_token,
            self.cap_pad_token,
        )
        if any(module is None for module in shared_modules):
            raise ValueError(
                "Required modules are `None`, use `from_transformer` to share required modules from `transformer`."
            )
        assert patch_size in self.config.all_patch_size
        assert f_patch_size in self.config.all_f_patch_size

        bsz = len(x)
        device = x[0].device
        embedder_key = f"{patch_size}-{f_patch_size}"
        t = self.t_embedder(t * self.t_scale)

        (
            x,
            cap_feats,
            _x_size,
            x_pos_ids,
            cap_pos_ids,
            x_inner_pad_mask,
            cap_inner_pad_mask,
        ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
        x_item_seqlens = [len(item) for item in x]
        assert all(seq_len % SEQ_MULTI_OF == 0 for seq_len in x_item_seqlens)
        x_max_item_seqlen = max(x_item_seqlens)
        x_pad_mask = torch.cat(x_inner_pad_mask)

        # Control embed: patchified like x; x's inner pad positions take the shared x pad token.
        control_context = torch.cat(self.patchify(control_context, patch_size, f_patch_size), dim=0)
        control_context = self.control_all_x_embedder[embedder_key](control_context)
        control_context[x_pad_mask] = self.x_pad_token
        control_context = pad_sequence(
            list(control_context.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0
        )

        # x embed & refine
        x = self.all_x_embedder[embedder_key](torch.cat(x, dim=0))
        # Match t_embedder output dtype to x for layerwise casting compatibility
        adaln_input = t.type_as(x)
        x[x_pad_mask] = self.x_pad_token
        x = list(x.split(x_item_seqlens, dim=0))
        x_freqs_cis = list(
            self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(ids) for ids in x_pos_ids], dim=0)
        )
        x = pad_sequence(x, batch_first=True, padding_value=0.0)
        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
        # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
        x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(x_item_seqlens):
            x_attn_mask[i, :seq_len] = 1

        # Optional control refinement whose hints are added to the shared noise refiner's outputs.
        noise_refiner_block_samples = None
        if self.add_control_noise_refiner is not None:
            if self.add_control_noise_refiner == "control_layers":
                layers = self.control_layers
            elif self.add_control_noise_refiner == "control_noise_refiner":
                layers = self.control_noise_refiner
            else:
                raise ValueError(f"Unsupported `add_control_noise_refiner` type: {self.add_control_noise_refiner}.")
            for layer in layers:
                control_context = self._call_layer(layer, control_context, x, x_attn_mask, x_freqs_cis, adaln_input)
            # The stack holds one hint per layer followed by the final control state.
            *hints, control_context = torch.unbind(control_context)
            noise_refiner_block_samples = {
                layer_idx: hints[idx] * conditioning_scale
                for idx, layer_idx in enumerate(self.control_refiner_layers_places)
            }

        for layer_idx, layer in enumerate(self.noise_refiner):
            x = self._call_layer(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
            if noise_refiner_block_samples is not None and layer_idx in noise_refiner_block_samples:
                x = x + noise_refiner_block_samples[layer_idx]

        # cap embed & refine
        cap_item_seqlens = [len(feat) for feat in cap_feats]
        cap_max_item_seqlen = max(cap_item_seqlens)
        cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0))
        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
        cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
        cap_freqs_cis = list(
            self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(ids) for ids in cap_pos_ids], dim=0)
        )
        cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
        # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
        cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
        cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(cap_item_seqlens):
            cap_attn_mask[i, :seq_len] = 1
        for layer in self.context_refiner:
            cap_feats = self._call_layer(layer, cap_feats, cap_attn_mask, cap_freqs_cis)

        # unified: per sample, image tokens followed by caption tokens
        unified = []
        unified_freqs_cis = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
            unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
        unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
        assert unified_item_seqlens == [len(item) for item in unified]
        unified_max_item_seqlen = max(unified_item_seqlens)
        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
        unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(unified_item_seqlens):
            unified_attn_mask[i, :seq_len] = 1

        ## ControlNet start
        if self.add_control_noise_refiner is None:
            # Default mode: refine the control tokens with the plain control noise refiner.
            for layer in self.control_noise_refiner:
                control_context = self._call_layer(layer, control_context, x_attn_mask, x_freqs_cis, adaln_input)

        # Control stream mirrors the unified layout: control tokens followed by caption tokens.
        control_context_unified = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]]))
        control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)

        for layer in self.control_layers:
            control_context_unified = self._call_layer(
                layer, control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input
            )

        hints = torch.unbind(control_context_unified)[:-1]
        controlnet_block_samples = {
            layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places)
        }
        return controlnet_block_samples