| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import inspect |
| import math |
| from typing import Any |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import Tensor |
|
|
| from ...configuration_utils import ConfigMixin, register_to_config |
| from ...loaders import FromOriginalModelMixin, PeftAdapterMixin |
| from ...utils import ( |
| logging, |
| ) |
| from ..attention import AttentionMixin, AttentionModuleMixin |
| from ..attention_dispatch import _CAN_USE_FLEX_ATTN, dispatch_attention_fn |
| from ..cache_utils import CacheMixin |
| from ..modeling_outputs import Transformer2DModelOutput |
| from ..modeling_utils import ModelMixin |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
def get_freqs(dim, max_period=10000.0):
    """Return `dim` geometrically decaying frequencies as a float32 tensor.

    Element `i` of the result equals `exp(-log(max_period) * i / dim)`, so the first value is 1.0 and the values
    shrink towards `1 / max_period` as `i` grows.

    Args:
        dim: Number of frequencies to generate.
        max_period: Base of the geometric progression.

    Returns:
        A 1-D `torch.float32` tensor of length `dim`.
    """
    indices = torch.arange(start=0, end=dim, dtype=torch.float32)
    log_period = math.log(max_period)
    exponents = -log_period * indices / dim
    return torch.exp(exponents)
|
|
|
|
def fractal_flatten(x, rope, shape, block_mask=False):
    """Collapse the three axes after the batch axis of `x` and `rope` into one sequence axis.

    Args:
        x: Tensor whose axes 1..3 are flattened.
        rope: Tensor laid out like `x` on axes 0..3; it is flattened in exactly the same way.
        shape: 4-tuple forwarded to `local_patching`; only consulted when `block_mask` is true.
        block_mask: When false, axes 1..3 are flattened in their natural order. When true, both tensors are first
            regrouped by `local_patching` into windows of size `(1, 8, 8)` so that the elements of one window end
            up next to each other in the flattened sequence.

    Returns:
        The tuple `(x, rope)` after flattening.
    """
    if not block_mask:
        return x.flatten(1, 3), rope.flatten(1, 3)

    pixel_size = 8
    window = (1, pixel_size, pixel_size)
    x_windows = local_patching(x, shape, window, dim=1)
    rope_windows = local_patching(rope, shape, window, dim=1)
    # local_patching leaves (num_windows, window_elements) on axes 1 and 2; merge them into one sequence axis.
    return x_windows.flatten(1, 2), rope_windows.flatten(1, 2)
|
|
|
|
def fractal_unflatten(x, shape, block_mask=False):
    """Undo `fractal_flatten`: expand the sequence axis of `x` back into the axes described by `shape`.

    Args:
        x: Tensor whose axis 1 is the flattened sequence; axes from 2 onwards are kept unchanged.
        shape: 4-tuple describing the leading axes to restore.
        block_mask: Must match the flag used when flattening. When true, the sequence is split into groups of
            `8 ** 2` elements and `local_merge` puts the `(1, 8, 8)` windows back in place.

    Returns:
        The unflattened tensor.
    """
    trailing = x.shape[2:]
    if not block_mask:
        return x.reshape(*shape, *trailing)

    pixel_size = 8
    windows = x.reshape(x.shape[0], -1, pixel_size**2, *trailing)
    return local_merge(windows, shape, (1, pixel_size, pixel_size), dim=1)
|
|
|
|
def local_patching(x, shape, group_size, dim=0):
    """Regroup three consecutive axes of `x` into `(number of windows, elements per window)`.

    The axes `dim`, `dim + 1`, `dim + 2` of `x` are interpreted as `(duration, height, width)` taken from `shape`.
    Each is split into `(blocks, group)`; the three block axes are moved in front of the three group axes and each
    triple is flattened.

    Args:
        x: Input tensor.
        shape: 4-tuple `(batch_size, duration, height, width)`; the first entry is not used.
        group_size: 3-tuple with the window extent along duration, height and width.
        dim: Index of the duration axis inside `x`.

    Returns:
        Tensor in which axis `dim` enumerates the windows and axis `dim + 1` the elements inside a window; all other
        axes are unchanged.
    """
    _, duration, height, width = shape
    g_t, g_h, g_w = group_size

    leading, trailing = x.shape[:dim], x.shape[dim + 3 :]
    split = x.reshape(
        *leading,
        duration // g_t,
        g_t,
        height // g_h,
        g_h,
        width // g_w,
        g_w,
        *trailing,
    )

    num_leading = len(split.shape[:dim])
    block_axes = (dim, dim + 2, dim + 4)
    inner_axes = (dim + 1, dim + 3, dim + 5)
    order = [*range(num_leading), *block_axes, *inner_axes, *range(dim + 6, len(split.shape))]
    grouped = split.permute(*order)

    # First merge the three block axes, then the three in-window axes (which have shifted to dim + 1 .. dim + 3).
    return grouped.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
