|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect |
|
|
import math |
|
|
from typing import Any, Dict, Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from ...configuration_utils import ConfigMixin, register_to_config |
|
|
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin |
|
|
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers |
|
|
from ...utils.torch_utils import maybe_allow_in_graph |
|
|
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward |
|
|
from ..attention_dispatch import dispatch_attention_fn |
|
|
from ..cache_utils import CacheMixin |
|
|
from ..embeddings import PixArtAlphaTextProjection |
|
|
from ..modeling_outputs import Transformer2DModelOutput |
|
|
from ..modeling_utils import ModelMixin |
|
|
from ..normalization import AdaLayerNormSingle, RMSNorm |
|
|
|
|
|
|
|
|
# Module-level logger; used by `LTXAttention.forward` and `LTXVideoTransformer3DModel.forward` to emit warnings.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
class LTXVideoAttentionProcessor2_0:
    r"""
    Deprecated alias of `LTXVideoAttnProcessor`.

    Constructing this class reports a deprecation through `deprecate` and then hands back an
    `LTXVideoAttnProcessor` built from the same arguments. Because `__new__` returns an object of a different
    class, the caller never receives an instance of `LTXVideoAttentionProcessor2_0` itself.
    """

    def __new__(cls, *args, **kwargs):
        message = (
            "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. "
            "Please use `LTXVideoAttnProcessor`"
        )
        deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", message)
        return LTXVideoAttnProcessor(*args, **kwargs)
