| import inspect |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from ...configuration_utils import ConfigMixin, register_to_config |
| from ...loaders import FromOriginalModelMixin, PeftAdapterMixin |
| from ...utils import apply_lora_scale, logging |
| from ...utils.torch_utils import maybe_allow_in_graph |
| from ..attention import AttentionModuleMixin, FeedForward |
| from ..attention_dispatch import dispatch_attention_fn |
| from ..cache_utils import CacheMixin |
| from ..embeddings import TimestepEmbedding, apply_rotary_emb, get_timestep_embedding |
| from ..modeling_outputs import Transformer2DModelOutput |
| from ..modeling_utils import ModelMixin |
| from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle |
|
|
|
|
# Module-level logger; `BriaAttention.forward` uses it to warn about attention kwargs the processor ignores.
logger = logging.get_logger(__name__)
|
|
|
|
def _get_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None):
    """
    Compute unfused Q/K/V projections for the image stream and, optionally, the text stream.

    Returns:
        ``(query, key, value, encoder_query, encoder_key, encoder_value)``. The three encoder entries are `None`
        unless `encoder_hidden_states` is given and `attn.added_kv_proj_dim` is set.
    """
    image_qkv = (attn.to_q(hidden_states), attn.to_k(hidden_states), attn.to_v(hidden_states))

    has_context_branch = encoder_hidden_states is not None and attn.added_kv_proj_dim is not None
    if not has_context_branch:
        return (*image_qkv, None, None, None)

    context_qkv = (
        attn.add_q_proj(encoder_hidden_states),
        attn.add_k_proj(encoder_hidden_states),
        attn.add_v_proj(encoder_hidden_states),
    )
    return (*image_qkv, *context_qkv)
|
|
|
|
def _get_fused_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None):
    """
    Compute Q/K/V projections using fused projection layers.

    The image-stream projections come from a single ``attn.to_qkv`` layer whose output is split into three equal
    chunks along the last dimension. The text-stream projections come from ``attn.to_added_qkv`` when that layer
    exists and ``encoder_hidden_states`` is given.

    Returns:
        ``(query, key, value, encoder_query, encoder_key, encoder_value)``. The three encoder entries are `None`
        when no text-stream projection is computed, matching `_get_projections`.
    """
    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)

    # Plain `None` (not a one-element tuple) so callers get the same "absent" sentinel as from `_get_projections`.
    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)

    return query, key, value, encoder_query, encoder_key, encoder_value
|
|
|
|
def _get_qkv_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None):
    """
    Dispatch to the fused or unfused projection helper depending on ``attn.fused_projections``.

    Returns the 6-tuple ``(query, key, value, encoder_query, encoder_key, encoder_value)`` produced by the
    selected helper.
    """
    projector = _get_fused_projections if attn.fused_projections else _get_projections
    return projector(attn, hidden_states, encoder_hidden_states)