|
|
|
|
def local_merge(x, shape, group_size, dim=0):
    """Inverse of `local_patching`: scatter windows back onto a `(duration, height, width)` grid.

    Axis `dim` of `x` is expected to enumerate windows and axis `dim + 1` the elements inside one window. Both are
    split into their three components, block and in-window axes are interleaved, and each pair is merged.

    Args:
        x: Input tensor in windowed layout.
        shape: 4-tuple `(batch_size, duration, height, width)`; the first entry is not used.
        group_size: 3-tuple with the window extent along duration, height and width.
        dim: Index of the window axis inside `x`.

    Returns:
        Tensor whose axes `dim`, `dim + 1`, `dim + 2` have sizes derived from `duration`, `height`, `width`; all
        other axes are unchanged.
    """
    _, duration, height, width = shape
    g_t, g_h, g_w = group_size

    split = x.reshape(
        *x.shape[:dim],
        duration // g_t,
        height // g_h,
        width // g_w,
        g_t,
        g_h,
        g_w,
        *x.shape[dim + 2 :],
    )

    num_leading = len(split.shape[:dim])
    # Pair every block axis with its in-window axis: (t_blocks, g_t, h_blocks, g_h, w_blocks, g_w).
    interleaved = (dim, dim + 3, dim + 1, dim + 4, dim + 2, dim + 5)
    order = [*range(num_leading), *interleaved, *range(dim + 6, len(split.shape))]
    merged = split.permute(*order)

    # Merge each (blocks, group) pair; after every merge the next pair starts one axis further right.
    for axis in (dim, dim + 1, dim + 2):
        merged = merged.flatten(axis, axis + 1)
    return merged
|
|
|
|
def nablaT_v2(
    q: Tensor,
    k: Tensor,
    sta: Tensor,
    thr: float = 0.9,
):
    """Build a flex-attention `BlockMask` that keeps only the most relevant key blocks for each query block.

    Queries and keys are mean-pooled over consecutive groups of 64 sequence positions and a block-level softmax
    attention map is computed from the pooled tensors. For every query block the key blocks are sorted by
    probability and the lowest ones are dropped for as long as their cumulative probability stays below `1 - thr`.
    The surviving blocks are OR-ed with `sta`.

    Args:
        q: Query tensor. Axes 1 and 2 are swapped before it is unpacked as `(B, h, S, D)`; `S` has to be a multiple
            of 64 for the block reshape to succeed.
        k: Key tensor, reshaped to the same pooled layout as `q`.
        sta: Mask combined with the computed block mask through `torch.logical_or`; it must broadcast against a
            `(B, h, S // 64, S // 64)` tensor.
        thr: Threshold on the cumulative probability as described above.

    Returns:
        The `BlockMask` produced by `BlockMask.from_kv_blocks` with `BLOCK_SIZE=64`. The selected blocks are passed
        in the "full block" arguments and the first (partial) block count is all zeros.

    Raises:
        ValueError: If `_CAN_USE_FLEX_ATTN` is false.
    """
    if not _CAN_USE_FLEX_ATTN:
        raise ValueError("Nabla attention is not supported with this version of PyTorch")
    from torch.nn.attention.flex_attention import BlockMask

    q = q.transpose(1, 2).contiguous()
    k = k.transpose(1, 2).contiguous()

    # Block-level attention map from mean-pooled queries and keys.
    batch, heads, seq_len, head_dim = q.shape
    num_blocks = seq_len // 64
    q_pooled = q.reshape(batch, heads, num_blocks, 64, head_dim).mean(-2)
    k_pooled = k.reshape(batch, heads, num_blocks, 64, head_dim).mean(-2).transpose(-2, -1)
    block_scores = q_pooled @ k_pooled
    block_probs = torch.softmax(block_scores / math.sqrt(head_dim), dim=-1)

    # Ascending sort + cumulative sum: an entry is kept once the running mass reaches 1 - thr.
    sorted_probs, sort_order = block_probs.sort(-1)
    cumulative = sorted_probs.cumsum_(-1)
    keep_sorted = (cumulative >= 1 - thr).int()
    # argsort of the sort order is its inverse permutation, which restores the original key-block order.
    keep = keep_sorted.gather(-1, sort_order.argsort(-1))

    keep = torch.logical_or(keep, sta)

    # Convert the boolean map into the (count, index list) representation expected by BlockMask.
    kv_counts = keep.sum(-1).to(torch.int32)
    kv_indices = keep.argsort(dim=-1, descending=True).to(torch.int32)
    return BlockMask.from_kv_blocks(
        torch.zeros_like(kv_counts), kv_indices, kv_counts, kv_indices, BLOCK_SIZE=64, mask_mod=None
    )
