# SPDX-License-Identifier: Apache-2.0 # Copyright 2025 Black Forest Labs. # Copyright (c) 2026 World Labs. Modifications: depth stream, joint RGBD # attention, depth decoder blocks, cross-stream timestep mixing. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """FluxRGBDDiT — FLUX.2 dual+single-stream transformer extended for joint RGB+depth. Derived from the FLUX.2 reference implementation: https://github.com/black-forest-labs/flux2 Architecture, inference-only: * Depth is a peer stream alongside image and text — each gets its own pre-norm, Q/K/V, MLP inside every dual-stream block. Joint attention runs over the concatenated `[txt, img, depth]` sequence. * 8 dual-stream blocks + 24 single-stream blocks (FLUX.2 backbone). * 4-block depth-only decoder before the depth output head. * 4D RoPE position ids (axes_dim=(32,32,32,32), theta=2000) shared via EmbedND. RGB tokens use time_id=0, depth uses time_id=1 so RoPE can distinguish them. Reuses the FLUX.2 reference building blocks unchanged so img/txt-side weights lifted from an RGB-only FLUX.2 [klein-9B] base load 1:1. """ from __future__ import annotations import einops import torch from torch import Tensor, nn from flux_rgbd._flux2 import model as flux2_model from flux_rgbd._flux2.model import ( EmbedND, LastLayer, MLPEmbedder, Modulation, SelfAttention, SiLUActivation, SingleStreamBlock, timestep_embedding, ) class _TripleStreamBlock(nn.Module): """Joint-attention block over (img, text, depth) with per-stream AdaLN. Submodule names on the img/txt halves match FLUX.2's DoubleStreamBlock so RGB-only base weights load 1:1; the depth half adds new submodules with the same shapes as img. """ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float): super().__init__() if hidden_size % num_heads != 0: raise ValueError(f"hidden_size {hidden_size} not divisible by num_heads {num_heads}") mlp_hidden = int(hidden_size * mlp_ratio) self.num_heads = num_heads # Same shapes per stream: pre-norm + (Q,K,V,proj) + post-norm + 2-up SiLU MLP. # The `* 2` in the first Linear is the gated-SiLU activation packing FLUX.2 uses. for prefix in ("img", "txt", "depth"): setattr(self, f"{prefix}_norm1", nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)) setattr(self, f"{prefix}_attn", SelfAttention(dim=hidden_size, num_heads=num_heads)) setattr(self, f"{prefix}_norm2", nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)) setattr(self, f"{prefix}_mlp", nn.Sequential( nn.Linear(hidden_size, mlp_hidden * 2, bias=False), SiLUActivation(), nn.Linear(mlp_hidden, hidden_size, bias=False), )) def _qkv(self, x: Tensor, norm: nn.LayerNorm, attn: SelfAttention, mod1: tuple[Tensor, Tensor, Tensor]): """Pre-norm + AdaLN + Q/K/V projection. Returns (q, k, v, gate).""" shift, scale, gate = mod1 qkv = attn.qkv((1 + scale) * norm(x) + shift) q, k, v = einops.rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k = attn.norm(q, k, v) return q, k, v, gate @staticmethod def _residual(x, attn_out, proj, gate1, mod2, norm2, mlp): """attn residual then MLP residual, both gated. Standard DiT pattern.""" x = x + gate1 * proj(attn_out) shift, scale, gate2 = mod2 return x + gate2 * mlp((1 + scale) * norm2(x) + shift) def forward(self, img, txt, depth, pe_img, pe_txt, pe_depth, mod_img, mod_txt, mod_depth): q_img, k_img, v_img, g_img = self._qkv(img, self.img_norm1, self.img_attn, mod_img[0]) q_txt, k_txt, v_txt, g_txt = self._qkv(txt, self.txt_norm1, self.txt_attn, mod_txt[0]) q_d, k_d, v_d, g_d = self._qkv(depth, self.depth_norm1, self.depth_attn, mod_depth[0]) # Joint attention over [txt, img, depth]; BFL's `attention` does # RoPE + scaled_dot_product_attention + rearrange. q = torch.cat([q_txt, q_img, q_d], dim=2) k = torch.cat([k_txt, k_img, k_d], dim=2) v = torch.cat([v_txt, v_img, v_d], dim=2) pe = torch.cat([pe_txt, pe_img, pe_depth], dim=2) out = flux2_model.attention(q, k, v, pe) n_txt, n_img = q_txt.shape[2], q_img.shape[2] txt_out = out[:, :n_txt] img_out = out[:, n_txt : n_txt + n_img] depth_out = out[:, n_txt + n_img :] img = self._residual(img, img_out, self.img_attn.proj, g_img, mod_img[1], self.img_norm2, self.img_mlp) txt = self._residual(txt, txt_out, self.txt_attn.proj, g_txt, mod_txt[1], self.txt_norm2, self.txt_mlp) depth = self._residual(depth, depth_out, self.depth_attn.proj, g_d, mod_depth[1], self.depth_norm2, self.depth_mlp) return img, txt, depth def _stack_per_token_mod( mod_img: tuple[Tensor, Tensor, Tensor], mod_depth: tuple[Tensor, Tensor, Tensor], n_txt: int, n_img: int, n_depth: int, ) -> tuple[Tensor, Tensor, Tensor]: """Build the per-token modulation triple consumed by SingleStreamBlock. Each ``Modulation`` output is broadcast over the sequence axis. For the triple-stream single stack we need different (shift, scale, gate) for the depth tokens vs the txt+img tokens — so we expand each entry along seq and ``cat``. """ n_txt_img = n_txt + n_img def _build(slot: int) -> Tensor: head = mod_img[slot].expand(-1, n_txt_img, -1) tail = mod_depth[slot].expand(-1, n_depth, -1) return torch.cat([head, tail], dim=1) return _build(0), _build(1), _build(2) class FluxRGBDDiT(nn.Module): """FLUX.2 Klein-9B DiT with peer depth stream. Forward: (img, depth, text, RGB timestep, depth timestep, *_ids) → (rgb_latent_out, depth_latent_out). Architecture (v1 defaults match the published checkpoint): depth_double=8 triple-stream blocks → full-token concat → depth_single=24 single-stream blocks → split → depth_decoder_num_layers=4 single blocks over depth only → final_layer / depth_final_layer. """ def __init__( self, *, in_channels: int = 128, depth_channels: int = 256, context_in_dim: int = 12288, hidden_size: int = 4096, num_heads: int = 32, depth_double: int = 8, depth_single: int = 24, depth_decoder_num_layers: int = 4, axes_dim: tuple[int, ...] = (32, 32, 32, 32), theta: int = 2000, mlp_ratio: float = 3.0, use_guidance_embed: bool = False, cross_stream_timestep_mixing: bool = False, ): super().__init__() pe_dim = hidden_size // num_heads if sum(axes_dim) != pe_dim: raise ValueError( f"axes_dim sum {sum(axes_dim)} != hidden_size/num_heads {pe_dim}" ) self.in_channels = in_channels self.depth_channels = depth_channels self.hidden_size = hidden_size self.num_heads = num_heads self.depth_double = depth_double self.depth_single = depth_single self.depth_decoder_num_layers = depth_decoder_num_layers # Stream-in projections. self.img_in = nn.Linear(in_channels, hidden_size, bias=False) self.txt_in = nn.Linear(context_in_dim, hidden_size, bias=False) self.depth_in = nn.Linear(depth_channels, hidden_size, bias=False) # Timestep embedders. `time_in` carries the RGB timestep (plus optional # guidance scale); `time_in_depth` carries the depth timestep. self.time_in = MLPEmbedder( in_dim=256, hidden_dim=hidden_size, disable_bias=True ) self.time_in_depth = MLPEmbedder( in_dim=256, hidden_dim=hidden_size, disable_bias=True ) self.use_guidance_embed = use_guidance_embed if use_guidance_embed: self.guidance_in = MLPEmbedder( in_dim=256, hidden_dim=hidden_size, disable_bias=True ) # Cross-stream timestep mixing: each stream's modulation vector also # sees the *other* stream's noise level, projected through a small MLP # and gated by a learnable scalar (trained from a zero init, so the # path is identity at step 0 and a learned delta thereafter): # vec_img += cross_alpha_img * time_in_depth_to_img(emb(t_depth)) # vec_depth += cross_alpha_depth * time_in_rgb_to_depth(emb(t_rgb)) # This lets the var-AR checkpoint condition each modality on how noisy # the other one is (e.g. clean RGB in i2d, clean depth in d2i). self.cross_stream_timestep_mixing = cross_stream_timestep_mixing if cross_stream_timestep_mixing: self.time_in_depth_to_img = MLPEmbedder( in_dim=256, hidden_dim=hidden_size, disable_bias=True ) self.time_in_rgb_to_depth = MLPEmbedder( in_dim=256, hidden_dim=hidden_size, disable_bias=True ) # Shape (1,) not () to mirror the training checkpoint's parameter # layout (kept 1-D for FSDP2's no-scalar-parameter rule). self.cross_alpha_img = nn.Parameter(torch.zeros(1)) self.cross_alpha_depth = nn.Parameter(torch.zeros(1)) # Position embedder shared across streams. self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) # Dual-stream stack. self.double_blocks = nn.ModuleList( [_TripleStreamBlock(hidden_size, num_heads, mlp_ratio) for _ in range(depth_double)] ) # Single-stream stack — reuses the FLUX.2 SingleStreamBlock with # per-token modulation built on the fly. self.single_blocks = nn.ModuleList( [SingleStreamBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth_single)] ) # Modulation heads. self.double_stream_modulation_img = Modulation( hidden_size, double=True, disable_bias=True ) self.double_stream_modulation_txt = Modulation( hidden_size, double=True, disable_bias=True ) self.double_stream_modulation_depth = Modulation( hidden_size, double=True, disable_bias=True ) self.single_stream_modulation = Modulation( hidden_size, double=False, disable_bias=True ) self.single_stream_modulation_depth = Modulation( hidden_size, double=False, disable_bias=True ) # Optional depth-only decoder stack (4 SingleStreamBlocks for v1). if depth_decoder_num_layers > 0: self.depth_decoder_blocks = nn.ModuleList( [SingleStreamBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth_decoder_num_layers)] ) self.depth_decoder_modulation = Modulation( hidden_size, double=False, disable_bias=True ) else: self.depth_decoder_blocks = nn.ModuleList() self.depth_decoder_modulation = None # Final unpatchifiers. self.final_layer = LastLayer(hidden_size, in_channels) self.depth_final_layer = LastLayer(hidden_size, depth_channels) def forward( self, img: Tensor, img_ids: Tensor, timesteps: Tensor, ctx: Tensor, ctx_ids: Tensor, depth: Tensor, depth_ids: Tensor, depth_timesteps: Tensor, guidance: Tensor | None = None, ) -> tuple[Tensor, Tensor]: """Triple-stream forward. Args: img: `(B, n_img, in_channels)` RGB latent tokens img_ids: `(B, n_img, 4)` 4D position ids timesteps: `(B,)` RGB diffusion time ctx: `(B, n_txt, context_in_dim)` text embeddings ctx_ids: `(B, n_txt, 4)` text position ids depth: `(B, n_depth, depth_channels)` depth latent tokens depth_ids: `(B, n_depth, 4)` depth position ids depth_timesteps: `(B,)` depth diffusion time guidance: `(B,)` optional guidance scale Returns: `(rgb_out, depth_out)` patches. """ # Body dtype (set by the parameters of the input projections). Inputs # are cast to this so mixed-precision callers (BF16 body + FP32 # sampling state) work transparently. body_dtype = self.img_in.weight.dtype img = img.to(body_dtype) ctx = ctx.to(body_dtype) depth = depth.to(body_dtype) timesteps = timesteps.to(body_dtype) depth_timesteps = depth_timesteps.to(body_dtype) if guidance is not None: guidance = guidance.to(body_dtype) # Timestep embeddings (separate per modality). `timestep_embedding` # always returns FP32; cast before the BF16 MLPEmbedder weights. rgb_t_emb = timestep_embedding(timesteps, 256).to(body_dtype) depth_t_emb = timestep_embedding(depth_timesteps, 256).to(body_dtype) vec_img = self.time_in(rgb_t_emb) if self.use_guidance_embed and guidance is not None: vec_img = vec_img + self.guidance_in( timestep_embedding(guidance, 256).to(body_dtype) ) vec_depth = self.time_in_depth(depth_t_emb) # Cross-stream timestep mixing (see __init__): fold the other stream's # noise level into each modulation vector. The scalar gate is part of # the body dtype after `.to(dtype=…)`, so no extra cast is needed. if self.cross_stream_timestep_mixing: vec_img = vec_img + self.cross_alpha_img * self.time_in_depth_to_img( depth_t_emb ) vec_depth = vec_depth + self.cross_alpha_depth * self.time_in_rgb_to_depth( rgb_t_emb ) # Per-stream modulation coefficients. mod_img_double = self.double_stream_modulation_img(vec_img) mod_txt_double = self.double_stream_modulation_txt(vec_img) mod_depth_double = self.double_stream_modulation_depth(vec_depth) mod_img_single, _ = self.single_stream_modulation(vec_img) mod_depth_single, _ = self.single_stream_modulation_depth(vec_depth) # Project each stream into the hidden dim. img = self.img_in(img) txt = self.txt_in(ctx) depth = self.depth_in(depth) # Position embeddings. pe_img = self.pe_embedder(img_ids) pe_txt = self.pe_embedder(ctx_ids) pe_depth = self.pe_embedder(depth_ids) # Dual-stream stack. for block in self.double_blocks: img, txt, depth = block( img, txt, depth, pe_img, pe_txt, pe_depth, mod_img_double, mod_txt_double, mod_depth_double, ) # Single-stream stack over the concatenated [txt, img, depth] sequence # with per-token modulation. n_txt = txt.shape[1] n_img = img.shape[1] n_depth = depth.shape[1] joint = torch.cat([txt, img, depth], dim=1) pe_joint = torch.cat([pe_txt, pe_img, pe_depth], dim=2) per_token_mod = _stack_per_token_mod( mod_img_single, mod_depth_single, n_txt, n_img, n_depth ) for block in self.single_blocks: joint = block(joint, pe_joint, per_token_mod) # Split back per stream. rgb_tokens = joint[:, n_txt : n_txt + n_img] depth_tokens = joint[:, n_txt + n_img :] # Optional depth-only decoder. if self.depth_decoder_num_layers > 0: depth_decoder_mod, _ = self.depth_decoder_modulation(vec_depth) for block in self.depth_decoder_blocks: depth_tokens = block(depth_tokens, pe_depth, depth_decoder_mod) # Each output head may be in a different dtype than the body # (e.g. FP32 depth head over a BF16 body); cast inputs to the # head's own dtype at the boundary. rgb_head_dtype = self.final_layer.linear.weight.dtype depth_head_dtype = self.depth_final_layer.linear.weight.dtype rgb_out = self.final_layer( rgb_tokens.to(rgb_head_dtype), vec_img.to(rgb_head_dtype) ) depth_out = self.depth_final_layer( depth_tokens.to(depth_head_dtype), vec_depth.to(depth_head_dtype) ) return rgb_out, depth_out