|
|
|
|
| def get_1d_rotary_pos_embed( |
| dim: int, |
| pos: np.ndarray | int, |
| theta: float = 10000.0, |
| use_real=False, |
| linear_factor=1.0, |
| ntk_factor=1.0, |
| repeat_interleave_real=True, |
| freqs_dtype=torch.float32, |
| ): |
| """ |
| Precompute the frequency tensor for complex exponentials (cis) with given dimensions. |
| |
| This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end |
| index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 |
| data type. |
| |
| Args: |
| dim (`int`): Dimension of the frequency tensor. |
| pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar |
| theta (`float`, *optional*, defaults to 10000.0): |
| Scaling factor for frequency computation. Defaults to 10000.0. |
| use_real (`bool`, *optional*): |
| If True, return real part and imaginary part separately. Otherwise, return complex numbers. |
| linear_factor (`float`, *optional*, defaults to 1.0): |
| Scaling factor for the context extrapolation. Defaults to 1.0. |
| ntk_factor (`float`, *optional*, defaults to 1.0): |
| Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. |
| repeat_interleave_real (`bool`, *optional*, defaults to `True`): |
| If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. |
| Otherwise, they are concateanted with themselves. |
| freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): |
| the dtype of the frequency tensor. |
| Returns: |
| `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] |
| """ |
| assert dim % 2 == 0 |
|
|
| if isinstance(pos, int): |
| pos = torch.arange(pos) |
| if isinstance(pos, np.ndarray): |
| pos = torch.from_numpy(pos) |
|
|
| theta = theta * ntk_factor |
| freqs = ( |
| 1.0 |
| / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) |
| / linear_factor |
| ) |
| freqs = torch.outer(pos, freqs) |
| if use_real and repeat_interleave_real: |
| |
| freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() |
| freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() |
| return freqs_cos, freqs_sin |
| elif use_real: |
| |
| freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() |
| freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() |
| return freqs_cos, freqs_sin |
| else: |
| |
| freqs_cis = torch.polar(torch.ones_like(freqs), freqs) |
| return freqs_cis |
|
|
|
|
class BriaAttnProcessor:
    """
    Attention processor for `BriaAttention`.

    Projects the image stream (and, when present, the text stream) to Q/K/V, applies the per-head Q/K RMSNorms,
    concatenates the text tokens in front of the image tokens, applies rotary embeddings to the joint sequence,
    and runs attention through `dispatch_attention_fn`.
    """

    # Forwarded verbatim to `dispatch_attention_fn` as `backend` / `parallel_config`.
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        # `F.scaled_dot_product_attention` is only available from PyTorch 2.0 on.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "BriaAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Run (joint) attention.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` when `encoder_hidden_states` is given (image part through
            `attn.to_out`, text part through `attn.to_add_out`); otherwise the raw attention output with no output
            projection applied.
        """
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        # Split the channel dim into (heads, head_dim); the token axis stays at dim 1.
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # NOTE(review): this branch assumes `encoder_hidden_states` was passed; if it is None while
        # `added_kv_proj_dim` is set, `encoder_query` is None and `unflatten` fails.
        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            # Text tokens go first in the joint sequence.
            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        # Rotary embeddings are applied to the (possibly concatenated) sequence, so they must cover every token.
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        # Merge (heads, head_dim) back into a single channel dim.
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Undo the concatenation: leading tokens are text, the rest are image.
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
|
|
|
|
class BriaAttention(torch.nn.Module, AttentionModuleMixin):
    """
    Multi-head attention with RMS-normalized queries/keys and an optional second (text) input stream.

    When `added_kv_proj_dim` is set, `add_{q,k,v}_proj` project `encoder_hidden_states` and `to_add_out`
    projects the text part of the output back to `query_dim`. When `pre_only` is `True`, no `to_out` projection
    is created. The computation itself is delegated to `self.processor` (`BriaAttnProcessor` by default).
    """

    _default_processor_cls = BriaAttnProcessor
    _available_processors = [
        BriaAttnProcessor,
    ]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: int | None = None,
        added_proj_bias: bool | None = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int | None = None,
        context_pre_only: bool | None = None,
        pre_only: bool = False,
        elementwise_affine: bool = True,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        # If `out_dim` is given it sets the projection width and (as `out_dim // dim_head`) the head count;
        # `heads` is then ignored.
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = out_dim if out_dim is not None else query_dim
        # Stored only; not read by `BriaAttnProcessor` in this file.
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.added_proj_bias = added_proj_bias

        # Per-head RMSNorm over head_dim for queries and keys.
        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        if not self.pre_only:
            # Output projection followed by dropout, indexed as to_out[0] / to_out[1] by the processor.
            self.to_out = torch.nn.ModuleList([])
            self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
            self.to_out.append(torch.nn.Dropout(dropout))

        if added_kv_proj_dim is not None:
            # NOTE: `elementwise_affine` is not forwarded to these norms, so they use torch's default (True).
            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """
        Call `self.processor`, forwarding only the extra kwargs its `__call__` signature accepts.

        Unaccepted kwargs are dropped with a warning, except `ip_adapter_masks` / `ip_hidden_states`, which are
        dropped silently.
        """
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
|
|
|
class BriaEmbedND(torch.nn.Module):
    """
    Multi-axis rotary position embedding.

    Each column of the position ids gets its own 1-D rotary embedding (dimension `axes_dim[i]`), and the
    per-axis results are concatenated along the last dimension.

    Args:
        theta: Base passed as `theta` to `get_1d_rotary_pos_embed` for every axis.
        axes_dim: Rotary dimension for each position axis; needs at least one entry per column of `ids`.
    """

    def __init__(self, theta: int, axes_dim: list[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            ids: Position ids of shape (num_tokens, n_axes).

        Returns:
            ``(cos, sin)``, each of shape (num_tokens, sum(axes_dim[:n_axes])) on `ids.device`.
        """
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        # MPS has no float64 support, so frequencies are computed in float32 there.
        freqs_dtype = torch.float32 if ids.device.type == "mps" else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
|
|
|
|
class BriaTimesteps(nn.Module):
    """
    Sinusoidal timestep projection with a configurable maximum period.

    Thin wrapper around `get_timestep_embedding` that stores its options; `time_theta` is passed as the
    `max_period` argument.
    """

    def __init__(
        self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
    ):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale
        self.time_theta = time_theta

    def forward(self, timesteps):
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
            max_period=self.time_theta,
        )