|
|
|
|
class Kandinsky5TimeEmbeddings(nn.Module):
    """Sinusoidal timestep features followed by a two-layer MLP.

    Args:
        model_dim: Width of the sinusoidal feature vector (half cosines, half sines); must be even.
        time_dim: Width of the returned embedding.
        max_period: Passed to `get_freqs` to build the frequency table.
    """

    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        # Kept as a plain tensor attribute (not a registered buffer); `forward` moves it to the input's device.
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        """Embed a 1-D tensor of timesteps and return one `time_dim`-wide row per timestep."""
        timesteps = time.to(torch.float32)
        freqs = self.freqs.to(device=time.device)
        phases = torch.outer(timesteps, freqs)
        features = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        hidden = self.activation(self.in_layer(features))
        return self.out_layer(hidden)
|
|
|
|
class Kandinsky5TextEmbeddings(nn.Module):
    """Linear projection of text features followed by an affine `LayerNorm`.

    Args:
        text_dim: Size of the last axis of the incoming text features.
        model_dim: Size of the last axis of the output.
    """

    def __init__(self, text_dim, model_dim):
        super().__init__()
        self.in_layer = nn.Linear(text_dim, model_dim, bias=True)
        self.norm = nn.LayerNorm(model_dim, elementwise_affine=True)

    def forward(self, text_embed):
        """Project and normalise `text_embed`; the result is cast back to the dtype of the projection output."""
        projected = self.in_layer(text_embed)
        normalized = self.norm(projected)
        return normalized.type_as(projected)
|
|
|
|
class Kandinsky5VisualEmbeddings(nn.Module):
    """Cut a `(batch, duration, height, width, channels)` tensor into patches and project each patch linearly.

    Args:
        visual_dim: Number of channels of the input.
        model_dim: Output width of the projection.
        patch_size: Patch extent along duration, height and width.
    """

    def __init__(self, visual_dim, model_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim)

    def forward(self, x):
        """Return patch embeddings of shape `(batch, duration / p_t, height / p_h, width / p_w, model_dim)`."""
        batch_size, duration, height, width, dim = x.shape
        p_t, p_h, p_w = self.patch_size[0], self.patch_size[1], self.patch_size[2]

        # Split every spatial/temporal axis into (number of patches, patch extent).
        grid = x.view(
            batch_size,
            duration // p_t,
            p_t,
            height // p_h,
            p_h,
            width // p_w,
            p_w,
            dim,
        )
        # Bring the three patch-count axes forward and flatten (p_t, p_h, p_w, channels) into one feature axis.
        patches = grid.permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
        return self.in_layer(patches)
