| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import inspect |
| from typing import Any |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from ...configuration_utils import ConfigMixin, register_to_config |
| from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin |
| from ...utils import apply_lora_scale, logging |
| from .._modeling_parallel import ContextParallelInput, ContextParallelOutput |
| from ..attention import AttentionMixin, AttentionModuleMixin |
| from ..attention_dispatch import dispatch_attention_fn |
| from ..cache_utils import CacheMixin |
| from ..embeddings import ( |
| TimestepEmbedding, |
| Timesteps, |
| apply_rotary_emb, |
| get_1d_rotary_pos_embed, |
| ) |
| from ..modeling_outputs import Transformer2DModelOutput |
| from ..modeling_utils import ModelMixin |
| from ..normalization import AdaLayerNormContinuous |
|
|
|
|
# Module-level logger from the package's `logging` utility (imported from `...utils`); the attention modules below
# use it to warn about attention kwargs their processor does not accept.
logger = logging.get_logger(__name__)
|
|
|
|
def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
    """
    Compute Q/K/V with the separate (unfused) projection layers of `attn`.

    Returns a 6-tuple `(query, key, value, encoder_query, encoder_key, encoder_value)`. The three encoder entries are
    `None` unless `encoder_hidden_states` is given AND `attn` was built with `added_kv_proj_dim`.
    """
    query, key, value = (proj(hidden_states) for proj in (attn.to_q, attn.to_k, attn.to_v))

    # No context stream to project: either none was passed or the module has no context projections.
    if encoder_hidden_states is None or attn.added_kv_proj_dim is None:
        return query, key, value, None, None, None

    context_projs = (attn.add_q_proj, attn.add_k_proj, attn.add_v_proj)
    encoder_query, encoder_key, encoder_value = (proj(encoder_hidden_states) for proj in context_projs)
    return query, key, value, encoder_query, encoder_key, encoder_value
|
|
|
|
def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
    """
    Compute Q/K/V with the fused projection layers of `attn` (`to_qkv` and, when present, `to_added_qkv`).

    Returns the same 6-tuple layout as `_get_projections`: `(query, key, value, encoder_query, encoder_key,
    encoder_value)`. The three encoder entries are `None` when `encoder_hidden_states` is not given or `attn` has no
    `to_added_qkv` layer.
    """
    # The fused layer's output is split into three equal parts along the last dim, in q, k, v order.
    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)

    # Plain `None` (previously a `(None,)` tuple) so callers can test `encoder_query is None` exactly as they can
    # for the unfused path.
    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)

    return query, key, value, encoder_query, encoder_key, encoder_value
|
|
|
|
def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
    """
    Return `(query, key, value, encoder_query, encoder_key, encoder_value)` for `attn`.

    Uses the fused projection path when `attn.fused_projections` is truthy and the separate per-tensor projections
    otherwise.
    """
    project = _get_fused_projections if attn.fused_projections else _get_projections
    return project(attn, hidden_states, encoder_hidden_states)
|
|
|
|
class Flux2SwiGLU(nn.Module):
    """
    Parameter-free SwiGLU gate used in the Flux 2 feedforward sub-blocks.

    The gate/value projection is fused into the preceding linear layer, so this module only splits its input into two
    halves along the last dimension and returns `silu(first_half) * second_half`. It has no trainable parameters.
    """

    def __init__(self):
        super().__init__()
        self.gate_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = x.chunk(2, dim=-1)
        return self.gate_fn(gate) * value
