# File size: 20,029 Bytes | be9fa39
# Copyright 2025 OmniGen team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention_processor import Attention
from ..embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, RMSNorm
# Module-level logger, named after this module via the package's logging helper.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class OmniGenFeedForward(nn.Module):
    """Gated feed-forward layer with a fused gate/up projection.

    One bias-free linear layer produces `2 * intermediate_size` features that are split
    into a gate half and a value half. The gate goes through SiLU, multiplies the value
    half, and a second bias-free linear layer maps the product back to `hidden_size`.

    Args:
        hidden_size (`int`): Size of the last dimension of the input and of the output.
        intermediate_size (`int`): Width of each of the gate and value halves.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()

        # Gate and value projections share a single matmul; `forward` splits the result.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.activation_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection to `hidden_states` along its last dimension."""
        fused = self.gate_up_proj(hidden_states)
        # First half of the last dimension is the gate, second half is the value path.
        gate, value = torch.chunk(fused, 2, dim=-1)
        gated = value * self.activation_fn(gate)
        return self.down_proj(gated)
class OmniGenPatchEmbed(nn.Module):
    """Turn image latents into patch tokens and add a cropped 2D sin-cos positional embedding.

    Two separate strided convolutions are kept: `input_image_proj` for input (conditioning)
    images and `output_image_proj` for the output latent. A square positional table of side
    `pos_embed_max_size` is built once and stored in the `pos_embed` buffer; at call time its
    centre window matching the patch grid is added to the tokens.

    Args:
        patch_size (`int`, defaults to `2`): Kernel size and stride of the patch convolutions.
        in_channels (`int`, defaults to `4`): Number of channels of the incoming latents.
        embed_dim (`int`, defaults to `768`): Number of output channels, i.e. the token width.
        bias (`bool`, defaults to `True`): Whether the convolutions carry a bias.
        interpolation_scale (`float`, defaults to `1`): Forwarded to `get_2d_sincos_pos_embed`.
        pos_embed_max_size (`int`, defaults to `192`): Side length, in patches, of the positional table.
        base_size (`int`, defaults to `64`): Forwarded to `get_2d_sincos_pos_embed`.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        embed_dim: int = 768,
        bias: bool = True,
        interpolation_scale: float = 1,
        pos_embed_max_size: int = 192,
        base_size: int = 64,
    ):
        super().__init__()

        # Both projections use the same non-overlapping patch convolution geometry.
        conv_kwargs = {"kernel_size": (patch_size, patch_size), "stride": patch_size, "bias": bias}
        self.output_image_proj = nn.Conv2d(in_channels, embed_dim, **conv_kwargs)
        self.input_image_proj = nn.Conv2d(in_channels, embed_dim, **conv_kwargs)

        self.patch_size = patch_size
        self.interpolation_scale = interpolation_scale
        self.pos_embed_max_size = pos_embed_max_size

        sincos_table = get_2d_sincos_pos_embed(
            embed_dim,
            self.pos_embed_max_size,
            base_size=base_size,
            interpolation_scale=self.interpolation_scale,
            output_type="pt",
        )
        # Stored with a leading singleton dimension and saved with the state dict.
        self.register_buffer("pos_embed", sincos_table.float().unsqueeze(0), persistent=True)

    def _cropped_pos_embed(self, height, width):
        """Return the centre window of the positional table for a `height` x `width` latent.

        `height` and `width` are given in latent pixels and are converted to patch counts
        here. The result has shape `(1, patches_h * patches_w, embed_dim)`.

        Raises:
            ValueError: If `pos_embed_max_size` is unset, or either patch count exceeds it.
        """
        max_size = self.pos_embed_max_size
        if max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        height = height // self.patch_size
        width = width // self.patch_size
        if height > max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        # Centre the requested grid inside the square table.
        row_start = (max_size - height) // 2
        col_start = (max_size - width) // 2
        grid = self.pos_embed.reshape(1, max_size, max_size, -1)
        window = grid[:, row_start : row_start + height, col_start : col_start + width, :]
        return window.reshape(1, -1, window.shape[-1])

    def _patch_embeddings(self, hidden_states: torch.Tensor, is_input_image: bool) -> torch.Tensor:
        """Run the matching patch convolution and flatten the spatial grid into a token axis."""
        projection = self.input_image_proj if is_input_image else self.output_image_proj
        return projection(hidden_states).flatten(2).transpose(1, 2)

    def forward(
        self, hidden_states: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None
    ) -> torch.Tensor:
        """Embed a single latent tensor, or each latent of a list.

        Args:
            hidden_states: One latent tensor, or a list of latent tensors handled one by one.
            is_input_image: Selects `input_image_proj` (`True`) or `output_image_proj` (`False`).
            padding_latent: Only used for list input. One entry per latent; a non-`None` entry is
                concatenated after that latent's tokens along the token dimension.

        Returns:
            The embedded tokens for tensor input, or a list of embedded tokens for list input.
        """
        if not isinstance(hidden_states, list):
            height, width = hidden_states.shape[-2:]
            pos_embed = self._cropped_pos_embed(height, width)
            tokens = self._patch_embeddings(hidden_states, is_input_image)
            return tokens + pos_embed

        paddings = [None] * len(hidden_states) if padding_latent is None else padding_latent

        patched_latents = []
        for latent, pad in zip(hidden_states, paddings):
            height, width = latent.shape[-2:]
            tokens = self._patch_embeddings(latent, is_input_image)
            tokens = tokens + self._cropped_pos_embed(height, width)
            if pad is not None:
                tokens = torch.cat([tokens, pad.to(tokens.device)], dim=-2)
            patched_latents.append(tokens)
        return patched_latents
class OmniGenSuScaledRotaryEmbedding(nn.Module):
    """Rotary position embedding whose inverse frequencies are rescaled by short/long factors.

    On every call the inverse frequencies are recomputed as
    `1 / (ext_factors * base ** (2i / dim))`, where `ext_factors` is `long_factor` when the
    largest position id plus one exceeds `original_max_position_embeddings`, and
    `short_factor` otherwise. The resulting cos/sin tables are multiplied by a constant that
    depends only on `max_position_embeddings / original_max_position_embeddings`.

    Args:
        dim: Rotary dimension; `dim // 2` frequencies are generated.
        max_position_embeddings: Extended context length used for the cos/sin scale constant.
        original_max_position_embeddings: Length threshold for choosing `long_factor`.
        base: Base of the geometric frequency progression.
        rope_scaling: Mapping with the keys `short_factor` and `long_factor`. Each value is
            turned into a tensor that multiplies the `dim // 2` frequencies.

    Raises:
        ValueError: If `rope_scaling` is `None`.
    """

    def __init__(
        self, dim, max_position_embeddings=131072, original_max_position_embeddings=4096, base=10000, rope_scaling=None
    ):
        super().__init__()

        # The signature defaults to `None`, but the factors are mandatory. Fail with a clear
        # message instead of "'NoneType' object is not subscriptable".
        if rope_scaling is None:
            raise ValueError(
                "`rope_scaling` must be a dict containing `short_factor` and `long_factor`, but got `None`."
            )

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

        self.short_factor = rope_scaling["short_factor"]
        self.long_factor = rope_scaling["long_factor"]
        self.original_max_position_embeddings = original_max_position_embeddings

    def forward(self, hidden_states, position_ids):
        """Compute the rotary cos/sin tables for `position_ids`.

        Args:
            hidden_states: Only its device is read; it selects where the tables are built.
            position_ids: 2D tensor of position indices; the first dimension is the batch.

        Returns:
            A `(cos, sin)` tuple, each of shape `(position_ids.shape[1], dim)`. Only the first
            batch row of `position_ids` contributes to the returned tables.
        """
        # Pick the factor set from the longest position seen in this call.
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.original_max_position_embeddings:
            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=hidden_states.device)
        else:
            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=hidden_states.device)

        inv_freq_shape = (
            torch.arange(0, self.dim, 2, dtype=torch.int64, device=hidden_states.device).float() / self.dim
        )
        # The non-persistent buffer is overwritten so it always reflects the last factor set used.
        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)

        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = hidden_states.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the frequencies to cover the full rotary dim, then keep batch row 0.
            emb = torch.cat((freqs, freqs), dim=-1)[0]

            scale = self.max_position_embeddings / self.original_max_position_embeddings
            if scale <= 1.0:
                scaling_factor = 1.0
            else:
                scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))

            cos = emb.cos() * scaling_factor
            sin = emb.sin() * scaling_factor
        return cos, sin
class OmniGenAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the OmniGen model.

    The number of key/value heads is derived from the width of the key projection. When it is smaller than
    `attn.heads` and divides it evenly, the key/value heads are repeated so that every query head has a matching
    key/value head before `F.scaled_dot_product_attention` is called.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "OmniGenAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention with the projections owned by `attn`.

        Args:
            attn: Module providing `to_q`, `to_k`, `to_v`, `to_out`, `heads` and `out_dim`.
            hidden_states: 3D tensor `(batch, query_len, channels)` that is projected to queries.
            encoder_hidden_states: Tensor that is projected to keys and values.
            attention_mask: Optional mask passed as `attn_mask` to scaled dot-product attention.
            image_rotary_emb: Optional rotary embedding applied to queries and keys.

        Returns:
            Tensor of shape `(batch, query_len, attn.out_dim)` after the output projection.
        """
        batch_size, query_len, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # The key projection may be narrower than the query projection (fewer key/value heads).
        head_dim = query.shape[-1] // attn.heads
        kv_heads = key.shape[-1] // head_dim

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            from ..embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb, use_real_unbind_dim=-2)
            key = apply_rotary_emb(key, image_rotary_emb, use_real_unbind_dim=-2)

        # Grouped-query attention: expand key/value heads to the query head count. Without this,
        # the head dimensions of query and key/value do not line up unless kv_heads is 1 or equal
        # to attn.heads. Each key/value head is repeated contiguously, so query head `h` uses
        # key/value head `h // n_rep`.
        if kv_heads < attn.heads and attn.heads % kv_heads == 0:
            n_rep = attn.heads // kv_heads
            key = key.repeat_interleave(n_rep, dim=1)
            value = value.repeat_interleave(n_rep, dim=1)

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).type_as(query)
        hidden_states = hidden_states.reshape(batch_size, query_len, attn.out_dim)
        hidden_states = attn.to_out[0](hidden_states)
        return hidden_states