|
|
|
|
class Kandinsky5RoPE1D(nn.Module):
    """Table of 2x2 rotation matrices for one-dimensional rotary position embeddings.

    A non-persistent buffer `args` of shape `(max_pos, dim // 2)` stores `position * frequency` for every position
    below `max_pos`.

    Args:
        dim: Rotary dimension; `dim // 2` frequencies are generated.
        max_pos: Number of positions held in the table.
        max_period: Passed to `get_freqs`.
    """

    def __init__(self, dim, max_pos=1024, max_period=10000.0):
        super().__init__()
        self.max_period = max_period
        self.dim = dim
        self.max_pos = max_pos
        freq = get_freqs(dim // 2, max_period)
        pos = torch.arange(max_pos, dtype=freq.dtype)
        self.register_buffer("args", torch.outer(pos, freq), persistent=False)

    def forward(self, pos):
        """Look up the rotation matrices for the positions in `pos`.

        Returns a tensor of shape `(*index_shape, 1, dim // 2, 2, 2)` where every trailing 2x2 matrix is
        `[[cos, -sin], [sin, cos]]`; the singleton axis is inserted just before the frequency axis.
        """
        angles = self.args[pos]
        cos, sin = torch.cos(angles), torch.sin(angles)
        rotation = torch.stack([cos, -sin, sin, cos], dim=-1)
        rotation = rotation.view(*angles.shape, 2, 2)
        return rotation.unsqueeze(-4)
|
|
|
|
class Kandinsky5RoPE3D(nn.Module):
    """Rotary position embeddings over three axes (duration, height, width).

    For axis `i` a non-persistent buffer `args_{i}` of shape `(max_pos[i], axes_dims[i] // 2)` stores
    `position * frequency`.

    Args:
        axes_dims: Rotary dimension assigned to each of the three axes.
        max_pos: Number of positions held in the table of each axis.
        max_period: Passed to `get_freqs`.
    """

    def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0):
        super().__init__()
        self.axes_dims = axes_dims
        self.max_pos = max_pos
        self.max_period = max_period

        for i, (axes_dim, ax_max_pos) in enumerate(zip(axes_dims, max_pos)):
            freq = get_freqs(axes_dim // 2, max_period)
            pos = torch.arange(ax_max_pos, dtype=freq.dtype)
            self.register_buffer(f"args_{i}", torch.outer(pos, freq), persistent=False)

    def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)):
        """Build the rotation matrices for a `(batch_size, duration, height, width)` grid.

        Args:
            shape: 4-tuple `(batch_size, duration, height, width)`.
            pos: Per-axis indices into the three tables; `pos[i]` has to select as many rows as the corresponding
                grid axis is long.
            scale_factor: Per-axis divisor applied to the looked-up angles.

        Returns:
            Tensor of shape `(batch_size, duration, height, width, 1, F, 2, 2)` where `F` is the total number of
            frequencies over the three axes and every trailing 2x2 matrix is `[[cos, -sin], [sin, cos]]`.
        """
        batch_size, duration, height, width = shape
        angles_t, angles_h, angles_w = (
            getattr(self, f"args_{axis}")[pos[axis]] / scale_factor[axis] for axis in range(3)
        )

        # Broadcast every axis table over the other two axes and the batch, then join along the frequency axis.
        tiled = (
            angles_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1),
            angles_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1),
            angles_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1),
        )
        angles = torch.cat(tiled, dim=-1)

        cos, sin = torch.cos(angles), torch.sin(angles)
        rotation = torch.stack([cos, -sin, sin, cos], dim=-1)
        rotation = rotation.view(*angles.shape, 2, 2)
        return rotation.unsqueeze(-4)
|
|
|
|
class Kandinsky5Modulation(nn.Module):
    """SiLU followed by a zero-initialised linear layer producing `num_params * model_dim` modulation values.

    Args:
        time_dim: Size of the last axis of the input.
        model_dim: Width of each modulation parameter group.
        num_params: Number of parameter groups packed into the output.
    """

    def __init__(self, time_dim, model_dim, num_params):
        super().__init__()
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, num_params * model_dim)
        # Weight and bias start at zero, so a freshly constructed module outputs zeros for any input.
        for param in (self.out_layer.weight, self.out_layer.bias):
            param.data.zero_()

    def forward(self, x):
        """Return the packed modulation values for `x`."""
        activated = self.activation(x)
        return self.out_layer(activated)