|
|
|
|
class BriaTimestepProjEmbeddings(nn.Module):
    """
    Timestep embedding: a 256-channel sinusoidal projection followed by `TimestepEmbedding` to `embedding_dim`.

    Args:
        embedding_dim: Output embedding size.
        time_theta: Maximum period of the sinusoidal projection. Defaults to 10000 so the module can also be
            built without it (e.g. for an optional guidance embedder).
    """

    def __init__(self, embedding_dim, time_theta=10000):
        super().__init__()

        self.time_proj = BriaTimesteps(
            num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
        )
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timestep, dtype):
        timesteps_proj = self.time_proj(timestep)
        # The sinusoidal projection is cast to the requested dtype before the learned embedding.
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype))
        return timesteps_emb
|
|
|
|
class BriaPosEmbed(torch.nn.Module):
    """
    Multi-axis rotary position embedding; computes the same result as `BriaEmbedND`.

    Args:
        theta: Base passed as `theta` to `get_1d_rotary_pos_embed` for every axis.
        axes_dim: Rotary dimension for each position axis; needs at least one entry per column of `ids`.
    """

    def __init__(self, theta: int, axes_dim: list[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            ids: Position ids of shape (num_tokens, n_axes).

        Returns:
            ``(cos, sin)``, each of shape (num_tokens, sum(axes_dim[:n_axes])) on `ids.device`.
        """
        positions = ids.float()
        # MPS has no float64 support, so frequencies are computed in float32 there.
        dtype = torch.float32 if ids.device.type == "mps" else torch.float64

        per_axis = [
            get_1d_rotary_pos_embed(
                self.axes_dim[axis],
                positions[:, axis],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=dtype,
            )
            for axis in range(ids.shape[-1])
        ]
        cos_parts = [cos for cos, _ in per_axis]
        sin_parts = [sin for _, sin in per_axis]
        return torch.cat(cos_parts, dim=-1).to(ids.device), torch.cat(sin_parts, dim=-1).to(ids.device)
|
|
|
|
@maybe_allow_in_graph
class BriaTransformerBlock(nn.Module):
    """
    Dual-stream block: image tokens (`hidden_states`) and text tokens (`encoder_hidden_states`) have separate
    AdaLN-Zero norms and feed-forward networks but share one joint attention call.

    Args:
        dim: Hidden size of both streams.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Channels per head.
        qk_norm: Accepted but not read in this block.
        eps: Epsilon forwarded to the attention's Q/K RMSNorms.
    """

    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
    ):
        super().__init__()

        # AdaLN-Zero modulation for the image stream and the text stream respectively.
        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)

        # Joint attention; `added_kv_proj_dim=dim` enables the text-stream projections.
        self.attn = BriaAttention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=BriaAttnProcessor(),
            eps=eps,
        )

        # Non-affine LayerNorms; scale/shift are applied from the AdaLN-Zero outputs in `forward`.
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_kwargs: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            ``(encoder_hidden_states, hidden_states)``; note that the text stream comes first.
        """
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )
        attention_kwargs = attention_kwargs or {}

        # Joint attention over both streams.
        attention_outputs = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **attention_kwargs,
        )

        # A 3-tuple carries an extra output added to the image stream (presumably an IP-adapter branch, judging by
        # the name). NOTE(review): any other length leaves `attn_output` unbound and fails below.
        if len(attention_outputs) == 2:
            attn_output, context_attn_output = attention_outputs
        elif len(attention_outputs) == 3:
            attn_output, context_attn_output, ip_attn_output = attention_outputs

        # Image stream: gated attention residual.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        # Image stream: modulated LayerNorm -> feed-forward -> gated residual.
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output
        if len(attention_outputs) == 3:
            hidden_states = hidden_states + ip_attn_output

        # Text stream: same pattern with the `c_*` modulation parameters.
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        # Clamp to the finite float16 range (+-65504) to avoid overflow to inf; only the text stream is clamped here.
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states
|
|
|
|
@maybe_allow_in_graph
class BriaSingleTransformerBlock(nn.Module):
    """
    Single-stream block: text and image tokens are concatenated and processed jointly.

    An attention branch and an MLP branch run in parallel on the same AdaLN-Zero-normalized input. Their outputs
    are concatenated along the channel dim, fused by `proj_out`, gated, and added back as a residual. The joint
    sequence is split again into text and image tokens on return.

    Args:
        dim: Hidden size.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Channels per head.
        mlp_ratio: Width of the MLP branch relative to `dim`.
    """

    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

        # `pre_only=True`: no `to_out` inside the attention; `proj_out` above plays that role.
        self.attn = BriaAttention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=BriaAttnProcessor(),
            eps=1e-6,
            pre_only=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_kwargs: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            ``(encoder_hidden_states, hidden_states)``, split back at the original text length.
        """
        split_at = encoder_hidden_states.shape[1]
        joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        normed, gate = self.norm(joint, emb=temb)
        mlp_branch = self.act_mlp(self.proj_mlp(normed))
        attn_branch = self.attn(
            hidden_states=normed,
            image_rotary_emb=image_rotary_emb,
            **(attention_kwargs or {}),
        )

        fused = self.proj_out(torch.cat([attn_branch, mlp_branch], dim=2))
        joint = joint + gate.unsqueeze(1) * fused
        # Keep float16 activations inside the finite range.
        if joint.dtype == torch.float16:
            joint = joint.clip(-65504, 65504)

        return joint[:, :split_at], joint[:, split_at:]
|
|
|
|
class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
    """
    The Transformer model introduced in Flux. Based on FluxPipeline with several changes:
    - no pooled embeddings
    - We use zero padding for prompts
    - No guidance embedding since this is not a distilled version
    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Parameters:
        patch_size (`int`, *optional*, defaults to 1): Patch size; only used to size the final output projection.
        in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 19): The number of layers of dual-stream (MMDiT) blocks to use.
        num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*, defaults to 4096): The number of `encoder_hidden_states` dimensions.
        pooled_projection_dim (`int`, *optional*): Not used by this model; kept for config compatibility.
        guidance_embeds (`bool`, defaults to False): Whether to create a guidance embedder.
        axes_dims_rope (`list[int]`, defaults to `[16, 56, 56]`): Rotary embedding dimension of each position axis.
        rope_theta (defaults to 10000): Base of the rotary position embedding frequencies.
        time_theta (defaults to 10000): Maximum period of the sinusoidal timestep (and guidance) embedding.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = None,
        guidance_embeds: bool = False,
        axes_dims_rope: list[int] = [16, 56, 56],
        rope_theta=10000,
        time_theta=10000,
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        self.pos_embed = BriaEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)

        self.time_embed = BriaTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
        if guidance_embeds:
            # `time_theta` must be passed explicitly: the embedder's constructor may require it.
            self.guidance_embed = BriaTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)

        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BriaTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                BriaSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for _ in range(self.config.num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    @apply_lora_scale("attention_kwargs")
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        attention_kwargs: dict[str, Any] | None = None,
        return_dict: bool = True,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
    ) -> tuple[torch.Tensor] | Transformer2DModelOutput:
        """
        The [`BriaTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input image tokens (the last dimension is embedded by `x_embedder`).
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor`, *optional*): Accepted for API compatibility; not used.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            img_ids (`torch.Tensor` of shape `(image_sequence_length, n_axes)`):
                Position ids of the image tokens. A 3D (batched) tensor is accepted; only its first element is used.
            txt_ids (`torch.Tensor` of shape `(text_sequence_length, n_axes)`):
                Position ids of the text tokens. A 3D (batched) tensor is accepted; only its first element is used.
            guidance (`torch.Tensor`, *optional*):
                Guidance values; only usable when the model was built with `guidance_embeds=True`.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
                Residuals added to the image tokens after the dual-stream blocks; each sample is reused for
                `ceil(num_layers / len(samples))` consecutive blocks.
            controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
                Residuals added to the image tokens after the single-stream blocks, distributed the same way.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype)
        temb = self.time_embed(timestep, dtype=hidden_states.dtype)

        # `is not None` rather than truthiness: a batched guidance tensor has no single truth value, and a
        # guidance of 0 must still be embedded.
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype)
            temb = temb + self.guidance_embed(guidance, dtype=hidden_states.dtype)

        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        # Position ids are shared across the batch, so only the first element of batched ids is used.
        if txt_ids.ndim == 3:
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            img_ids = img_ids[0]

        # Text ids first, matching the text-then-image token order used by the attention processor.
        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        # How many consecutive blocks share one controlnet residual.
        block_interval = (
            int(np.ceil(len(self.transformer_blocks) / len(controlnet_block_samples)))
            if controlnet_block_samples is not None
            else None
        )
        single_block_interval = (
            int(np.ceil(len(self.single_transformer_blocks) / len(controlnet_single_block_samples)))
            if controlnet_single_block_samples is not None
            else None
        )

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    attention_kwargs,
                )
            else:
                # Pass `attention_kwargs` here too so results do not depend on gradient checkpointing.
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    attention_kwargs=attention_kwargs,
                )

            if controlnet_block_samples is not None:
                hidden_states = hidden_states + controlnet_block_samples[index_block // block_interval]

        for index_block, block in enumerate(self.single_transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    attention_kwargs,
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    attention_kwargs=attention_kwargs,
                )

            if controlnet_single_block_samples is not None:
                # Single-stream blocks return the image tokens separately from the text tokens, so the residual
                # applies to the whole of `hidden_states` (no text-length offset).
                hidden_states = (
                    hidden_states + controlnet_single_block_samples[index_block // single_block_interval]
                )

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|
|