# Copyright 2025 OmniGen team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention_processor import Attention
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_2d_sincos_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, RMSNorm


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class OmniGenFeedForward(nn.Module):
    r"""
    Gated (SiLU) feed-forward layer.

    A single bias-free projection produces both the gate and the value halves; the gated value is projected back to
    `hidden_size`.

    Parameters:
        hidden_size (`int`): Input and output feature dimension.
        intermediate_size (`int`): Feature dimension of the gated hidden layer.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()

        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.activation_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        up_states = self.gate_up_proj(hidden_states)
        # First half of the last dimension is the gate, second half is the value.
        gate, up_states = up_states.chunk(2, dim=-1)
        up_states = up_states * self.activation_fn(gate)
        return self.down_proj(up_states)


class OmniGenPatchEmbed(nn.Module):
    r"""
    Patchifies image latents with a strided convolution and adds fixed 2D sin-cos positional embeddings.

    Conditioning ("input") images and the image being generated ("output") use separate projection convolutions but
    share the same positional grid.

    Parameters:
        patch_size (`int`, defaults to `2`): Side length of each square patch.
        in_channels (`int`, defaults to `4`): Number of latent channels.
        embed_dim (`int`, defaults to `768`): Dimension of each patch token.
        bias (`bool`, defaults to `True`): Whether the projection convolutions use a bias.
        interpolation_scale (`float`, defaults to `1`): Passed to `get_2d_sincos_pos_embed`.
        pos_embed_max_size (`int`, defaults to `192`): Side length, in patches, of the precomputed positional grid.
        base_size (`int`, defaults to `64`): Passed to `get_2d_sincos_pos_embed`.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        embed_dim: int = 768,
        bias: bool = True,
        interpolation_scale: float = 1,
        pos_embed_max_size: int = 192,
        base_size: int = 64,
    ):
        super().__init__()

        self.output_image_proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.input_image_proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )

        self.patch_size = patch_size
        self.interpolation_scale = interpolation_scale
        self.pos_embed_max_size = pos_embed_max_size

        pos_embed = get_2d_sincos_pos_embed(
            embed_dim,
            self.pos_embed_max_size,
            base_size=base_size,
            interpolation_scale=self.interpolation_scale,
            output_type="pt",
        )
        self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=True)

    def _cropped_pos_embed(self, height, width):
        """Crops positional embeddings for SD3 compatibility.

        Returns the centered `(height // patch_size, width // patch_size)` window of the positional grid, flattened
        to shape `(1, num_patches, embed_dim)`.

        Raises:
            ValueError: If `pos_embed_max_size` is unset, or the patched height/width exceeds it.
        """
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        height = height // self.patch_size
        width = width // self.patch_size
        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
        spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed

    def _patch_embeddings(self, hidden_states: torch.Tensor, is_input_image: bool) -> torch.Tensor:
        """Project a `(batch, channels, H, W)` latent to patch tokens of shape `(batch, num_patches, embed_dim)`."""
        if is_input_image:
            hidden_states = self.input_image_proj(hidden_states)
        else:
            hidden_states = self.output_image_proj(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        return hidden_states

    def forward(
        self,
        hidden_states: Union[torch.Tensor, List[torch.Tensor]],
        is_input_image: bool,
        padding_latent: Optional[List[Optional[torch.Tensor]]] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Patch-embed one latent tensor, or each latent in a list, and add positional embeddings.

        Args:
            hidden_states: A latent tensor, or a list of latents (each may have its own spatial size).
            is_input_image: Selects the conditioning-image projection (`True`) or the output-image projection.
            padding_latent: Only used for list input; per-latent token tensors appended after the patch tokens
                (along dim -2), or `None` entries to skip.

        Returns:
            A tensor for tensor input, or a list of tensors for list input.
        """
        if isinstance(hidden_states, list):
            if padding_latent is None:
                padding_latent = [None] * len(hidden_states)
            patched_latents = []
            for sub_latent, padding in zip(hidden_states, padding_latent):
                height, width = sub_latent.shape[-2:]
                sub_latent = self._patch_embeddings(sub_latent, is_input_image)
                pos_embed = self._cropped_pos_embed(height, width)
                sub_latent = sub_latent + pos_embed
                if padding is not None:
                    sub_latent = torch.cat([sub_latent, padding.to(sub_latent.device)], dim=-2)
                patched_latents.append(sub_latent)
        else:
            height, width = hidden_states.shape[-2:]
            pos_embed = self._cropped_pos_embed(height, width)
            hidden_states = self._patch_embeddings(hidden_states, is_input_image)
            patched_latents = hidden_states + pos_embed

        return patched_latents


class OmniGenSuScaledRotaryEmbedding(nn.Module):
    r"""
    Rotary position embedding whose frequencies are divided by per-frequency scaling factors.

    `short_factor` is used while the largest position is within `original_max_position_embeddings`, `long_factor`
    beyond it. The returned `cos`/`sin` tables are additionally scaled by a magnitude factor derived from
    `max_position_embeddings / original_max_position_embeddings`.

    Parameters:
        dim (`int`): Per-head dimension the rotary embedding is applied to.
        max_position_embeddings (`int`, defaults to `131072`): Target maximum sequence length.
        original_max_position_embeddings (`int`, defaults to `4096`): Threshold for switching to `long_factor`.
        base (`int`, defaults to `10000`): RoPE theta.
        rope_scaling (`dict`):
            Required. Must contain `short_factor` and `long_factor`, whose values must broadcast against the
            `dim // 2` rotary frequencies.

    Raises:
        ValueError: If `rope_scaling` is missing or lacks `short_factor`/`long_factor`.
    """

    def __init__(
        self, dim, max_position_embeddings=131072, original_max_position_embeddings=4096, base=10000, rope_scaling=None
    ):
        super().__init__()

        # Fail early with a clear message instead of an opaque `TypeError` / `KeyError` on subscripting.
        if rope_scaling is None or "short_factor" not in rope_scaling or "long_factor" not in rope_scaling:
            raise ValueError("`rope_scaling` must be a dict containing both `short_factor` and `long_factor`.")

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

        self.short_factor = rope_scaling["short_factor"]
        self.long_factor = rope_scaling["long_factor"]
        self.original_max_position_embeddings = original_max_position_embeddings

    def forward(self, hidden_states, position_ids):
        """
        Compute the rotary `(cos, sin)` tables for `position_ids`.

        Args:
            hidden_states (`torch.Tensor`): Only its device is used.
            position_ids (`torch.Tensor`): Positions of shape `(batch_size, seq_len)`.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: `cos` and `sin`, each of shape `(seq_len, dim)`. Only the first
            row of `position_ids` contributes (rows are assumed identical across the batch).
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.original_max_position_embeddings:
            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=hidden_states.device)
        else:
            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=hidden_states.device)

        inv_freq_shape = (
            torch.arange(0, self.dim, 2, dtype=torch.int64, device=hidden_states.device).float() / self.dim
        )
        # NOTE: overwrites the `inv_freq` buffer on every call with the factor-scaled frequencies.
        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)

        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = hidden_states.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # (batch, seq, dim // 2) -> (batch, seq, dim), then keep batch row 0 only.
            emb = torch.cat((freqs, freqs), dim=-1)[0]

            scale = self.max_position_embeddings / self.original_max_position_embeddings
            if scale <= 1.0:
                scaling_factor = 1.0
            else:
                scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))

            cos = emb.cos() * scaling_factor
            sin = emb.sin() * scaling_factor
        return cos, sin


class OmniGenAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the OmniGen model.

    Supports grouped-query attention: when the key/value projections produce fewer heads than the query projection,
    each key/value head is shared by a contiguous group of query heads.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        batch_size, query_len, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = query.shape[-1] // attn.heads
        # Key/value projections may use fewer heads than the query (grouped-query attention).
        kv_heads = key.shape[-1] // head_dim

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real_unbind_dim=-2)
            key = apply_rotary_emb(key, image_rotary_emb, use_real_unbind_dim=-2)

        if kv_heads != attn.heads:
            # `F.scaled_dot_product_attention` only broadcasts singleton head dims, so expand K/V to match Q.
            num_repeats = attn.heads // kv_heads
            key = key.repeat_interleave(num_repeats, dim=1)
            value = value.repeat_interleave(num_repeats, dim=1)

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).type_as(query)
        hidden_states = hidden_states.reshape(batch_size, query_len, attn.out_dim)
        hidden_states = attn.to_out[0](hidden_states)
        return hidden_states


class OmniGenBlock(nn.Module):
    r"""
    Pre-norm transformer layer: RMSNorm -> self-attention (+ residual) -> RMSNorm -> gated MLP (+ residual).

    Parameters:
        hidden_size (`int`): Model width.
        num_attention_heads (`int`): Number of query heads.
        num_key_value_heads (`int`): Number of key/value heads.
        intermediate_size (`int`): Hidden width of the feed-forward layer.
        rms_norm_eps (`float`): Epsilon of the RMSNorm layers.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        intermediate_size: int,
        rms_norm_eps: float,
    ) -> None:
        super().__init__()

        self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.self_attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=hidden_size,
            dim_head=hidden_size // num_attention_heads,
            heads=num_attention_heads,
            kv_heads=num_key_value_heads,
            bias=False,
            out_dim=hidden_size,
            out_bias=False,
            processor=OmniGenAttnProcessor2_0(),
        )
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.mlp = OmniGenFeedForward(hidden_size, intermediate_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        # 1. Attention
        norm_hidden_states = self.input_layernorm(hidden_states)
        attn_output = self.self_attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_output

        # 2. Feed Forward
        norm_hidden_states = self.post_attention_layernorm(hidden_states)
        ff_output = self.mlp(norm_hidden_states)
        hidden_states = hidden_states + ff_output
        return hidden_states


class OmniGenTransformer2DModel(ModelMixin, ConfigMixin):
    """
    The Transformer model introduced in OmniGen (https://huggingface.co/papers/2409.11340).

    Parameters:
        in_channels (`int`, defaults to `4`):
            The number of channels in the input.
        patch_size (`int`, defaults to `2`):
            The size of the spatial patches to use in the patch embedding layer.
        hidden_size (`int`, defaults to `3072`):
            The dimensionality of the hidden layers in the model.
        rms_norm_eps (`float`, defaults to `1e-5`):
            Eps for RMSNorm layer.
        num_attention_heads (`int`, defaults to `32`):
            The number of heads to use for multi-head attention.
        num_key_value_heads (`int`, defaults to `32`):
            The number of heads to use for keys and values in multi-head attention.
        intermediate_size (`int`, defaults to `8192`):
            Dimension of the hidden layer in FeedForward layers.
        num_layers (`int`, default to `32`):
            The number of layers of transformer blocks to use.
        pad_token_id (`int`, default to `32000`):
            The id of the padding token.
        vocab_size (`int`, default to `32064`):
            The size of the vocabulary of the embedding vocabulary.
        max_position_embeddings (`int`, default to `131072`):
            Target maximum sequence length of the RoPE; its ratio to `original_max_position_embeddings` sets the
            magnitude scaling applied to the RoPE tables.
        original_max_position_embeddings (`int`, default to `4096`):
            Sequence-length threshold above which the RoPE uses `long_factor` instead of `short_factor`.
        rope_base (`int`, default to `10000`):
            The default theta value to use when creating RoPE.
        rope_scaling (`Dict`):
            The scaling factors for the RoPE. Required; must contain `short_factor` and `long_factor`.
        pos_embed_max_size (`int`, default to `192`):
            The maximum size of the positional embeddings.
        time_step_dim (`int`, default to `256`):
            Output dimension of timestep embeddings.
        flip_sin_to_cos (`bool`, default to `True`):
            Whether to flip the sin and cos in the positional embeddings when preparing timestep embeddings.
        downscale_freq_shift (`int`, default to `0`):
            The frequency shift to use when downscaling the timestep embeddings.
        timestep_activation_fn (`str`, default to `silu`):
            The activation function to use for the timestep embeddings.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["OmniGenBlock"]
    _skip_layerwise_casting_patterns = ["patch_embedding", "embed_tokens", "norm"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        patch_size: int = 2,
        hidden_size: int = 3072,
        rms_norm_eps: float = 1e-5,
        num_attention_heads: int = 32,
        num_key_value_heads: int = 32,
        intermediate_size: int = 8192,
        num_layers: int = 32,
        pad_token_id: int = 32000,
        vocab_size: int = 32064,
        max_position_embeddings: int = 131072,
        original_max_position_embeddings: int = 4096,
        rope_base: int = 10000,
        rope_scaling: Optional[Dict] = None,
        pos_embed_max_size: int = 192,
        time_step_dim: int = 256,
        flip_sin_to_cos: bool = True,
        downscale_freq_shift: int = 0,
        timestep_activation_fn: str = "silu",
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels

        self.patch_embedding = OmniGenPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
            pos_embed_max_size=pos_embed_max_size,
        )

        self.time_proj = Timesteps(time_step_dim, flip_sin_to_cos, downscale_freq_shift)
        # `time_token` becomes an extra sequence token; `t_embedder` conditions the output AdaLayerNorm.
        self.time_token = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
        self.t_embedder = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)

        self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id)
        self.rope = OmniGenSuScaledRotaryEmbedding(
            hidden_size // num_attention_heads,
            max_position_embeddings=max_position_embeddings,
            original_max_position_embeddings=original_max_position_embeddings,
            base=rope_base,
            rope_scaling=rope_scaling,
        )

        self.layers = nn.ModuleList(
            [
                OmniGenBlock(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, rms_norm_eps)
                for _ in range(num_layers)
            ]
        )

        self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1)
        self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    def _get_multimodal_embeddings(
        self,
        input_ids: Optional[torch.Tensor],
        input_img_latents: Optional[List[torch.Tensor]],
        input_image_sizes: Optional[Dict],
    ) -> Optional[torch.Tensor]:
        """
        Embed the condition tokens and overwrite image-placeholder spans with conditioning-image patch tokens.

        Args:
            input_ids: Token ids, or `None` when there is no conditioning sequence.
            input_img_latents: Conditioning-image latents, consumed in order; `None` or empty for text-only input.
            input_image_sizes: Maps a batch index to a list of `(start, end)` token spans, one per image, in the
                same order as `input_img_latents`; `None` or empty when there are no images.

        Returns:
            The condition token embeddings, or `None` when `input_ids` is `None`.
        """
        if input_ids is None:
            return None

        condition_tokens = self.embed_tokens(input_ids)
        if not input_image_sizes:
            # Text-only conditioning: no placeholder spans to fill.
            return condition_tokens

        input_img_latents = [x.to(self.dtype) for x in (input_img_latents or [])]
        input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True)

        input_img_inx = 0
        for b_inx, spans in input_image_sizes.items():
            for start_inx, end_inx in spans:
                # replace the placeholder in text tokens with the image embedding.
                condition_tokens[b_inx, start_inx:end_inx] = input_image_tokens[input_img_inx].to(
                    condition_tokens.dtype
                )
                input_img_inx += 1
        return condition_tokens

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.FloatTensor],
        input_ids: torch.Tensor,
        input_img_latents: List[torch.Tensor],
        input_image_sizes: Dict[int, List[int]],
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[Transformer2DModelOutput, Tuple[torch.Tensor]]:
        """
        Denoise one step for the image latent `hidden_states`.

        Args:
            hidden_states: Latent of shape `(batch_size, in_channels, height, width)`.
            timestep: A Python number, a 0-d tensor, or a 1-d tensor of length 1 or `batch_size`; broadcast to the
                batch.
            input_ids: Condition token ids, or `None`.
            input_img_latents: Conditioning-image latents (see `_get_multimodal_embeddings`).
            input_image_sizes: Batch index -> list of `(start, end)` placeholder spans.
            attention_mask: If 3-D, entries of 1 are attended and entries of 0 are masked (converted to an additive
                mask); other masks are passed to attention unchanged.
            position_ids: Positions reshaped to `(-1, total_sequence_length)`.
            return_dict: Whether to return a `Transformer2DModelOutput` instead of a tuple.

        Returns:
            The predicted latent of shape `(batch_size, out_channels, height, width)` (height/width rounded down to
            multiples of the patch size).
        """
        batch_size, _, height, width = hidden_states.shape
        p = self.config.patch_size
        post_patch_height, post_patch_width = height // p, width // p

        # 1. Patch & Timestep & Conditional Embedding
        hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
        num_tokens_for_output_image = hidden_states.size(1)

        # Normalize `timestep` to shape (batch_size,) so the time token can be concatenated per sample below.
        if not torch.is_tensor(timestep):
            timestep = torch.tensor([timestep], device=hidden_states.device)
        if timestep.dim() == 0:
            timestep = timestep[None]
        timestep = timestep.expand(batch_size)

        timestep_proj = self.time_proj(timestep).type_as(hidden_states)
        time_token = self.time_token(timestep_proj).unsqueeze(1)
        temb = self.t_embedder(timestep_proj)

        condition_tokens = self._get_multimodal_embeddings(input_ids, input_img_latents, input_image_sizes)
        if condition_tokens is not None:
            hidden_states = torch.cat([condition_tokens, time_token, hidden_states], dim=1)
        else:
            hidden_states = torch.cat([time_token, hidden_states], dim=1)

        seq_length = hidden_states.size(1)
        position_ids = position_ids.view(-1, seq_length).long()

        # 2. Attention mask preprocessing
        if attention_mask is not None and attention_mask.dim() == 3:
            dtype = hidden_states.dtype
            min_dtype = torch.finfo(dtype).min
            if attention_mask.dtype == torch.bool:
                # `1 - mask` is not defined for boolean tensors.
                attention_mask = attention_mask.to(dtype)
            attention_mask = (1 - attention_mask) * min_dtype
            # Add a singleton head dimension for broadcasting in scaled dot-product attention.
            attention_mask = attention_mask.unsqueeze(1).type_as(hidden_states)

        # 3. Rotary position embedding
        image_rotary_emb = self.rope(hidden_states, position_ids)

        # 4. Transformer blocks
        for block in self.layers:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, attention_mask, image_rotary_emb
                )
            else:
                hidden_states = block(hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb)

        # 5. Output norm & projection
        hidden_states = self.norm(hidden_states)
        # Only the trailing output-image tokens are decoded back into a latent.
        hidden_states = hidden_states[:, -num_tokens_for_output_image:]
        hidden_states = self.norm_out(hidden_states, temb=temb)
        hidden_states = self.proj_out(hidden_states)
        # (B, H/p, W/p, p, p, C) -> (B, C, H, W)
        hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, p, p, -1)
        output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)