|
|
|
|
class Kandinsky5AttnProcessor:
    """Attention processor used by `Kandinsky5Attention` for both self- and cross-attention.

    The class attributes `_attention_backend` and `_parallel_config` are forwarded unchanged to
    `dispatch_attention_fn` on every call.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    @staticmethod
    def _split_heads(x, num_heads):
        """Reshape the last axis of `x` into `(num_heads, head_dim)`."""
        return x.reshape(*x.shape[:-1], num_heads, -1)

    @staticmethod
    def _apply_rotary(x, rope):
        """Rotate consecutive feature pairs of `x` by the 2x2 matrices in `rope`.

        The product is computed in float32 and cast straight back to the dtype of `x`. The result is not routed
        through bfloat16, so float32 and float16 inputs keep their full precision.
        """
        pairs = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
        # (..., 2, 2) * (..., 1, 2) summed over the last axis is a batched 2x2 matrix-vector product.
        rotated = (rope * pairs).sum(dim=-1)
        return rotated.reshape(*x.shape).to(x.dtype)

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None):
        """Run attention for the module `attn`.

        Args:
            attn: Module providing `to_query`, `to_key`, `to_value`, `query_norm`, `key_norm`, `out_layer` and
                `num_heads`.
            hidden_states: Source of the queries and, for self-attention, of the keys and values.
            encoder_hidden_states: When given, keys and values are computed from it instead (cross-attention).
            rotary_emb: Optional rotation matrices applied to queries and keys; they must broadcast against the
                per-head feature pairs.
            sparse_params: Optional dict; its `"sta_mask"` and `"P"` entries are passed to `nablaT_v2` to build a
                block mask.

        Returns:
            The output of `attn.out_layer` applied to the attention result with the head axes merged.
        """
        kv_source = hidden_states if encoder_hidden_states is None else encoder_hidden_states

        query = self._split_heads(attn.to_query(hidden_states), attn.num_heads)
        key = self._split_heads(attn.to_key(kv_source), attn.num_heads)
        value = self._split_heads(attn.to_value(kv_source), attn.num_heads)

        # Normalise in float32, then return to the working dtype.
        query = attn.query_norm(query.float()).type_as(query)
        key = attn.key_norm(key.float()).type_as(key)

        if rotary_emb is not None:
            query = self._apply_rotary(query, rotary_emb)
            key = self._apply_rotary(key, rotary_emb)

        attn_mask = None
        if sparse_params is not None:
            attn_mask = nablaT_v2(
                query,
                key,
                sparse_params["sta_mask"],
                thr=sparse_params["P"],
            )

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attn_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        # Merge (num_heads, head_dim) back into a single channel axis.
        hidden_states = hidden_states.flatten(-2, -1)

        return attn.out_layer(hidden_states)
|
|
|
|
class Kandinsky5Attention(nn.Module, AttentionModuleMixin):
    """Multi-head attention layer whose computation is delegated to a pluggable processor.

    Args:
        num_channels: Width of the input, of the query/key/value projections and of the output.
        head_dim: Width of one attention head; must divide `num_channels`.
        processor: Attention processor to install. An instance of `_default_processor_cls` is created when omitted.
    """

    _default_processor_cls = Kandinsky5AttnProcessor
    _available_processors = [
        Kandinsky5AttnProcessor,
    ]

    def __init__(self, num_channels, head_dim, processor=None):
        super().__init__()
        assert num_channels % head_dim == 0
        self.num_heads = num_channels // head_dim

        self.to_query = nn.Linear(num_channels, num_channels, bias=True)
        self.to_key = nn.Linear(num_channels, num_channels, bias=True)
        self.to_value = nn.Linear(num_channels, num_channels, bias=True)
        # Queries and keys are normalised per head, hence the norm width is head_dim.
        self.query_norm = nn.RMSNorm(head_dim)
        self.key_norm = nn.RMSNorm(head_dim)

        self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        sparse_params: dict[str, Any] | None = None,
        rotary_emb: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the installed processor.

        Args:
            hidden_states: Input the queries are computed from.
            encoder_hidden_states: Optional source of keys and values for cross-attention.
            sparse_params: Optional dict of sparse-attention settings, forwarded to the processor.
            rotary_emb: Optional rotary embedding tensor, forwarded to the processor.
            **kwargs: Extra arguments. Those named in the processor's `__call__` signature are forwarded; the rest
                are dropped after a warning.

        Returns:
            The tensor returned by the processor.
        """
        accepted = set(inspect.signature(self.processor.__call__).parameters)
        unused_kwargs = [name for name in kwargs if name not in accepted]
        if unused_kwargs:
            logger.warning(
                f"attention_processor_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {name: value for name, value in kwargs.items() if name in accepted}

        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            sparse_params=sparse_params,
            rotary_emb=rotary_emb,
            **kwargs,
        )
|
|
|
|
class Kandinsky5FeedForward(nn.Module):
    """Two-layer bias-free MLP with a GELU in between.

    Args:
        dim: Input and output width.
        ff_dim: Hidden width.
    """

    def __init__(self, dim, ff_dim):
        super().__init__()
        self.in_layer = nn.Linear(dim, ff_dim, bias=False)
        self.activation = nn.GELU()
        self.out_layer = nn.Linear(ff_dim, dim, bias=False)

    def forward(self, x):
        """Apply `out_layer(GELU(in_layer(x)))`."""
        hidden = self.in_layer(x)
        hidden = self.activation(hidden)
        return self.out_layer(hidden)