|
|
|
|
|
|
|
|
class LTXVideoAttnProcessor:
    r"""
    Attention processor for the LTX video transformer.

    Given an `LTXAttention` module, it does the following:

    - projects queries from `hidden_states`;
    - projects keys/values from `encoder_hidden_states`, falling back to `hidden_states` for self-attention;
    - applies the module's RMS norms to queries and keys;
    - optionally rotates queries and keys with rotary position embeddings;
    - computes attention through `dispatch_attention_fn` with the backend stored in `_attention_backend`.
    """

    # Forwarded as `backend=` to `dispatch_attention_fn` (`None` presumably selects the dispatcher default).
    _attention_backend = None

    def __init__(self):
        # Guard clause: the processor refuses to run on PyTorch older than 2.0.
        if not is_torch_version("<", "2.0"):
            return
        raise ValueError(
            "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
        )

    def __call__(
        self,
        attn: "LTXAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run attention for `attn`.

        Args:
            attn: The `LTXAttention` module providing `to_q`/`to_k`/`to_v`, `norm_q`/`norm_k`, `to_out` and `heads`.
            hidden_states: Query-side input, unpacked below as a 3D `(batch, seq_len, channels)` tensor.
            encoder_hidden_states: Optional key/value-side input; `None` means self-attention over `hidden_states`.
            attention_mask: Optional mask. It is passed through `attn.prepare_attention_mask` and viewed as
                `(batch, heads, -1, key_len)`.
            image_rotary_emb: Optional `(cos, sin)` pair applied to queries and keys via `apply_rotary_emb`.

        Returns:
            The attention output after `attn.to_out[0]` and `attn.to_out[1]`.
        """
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, context_len, _ = context.shape

        if attention_mask is not None:
            mask = attn.prepare_attention_mask(attention_mask, context_len, batch_size)
            attention_mask = mask.view(batch_size, attn.heads, -1, mask.shape[-1])

        q = attn.to_q(hidden_states)
        k = attn.to_k(context)
        v = attn.to_v(context)

        # Normalization is applied across all heads at once, before the head split.
        q = attn.norm_q(q)
        k = attn.norm_k(k)

        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb)
            k = apply_rotary_emb(k, image_rotary_emb)

        # (batch, seq, heads * head_dim) -> (batch, seq, heads, head_dim)
        q, k, v = (t.unflatten(2, (attn.heads, -1)) for t in (q, k, v))

        out = dispatch_attention_fn(
            q,
            k,
            v,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
        )
        out = out.flatten(2, 3).to(q.dtype)

        out = attn.to_out[0](out)
        return attn.to_out[1](out)
|
|
|
|
|
|
|
|
class LTXAttention(torch.nn.Module, AttentionModuleMixin):
    r"""
    Multi-head attention layer used by the LTX video transformer blocks.

    Args:
        query_dim (`int`): Number of channels of the query-side input (and of the output).
        heads (`int`, defaults to `8`): Number of query heads.
        kv_heads (`int`, *optional*, defaults to `8`): Number of key/value heads. `None` means the key/value
            projections have the same width as the query projection (`dim_head * heads`).
        dim_head (`int`, defaults to `64`): Number of channels per head.
        dropout (`float`, defaults to `0.0`): Dropout probability applied after the output projection.
        bias (`bool`, defaults to `True`): Whether the query/key/value projections use a bias.
        cross_attention_dim (`int`, *optional*): Number of channels of the key/value-side input; defaults to
            `query_dim`.
        out_bias (`bool`, defaults to `True`): Whether the output projection uses a bias.
        qk_norm (`str`, defaults to `"rms_norm_across_heads"`): Query/key normalization. Only
            `"rms_norm_across_heads"` is supported.
        processor (*optional*): Attention processor; defaults to an instance of `_default_processor_cls`.

    Raises:
        NotImplementedError: If `qk_norm` is not `"rms_norm_across_heads"`.
    """

    _default_processor_cls = LTXVideoAttnProcessor
    _available_processors = [LTXVideoAttnProcessor]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        kv_heads: Optional[int] = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = True,
        cross_attention_dim: Optional[int] = None,
        out_bias: bool = True,
        qk_norm: str = "rms_norm_across_heads",
        processor=None,
    ):
        super().__init__()
        if qk_norm != "rms_norm_across_heads":
            raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")

        self.head_dim = dim_head
        self.inner_dim = dim_head * heads
        # `kv_heads=None` means keys/values are as wide as queries.
        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
        self.query_dim = query_dim
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = query_dim
        self.heads = heads

        norm_eps = 1e-5
        norm_elementwise_affine = True
        # "Across heads": each norm spans the full projected width, not a single head.
        # The key norm is sized from `inner_kv_dim` so it always matches `to_k`'s output width.
        # This also covers `kv_heads=None`; `dim_head * kv_heads` would raise a TypeError there.
        self.norm_q = torch.nn.RMSNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(self.inner_kv_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
        self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
        self.to_out = torch.nn.ModuleList([])
        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(torch.nn.Dropout(dropout))

        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Delegate to `self.processor`.

        Extra keyword arguments are forwarded only if they appear in the processor's `__call__` signature. Any
        others are reported with a warning and dropped.
        """
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        unused_kwargs = [k for k in kwargs if k not in attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
|
|
|
|
|
|
|
class LTXVideoRotaryPosEmbed(nn.Module):
    r"""
    3D rotary position embedding over (frame, height, width) coordinates for the LTX video transformer.

    Args:
        dim (`int`): Embedding width; the returned cos/sin tables have this size in their last dimension.
        base_num_frames (`int`, defaults to `20`): Divisor applied to the frame coordinate.
        base_height (`int`, defaults to `2048`): Divisor applied to the height coordinate.
        base_width (`int`, defaults to `2048`): Divisor applied to the width coordinate.
        patch_size (`int`, defaults to `1`): Spatial patch size multiplied into generated height/width coordinates.
        patch_size_t (`int`, defaults to `1`): Temporal patch size multiplied into generated frame coordinates.
        theta (`float`, defaults to `10000.0`): Upper end of the geometric frequency range `[1, theta]`.
    """

    def __init__(
        self,
        dim: int,
        base_num_frames: int = 20,
        base_height: int = 2048,
        base_width: int = 2048,
        patch_size: int = 1,
        patch_size_t: int = 1,
        theta: float = 10000.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.base_num_frames = base_num_frames
        self.base_height = base_height
        self.base_width = base_width
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.theta = theta

    def _prepare_video_coords(
        self,
        batch_size: int,
        num_frames: int,
        height: int,
        width: int,
        rope_interpolation_scale: Tuple[torch.Tensor, float, float],
        device: torch.device,
    ) -> torch.Tensor:
        """
        Build per-token (frame, row, column) coordinates of shape `(batch_size, num_frames * height * width, 3)`.

        If `rope_interpolation_scale` is given, each axis is multiplied by its scale and patch size, then divided by
        its base size. Otherwise the unnormalized integer-valued positions are returned.
        """
        # Positions are always created in float32 so the rotary tables are computed in full precision.
        rows, cols, frames = (
            torch.arange(n, dtype=torch.float32, device=device) for n in (height, width, num_frames)
        )
        coords = torch.stack(torch.meshgrid(frames, rows, cols, indexing="ij"), dim=0)
        coords = coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

        if rope_interpolation_scale is not None:
            axis_factors = (
                (self.patch_size_t, self.base_num_frames),
                (self.patch_size, self.base_height),
                (self.patch_size, self.base_width),
            )
            for axis, (patch, base) in enumerate(axis_factors):
                coords[:, axis : axis + 1] = coords[:, axis : axis + 1] * rope_interpolation_scale[axis] * patch / base

        # (B, 3, F, H, W) -> (B, F*H*W, 3)
        return coords.flatten(2, 4).transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        num_frames: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
        video_coords: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the `(cos, sin)` rotary tables, each of shape `(batch, seq_len, dim)`.

        `hidden_states` only supplies the batch size and device. When `video_coords` is given, its
        `[:, 0]`, `[:, 1]` and `[:, 2]` slices are divided by the base sizes and used directly. Otherwise
        coordinates are generated from `num_frames`, `height` and `width`.
        """
        device = hidden_states.device

        if video_coords is not None:
            grid = torch.stack(
                [
                    video_coords[:, 0] / self.base_num_frames,
                    video_coords[:, 1] / self.base_height,
                    video_coords[:, 2] / self.base_width,
                ],
                dim=-1,
            )
        else:
            grid = self._prepare_video_coords(
                hidden_states.size(0),
                num_frames,
                height,
                width,
                rope_interpolation_scale=rope_interpolation_scale,
                device=device,
            )

        # Geometric frequency ladder from 1 to theta with dim // 6 entries.
        # The 6 is 3 axes x 2 interleaved copies.
        log_start = math.log(1.0, self.theta)
        log_end = math.log(self.theta, self.theta)
        exponents = torch.linspace(log_start, log_end, self.dim // 6, device=device, dtype=torch.float32)
        base_freqs = self.theta**exponents * math.pi / 2.0

        # Each coordinate c is mapped affinely to 2c - 1 and multiplies every frequency.
        angles = base_freqs * (grid.unsqueeze(-1) * 2 - 1)
        angles = angles.transpose(-1, -2).flatten(2)

        cos_freqs = angles.cos().repeat_interleave(2, dim=-1)
        sin_freqs = angles.sin().repeat_interleave(2, dim=-1)

        # When dim is not a multiple of 6, pad the leading channels with cos=1 / sin=0.
        remainder = self.dim % 6
        if remainder != 0:
            cos_pad = torch.ones_like(cos_freqs[:, :, :remainder])
            sin_pad = torch.zeros_like(cos_freqs[:, :, :remainder])
            cos_freqs = torch.cat([cos_pad, cos_freqs], dim=-1)
            sin_freqs = torch.cat([sin_pad, sin_freqs], dim=-1)

        return cos_freqs, sin_freqs
|
|
|
|
|
|
|
|
@maybe_allow_in_graph
class LTXVideoTransformerBlock(nn.Module):
    r"""
    Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

    The block applies, in order:

    - an AdaLN-modulated self-attention;
    - an un-modulated cross-attention over `encoder_hidden_states`;
    - an AdaLN-modulated feed-forward.

    Each step is added back residually.

    Args:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        cross_attention_dim (`int`):
            The number of channels of `encoder_hidden_states` consumed by the cross-attention.
        qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
            The query/key normalization used by both attention layers (`LTXAttention` only supports
            `"rms_norm_across_heads"`).
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        attention_bias (`bool`, defaults to `True`):
            Whether the query/key/value projections use a bias.
        attention_out_bias (`bool`, defaults to `True`):
            Whether the attention output projections use a bias.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        elementwise_affine (`bool`, defaults to `False`):
            Whether `norm1`/`norm2` have learnable affine parameters.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: int,
        qk_norm: str = "rms_norm_across_heads",
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = True,
        attention_out_bias: bool = True,
        eps: float = 1e-6,
        elementwise_affine: bool = False,
    ):
        super().__init__()

        self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
        self.attn1 = LTXAttention(
            query_dim=dim,
            heads=num_attention_heads,
            kv_heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
            cross_attention_dim=None,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
        )

        self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
        self.attn2 = LTXAttention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            kv_heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
        )

        self.ff = FeedForward(dim, activation_fn=activation_fn)

        # Learned offsets for the 6 modulation vectors: shift/scale/gate for attention and for the MLP.
        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply the block.

        Args:
            hidden_states: Block input; its first dimension is the batch.
            encoder_hidden_states: Key/value input for the cross-attention.
            temb: Modulation input, reshaped to `(batch, temb.size(1), 6, -1)` and added to `scale_shift_table`.
            image_rotary_emb: Optional `(cos, sin)` rotary tables used by the self-attention only.
            encoder_attention_mask: Optional mask passed to the cross-attention.

        Returns:
            The updated hidden states.
        """
        batch_size = hidden_states.size(0)
        norm_hidden_states = self.norm1(hidden_states)

        num_ada_params = self.scale_shift_table.shape[0]
        ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

        attn_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_hidden_states * gate_msa

        # Cross-attention reads the un-normalized residual stream and is added back without a gate.
        attn_hidden_states = self.attn2(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=None,
            attention_mask=encoder_attention_mask,
        )
        hidden_states = hidden_states + attn_hidden_states
        norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp

        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + ff_output * gate_mlp

        return hidden_states
|
|
|
|
|
|
|
|
@maybe_allow_in_graph
class LTXVideoTransformer3DModel(
    ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
):
    r"""
    A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

    Args:
        in_channels (`int`, defaults to `128`):
            The number of channels in the input.
        out_channels (`int`, defaults to `128`):
            The number of channels in the output. If falsy (e.g. `None`), `in_channels` is used.
        patch_size (`int`, defaults to `1`):
            The size of the spatial patches to use in the patch embedding layer.
        patch_size_t (`int`, defaults to `1`):
            The size of the temporal patches to use in the patch embedding layer.
        num_attention_heads (`int`, defaults to `32`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        cross_attention_dim (`int`, defaults to `2048`):
            The number of channels of the (projected) encoder hidden states consumed by cross-attention.
        num_layers (`int`, defaults to `28`):
            The number of layers of Transformer blocks to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
            The normalization layer to use.
        norm_elementwise_affine (`bool`, defaults to `False`):
            Whether the RMS norms inside each transformer block have learnable affine parameters.
        norm_eps (`float`, defaults to `1e-6`):
            Epsilon for the RMS norms inside each transformer block.
        caption_channels (`int`, defaults to `4096`):
            The number of channels of `encoder_hidden_states` before `caption_projection` maps them to the model
            width.
        attention_bias (`bool`, defaults to `True`):
            Whether the query/key/value projections use a bias.
        attention_out_bias (`bool`, defaults to `True`):
            Whether the attention output projections use a bias.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["norm"]
    _repeated_blocks = ["LTXVideoTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 128,
        out_channels: int = 128,
        patch_size: int = 1,
        patch_size_t: int = 1,
        num_attention_heads: int = 32,
        attention_head_dim: int = 64,
        cross_attention_dim: int = 2048,
        num_layers: int = 28,
        activation_fn: str = "gelu-approximate",
        qk_norm: str = "rms_norm_across_heads",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        caption_channels: int = 4096,
        attention_bias: bool = True,
        attention_out_bias: bool = True,
    ) -> None:
        super().__init__()

        out_channels = out_channels or in_channels
        inner_dim = num_attention_heads * attention_head_dim

        self.proj_in = nn.Linear(in_channels, inner_dim)

        # Learned offsets for the final output modulation (shift, scale).
        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
        self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)

        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

        self.rope = LTXVideoRotaryPosEmbed(
            dim=inner_dim,
            base_num_frames=20,
            base_height=2048,
            base_width=2048,
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            theta=10000.0,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                LTXVideoTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    qk_norm=qk_norm,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    attention_out_bias=attention_out_bias,
                    eps=norm_eps,
                    elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
        num_frames: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
        video_coords: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[Transformer2DModelOutput, Tuple[torch.Tensor]]:
        """
        Run the LTX video transformer.

        Args:
            hidden_states: Input tokens with `in_channels` features in the last dimension. This presumably has
                shape `(batch, num_frames * height * width, in_channels)` so it lines up with the rotary tables;
                TODO confirm against the pipeline.
            encoder_hidden_states: Conditioning tokens with `caption_channels` features, projected by
                `caption_projection`.
            timestep: Timesteps; flattened and embedded by `time_embed`, then viewed as `(batch, -1, ...)`.
            encoder_attention_mask: Optional mask for the cross-attention. A 2D keep-mask (1 = keep, 0 = drop) is
                converted into an additive bias of 0 / -10000 with an extra dimension inserted at position 1.
            num_frames, height, width: Token grid size used to build rotary coordinates when `video_coords` is
                not given.
            rope_interpolation_scale: Per-axis scale for the generated rotary coordinates.
            video_coords: Optional explicit rotary coordinates, indexed as `video_coords[:, axis]`.
            attention_kwargs: Optional extra kwargs. Only `"scale"` is consumed here, as the LoRA scale under the
                PEFT backend.
            return_dict: Whether to wrap the output in a `Transformer2DModelOutput`.

        Returns:
            `Transformer2DModelOutput(sample=output)` if `return_dict`, otherwise the tuple `(output,)`.
        """
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            # Must be checked before `pop`: afterwards the key is gone and the warning below could never fire.
            lora_scale_passed = attention_kwargs.get("scale", None) is not None
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale_passed = False
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif lora_scale_passed:
            logger.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.")

        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)

        # Convert a 2D keep-mask into an additive bias: kept positions get 0, dropped positions get -10000.
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        batch_size = hidden_states.size(0)
        hidden_states = self.proj_in(hidden_states)

        temb, embedded_timestep = self.time_embed(
            timestep.flatten(),
            batch_size=batch_size,
            hidden_dtype=hidden_states.dtype,
        )

        temb = temb.view(batch_size, -1, temb.size(-1))
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Positional order must match `LTXVideoTransformerBlock.forward`.
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    encoder_attention_mask,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    encoder_attention_mask=encoder_attention_mask,
                )

        scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

        hidden_states = self.norm_out(hidden_states)
        hidden_states = hidden_states * (1 + scale) + shift
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
|
|
|
|
|
|
|
|
def apply_rotary_emb(x, freqs):
    """
    Rotate adjacent channel pairs of `x` by the angles encoded in `freqs`.

    Args:
        x: Tensor whose last-dimension channels are grouped as pairs `(x[2i], x[2i + 1])`. Dimension 2 is split
            into pairs, so `x` is expected to be 3D, e.g. `(batch, seq, channels)`.
        freqs: A `(cos, sin)` pair broadcastable against `x`.

    Returns:
        `x * cos + rot(x) * sin`, where `rot` maps each pair `(a, b)` to `(-b, a)`. The arithmetic is done in
        float32 and the result is cast back to `x.dtype`.
    """
    cos_table, sin_table = freqs
    pairs = x.unflatten(2, (-1, 2))
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = torch.stack([-second, first], dim=-1).flatten(2)
    mixed = x.float() * cos_table + rotated.float() * sin_table
    return mixed.to(x.dtype)
|
|
|