Krea2Transformer2DModel
The single-stream MMDiT flow-matching transformer used by Krea 2.
diffusers.Krea2Transformer2DModel
The single-stream MMDiT flow-matching backbone used by the Krea 2 pipeline.
Text conditioning enters as a stack of hidden states tapped from several layers of a multimodal text encoder. A
small text-fusion transformer collapses the layer axis and refines the token sequence; the result is concatenated
with the patchified image latents into a single [text, image] sequence processed by the transformer blocks. The
timestep conditions every block through one shared modulation vector plus per-block learned tables.
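The data flow described above can be followed as plain shape bookkeeping. The sketch below is not the library's code: the helper name `krea2_sequence_shapes` is hypothetical, and it assumes the fused text tokens end up at the model hidden size (`attention_head_dim * num_attention_heads`), which is what concatenation along the sequence axis would require.

```python
def krea2_sequence_shapes(batch_size, text_seq_len, image_seq_len,
                          num_text_layers=12, text_hidden_dim=2560,
                          attention_head_dim=128, num_attention_heads=48,
                          in_channels=64):
    """Hypothetical helper: track tensor shapes through the forward pass."""
    # Total hidden size as documented: head dim times number of query heads.
    hidden_size = attention_head_dim * num_attention_heads

    return {
        # Stack of tapped text-encoder hidden states, one slice per tapped layer.
        "text_stack": (batch_size, text_seq_len, num_text_layers, text_hidden_dim),
        # Assumption: text fusion collapses the layer axis and the result is
        # brought to the model hidden size before concatenation.
        "fused_text": (batch_size, text_seq_len, hidden_size),
        # Patchified image latents as passed to forward().
        "image_latents": (batch_size, image_seq_len, in_channels),
        # Single [text, image] sequence seen by the transformer blocks.
        "joint_sequence": (batch_size, text_seq_len + image_seq_len, hidden_size),
        # The velocity is returned for the image tokens only.
        "velocity": (batch_size, image_seq_len, in_channels),
    }


shapes = krea2_sequence_shapes(batch_size=2, text_seq_len=77, image_seq_len=1024)
for name, shape in shapes.items():
    print(f"{name:15s} {shape}")
```

Note that the output has the same shape as the packed image latents, so the text rows of the joint sequence are dropped before the velocity is returned.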
forward
diffusers.Krea2Transformer2DModel.forward(hidden_states: Tensor, encoder_hidden_states: Tensor, timestep: Tensor, position_ids: Tensor, encoder_attention_mask: torch.Tensor | None = None, attention_kwargs: dict[str, typing.Any] | None = None, return_dict: bool = True)
Source: https://github.com/huggingface/diffusers/blob/vr_13751/src/diffusers/models/transformers/transformer_krea2.py#L447
- hidden_states (torch.Tensor of shape (batch_size, image_seq_len, in_channels)) -- Packed (patchified) noisy image latents.
- encoder_hidden_states (torch.Tensor of shape (batch_size, text_seq_len, num_text_layers, text_hidden_dim)) -- Stack of tapped text-encoder hidden states per token.
- timestep (torch.Tensor of shape (batch_size,)) -- Flow-matching time in [0, 1] (1 is pure noise, 0 is clean data).
- position_ids (torch.Tensor of shape (text_seq_len + image_seq_len, 3)) -- (t, h, w) rotary coordinates for the combined sequence. Text rows are all-zero; image rows hold the latent-grid coordinates.
- encoder_attention_mask (torch.Tensor of shape (batch_size, text_seq_len), optional) -- Boolean mask marking valid text tokens. Pass None when every text token is valid.
- attention_kwargs (dict, optional) -- A kwargs dictionary that, when it contains a scale entry, sets the LoRA scale applied to this transformer's adapters for the duration of the forward pass.
- return_dict (bool, optional, defaults to True) -- Whether to return a Transformer2DModelOutput instead of a plain tuple.
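As an illustration of the position_ids and encoder_attention_mask layouts described in the forward arguments, the following sketch builds both as plain Python lists. The helper names are hypothetical, and three details are assumptions rather than documented behavior: image tokens are enumerated in row-major order, the image t coordinate is 0, and text padding sits at the end of each sequence.

```python
def build_position_ids(text_seq_len, grid_height, grid_width):
    """Hypothetical helper: (t, h, w) rows for the combined [text, image] sequence."""
    # Text rows are all-zero, as documented.
    text_rows = [(0, 0, 0)] * text_seq_len
    # Assumption: image rows walk the latent grid row by row with t fixed at 0.
    image_rows = [(0, h, w) for h in range(grid_height) for w in range(grid_width)]
    return text_rows + image_rows


def build_text_mask(valid_lengths, text_seq_len):
    """Hypothetical helper: True marks a valid text token (right padding assumed)."""
    return [[i < n for i in range(text_seq_len)] for n in valid_lengths]


position_ids = build_position_ids(text_seq_len=3, grid_height=2, grid_width=2)
text_mask = build_text_mask(valid_lengths=[2, 3], text_seq_len=3)
print(position_ids)
print(text_mask)
```

Both results can be wrapped with torch.tensor(...) before being passed to forward; the position list has text_seq_len + image_seq_len rows of 3 entries, matching the documented shape.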
Predict the flow-matching velocity for the image tokens.
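To show how a predicted velocity is typically consumed, here is a minimal Euler integration sketch from t = 1 (noise) to t = 0 (clean data). It assumes the common linear path x_t = (1 - t) * x_0 + t * noise, whose velocity is noise - x_0; the scheduler actually paired with Krea 2 may use a different parameterization or step schedule. `euler_sample` and `toy_velocity` are hypothetical names, and the toy function stands in for the transformer.

```python
def euler_sample(velocity_fn, noise, num_steps=4):
    """Integrate dx/dt = v(x, t) from t = 1 down to t = 0 with fixed Euler steps."""
    x = list(noise)
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = 1.0 - step * dt          # current flow-matching time
        v = velocity_fn(x, t)        # stand-in for the model's forward pass
        # Moving toward t = 0 means stepping against the velocity.
        x = [xi - dt * vi for xi, vi in zip(x, v)]
    return x


# Toy setup: under the assumed linear path the exact velocity is constant.
clean = [0.5, -1.0, 2.0]
noise = [1.0, 0.0, -1.0]


def toy_velocity(x, t):
    return [n - c for n, c in zip(noise, clean)]


sample = euler_sample(toy_velocity, noise)
print(sample)
```

With the exact velocity the integration recovers the clean sample; with a learned model the result depends on the number of steps and the quality of the prediction.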
Parameters:
in_channels (int, defaults to 64) : Latent channel count after patchification (vae_channels * patch_size ** 2).
num_layers (int, defaults to 28) : Number of transformer blocks.
attention_head_dim (int, defaults to 128) : Dimension of each attention head; the total hidden size is attention_head_dim * num_attention_heads.
num_attention_heads (int, defaults to 48) : Number of query heads.
num_key_value_heads (int, defaults to 12) : Number of key/value heads for grouped-query attention.
intermediate_size (int, defaults to 16384) : Feed-forward hidden size of the SwiGLU MLP inside each block.
timestep_embed_dim (int, defaults to 256) : Width of the sinusoidal timestep embedding before its MLP.
text_hidden_dim (int, defaults to 2560) : Hidden size of the text encoder whose hidden states are consumed.
num_text_layers (int, defaults to 12) : Number of tapped text-encoder hidden states stacked per token.
text_num_attention_heads (int, defaults to 20) : Number of query heads in the text fusion blocks.
text_num_key_value_heads (int, defaults to 20) : Number of key/value heads in the text fusion blocks.
text_intermediate_size (int, defaults to 6912) : Feed-forward hidden size of the SwiGLU MLP inside the text fusion blocks.
num_layerwise_text_blocks (int, defaults to 2) : Number of text fusion blocks applied across the tapped-layer axis (per token).
num_refiner_text_blocks (int, defaults to 2) : Number of text fusion blocks applied across the token sequence.
axes_dims_rope (tuple[int, int, int], defaults to (32, 48, 48)) : Head-dim split across the (t, h, w) rotary position axes.
rope_theta (float, defaults to 1000.0) : Base used by the rotary position embedding.
norm_eps (float, defaults to 1e-5) : Epsilon used by all RMSNorm modules.
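The default values listed above imply a few derived quantities that are easy to check by hand. The sketch below does that arithmetic on a plain dictionary; it does not instantiate the model. The factorization of in_channels into `vae_channels = 16` and `patch_size = 2` is an assumption chosen to match the documented formula, not a value stated in this reference.

```python
# Documented defaults copied into a plain dict for arithmetic only.
config = {
    "in_channels": 64,
    "attention_head_dim": 128,
    "num_attention_heads": 48,
    "num_key_value_heads": 12,
    "axes_dims_rope": (32, 48, 48),
}

# Total hidden size = head dim * number of query heads.
hidden_size = config["attention_head_dim"] * config["num_attention_heads"]

# Grouped-query attention: how many query heads share one key/value head.
queries_per_kv_head = config["num_attention_heads"] // config["num_key_value_heads"]

# The (t, h, w) rotary split should cover the whole head dimension.
rope_total = sum(config["axes_dims_rope"])

# Assumption: one factorization consistent with vae_channels * patch_size ** 2.
vae_channels, patch_size = 16, 2
packed_channels = vae_channels * patch_size ** 2

print(hidden_size, queries_per_kv_head, rope_total, packed_channels)
```

When overriding the defaults, the same checks are a cheap way to catch inconsistent settings, for example a rotary split that no longer sums to attention_head_dim or a query head count that is not a multiple of num_key_value_heads.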
Returns (forward):
Transformer2DModelOutput or a tuple whose first element is the velocity tensor of shape (batch_size, image_seq_len, in_channels).