|
|
|
|
class Kandinsky5OutLayer(nn.Module):
    """Final layer: time-conditioned normalisation, linear projection and un-patching.

    Args:
        model_dim: Width of the incoming visual embeddings.
        time_dim: Width of the time embedding that drives the modulation.
        visual_dim: Number of output channels per position.
        patch_size: Patch extent along duration, height and width.
    """

    def __init__(self, model_dim, time_dim, visual_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.modulation = Kandinsky5Modulation(time_dim, model_dim, 2)
        self.norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.out_layer = nn.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True)

    def forward(self, visual_embed, text_embed, time_embed):
        """Project patch embeddings back to the un-patched grid.

        Args:
            visual_embed: Tensor unpacked as `(batch, duration, height, width, model_dim)`.
            text_embed: Not used by this layer.
            time_embed: Time embedding from which shift and scale are derived.

        Returns:
            Tensor of shape `(batch, duration * p_t, height * p_h, width * p_w, channels)`.
        """
        modulation = self.modulation(time_embed).unsqueeze(dim=1)
        shift, scale = torch.chunk(modulation, 2, dim=-1)

        # Normalise and modulate in float32; the extra singleton axes let shift/scale broadcast over the grid.
        normed = self.norm(visual_embed.float())
        modulated = normed * (scale.float()[:, None, None] + 1.0) + shift.float()[:, None, None]
        modulated = modulated.type_as(visual_embed)

        x = self.out_layer(modulated)

        batch_size, duration, height, width, _ = x.shape
        p_t, p_h, p_w = self.patch_size[0], self.patch_size[1], self.patch_size[2]
        # Split the feature axis into (channels, p_t, p_h, p_w), then pair every patch axis with its grid axis.
        x = x.view(batch_size, duration, height, width, -1, p_t, p_h, p_w)
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
        return x.flatten(1, 2).flatten(2, 3).flatten(3, 4)
|
|
|
|
class Kandinsky5TransformerEncoderBlock(nn.Module):
    """Transformer block with time-modulated self-attention and feed-forward sub-layers.

    Args:
        model_dim: Width of the token embeddings.
        time_dim: Width of the time embedding that drives the modulation.
        ff_dim: Hidden width of the feed-forward sub-layer.
        head_dim: Width of one attention head.
    """

    def __init__(self, model_dim, time_dim, ff_dim, head_dim):
        super().__init__()
        # 6 modulation groups: (shift, scale, gate) for each of the two sub-layers.
        self.text_modulation = Kandinsky5Modulation(time_dim, model_dim, 6)

        self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())

        self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)

    def forward(self, x, time_embed, rope):
        """Apply both sub-layers to `x`.

        Each sub-layer normalises `x` in float32, applies `scale + 1` and `shift`, runs the sub-layer and adds the
        result back multiplied by `gate`. The output has the dtype of `x`.
        """
        modulation = self.text_modulation(time_embed).unsqueeze(dim=1)
        attn_mod, ff_mod = torch.chunk(modulation, 2, dim=-1)

        # Self-attention sub-layer.
        shift, scale, gate = torch.chunk(attn_mod, 3, dim=-1)
        attn_in = self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()
        attn_out = self.self_attention(attn_in.type_as(x), rotary_emb=rope)
        x = (x.float() + gate.float() * attn_out.float()).type_as(x)

        # Feed-forward sub-layer.
        shift, scale, gate = torch.chunk(ff_mod, 3, dim=-1)
        ff_in = self.feed_forward_norm(x.float()) * (scale.float() + 1.0) + shift.float()
        ff_out = self.feed_forward(ff_in.type_as(x))
        x = (x.float() + gate.float() * ff_out.float()).type_as(x)

        return x
