import math
from typing import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...models.attention_processor import Attention
from ...models.normalization import RMSNorm
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_dispatch import dispatch_attention_fn
from ..controlnets.controlnet import zero_module
from ..modeling_utils import ModelMixin


ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32
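# ADALN_EMBED_DIM bounds the conditioning width fed to each block's adaLN modulation MLP (see
# `min(dim, ADALN_EMBED_DIM)` in the transformer blocks below). SEQ_MULTI_OF is the multiple that every
# token sequence (image patches and caption tokens) is padded up to before attention.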
class TimestepEmbedder(nn.Module):
    def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
        super().__init__()
        if mid_size is None:
            mid_size = out_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, mid_size, bias=True),
            nn.SiLU(),
            nn.Linear(mid_size, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        with torch.amp.autocast("cuda", enabled=False):
            half = dim // 2
            freqs = torch.exp(
                -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
            )
            args = t[:, None].float() * freqs[None]
            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
            if dim % 2:
                embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        weight_dtype = self.mlp[0].weight.dtype
        compute_dtype = getattr(self.mlp[0], "compute_dtype", None)
        if weight_dtype.is_floating_point:
            t_freq = t_freq.to(weight_dtype)
        elif compute_dtype is not None:
            t_freq = t_freq.to(compute_dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
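# For reference, `TimestepEmbedder.timestep_embedding(t, dim)` builds DiT-style sinusoidal features:
#   freqs_i = max_period ** (-i / (dim // 2)) for i in [0, dim // 2)
#   embedding = concat([cos(t * freqs), sin(t * freqs)], dim=-1)  # shape: (batch, dim)
# `forward` then projects these features through the SiLU MLP to `out_size` channels.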
class ZSingleStreamAttnProcessor:
    """
    Processor for Z-Image single-stream attention that adapts the existing `Attention` class to match the
    behavior of the original Z-Image attention module.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        freqs_cis: torch.Tensor | None = None,
    ) -> torch.Tensor:
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # Reshape to (batch, seq_len, heads, head_dim)
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        # QK normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Rotary position embedding, applied in float32 via complex multiplication
        def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
            with torch.amp.autocast("cuda", enabled=False):
                x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
                freqs_cis = freqs_cis.unsqueeze(2)
                x_out = torch.view_as_real(x * freqs_cis).flatten(3)
                return x_out.type_as(x_in)

        if freqs_cis is not None:
            query = apply_rotary_emb(query, freqs_cis)
            key = apply_rotary_emb(key, freqs_cis)

        # Keep query/key in a common dtype for the attention kernel
        dtype = query.dtype
        query, key = query.to(dtype), key.to(dtype)

        # Expand a 2D padding mask to the broadcastable (batch, 1, 1, key_len) shape
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = attention_mask[:, None, None, :]

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )

        # Merge heads back: (batch, seq_len, heads, head_dim) -> (batch, seq_len, heads * head_dim)
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(dtype)

        output = attn.to_out[0](hidden_states)
        if len(attn.to_out) > 1:
            output = attn.to_out[1](output)

        return output
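# Note on `freqs_cis`: it is expected to be a complex tensor of shape (batch, seq_len, head_dim // 2).
# `apply_rotary_emb` views Q/K channel pairs as complex numbers, broadcasts `freqs_cis` across heads via
# `unsqueeze(2)`, multiplies, and casts back to the input dtype, so the rotation itself runs in float32.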
class FeedForward(nn.Module):
    """SiLU-gated (SwiGLU-style) feed-forward network: w2(silu(w1(x)) * w3(x))."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def _forward_silu_gating(self, x1, x3):
        return F.silu(x1) * x3

    def forward(self, x):
        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
def select_per_token(
    value_noisy: torch.Tensor,
    value_clean: torch.Tensor,
    noise_mask: torch.Tensor,
    seq_len: int,
) -> torch.Tensor:
    """Select `value_noisy` where `noise_mask == 1` and `value_clean` elsewhere, per token."""
    noise_mask_expanded = noise_mask.unsqueeze(-1)
    return torch.where(
        noise_mask_expanded == 1,
        value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
        value_clean.unsqueeze(1).expand(-1, seq_len, -1),
    )
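# Per-token adaLN: when a `noise_mask` is provided, the block below computes two modulation vectors (one
# from `adaln_noisy`, one from `adaln_clean`) and uses `select_per_token` to pick, for every token, which
# scale/gate pair applies. With (batch, dim) modulation vectors and a (batch, seq_len) mask this yields
# (batch, seq_len, dim) scales and gates, presumably so noised and clean tokens in the same sequence can be
# modulated differently.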
@maybe_allow_in_graph
class ZImageTransformerBlock(nn.Module):
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads

        self.attention = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // n_heads,
            heads=n_heads,
            qk_norm="rms_norm" if qk_norm else None,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=ZSingleStreamAttnProcessor(),
        )

        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
        self.layer_id = layer_id

        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)

        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.modulation = modulation
        if modulation:
            self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: torch.Tensor | None = None,
        noise_mask: torch.Tensor | None = None,
        adaln_noisy: torch.Tensor | None = None,
        adaln_clean: torch.Tensor | None = None,
    ):
        if self.modulation:
            seq_len = x.shape[1]

            if noise_mask is not None:
                # Separate modulation for noisy and clean tokens, selected per token
                mod_noisy = self.adaLN_modulation(adaln_noisy)
                mod_clean = self.adaLN_modulation(adaln_clean)

                scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
                scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)

                gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
                gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()

                scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
                scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean

                scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
                scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
                gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
                gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
            else:
                # Shared modulation for all tokens
                mod = self.adaLN_modulation(adaln_input)
                scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
                gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
                scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp

            # Attention block with gated residual
            attn_out = self.attention(
                self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
            )
            x = x + gate_msa * self.attention_norm2(attn_out)

            # Feed-forward block with gated residual
            x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
        else:
            # Attention block without modulation
            attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
            x = x + self.attention_norm2(attn_out)

            # Feed-forward block without modulation
            x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))

        return x
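# ZImageTransformerBlock uses a "sandwich" normalization scheme: RMSNorm is applied both before each
# sublayer (attention_norm1 / ffn_norm1) and to its output (attention_norm2 / ffn_norm2) ahead of the gated
# residual add, and the adaLN modulation provides only scale and gate terms (no shift).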
class RopeEmbedder:
    def __init__(
        self,
        theta: float = 256.0,
        axes_dims: list[int] = (16, 56, 56),
        axes_lens: list[int] = (64, 128, 128),
    ):
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
        self.freqs_cis = None

    @staticmethod
    def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0):
        with torch.device("cpu"):
            freqs_cis = []
            for i, (d, e) in enumerate(zip(dim, end)):
                freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
                timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
                freqs = torch.outer(timestep, freqs).float()
                freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)
                freqs_cis.append(freqs_cis_i)

            return freqs_cis

    def __call__(self, ids: torch.Tensor):
        assert ids.ndim == 2
        assert ids.shape[-1] == len(self.axes_dims)
        device = ids.device

        if self.freqs_cis is None:
            self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
            self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
        else:
            # Move the cached tables if the input device changed
            if self.freqs_cis[0].device != device:
                self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]

        result = []
        for i in range(len(self.axes_dims)):
            index = ids[:, i]
            result.append(self.freqs_cis[i][index])
        return torch.cat(result, dim=-1)
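# RopeEmbedder turns integer (axis0, height, width) position ids into per-token complex rotary frequencies.
# With the default axes_dims of (16, 56, 56), the per-axis dims sum to 128, which matches the default head
# width used below (dim=3840, n_heads=30). In `patchify_and_embed`, caption tokens occupy axis-0 positions
# starting at 1 with (0, 0) spatial ids, and image patch tokens start at axis-0 position cap_len + 1 with
# their (height, width) grid coordinates.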
@maybe_allow_in_graph
class ZImageControlTransformerBlock(nn.Module):
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        block_id=0,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads

        self.attention = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // n_heads,
            heads=n_heads,
            qk_norm="rms_norm" if qk_norm else None,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=ZSingleStreamAttnProcessor(),
        )

        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
        self.layer_id = layer_id

        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)

        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.modulation = modulation
        if modulation:
            self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))

        # Zero-initialized projections: `after_proj` produces the per-block hint, and block 0 additionally
        # fuses the incoming control features with the hidden states via `before_proj`
        self.block_id = block_id
        if block_id == 0:
            self.before_proj = zero_module(nn.Linear(self.dim, self.dim))
        self.after_proj = zero_module(nn.Linear(self.dim, self.dim))

    def forward(
        self,
        c: torch.Tensor,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: torch.Tensor | None = None,
    ):
        # The first block fuses the control features with the hidden states; later blocks receive the
        # stacked (hints..., features) tensor produced by the previous block and pop the running features.
        if self.block_id == 0:
            c = self.before_proj(c) + x
            all_c = []
        else:
            all_c = list(torch.unbind(c))
            c = all_c.pop(-1)

        if self.modulation:
            assert adaln_input is not None
            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
            gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
            scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp

            # Attention block with gated residual
            attn_out = self.attention(
                self.attention_norm1(c) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
            )
            c = c + gate_msa * self.attention_norm2(attn_out)

            # Feed-forward block with gated residual
            c = c + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(c) * scale_mlp))
        else:
            # Attention block without modulation
            attn_out = self.attention(self.attention_norm1(c), attention_mask=attn_mask, freqs_cis=freqs_cis)
            c = c + self.attention_norm2(attn_out)

            # Feed-forward block without modulation
            c = c + self.ffn_norm2(self.feed_forward(self.ffn_norm1(c)))

        # Append this block's zero-initialized hint and the running control features
        c_skip = self.after_proj(c)
        all_c += [c_skip, c]
        c = torch.stack(all_c)
        return c
class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        control_layers_places: list[int] = None,
        control_refiner_layers_places: list[int] = None,
        control_in_dim=None,
        add_control_noise_refiner: Literal["control_layers", "control_noise_refiner"] | None = None,
        all_patch_size=(2,),
        all_f_patch_size=(1,),
        dim=3840,
        n_refiner_layers=2,
        n_heads=30,
        n_kv_heads=30,
        norm_eps=1e-5,
        qk_norm=True,
    ):
        super().__init__()
        self.control_layers_places = control_layers_places
        self.control_in_dim = control_in_dim
        self.control_refiner_layers_places = control_refiner_layers_places
        self.add_control_noise_refiner = add_control_noise_refiner

        assert 0 in self.control_layers_places

        # Control transformer blocks, one per injected base-transformer layer
        self.control_layers = nn.ModuleList(
            [
                ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i)
                for i in self.control_layers_places
            ]
        )

        # Patch embedders for the control condition, one per (patch_size, f_patch_size) combination
        all_x_embedder = {}
        for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
            x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True)
            all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder

        self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
        if self.add_control_noise_refiner == "control_layers":
            self.control_noise_refiner = None
        elif self.add_control_noise_refiner == "control_noise_refiner":
            self.control_noise_refiner = nn.ModuleList(
                [
                    ZImageControlTransformerBlock(
                        1000 + layer_id,
                        dim,
                        n_heads,
                        n_kv_heads,
                        norm_eps,
                        qk_norm,
                        modulation=True,
                        block_id=layer_id,
                    )
                    for layer_id in range(n_refiner_layers)
                ]
            )
        else:
            self.control_noise_refiner = nn.ModuleList(
                [
                    ZImageTransformerBlock(
                        1000 + layer_id,
                        dim,
                        n_heads,
                        n_kv_heads,
                        norm_eps,
                        qk_norm,
                        modulation=True,
                    )
                    for layer_id in range(n_refiner_layers)
                ]
            )

        # Modules shared with the base transformer; populated by `from_transformer`
        self.t_scale: float | None = None
        self.t_embedder: TimestepEmbedder | None = None
        self.all_x_embedder: nn.ModuleDict | None = None
        self.cap_embedder: nn.Sequential | None = None
        self.rope_embedder: RopeEmbedder | None = None
        self.noise_refiner: nn.ModuleList | None = None
        self.context_refiner: nn.ModuleList | None = None
        self.x_pad_token: nn.Parameter | None = None
        self.cap_pad_token: nn.Parameter | None = None

    @classmethod
    def from_transformer(cls, controlnet, transformer):
        controlnet.t_scale = transformer.t_scale
        controlnet.t_embedder = transformer.t_embedder
        controlnet.all_x_embedder = transformer.all_x_embedder
        controlnet.cap_embedder = transformer.cap_embedder
        controlnet.rope_embedder = transformer.rope_embedder
        controlnet.noise_refiner = transformer.noise_refiner
        controlnet.context_refiner = transformer.context_refiner
        controlnet.x_pad_token = transformer.x_pad_token
        controlnet.cap_pad_token = transformer.cap_pad_token
        return controlnet
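    # Rough usage sketch (the base Z-Image transformer is assumed to expose the attributes assigned above):
    #   controlnet = ZImageControlNetModel(control_layers_places=[0, ...], control_in_dim=...)
    #   controlnet = ZImageControlNetModel.from_transformer(controlnet, transformer)
    # The shared modules are reused by reference rather than copied, so the controlnet stays consistent
    # with the base transformer's weights.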
    @staticmethod
    def create_coordinate_grid(size, start=None, device=None):
        """Build an integer coordinate grid of shape (*size, len(size)), starting at `start` (default 0)."""
        if start is None:
            start = (0 for _ in size)
        axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
        grids = torch.meshgrid(axes, indexing="ij")
        return torch.stack(grids, dim=-1)

    def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
        """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
        pH, pW, pF = patch_size, patch_size, f_patch_size
        C, F, H, W = image.size()
        F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
        image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
        image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
        return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
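    # Worked example (hypothetical sizes): a (C, 1, 64, 64) latent with patch_size=2 and f_patch_size=1
    # yields 32 * 32 = 1024 tokens of dimension 1 * 2 * 2 * C, ordered frame-major, then rows, then columns.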
    def _pad_with_ids(
        self,
        feat: torch.Tensor,
        pos_grid_size: tuple,
        pos_start: tuple,
        device: torch.device,
        noise_mask_val: int | None = None,
    ):
        """Pad feature to a multiple of SEQ_MULTI_OF, create position IDs and a pad mask."""
        ori_len = len(feat)
        pad_len = (-ori_len) % SEQ_MULTI_OF
        total_len = ori_len + pad_len

        # Position ids for the real tokens; padded rows get a dummy (0, 0, 0) id
        ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
        if pad_len > 0:
            pad_pos_ids = (
                self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
                .flatten(0, 2)
                .repeat(pad_len, 1)
            )
            pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
            padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
            pad_mask = torch.cat(
                [
                    torch.zeros(ori_len, dtype=torch.bool, device=device),
                    torch.ones(pad_len, dtype=torch.bool, device=device),
                ]
            )
        else:
            pos_ids = ori_pos_ids
            padded_feat = feat
            pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)

        noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None
        return padded_feat, pos_ids, pad_mask, total_len, noise_mask
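    # Padding scheme: each per-item feature is padded to a multiple of SEQ_MULTI_OF by repeating its last
    # row, and `pad_mask` marks those rows so `forward` can overwrite them with the shared pad token
    # (`x_pad_token` / `cap_pad_token`).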
    def patchify_and_embed(
        self, all_image: list[torch.Tensor], all_cap_feats: list[torch.Tensor], patch_size: int, f_patch_size: int
    ):
        """Patchify for basic mode: single image per batch item."""
        device = all_image[0].device
        all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], []
        all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []

        for image, cap_feat in zip(all_image, all_cap_feats):
            # Caption tokens: axis-0 positions start at 1, spatial ids stay (0, 0)
            cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids(
                cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device
            )
            all_cap_out.append(cap_out)
            all_cap_pos_ids.append(cap_pos_ids)
            all_cap_pad_mask.append(cap_pad_mask)

            # Image patch tokens: axis-0 positions start after the caption block, at cap_len + 1
            img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size)
            img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids(
                img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device
            )
            all_img_out.append(img_out)
            all_img_size.append(size)
            all_img_pos_ids.append(img_pos_ids)
            all_img_pad_mask.append(img_pad_mask)

        return (
            all_img_out,
            all_cap_out,
            all_img_size,
            all_img_pos_ids,
            all_cap_pos_ids,
            all_img_pad_mask,
            all_cap_pad_mask,
        )

    def patchify(
        self,
        all_image: list[torch.Tensor],
        patch_size: int,
        f_patch_size: int,
    ):
        pH = pW = patch_size
        pF = f_patch_size
        all_image_out = []

        for i, image in enumerate(all_image):
            # (C, F, H, W) -> patch token grid
            C, F, H, W = image.size()
            F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW

            image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
            # (num_patches, patch_dim), ordered frame-major, then rows, then columns
            image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)

            image_ori_len = len(image)
            image_padding_len = (-image_ori_len) % SEQ_MULTI_OF

            # Pad to a multiple of SEQ_MULTI_OF by repeating the last token
            image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
            all_image_out.append(image_padded_feat)

        return all_image_out
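    # Unlike `_pad_with_ids`, `patchify` (used for the control condition) builds no pad mask or position
    # ids of its own; `forward` reuses the latent masks and RoPE frequencies (`x_inner_pad_mask`,
    # `x_freqs_cis`), which requires the control latents to share the token layout of the image latents.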
    def forward(
        self,
        x: list[torch.Tensor],
        t,
        cap_feats: list[torch.Tensor],
        control_context: list[torch.Tensor],
        conditioning_scale: float = 1.0,
        patch_size=2,
        f_patch_size=1,
    ):
        if (
            self.t_scale is None
            or self.t_embedder is None
            or self.all_x_embedder is None
            or self.cap_embedder is None
            or self.rope_embedder is None
            or self.noise_refiner is None
            or self.context_refiner is None
            or self.x_pad_token is None
            or self.cap_pad_token is None
        ):
            raise ValueError(
                "Required modules are `None`, use `from_transformer` to share required modules from `transformer`."
            )

        assert patch_size in self.config.all_patch_size
        assert f_patch_size in self.config.all_f_patch_size

        bsz = len(x)
        device = x[0].device
        t = t * self.t_scale
        t = self.t_embedder(t)

        (
            x,
            cap_feats,
            x_size,
            x_pos_ids,
            cap_pos_ids,
            x_inner_pad_mask,
            cap_inner_pad_mask,
        ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

        x_item_seqlens = [len(_) for _ in x]
        assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
        x_max_item_seqlen = max(x_item_seqlens)

        # Embed the control condition with the controlnet's own patch embedder
        control_context = self.patchify(control_context, patch_size, f_patch_size)
        control_context = torch.cat(control_context, dim=0)
        control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)

        control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token
        control_context = list(control_context.split(x_item_seqlens, dim=0))

        control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)

        # Embed the noisy image patches with the shared embedder
        x = torch.cat(x, dim=0)
        x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)

        # Timestep conditioning, pad-token replacement, and RoPE frequencies for the image tokens
        adaln_input = t.type_as(x)
        x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
        x = list(x.split(x_item_seqlens, dim=0))
        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))

        x = pad_sequence(x, batch_first=True, padding_value=0.0)
        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
        # Trim the RoPE table to the padded sequence length
        x_freqs_cis = x_freqs_cis[:, : x.shape[1]]

        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(x_item_seqlens):
            x_attn_mask[i, :seq_len] = 1

        if self.add_control_noise_refiner is not None:
            if self.add_control_noise_refiner == "control_layers":
                layers = self.control_layers
            elif self.add_control_noise_refiner == "control_noise_refiner":
                layers = self.control_noise_refiner
            else:
                raise ValueError(f"Unsupported `add_control_noise_refiner` type: {self.add_control_noise_refiner}.")
            for layer in layers:
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    control_context = self._gradient_checkpointing_func(
                        layer, control_context, x, x_attn_mask, x_freqs_cis, adaln_input
                    )
                else:
                    control_context = layer(control_context, x, x_attn_mask, x_freqs_cis, adaln_input)

            hints = torch.unbind(control_context)[:-1]
            control_context = torch.unbind(control_context)[-1]
            noise_refiner_block_samples = {
                layer_idx: hints[idx] * conditioning_scale
                for idx, layer_idx in enumerate(self.control_refiner_layers_places)
            }
        else:
            noise_refiner_block_samples = None

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for layer_idx, layer in enumerate(self.noise_refiner):
                x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
                if noise_refiner_block_samples is not None:
                    if layer_idx in noise_refiner_block_samples:
                        x = x + noise_refiner_block_samples[layer_idx]
        else:
            for layer_idx, layer in enumerate(self.noise_refiner):
                x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
                if noise_refiner_block_samples is not None:
                    if layer_idx in noise_refiner_block_samples:
                        x = x + noise_refiner_block_samples[layer_idx]

        # Caption/context branch
        cap_item_seqlens = [len(_) for _ in cap_feats]
        cap_max_item_seqlen = max(cap_item_seqlens)

        cap_feats = torch.cat(cap_feats, dim=0)
        cap_feats = self.cap_embedder(cap_feats)
        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
        cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
        cap_freqs_cis = list(
            self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
        )

        cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
        # Trim the RoPE table to the padded caption length
        cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]

        cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(cap_item_seqlens):
            cap_attn_mask[i, :seq_len] = 1

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for layer in self.context_refiner:
                cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
        else:
            for layer in self.context_refiner:
                cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)

        # Build the unified (image + caption) sequence per batch item
        unified = []
        unified_freqs_cis = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
            unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
        unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
        assert unified_item_seqlens == [len(_) for _ in unified]
        unified_max_item_seqlen = max(unified_item_seqlens)

        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
        unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(unified_item_seqlens):
            unified_attn_mask[i, :seq_len] = 1

        # Refine the control tokens on their own when they were not already refined jointly with x above
        if not self.add_control_noise_refiner:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                for layer in self.control_noise_refiner:
                    control_context = self._gradient_checkpointing_func(
                        layer, control_context, x_attn_mask, x_freqs_cis, adaln_input
                    )
            else:
                for layer in self.control_noise_refiner:
                    control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input)

        # Build the unified (control + caption) sequence per batch item
        control_context_unified = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            cap_len = cap_item_seqlens[i]
            control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]]))
        control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)

        for layer in self.control_layers:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                control_context_unified = self._gradient_checkpointing_func(
                    layer, control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input
                )
            else:
                control_context_unified = layer(
                    control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input
                )

        hints = torch.unbind(control_context_unified)[:-1]
        controlnet_block_samples = {
            layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places)
        }
        return controlnet_block_samples
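    # The returned dict maps base-transformer layer indices (from `control_layers_places`) to residual hint
    # tensors over the unified (image + caption) sequence, already scaled by `conditioning_scale`. The base
    # Z-Image transformer is expected to add the hint for layer `i` to its hidden states after running that
    # layer, mirroring how `noise_refiner_block_samples` is consumed above.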