|
|
|
|
class Flux2FeedForward(nn.Module):
    """
    Flux 2 feedforward sub-block: `linear_in` -> `Flux2SwiGLU` -> `linear_out`.

    Args:
        dim: input feature size.
        dim_out: output feature size; falls back to `dim` when falsy.
        mult: hidden-size multiplier used only when `inner_dim` is not given (`inner_dim = int(dim * mult)`).
        inner_dim: explicit hidden size (the width after the SwiGLU gate).
        bias: whether both linear layers carry a bias.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int | None = None,
        mult: float = 3.0,
        inner_dim: int | None = None,
        bias: bool = False,
    ):
        super().__init__()
        hidden = int(dim * mult) if inner_dim is None else inner_dim
        out_features = dim_out or dim

        # `linear_in` emits twice the hidden width: the SwiGLU activation splits it into gate and value halves and
        # returns a tensor of width `hidden`.
        self.linear_in = nn.Linear(dim, 2 * hidden, bias=bias)
        self.act_fn = Flux2SwiGLU()
        self.linear_out = nn.Linear(hidden, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_out(self.act_fn(self.linear_in(x)))
|
|
|
|
class Flux2AttnProcessor:
    """
    Default attention processor for `Flux2Attention` (the dual-stream blocks).

    Projects image tokens (and, when available, context tokens) to Q/K/V and applies the per-head RMS norms. It then
    prepends the context tokens to the image tokens for joint attention, applies rotary embeddings, and runs attention
    through `dispatch_attention_fn`. Finally it splits the result back into the image and context streams before their
    output projections.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "Flux2Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            attn: the attention module whose projection, norm and output layers are used.
            hidden_states: image-stream tokens.
            encoder_hidden_states: optional context-stream tokens; they join the attention only when `attn` was built
                with `added_kv_proj_dim`.
            attention_mask: forwarded unchanged to `dispatch_attention_fn` as `attn_mask`.
            image_rotary_emb: rotary embedding for the (context + image) sequence, applied along dim 1.

        Returns:
            `hidden_states` when `encoder_hidden_states` is None, otherwise `(hidden_states, encoder_hidden_states)`.
        """
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        # Split the last dim into (heads, head_dim).
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # Gate on `encoder_hidden_states` as well as `added_kv_proj_dim`. A context-capable module can then also run as
        # plain self-attention instead of failing on `None.unflatten`; the output path below already handles
        # `encoder_hidden_states is None`.
        if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            # Context tokens go first: the joint sequence along dim 1 is [context | image].
            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        # Merge (heads, head_dim) back into one feature axis.
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Undo the [context | image] concatenation.
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        # Output linear projection, then dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
|
|
|
|
class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
    """
    Joint (image + context) attention module used by the Flux 2 dual-stream transformer blocks.

    Holds the image-stream Q/K/V projections, per-head RMS norms and output projection. When `added_kv_proj_dim` is
    set, it also holds a parallel set of projections and norms for the context stream. The computation itself is
    delegated to `self.processor` (`Flux2AttnProcessor` by default).

    Args:
        query_dim: feature size of the image-stream input.
        heads: number of attention heads; ignored when `out_dim` is given (then `heads = out_dim // dim_head`).
        dim_head: per-head feature size.
        dropout: dropout probability of the `nn.Dropout` after the output projection.
        bias: whether the image-stream Q/K/V projections have a bias.
        added_kv_proj_dim: feature size of the context-stream input; `None` disables the context projections.
        added_proj_bias: whether the context-stream Q/K/V projections have a bias.
        out_bias: whether the output projections have a bias.
        eps: epsilon of the RMS norms.
        out_dim: optional output size; when given it also sets the projection width (`inner_dim`).
        elementwise_affine: whether the image-stream Q/K RMS norms have a learnable scale.
        processor: attention processor; a new `_default_processor_cls()` instance is used when `None`.
    """

    _default_processor_cls = Flux2AttnProcessor
    _available_processors = [Flux2AttnProcessor]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: int | None = None,
        added_proj_bias: bool | None = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int | None = None,
        elementwise_affine: bool = True,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        # An explicit `out_dim` overrides both the projection width and the head count.
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.use_bias = bias
        self.dropout = dropout

        self.added_kv_proj_dim = added_kv_proj_dim
        self.added_proj_bias = added_proj_bias

        # Image-stream projections.
        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        # RMS norms over the last (dim_head-sized) axis of the per-head queries and keys.
        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)

        # Output projection followed by dropout (indexed as to_out[0] / to_out[1] by the processor).
        self.to_out = torch.nn.ModuleList([])
        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(torch.nn.Dropout(dropout))

        # Context-stream norms, projections and output projection.
        if added_kv_proj_dim is not None:
            # NOTE(review): `elementwise_affine` is not forwarded to these two norms, unlike norm_q / norm_k.
            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Run the configured processor on the inputs.

        Extra keyword arguments (the blocks pass their `joint_attention_kwargs` here) are forwarded only if the
        processor's `__call__` declares them; the rest are dropped with a warning.

        Returns:
            Whatever the processor returns; for `Flux2AttnProcessor`, `hidden_states` or
            `(hidden_states, encoder_hidden_states)`.
        """
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        unused_kwargs = [k for k in kwargs if k not in attn_parameters]
        if unused_kwargs:
            logger.warning(
                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
|
|
|
class Flux2ParallelSelfAttnProcessor:
    """
    Processor for `Flux2ParallelSelfAttention`.

    One fused input projection feeds both self-attention and a SwiGLU MLP branch. The attention output and the
    activated MLP features are concatenated and passed through a single fused output projection.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "Flux2ParallelSelfAttention",
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # One projection whose last dim is laid out as [q | k | v | mlp_in].
        projected = attn.to_qkv_mlp_proj(hidden_states)
        attn_width = 3 * attn.inner_dim
        mlp_width = attn.mlp_hidden_dim * attn.mlp_mult_factor
        qkv, mlp_hidden_states = torch.split(projected, [attn_width, mlp_width], dim=-1)

        # Separate q/k/v and split each into (heads, head_dim).
        heads = attn.heads
        query, key, value = (part.unflatten(-1, (heads, -1)) for part in qkv.chunk(3, dim=-1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        attn_out = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        attn_out = attn_out.flatten(2, 3).to(query.dtype)

        # SwiGLU halves the MLP branch width (mlp_hidden_dim * mult_factor -> mlp_hidden_dim when mult_factor is 2).
        mlp_out = attn.mlp_act_fn(mlp_hidden_states)

        return attn.to_out(torch.cat([attn_out, mlp_out], dim=-1))
|
|
|
|
class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
    """
    Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.

    This implements a parallel transformer block: the attention QKV projections are fused with the feedforward (FF)
    input projection, and the attention output projection is fused with the FF output projection. See the [ViT-22B
    paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.

    Args:
        query_dim: feature size of the input.
        heads: number of attention heads; ignored when `out_dim` is given (then `heads = out_dim // dim_head`).
        dim_head: per-head feature size.
        dropout: stored as `self.dropout`; no dropout layer is created in this module.
        bias: whether the fused input projection has a bias.
        out_bias: whether the fused output projection has a bias.
        eps: epsilon of the Q/K RMS norms.
        out_dim: optional output size; when given it also sets the attention width (`inner_dim`).
        elementwise_affine: whether the Q/K RMS norms have a learnable scale.
        mlp_ratio: MLP hidden size as a multiple of `query_dim`.
        mlp_mult_factor: width multiplier of the MLP part of the input projection (2 for the SwiGLU gate/value pair).
        processor: attention processor; a new `_default_processor_cls()` instance is used when `None`.
    """

    _default_processor_cls = Flux2ParallelSelfAttnProcessor
    _available_processors = [Flux2ParallelSelfAttnProcessor]
    # QKV are already fused together with the MLP input in `to_qkv_mlp_proj`; presumably why generic QKV fusion is
    # disabled for this module.
    _supports_qkv_fusion = False

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int | None = None,
        elementwise_affine: bool = True,
        mlp_ratio: float = 4.0,
        mlp_mult_factor: int = 2,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        # An explicit `out_dim` overrides both the attention width and the head count.
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.use_bias = bias
        self.dropout = dropout

        self.mlp_ratio = mlp_ratio
        self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
        self.mlp_mult_factor = mlp_mult_factor

        # Fused input projection: [q | k | v | mlp_in] along the output features.
        self.to_qkv_mlp_proj = torch.nn.Linear(
            self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
        )
        self.mlp_act_fn = Flux2SwiGLU()

        # RMS norms over the last (dim_head-sized) axis of the per-head queries and keys.
        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)

        # Fused output projection over [attention output | activated MLP features].
        self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)

        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run the configured processor on the inputs.

        Extra keyword arguments (the single-stream blocks pass their `joint_attention_kwargs` here) are forwarded only
        if the processor's `__call__` declares them; the rest are dropped with a warning.
        """
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        unused_kwargs = [k for k in kwargs if k not in attn_parameters]
        if unused_kwargs:
            logger.warning(
                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
|
|
|
class Flux2SingleTransformerBlock(nn.Module):
    """
    Flux 2 single-stream transformer block.

    Works on one token sequence, optionally built by prepending `encoder_hidden_states` to `hidden_states`. The block
    applies a shift/scale-modulated non-affine LayerNorm, then a `Flux2ParallelSelfAttention` (attention and MLP in
    parallel), then a gated residual update.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 3.0,
        eps: float = 1e-6,
        bias: bool = False,
    ):
        super().__init__()

        # Non-affine LayerNorm: scale and shift come from the modulation tensor instead.
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

        # Attention and MLP share one fused input projection and one fused output projection.
        self.attn = Flux2ParallelSelfAttention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=bias,
            out_bias=bias,
            eps=eps,
            mlp_ratio=mlp_ratio,
            mlp_mult_factor=2,
            processor=Flux2ParallelSelfAttnProcessor(),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
        temb_mod: torch.Tensor,
        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        joint_attention_kwargs: dict[str, Any] | None = None,
        split_hidden_states: bool = False,
        text_seq_len: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: the token sequence, or only the image tokens when `encoder_hidden_states` is given.
            encoder_hidden_states: optional text tokens, prepended to `hidden_states` along dim 1.
            temb_mod: modulation tensor holding one (shift, scale, gate) set (see `Flux2Modulation.split`).
            image_rotary_emb: rotary embedding for the combined sequence.
            joint_attention_kwargs: extra kwargs forwarded to the attention module.
            split_hidden_states: if True, return the text and image parts separately.
            text_seq_len: number of leading text tokens; taken from `encoder_hidden_states` when that is given.

        Returns:
            The updated combined sequence, or `(encoder_hidden_states, hidden_states)` when `split_hidden_states`.

        Raises:
            ValueError: if `split_hidden_states` is set but neither `encoder_hidden_states` nor `text_seq_len` is
                given, since the text/image boundary would be unknown.
        """
        if encoder_hidden_states is not None:
            text_seq_len = encoder_hidden_states.shape[1]
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # Without a boundary, `[:, :None]` would silently treat the whole sequence as text.
        if split_hidden_states and text_seq_len is None:
            raise ValueError(
                "`split_hidden_states=True` requires `encoder_hidden_states` or `text_seq_len` to locate the text tokens."
            )

        mod_shift, mod_scale, mod_gate = Flux2Modulation.split(temb_mod, 1)[0]

        norm_hidden_states = self.norm(hidden_states)
        norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift

        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        hidden_states = hidden_states + mod_gate * attn_output
        # Keep float16 activations within the finite fp16 range.
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        if split_hidden_states:
            encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
            return encoder_hidden_states, hidden_states
        else:
            return hidden_states
|
|
|
|
class Flux2TransformerBlock(nn.Module):
    """
    Flux 2 dual-stream transformer block.

    Image and text tokens keep separate norms, feedforwards and modulation parameters but attend jointly through a
    single `Flux2Attention`. Each stream receives gated residual updates from the attention and its own feedforward.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 3.0,
        eps: float = 1e-6,
        bias: bool = False,
    ):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        def plain_norm():
            # Non-affine LayerNorm; scale/shift come from the modulation tensors at runtime.
            return nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

        self.norm1 = plain_norm()
        self.norm1_context = plain_norm()

        self.attn = Flux2Attention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=bias,
            added_proj_bias=bias,
            out_bias=bias,
            eps=eps,
            processor=Flux2AttnProcessor(),
        )

        self.norm2 = plain_norm()
        self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)

        self.norm2_context = plain_norm()
        self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)

    @staticmethod
    def _modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """Apply `(1 + scale) * x + shift`."""
        return (1 + scale) * x + shift

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb_mod_img: torch.Tensor,
        temb_mod_txt: torch.Tensor,
        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        joint_attention_kwargs: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            `(encoder_hidden_states, hidden_states)` — the updated text stream first, then the image stream.
        """
        attn_kwargs = joint_attention_kwargs or {}

        # Each modulation tensor holds two (shift, scale, gate) sets: one for attention, one for the feedforward.
        img_attn_mod, img_ff_mod = Flux2Modulation.split(temb_mod_img, 2)
        txt_attn_mod, txt_ff_mod = Flux2Modulation.split(temb_mod_txt, 2)
        shift_msa, scale_msa, gate_msa = img_attn_mod
        shift_mlp, scale_mlp, gate_mlp = img_ff_mod
        c_shift_msa, c_scale_msa, c_gate_msa = txt_attn_mod
        c_shift_mlp, c_scale_mlp, c_gate_mlp = txt_ff_mod

        img_in = self._modulate(self.norm1(hidden_states), shift_msa, scale_msa)
        txt_in = self._modulate(self.norm1_context(encoder_hidden_states), c_shift_msa, c_scale_msa)

        img_attn, txt_attn = self.attn(
            hidden_states=img_in,
            encoder_hidden_states=txt_in,
            image_rotary_emb=image_rotary_emb,
            **attn_kwargs,
        )

        # Image stream: gated attention residual, then gated feedforward residual.
        hidden_states = hidden_states + gate_msa * img_attn
        img_ff = self.ff(self._modulate(self.norm2(hidden_states), shift_mlp, scale_mlp))
        hidden_states = hidden_states + gate_mlp * img_ff

        # Text stream: same structure with the context-specific layers and modulation.
        encoder_hidden_states = encoder_hidden_states + c_gate_msa * txt_attn
        txt_ff = self.ff_context(self._modulate(self.norm2_context(encoder_hidden_states), c_shift_mlp, c_scale_mlp))
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp * txt_ff
        # Only the text stream is clipped to the finite fp16 range here.
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states
|
|
|
|
class Flux2PosEmbed(nn.Module):
    """
    Multi-axis rotary position embedding for Flux 2.

    Column `i` of the position ids gets a 1D rotary embedding of size `axes_dim[i]` (via `get_1d_rotary_pos_embed`).
    The per-axis cos and sin tables are concatenated along the last dimension.

    Args:
        theta: rotary base passed to `get_1d_rotary_pos_embed`.
        axes_dim: rotary embedding size for each position-id axis.
    """

    def __init__(self, theta: int, axes_dim: list[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            ids: position ids whose last dimension indexes the axes (one column per entry of `axes_dim`).

        Returns:
            `(freqs_cos, freqs_sin)`: the per-axis tables concatenated along the last dimension, moved to
            `ids.device`.
        """
        cos_out = []
        sin_out = []
        pos = ids.float()
        # MPS and NPU get float32 frequencies (float64 is presumably unsupported or slow there).
        is_mps = ids.device.type == "mps"
        is_npu = ids.device.type == "npu"
        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64

        for axis, axis_dim in enumerate(self.axes_dim):
            cos, sin = get_1d_rotary_pos_embed(
                axis_dim,
                pos[..., axis],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
|
|
|
|
class Flux2TimestepGuidanceEmbeddings(nn.Module):
    """
    Embeds the diffusion timestep and, optionally, the guidance scale into one conditioning vector.

    Both scalars go through the same sinusoidal `Timesteps` projection followed by separate `TimestepEmbedding` MLPs.
    When guidance is used, the two embeddings are summed.

    Args:
        in_channels: number of sinusoidal channels of the projection.
        embedding_dim: output embedding size.
        bias: whether the embedding MLPs' input layer has a bias (`sample_proj_bias`).
        guidance_embeds: whether to build the guidance embedder; if False, guidance is always ignored.
    """

    def __init__(
        self,
        in_channels: int = 256,
        embedding_dim: int = 6144,
        bias: bool = False,
        guidance_embeds: bool = True,
    ):
        super().__init__()

        self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(
            in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
        )

        if guidance_embeds:
            self.guidance_embedder = TimestepEmbedding(
                in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
            )
        else:
            self.guidance_embedder = None

    def forward(self, timestep: torch.Tensor, guidance: torch.Tensor | None = None) -> torch.Tensor:
        """
        Args:
            timestep: timestep values to embed.
            guidance: optional guidance values; ignored when `None` or when the model has no guidance embedder.

        Returns:
            The timestep embedding, plus the guidance embedding when guidance is used.
        """
        timesteps_proj = self.time_proj(timestep)
        # Cast the sinusoidal projection back to the input dtype before the MLP.
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype))

        if guidance is not None and self.guidance_embedder is not None:
            guidance_proj = self.time_proj(guidance)
            guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype))
            time_guidance_emb = timesteps_emb + guidance_emb
            return time_guidance_emb
        else:
            return timesteps_emb
|
|
|
|
class Flux2Modulation(nn.Module):
    """
    Produces shift/scale/gate modulation parameters from a conditioning embedding.

    `forward` maps `temb` through SiLU and one linear layer to `3 * mod_param_sets` vectors of width `dim`,
    concatenated along the last dimension. `split` unpacks such a tensor into `mod_param_sets` triples.

    Args:
        dim: width of the conditioning embedding and of each modulation vector.
        mod_param_sets: number of (shift, scale, gate) triples to produce.
        bias: whether the linear layer has a bias.
    """

    def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
        super().__init__()
        self.mod_param_sets = mod_param_sets

        out_features = 3 * mod_param_sets * dim
        self.linear = nn.Linear(dim, out_features, bias=bias)
        self.act_fn = nn.SiLU()

    def forward(self, temb: torch.Tensor) -> torch.Tensor:
        return self.linear(self.act_fn(temb))

    @staticmethod
    def split(mod: torch.Tensor, mod_param_sets: int) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
        """
        Split a modulation tensor into `mod_param_sets` tuples of three tensors (shift, scale, gate).

        A 2D `mod` first gets a singleton dim inserted at position 1, presumably so each piece broadcasts over the
        sequence dimension of the hidden states. The split is along the last dimension.
        """
        expanded = mod.unsqueeze(1) if mod.ndim == 2 else mod
        pieces = torch.chunk(expanded, 3 * mod_param_sets, dim=-1)
        return tuple(pieces[start : start + 3] for start in range(0, 3 * mod_param_sets, 3))
|
|
|
|
class Flux2Transformer2DModel(
    ModelMixin,
    ConfigMixin,
    PeftAdapterMixin,
    FromOriginalModelMixin,
    FluxTransformer2DLoadersMixin,
    CacheMixin,
    AttentionMixin,
):
    """
    The Transformer model introduced in Flux 2.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Args:
        patch_size (`int`, defaults to `1`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `128`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If not specified, it defaults to `in_channels`.
        num_layers (`int`, defaults to `8`):
            The number of layers of dual stream DiT blocks to use.
        num_single_layers (`int`, defaults to `48`):
            The number of layers of single stream DiT blocks to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `48`):
            The number of attention heads to use.
        joint_attention_dim (`int`, defaults to `15360`):
            The number of dimensions to use for the joint attention (embedding/channel dimension of
            `encoder_hidden_states`).
        timestep_guidance_channels (`int`, defaults to `256`):
            The number of sinusoidal channels used to project the timestep (and guidance) before embedding.
        mlp_ratio (`float`, defaults to `3.0`):
            Feedforward hidden size as a multiple of the model width, used by both block types.
        axes_dims_rope (`tuple[int, ...]`, defaults to `(32, 32, 32, 32)`):
            The dimensions to use for the rotary positional embeddings, one per position-id axis.
        rope_theta (`int`, defaults to `2000`):
            The base of the rotary positional embeddings.
        eps (`float`, defaults to `1e-6`):
            Epsilon passed to the normalization layers of the blocks and of the output norm.
        guidance_embeds (`bool`, defaults to `True`):
            Whether to use guidance embeddings for guidance-distilled variant of the model.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
    _cp_plan = {
        "": {
            "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
            "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
            "img_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
            "txt_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        },
        "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
    }

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 128,
        out_channels: int | None = None,
        num_layers: int = 8,
        num_single_layers: int = 48,
        attention_head_dim: int = 128,
        num_attention_heads: int = 48,
        joint_attention_dim: int = 15360,
        timestep_guidance_channels: int = 256,
        mlp_ratio: float = 3.0,
        axes_dims_rope: tuple[int, ...] = (32, 32, 32, 32),
        rope_theta: int = 2000,
        eps: float = 1e-6,
        guidance_embeds: bool = True,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # Rotary position embedding over the multi-axis position ids.
        self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)

        # Timestep (+ optional guidance) conditioning embedding.
        self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
            in_channels=timestep_guidance_channels,
            embedding_dim=self.inner_dim,
            bias=False,
            guidance_embeds=guidance_embeds,
        )

        # Modulation layers are owned by the model, not by each block. Their outputs are computed once per forward
        # and shared by every block of the corresponding type.
        self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
        self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
        # Single-stream blocks need one (shift, scale, gate) set.
        self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)

        # Input projections of image latents and text embeddings to the model width.
        self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)

        # Dual-stream blocks: separate image/text weights, joint attention.
        self.transformer_blocks = nn.ModuleList(
            [
                Flux2TransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    eps=eps,
                    bias=False,
                )
                for _ in range(num_layers)
            ]
        )

        # Single-stream blocks: one shared set of weights over the concatenated [text | image] sequence.
        self.single_transformer_blocks = nn.ModuleList(
            [
                Flux2SingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    eps=eps,
                    bias=False,
                )
                for _ in range(num_single_layers)
            ]
        )

        # Conditioned output norm and projection back to (patched) latent channels.
        self.norm_out = AdaLayerNormContinuous(
            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
        )
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)

        self.gradient_checkpointing = False

    @apply_lora_scale("joint_attention_kwargs")
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: dict[str, Any] | None = None,
        return_dict: bool = True,
    ) -> torch.Tensor | Transformer2DModelOutput:
        """
        The [`Flux2Transformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
                Required: its sequence length determines how many leading tokens are dropped before the output.
            timestep (`torch.Tensor`):
                Used to indicate denoising step. Cast to the dtype of `hidden_states` and multiplied by 1000 before
                embedding.
            img_ids (`torch.Tensor`):
                Position ids of the image tokens, one column per rotary axis (see `axes_dims_rope`). If 3D, only the
                first batch entry is used.
            txt_ids (`torch.Tensor`):
                Position ids of the text tokens, with the same conventions as `img_ids`.
            guidance (`torch.Tensor`, *optional*):
                Guidance scale. It is cast, multiplied by 1000 and embedded when the model was built with
                `guidance_embeds=True`; otherwise it is ignored.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # Number of leading text tokens in the single-stream sequence; stripped before the output projection.
        num_txt_tokens = encoder_hidden_states.shape[1]

        # 1. Conditioning embedding from the timestep (and optional guidance), both scaled by 1000.
        timestep = timestep.to(hidden_states.dtype) * 1000

        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000

        temb = self.time_guidance_embed(timestep, guidance)

        # 2. Modulation parameters, computed once and shared by every block of the matching type.
        double_stream_mod_img = self.double_stream_modulation_img(temb)
        double_stream_mod_txt = self.double_stream_modulation_txt(temb)
        single_stream_mod = self.single_stream_modulation(temb)

        # 3. Project inputs to the model width.
        hidden_states = self.x_embedder(hidden_states)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        # 4. Rotary embeddings. A batched (3D) id tensor is reduced to its first entry, so the same positions are
        #    used for the whole batch.
        if img_ids.ndim == 3:
            img_ids = img_ids[0]
        if txt_ids.ndim == 3:
            txt_ids = txt_ids[0]

        image_rotary_emb = self.pos_embed(img_ids)
        text_rotary_emb = self.pos_embed(txt_ids)
        # Same [text | image] order the attention uses when it concatenates the two streams.
        concat_rotary_emb = (
            torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
            torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
        )

        # 5. Dual-stream blocks.
        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    double_stream_mod_img,
                    double_stream_mod_txt,
                    concat_rotary_emb,
                    joint_attention_kwargs,
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb_mod_img=double_stream_mod_img,
                    temb_mod_txt=double_stream_mod_txt,
                    image_rotary_emb=concat_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )
        # Merge the streams into one [text | image] sequence for the single-stream blocks.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # 6. Single-stream blocks.
        for block in self.single_transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    None,
                    single_stream_mod,
                    concat_rotary_emb,
                    joint_attention_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=None,
                    temb_mod=single_stream_mod,
                    image_rotary_emb=concat_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )
        # Keep only the image tokens.
        hidden_states = hidden_states[:, num_txt_tokens:, ...]

        # 7. Conditioned output norm and projection.
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|
|