|
|
|
|
class Kandinsky5TransformerDecoderBlock(nn.Module):
    """Transformer block with time-modulated self-attention, cross-attention and feed-forward sub-layers.

    Args:
        model_dim: Width of the visual and text embeddings.
        time_dim: Width of the time embedding that drives the modulation.
        ff_dim: Hidden width of the feed-forward sub-layer.
        head_dim: Width of one attention head.
    """

    def __init__(self, model_dim, time_dim, ff_dim, head_dim):
        super().__init__()
        # 9 modulation groups: (shift, scale, gate) for each of the three sub-layers.
        self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9)

        self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())

        self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.cross_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())

        self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)

    @staticmethod
    def _modulate(norm, x, shift, scale):
        """Normalise `x` in float32, apply `scale + 1` and `shift`, and cast back to the dtype of `x`."""
        return (norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)

    @staticmethod
    def _gated_residual(x, gate, update):
        """Return `x + gate * update`, computed in float32 and cast back to the dtype of `x`."""
        return (x.float() + gate.float() * update.float()).type_as(x)

    def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
        """Apply the three sub-layers to `visual_embed`.

        Args:
            visual_embed: Visual tokens that are updated by the block.
            text_embed: Text tokens used as keys and values of the cross-attention.
            time_embed: Time embedding from which all shift/scale/gate values are derived.
            rope: Rotary embedding passed to the self-attention.
            sparse_params: Sparse-attention settings passed to the self-attention (may be `None`).

        Returns:
            The updated visual tokens, in the dtype of the input.
        """
        modulation = self.visual_modulation(time_embed).unsqueeze(dim=1)
        self_attn_mod, cross_attn_mod, ff_mod = torch.chunk(modulation, 3, dim=-1)

        # Self-attention over the visual tokens.
        shift, scale, gate = torch.chunk(self_attn_mod, 3, dim=-1)
        update = self._modulate(self.self_attention_norm, visual_embed, shift, scale)
        update = self.self_attention(update, rotary_emb=rope, sparse_params=sparse_params)
        visual_embed = self._gated_residual(visual_embed, gate, update)

        # Cross-attention from visual tokens to text tokens.
        shift, scale, gate = torch.chunk(cross_attn_mod, 3, dim=-1)
        update = self._modulate(self.cross_attention_norm, visual_embed, shift, scale)
        update = self.cross_attention(update, encoder_hidden_states=text_embed)
        visual_embed = self._gated_residual(visual_embed, gate, update)

        # Feed-forward.
        shift, scale, gate = torch.chunk(ff_mod, 3, dim=-1)
        update = self._modulate(self.feed_forward_norm, visual_embed, shift, scale)
        update = self.feed_forward(update)
        visual_embed = self._gated_residual(visual_embed, gate, update)

        return visual_embed