class OmniGenBlock(nn.Module):
    """One transformer layer: RMSNorm, self-attention and residual, then RMSNorm, gated MLP and residual.

    Args:
        hidden_size (`int`): Width of the token representations.
        num_attention_heads (`int`): Number of query heads; the per-head size is
            `hidden_size // num_attention_heads`.
        num_key_value_heads (`int`): Number of key/value heads passed to the attention module.
        intermediate_size (`int`): Inner width of the feed-forward layer.
        rms_norm_eps (`float`): Epsilon used by both RMSNorm layers.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        intermediate_size: int,
        rms_norm_eps: float,
    ) -> None:
        super().__init__()

        head_dim = hidden_size // num_attention_heads

        self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
        # Keys and values come from the same sequence as the queries, so the "cross attention"
        # width equals the hidden size.
        self.self_attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=hidden_size,
            dim_head=head_dim,
            heads=num_attention_heads,
            kv_heads=num_key_value_heads,
            bias=False,
            out_dim=hidden_size,
            out_bias=False,
            processor=OmniGenAttnProcessor2_0(),
        )
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.mlp = OmniGenFeedForward(hidden_size, intermediate_size)

    def forward(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor
    ) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers, each wrapped in a residual connection."""
        # 1. Attention on the normalized tokens; the same tensor provides queries, keys and values.
        normed = self.input_layernorm(hidden_states)
        attn_output = self.self_attn(
            hidden_states=normed,
            encoder_hidden_states=normed,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_output

        # 2. Feed forward on the re-normalized tokens.
        normed = self.post_attention_layernorm(hidden_states)
        return hidden_states + self.mlp(normed)
class OmniGenTransformer2DModel(ModelMixin, ConfigMixin):
    """
    The Transformer model introduced in OmniGen (https://huggingface.co/papers/2409.11340).

    Parameters:
        in_channels (`int`, defaults to `4`):
            The number of channels in the input. The output has the same number of channels.
        patch_size (`int`, defaults to `2`):
            The size of the spatial patches to use in the patch embedding layer.
        hidden_size (`int`, defaults to `3072`):
            The dimensionality of the hidden layers in the model.
        rms_norm_eps (`float`, defaults to `1e-5`):
            Eps for RMSNorm layer.
        num_attention_heads (`int`, defaults to `32`):
            The number of heads to use for multi-head attention.
        num_key_value_heads (`int`, defaults to `32`):
            The number of heads to use for keys and values in multi-head attention.
        intermediate_size (`int`, defaults to `8192`):
            Dimension of the hidden layer in FeedForward layers.
        num_layers (`int`, default to `32`):
            The number of layers of transformer blocks to use.
        pad_token_id (`int`, default to `32000`):
            The id of the padding token.
        vocab_size (`int`, default to `32064`):
            The size of the vocabulary of the embedding vocabulary.
        max_position_embeddings (`int`, default to `131072`):
            Forwarded to the rotary embedding, where its ratio to `original_max_position_embeddings` sets the
            constant scale applied to the cos/sin tables.
        original_max_position_embeddings (`int`, default to `4096`):
            Forwarded to the rotary embedding, where it is the sequence-length threshold above which
            `long_factor` is used instead of `short_factor`.
        rope_base (`int`, default to `10000`):
            The default theta value to use when creating RoPE.
        rope_scaling (`Dict`, optional):
            The scaling factors for the RoPE. Must contain `short_factor` and `long_factor`.
        pos_embed_max_size (`int`, default to `192`):
            The maximum size of the positional embeddings.
        time_step_dim (`int`, default to `256`):
            Output dimension of timestep embeddings.
        flip_sin_to_cos (`bool`, default to `True`):
            Whether to flip the sin and cos in the positional embeddings when preparing timestep embeddings.
        downscale_freq_shift (`int`, default to `0`):
            The frequency shift to use when downscaling the timestep embeddings.
        timestep_activation_fn (`str`, default to `silu`):
            The activation function to use for the timestep embeddings.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["OmniGenBlock"]
    _skip_layerwise_casting_patterns = ["patch_embedding", "embed_tokens", "norm"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        patch_size: int = 2,
        hidden_size: int = 3072,
        rms_norm_eps: float = 1e-5,
        num_attention_heads: int = 32,
        num_key_value_heads: int = 32,
        intermediate_size: int = 8192,
        num_layers: int = 32,
        pad_token_id: int = 32000,
        vocab_size: int = 32064,
        max_position_embeddings: int = 131072,
        original_max_position_embeddings: int = 4096,
        rope_base: int = 10000,
        rope_scaling: Dict = None,
        pos_embed_max_size: int = 192,
        time_step_dim: int = 256,
        flip_sin_to_cos: bool = True,
        downscale_freq_shift: int = 0,
        timestep_activation_fn: str = "silu",
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels

        # Latent -> token embedding (separate projections for input and output images).
        self.patch_embedding = OmniGenPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
            pos_embed_max_size=pos_embed_max_size,
        )

        # One sinusoidal projection feeds two MLPs: `time_token` becomes a sequence token,
        # `t_embedder` conditions the final adaptive layer norm.
        self.time_proj = Timesteps(time_step_dim, flip_sin_to_cos, downscale_freq_shift)
        self.time_token = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
        self.t_embedder = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)

        self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id)
        self.rope = OmniGenSuScaledRotaryEmbedding(
            hidden_size // num_attention_heads,
            max_position_embeddings=max_position_embeddings,
            original_max_position_embeddings=original_max_position_embeddings,
            base=rope_base,
            rope_scaling=rope_scaling,
        )

        blocks = [
            OmniGenBlock(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, rms_norm_eps)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

        self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1)
        self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    def _get_multimodal_embeddings(
        self, input_ids: torch.Tensor, input_img_latents: List[torch.Tensor], input_image_sizes: Dict
    ) -> Optional[torch.Tensor]:
        """Embed the text tokens and overwrite image placeholder spans with image patch tokens.

        `input_image_sizes` maps a batch index to a sequence of `(start, end)` token spans. Spans
        are consumed in iteration order, and the k-th span overall receives the k-th entry of
        `input_img_latents`. Returns `None` when `input_ids` is `None`.
        """
        if input_ids is None:
            return None

        image_latents = [latent.to(self.dtype) for latent in input_img_latents]
        condition_tokens = self.embed_tokens(input_ids)
        image_tokens = self.patch_embedding(image_latents, is_input_image=True)

        image_index = 0
        for batch_index, spans in input_image_sizes.items():
            for start, end in spans:
                # replace the placeholder in text tokens with the image embedding.
                condition_tokens[batch_index, start:end] = image_tokens[image_index].to(condition_tokens.dtype)
                image_index += 1
        return condition_tokens

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.FloatTensor],
        input_ids: torch.Tensor,
        input_img_latents: List[torch.Tensor],
        input_image_sizes: Dict[int, List[int]],
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[Transformer2DModelOutput, Tuple[torch.Tensor]]:
        """Predict the output latent from the noisy latent, the timestep and the multimodal condition.

        Args:
            hidden_states: 4D latent `(batch, channels, height, width)`.
            timestep: Value passed to the timestep projection.
            input_ids: Condition token ids, or `None` to run without condition tokens.
            input_img_latents: Latents of the input images referenced by `input_image_sizes`.
            input_image_sizes: Per-batch-index `(start, end)` spans of image placeholders in `input_ids`.
            attention_mask: Optional mask; a 3D mask is converted to an additive mask with a head axis.
            position_ids: Position indices, reshaped to `(-1, sequence_length)`.
            return_dict: Whether to wrap the result in `Transformer2DModelOutput`.

        Returns:
            `Transformer2DModelOutput(sample=output)` or `(output,)`, where `output` has the same
            batch size as the input and spatial size `(height // p * p, width // p * p)`.
        """
        batch_size, _, height, width = hidden_states.shape
        patch = self.config.patch_size
        grid_height, grid_width = height // patch, width // patch

        # 1. Patch & Timestep & Conditional Embedding
        hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
        num_output_tokens = hidden_states.size(1)

        timestep_proj = self.time_proj(timestep).type_as(hidden_states)
        time_token = self.time_token(timestep_proj).unsqueeze(1)
        temb = self.t_embedder(timestep_proj)

        # Sequence layout: [condition tokens (optional), time token, output image tokens].
        sequence_parts = [time_token, hidden_states]
        condition_tokens = self._get_multimodal_embeddings(input_ids, input_img_latents, input_image_sizes)
        if condition_tokens is not None:
            sequence_parts.insert(0, condition_tokens)
        hidden_states = torch.cat(sequence_parts, dim=1)

        position_ids = position_ids.view(-1, hidden_states.size(1)).long()

        # 2. Attention mask preprocessing: 1 -> 0 and 0 -> most negative value of the working
        # dtype, plus a singleton head axis.
        if attention_mask is not None and attention_mask.dim() == 3:
            min_dtype = torch.finfo(hidden_states.dtype).min
            additive_mask = (1 - attention_mask) * min_dtype
            attention_mask = additive_mask.unsqueeze(1).type_as(hidden_states)

        # 3. Rotary position embedding
        image_rotary_emb = self.rope(hidden_states, position_ids)

        # 4. Transformer blocks
        for block in self.layers:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, attention_mask, image_rotary_emb
                )
            else:
                hidden_states = block(hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb)

        # 5. Output norm & projection (only the trailing output-image tokens are decoded)
        hidden_states = self.norm(hidden_states)
        hidden_states = hidden_states[:, -num_output_tokens:]
        hidden_states = self.norm_out(hidden_states, temb=temb)
        hidden_states = self.proj_out(hidden_states)

        # Un-patchify: (B, h, w, p, p, C) -> (B, C, h * p, w * p).
        hidden_states = hidden_states.reshape(batch_size, grid_height, grid_width, patch, patch, -1)
        output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)

        if return_dict:
            return Transformer2DModelOutput(sample=output)
        return (output,)
|