|
|
|
|
class Kandinsky5Transformer3DModel(
    ModelMixin,
    ConfigMixin,
    PeftAdapterMixin,
    FromOriginalModelMixin,
    CacheMixin,
    AttentionMixin,
):
    """
    A 3D Diffusion Transformer model for video-like data.

    Text tokens are embedded and refined by a stack of `Kandinsky5TransformerEncoderBlock`s. Visual inputs are
    patch-embedded, flattened into a sequence, processed by a stack of `Kandinsky5TransformerDecoderBlock`s that
    cross-attend to the text tokens, and finally projected back to the input grid by `Kandinsky5OutLayer`.

    Args:
        in_visual_dim: Number of channels of the visual input.
        in_text_dim: Width of the per-token text features.
        in_text_dim2: Width of the pooled text features.
        time_dim: Width of the time embedding.
        out_visual_dim: Number of output channels.
        patch_size: Patch extent along duration, height and width.
        model_dim: Width of the token embeddings.
        ff_dim: Hidden width of the feed-forward layers.
        num_text_blocks: Number of encoder blocks applied to the text tokens.
        num_visual_blocks: Number of decoder blocks applied to the visual tokens.
        axes_dims: Rotary dimensions per axis; their sum is used as the attention head width.
        visual_cond: When true, the visual embedding layer expects `2 * in_visual_dim + 1` input channels.
        attention_type: Stored on the instance as `attention_type`.
        attention_causal, attention_local, attention_glob, attention_window, attention_P, attention_wT,
        attention_wW, attention_wH, attention_add_sta, attention_method: Not read by the code in this class;
            presumably only recorded in the config by `register_to_config` (confirm against the pipeline).
    """

    _repeated_blocks = [
        "Kandinsky5TransformerEncoderBlock",
        "Kandinsky5TransformerDecoderBlock",
    ]
    _keep_in_fp32_modules = ["time_embeddings", "modulation", "visual_modulation", "text_modulation"]
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_visual_dim=4,
        in_text_dim=3584,
        in_text_dim2=768,
        time_dim=512,
        out_visual_dim=4,
        patch_size=(1, 2, 2),
        model_dim=2048,
        ff_dim=5120,
        num_text_blocks=2,
        num_visual_blocks=32,
        axes_dims=(16, 24, 24),
        visual_cond=False,
        attention_type: str = "regular",
        attention_causal: bool = None,
        attention_local: bool = None,
        attention_glob: bool = None,
        attention_window: int = None,
        attention_P: float = None,
        attention_wT: int = None,
        attention_wW: int = None,
        attention_wH: int = None,
        attention_add_sta: bool = None,
        attention_method: str = None,
    ):
        super().__init__()

        self.in_visual_dim = in_visual_dim
        self.model_dim = model_dim
        self.patch_size = patch_size
        self.visual_cond = visual_cond
        self.attention_type = attention_type

        # One attention head spans the rotary dimensions of all three axes.
        head_dim = sum(axes_dims)

        if visual_cond:
            visual_embed_dim = 2 * in_visual_dim + 1
        else:
            visual_embed_dim = in_visual_dim

        # Input embeddings.
        self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim)
        self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim)
        self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim)
        self.visual_embeddings = Kandinsky5VisualEmbeddings(visual_embed_dim, model_dim, patch_size)

        # Rotary position embeddings.
        self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim)
        self.visual_rope_embeddings = Kandinsky5RoPE3D(axes_dims)

        # Transformer stacks.
        self.text_transformer_blocks = nn.ModuleList(
            Kandinsky5TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) for _ in range(num_text_blocks)
        )
        self.visual_transformer_blocks = nn.ModuleList(
            Kandinsky5TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim)
            for _ in range(num_visual_blocks)
        )

        # Output projection.
        self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        pooled_projections: torch.Tensor,
        visual_rope_pos: tuple[int, int, int],
        text_rope_pos: torch.LongTensor,
        scale_factor: tuple[float, float, float] = (1.0, 1.0, 1.0),
        sparse_params: dict[str, Any] | None = None,
        return_dict: bool = True,
    ) -> Transformer2DModelOutput | torch.FloatTensor:
        """
        Forward pass of the Kandinsky5 3D Transformer.

        Args:
            hidden_states (`torch.FloatTensor`):
                Visual input; the patch embedding unpacks it as `(batch, duration, height, width, channels)`.
            encoder_hidden_states (`torch.FloatTensor`):
                Per-token text features.
            timestep (`torch.Tensor`):
                Timesteps fed to the time embedding.
            pooled_projections (`torch.FloatTensor`):
                Pooled text features; their embedding is added to the time embedding.
            visual_rope_pos:
                Per-axis positions used to index the 3D RoPE tables. NOTE(review): annotated as a tuple of ints, but
                each entry is used as an index selecting one table row per grid position along its axis - confirm
                the expected type with the caller.
            text_rope_pos (`torch.LongTensor`):
                Positions used to index the 1D RoPE table of the text tokens.
            scale_factor (`tuple[float, float, float]`, optional):
                Per-axis divisor applied to the visual RoPE angles.
            sparse_params (`dict[str, Any]`, optional):
                Sparse-attention settings. Its `"to_fractal"` entry selects the windowed sequence layout; the whole
                dict is passed on to every visual block.
            return_dict (`bool`, optional):
                Whether to wrap the result in a `Transformer2DModelOutput`.

        Returns:
            [`~models.transformer_2d.Transformer2DModelOutput`] or `torch.FloatTensor`: The output of the
            transformer; the bare tensor when `return_dict` is false.
        """
        use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing

        # Embed all inputs; the pooled text embedding conditions the time embedding.
        text_embed = self.text_embeddings(encoder_hidden_states)
        time_embed = self.time_embeddings(timestep)
        time_embed = time_embed + self.pooled_text_embeddings(pooled_projections)
        visual_embed = self.visual_embeddings(hidden_states)
        text_rope = self.text_rope_embeddings(text_rope_pos).unsqueeze(dim=0)

        # Text stack.
        for block in self.text_transformer_blocks:
            if use_checkpointing:
                text_embed = self._gradient_checkpointing_func(block, text_embed, time_embed, text_rope)
            else:
                text_embed = block(text_embed, time_embed, text_rope)

        # Flatten the visual grid (and its rotary embedding) into a token sequence.
        visual_shape = visual_embed.shape[:-1]
        visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)
        to_fractal = False if sparse_params is None else sparse_params["to_fractal"]
        visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape, block_mask=to_fractal)

        # Visual stack.
        for block in self.visual_transformer_blocks:
            block_args = (visual_embed, text_embed, time_embed, visual_rope, sparse_params)
            if use_checkpointing:
                visual_embed = self._gradient_checkpointing_func(block, *block_args)
            else:
                visual_embed = block(*block_args)

        # Restore the grid layout and project back to the output channels.
        visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal)
        output = self.out_layer(visual_embed, text_embed, time_embed)

        if return_dict:
            return Transformer2DModelOutput(sample=output)
        return output
|
|