QwenTest
/
pythonProject
/diffusers-main
/src
/diffusers
/models
/transformers
/transformer_omnigen.py
| # Copyright 2025 OmniGen team and The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import math | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ...configuration_utils import ConfigMixin, register_to_config | |
| from ...utils import logging | |
| from ..attention_processor import Attention | |
| from ..embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed | |
| from ..modeling_outputs import Transformer2DModelOutput | |
| from ..modeling_utils import ModelMixin | |
| from ..normalization import AdaLayerNorm, RMSNorm | |
| logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
class OmniGenFeedForward(nn.Module):
    r"""
    Gated feed-forward layer used by the OmniGen transformer blocks.

    One bias-free linear layer produces the gate and the up projection in a single pass. The SiLU-activated gate
    multiplies the up projection, and a second bias-free linear layer maps the result back to `hidden_size`.

    Parameters:
        hidden_size (`int`):
            Size of the last dimension of both the input and the output.
        intermediate_size (`int`):
            Width of the gated inner representation.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()

        # Gate and up projections share one fused weight; `forward` splits its output in two halves.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.activation_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        fused = self.gate_up_proj(hidden_states)
        # First half of the fused projection is the gate, second half is the value being gated.
        gate, up = torch.chunk(fused, 2, dim=-1)
        gated = up * self.activation_fn(gate)
        return self.down_proj(gated)
class OmniGenPatchEmbed(nn.Module):
    r"""
    Patch embedding with separate convolutional projections for input images and output images.

    Latents are turned into token sequences by a strided convolution, and a centre crop of a fixed 2D sin-cos
    positional-embedding grid is added to the tokens.

    Parameters:
        patch_size (`int`, defaults to `2`):
            Kernel size and stride of the patch projection.
        in_channels (`int`, defaults to `4`):
            Number of channels of the latents.
        embed_dim (`int`, defaults to `768`):
            Number of output channels of the projection, i.e. the token dimension.
        bias (`bool`, defaults to `True`):
            Whether the projections use a bias.
        interpolation_scale (`float`, defaults to `1`):
            Forwarded to `get_2d_sincos_pos_embed`.
        pos_embed_max_size (`int`, defaults to `192`):
            Side length (in patches) of the stored positional-embedding grid.
        base_size (`int`, defaults to `64`):
            Forwarded to `get_2d_sincos_pos_embed`.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        embed_dim: int = 768,
        bias: bool = True,
        interpolation_scale: float = 1,
        pos_embed_max_size: int = 192,
        base_size: int = 64,
    ):
        super().__init__()

        # Both projections share the same geometry; only their weights differ.
        conv_kwargs = {"kernel_size": (patch_size, patch_size), "stride": patch_size, "bias": bias}
        self.output_image_proj = nn.Conv2d(in_channels, embed_dim, **conv_kwargs)
        self.input_image_proj = nn.Conv2d(in_channels, embed_dim, **conv_kwargs)

        self.patch_size = patch_size
        self.interpolation_scale = interpolation_scale
        self.pos_embed_max_size = pos_embed_max_size

        grid = get_2d_sincos_pos_embed(
            embed_dim,
            self.pos_embed_max_size,
            base_size=base_size,
            interpolation_scale=self.interpolation_scale,
            output_type="pt",
        )
        self.register_buffer("pos_embed", grid.float().unsqueeze(0), persistent=True)

    def _cropped_pos_embed(self, height, width):
        """Return the centred window of the positional-embedding grid for a latent of size `height` x `width`."""
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        # Convert latent pixels to patch counts before comparing against the grid size.
        height = height // self.patch_size
        width = width // self.patch_size
        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        side = self.pos_embed_max_size
        top = (side - height) // 2
        left = (side - width) // 2

        # View the flat buffer as a (1, side, side, dim) grid, cut the window, then flatten back to tokens.
        window = self.pos_embed.reshape(1, side, side, -1)[:, top : top + height, left : left + width, :]
        return window.reshape(1, -1, window.shape[-1])

    def _patch_embeddings(self, hidden_states: torch.Tensor, is_input_image: bool) -> torch.Tensor:
        projection = self.input_image_proj if is_input_image else self.output_image_proj
        # Conv output has spatial dims last; flatten them and move the token axis before the channel axis.
        return projection(hidden_states).flatten(2).transpose(1, 2)

    def forward(
        self, hidden_states: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Embed one latent tensor, or a list of latent tensors, into token sequences with positional embeddings.

        Args:
            hidden_states (`torch.Tensor` or `list` of `torch.Tensor`):
                Latent(s) to embed. The last two dimensions are read as height and width.
            is_input_image (`bool`):
                Selects `input_image_proj` when `True`, otherwise `output_image_proj`.
            padding_latent (`list`, *optional*):
                Only used with a list input. One entry per latent; a non-`None` entry is concatenated after that
                latent's tokens along the token dimension.

        Returns:
            A tensor for a tensor input, or a list of tensors (one per latent) for a list input.
        """
        if not isinstance(hidden_states, list):
            height, width = hidden_states.shape[-2:]
            pos_embed = self._cropped_pos_embed(height, width)
            return self._patch_embeddings(hidden_states, is_input_image) + pos_embed

        if padding_latent is None:
            padding_latent = [None] * len(hidden_states)

        patched_latents = []
        for latent, padding in zip(hidden_states, padding_latent):
            height, width = latent.shape[-2:]
            tokens = self._patch_embeddings(latent, is_input_image) + self._cropped_pos_embed(height, width)
            if padding is not None:
                tokens = torch.cat([tokens, padding.to(tokens.device)], dim=-2)
            patched_latents.append(tokens)
        return patched_latents
class OmniGenSuScaledRotaryEmbedding(nn.Module):
    r"""
    Rotary position embedding whose inverse frequencies are divided by per-frequency "short" or "long" factors.

    Parameters:
        dim (`int`):
            Rotary dimension; `dim // 2` frequencies are generated.
        max_position_embeddings (`int`, defaults to `131072`):
            Extended context length, used only to derive the cos/sin magnitude scaling.
        original_max_position_embeddings (`int`, defaults to `4096`):
            Context length above which the long factors replace the short factors.
        base (`int`, defaults to `10000`):
            Base of the geometric frequency progression.
        rope_scaling (`Dict`, *optional*):
            Mapping read for its `"short_factor"` and `"long_factor"` entries. Although the default is `None`, the
            constructor indexes into it unconditionally, so a mapping has to be supplied.
    """

    def __init__(
        self, dim, max_position_embeddings=131072, original_max_position_embeddings=4096, base=10000, rope_scaling=None
    ):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim
        self.register_buffer("inv_freq", tensor=1.0 / (self.base**exponents), persistent=False)

        self.short_factor = rope_scaling["short_factor"]
        self.long_factor = rope_scaling["long_factor"]
        self.original_max_position_embeddings = original_max_position_embeddings

    def forward(self, hidden_states, position_ids):
        """
        Compute the rotary cos/sin tables for `position_ids`.

        `hidden_states` is only consulted for its device. The returned tables are built from the first row of
        `position_ids` (index 0 along the leading dimension of the computed embedding).

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: the `(cos, sin)` tables.
        """
        device = hidden_states.device

        # The largest position id decides whether the long-context factors are needed.
        exceeds_original = torch.max(position_ids) + 1 > self.original_max_position_embeddings
        chosen_factors = self.long_factor if exceeds_original else self.short_factor
        ext_factors = torch.tensor(chosen_factors, dtype=torch.float32, device=device)

        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim
        # NOTE: the buffer is overwritten on every call with the factor-adjusted frequencies.
        self.inv_freq = 1.0 / (ext_factors * self.base**exponents)

        batch = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)[0]

            ratio = self.max_position_embeddings / self.original_max_position_embeddings
            scaling_factor = (
                1.0
                if ratio <= 1.0
                else math.sqrt(1 + math.log(ratio) / math.log(self.original_max_position_embeddings))
            )

            cos = emb.cos() * scaling_factor
            sin = emb.sin() * scaling_factor
        return cos, sin
class OmniGenAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the OmniGen model.

    The number of key/value heads is derived from the width of the key projection, so it may be smaller than
    `attn.heads` (grouped-query attention). In that case every key/value head is repeated so that it is shared by
    `attn.heads // kv_heads` consecutive query heads.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run attention with queries from `hidden_states` and keys/values from `encoder_hidden_states`.

        Args:
            attn (`Attention`):
                Attention module providing `to_q`, `to_k`, `to_v`, `to_out`, `heads` and `out_dim`.
            hidden_states (`torch.Tensor`):
                Query source; unpacked as `(batch, query_length, channels)`.
            encoder_hidden_states (`torch.Tensor`):
                Key/value source.
            attention_mask (`torch.Tensor`, *optional*):
                Passed unchanged as `attn_mask` to `F.scaled_dot_product_attention`.
            image_rotary_emb (*optional*):
                Rotary embedding forwarded to `apply_rotary_emb` for both queries and keys; skipped when `None`.

        Returns:
            `torch.Tensor`: attention output after `attn.to_out[0]`, with a last dimension fed from `attn.out_dim`.
        """
        batch_size, q_len, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = query.shape[-1] // attn.heads
        # The key projection may be narrower than the query projection: that is how fewer KV heads show up here.
        kv_heads = key.shape[-1] // head_dim

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            from ..embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb, use_real_unbind_dim=-2)
            key = apply_rotary_emb(key, image_rotary_emb, use_real_unbind_dim=-2)

        # `scaled_dot_product_attention` needs matching head counts, so expand the shared KV heads explicitly.
        # This is a no-op for plain multi-head attention (kv_heads == attn.heads).
        if kv_heads != attn.heads:
            repeats = attn.heads // kv_heads
            key = key.repeat_interleave(repeats, dim=1)
            value = value.repeat_interleave(repeats, dim=1)

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).type_as(query)
        hidden_states = hidden_states.reshape(batch_size, q_len, attn.out_dim)
        hidden_states = attn.to_out[0](hidden_states)
        return hidden_states
class OmniGenBlock(nn.Module):
    r"""
    Single OmniGen transformer layer: RMSNorm, self-attention and a residual add, followed by RMSNorm, a gated
    feed-forward and a second residual add.

    Parameters:
        hidden_size (`int`):
            Token dimension; also used as query, cross-attention and output dimension of the attention layer.
        num_attention_heads (`int`):
            Number of query heads. The per-head dimension is `hidden_size // num_attention_heads`.
        num_key_value_heads (`int`):
            Number of key/value heads passed to the attention layer.
        intermediate_size (`int`):
            Inner width of the feed-forward layer.
        rms_norm_eps (`float`):
            Epsilon of both RMSNorm layers.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        intermediate_size: int,
        rms_norm_eps: float,
    ) -> None:
        super().__init__()

        head_dim = hidden_size // num_attention_heads

        self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.self_attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=hidden_size,
            dim_head=head_dim,
            heads=num_attention_heads,
            kv_heads=num_key_value_heads,
            bias=False,
            out_dim=hidden_size,
            out_bias=False,
            processor=OmniGenAttnProcessor2_0(),
        )
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.mlp = OmniGenFeedForward(hidden_size, intermediate_size)

    def forward(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor
    ) -> torch.Tensor:
        # 1. Attention: the normalised states serve as both the query source and the key/value source.
        attn_input = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + self.self_attn(
            hidden_states=attn_input,
            encoder_hidden_states=attn_input,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )

        # 2. Feed Forward on the post-attention residual stream.
        ff_input = self.post_attention_layernorm(hidden_states)
        return hidden_states + self.mlp(ff_input)
class OmniGenTransformer2DModel(ModelMixin, ConfigMixin):
    """
    The Transformer model introduced in OmniGen (https://huggingface.co/papers/2409.11340).

    Parameters:
        in_channels (`int`, defaults to `4`):
            The number of channels in the input.
        patch_size (`int`, defaults to `2`):
            The size of the spatial patches to use in the patch embedding layer.
        hidden_size (`int`, defaults to `3072`):
            The dimensionality of the hidden layers in the model.
        rms_norm_eps (`float`, defaults to `1e-5`):
            Eps for RMSNorm layer.
        num_attention_heads (`int`, defaults to `32`):
            The number of heads to use for multi-head attention.
        num_key_value_heads (`int`, defaults to `32`):
            The number of heads to use for keys and values in multi-head attention.
        intermediate_size (`int`, defaults to `8192`):
            Dimension of the hidden layer in FeedForward layers.
        num_layers (`int`, defaults to `32`):
            The number of layers of transformer blocks to use.
        pad_token_id (`int`, defaults to `32000`):
            The id of the padding token.
        vocab_size (`int`, defaults to `32064`):
            The size of the vocabulary of the token embedding.
        max_position_embeddings (`int`, defaults to `131072`):
            Forwarded to the rotary embedding as its `max_position_embeddings`.
        original_max_position_embeddings (`int`, defaults to `4096`):
            Forwarded to the rotary embedding as its `original_max_position_embeddings`.
        rope_base (`int`, defaults to `10000`):
            The default theta value to use when creating RoPE.
        rope_scaling (`Dict`, *optional*):
            The scaling factors for the RoPE. Must contain `short_factor` and `long_factor`.
        pos_embed_max_size (`int`, defaults to `192`):
            The maximum size of the positional embeddings.
        time_step_dim (`int`, defaults to `256`):
            Output dimension of timestep embeddings.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin and cos in the positional embeddings when preparing timestep embeddings.
        downscale_freq_shift (`int`, defaults to `0`):
            The frequency shift to use when downscaling the timestep embeddings.
        timestep_activation_fn (`str`, defaults to `silu`):
            The activation function to use for the timestep embeddings.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["OmniGenBlock"]
    _skip_layerwise_casting_patterns = ["patch_embedding", "embed_tokens", "norm"]

    # `forward` reads `self.config.patch_size`; the constructor arguments only reach `self.config` when
    # `__init__` is wrapped by `register_to_config`.
    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        patch_size: int = 2,
        hidden_size: int = 3072,
        rms_norm_eps: float = 1e-5,
        num_attention_heads: int = 32,
        num_key_value_heads: int = 32,
        intermediate_size: int = 8192,
        num_layers: int = 32,
        pad_token_id: int = 32000,
        vocab_size: int = 32064,
        max_position_embeddings: int = 131072,
        original_max_position_embeddings: int = 4096,
        rope_base: int = 10000,
        rope_scaling: Dict = None,
        pos_embed_max_size: int = 192,
        time_step_dim: int = 256,
        flip_sin_to_cos: bool = True,
        downscale_freq_shift: int = 0,
        timestep_activation_fn: str = "silu",
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels

        self.patch_embedding = OmniGenPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
            pos_embed_max_size=pos_embed_max_size,
        )

        # The same sinusoidal projection feeds two embeddings: `time_token` becomes a sequence token,
        # `t_embedder` conditions the final adaptive layer norm.
        self.time_proj = Timesteps(time_step_dim, flip_sin_to_cos, downscale_freq_shift)
        self.time_token = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
        self.t_embedder = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)

        self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id)
        self.rope = OmniGenSuScaledRotaryEmbedding(
            hidden_size // num_attention_heads,
            max_position_embeddings=max_position_embeddings,
            original_max_position_embeddings=original_max_position_embeddings,
            base=rope_base,
            rope_scaling=rope_scaling,
        )

        self.layers = nn.ModuleList(
            [
                OmniGenBlock(hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, rms_norm_eps)
                for _ in range(num_layers)
            ]
        )

        self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
        self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1)
        self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    def _get_multimodal_embeddings(
        self, input_ids: torch.Tensor, input_img_latents: List[torch.Tensor], input_image_sizes: Dict
    ) -> Optional[torch.Tensor]:
        """
        Embed the text tokens and overwrite the image placeholder spans with patch-embedded input images.

        Args:
            input_ids (`torch.Tensor`):
                Token ids; when `None` the method returns `None` without touching the other arguments.
            input_img_latents (`List[torch.Tensor]`):
                Input image latents, consumed in order, one per placeholder span.
            input_image_sizes (`Dict`):
                Maps a batch index to an iterable of `(start, end)` token positions to overwrite.

        Returns:
            `torch.Tensor` or `None`: the condition tokens with image embeddings written in place.
        """
        if input_ids is None:
            return None

        input_img_latents = [x.to(self.dtype) for x in input_img_latents]
        condition_tokens = self.embed_tokens(input_ids)
        input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True)

        # Images are stored flat, so a running index walks them in the same order as the spans.
        input_img_inx = 0
        for b_inx, spans in input_image_sizes.items():
            for start_inx, end_inx in spans:
                # replace the placeholder in text tokens with the image embedding.
                condition_tokens[b_inx, start_inx:end_inx] = input_image_tokens[input_img_inx].to(
                    condition_tokens.dtype
                )
                input_img_inx += 1
        return condition_tokens

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.FloatTensor],
        input_ids: torch.Tensor,
        input_img_latents: List[torch.Tensor],
        input_image_sizes: Dict[int, List[int]],
        attention_mask: torch.Tensor,
        position_ids: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[Transformer2DModelOutput, Tuple[torch.Tensor]]:
        """
        Run the OmniGen transformer on the output-image latents and the optional multimodal condition.

        Args:
            hidden_states (`torch.Tensor`):
                Output-image latents, unpacked as `(batch_size, channels, height, width)`.
            timestep (`int`, `float` or `torch.FloatTensor`):
                Timestep passed to the timestep projection.
            input_ids (`torch.Tensor`):
                Condition token ids; `None` disables the condition tokens.
            input_img_latents (`List[torch.Tensor]`):
                Latents of the input images referenced by the condition.
            input_image_sizes (`Dict`):
                Placeholder spans of the input images inside `input_ids`, keyed by batch index.
            attention_mask (`torch.Tensor`):
                Attention mask. A 3-dimensional mask is turned into an additive mask (`0` where the mask is `1`,
                the dtype minimum where it is `0`) and given an extra dimension at index 1; any other mask is passed
                through unchanged.
            position_ids (`torch.Tensor`):
                Position ids, reshaped to `(-1, sequence_length)` of the concatenated token sequence.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

        Returns:
            `Transformer2DModelOutput` or `tuple`: the prediction, reshaped back to the latent layout.
        """
        batch_size, _, height, width = hidden_states.shape
        p = self.config.patch_size
        post_patch_height, post_patch_width = height // p, width // p

        # 1. Patch & Timestep & Conditional Embedding
        hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
        num_tokens_for_output_image = hidden_states.size(1)

        timestep_proj = self.time_proj(timestep).type_as(hidden_states)
        time_token = self.time_token(timestep_proj).unsqueeze(1)
        temb = self.t_embedder(timestep_proj)

        # Sequence layout: [condition tokens (optional), time token, output image tokens].
        condition_tokens = self._get_multimodal_embeddings(input_ids, input_img_latents, input_image_sizes)
        if condition_tokens is not None:
            hidden_states = torch.cat([condition_tokens, time_token, hidden_states], dim=1)
        else:
            hidden_states = torch.cat([time_token, hidden_states], dim=1)

        seq_length = hidden_states.size(1)
        position_ids = position_ids.view(-1, seq_length).long()

        # 2. Attention mask preprocessing
        if attention_mask is not None and attention_mask.dim() == 3:
            dtype = hidden_states.dtype
            min_dtype = torch.finfo(dtype).min
            attention_mask = (1 - attention_mask) * min_dtype
            attention_mask = attention_mask.unsqueeze(1).type_as(hidden_states)

        # 3. Rotary position embedding
        image_rotary_emb = self.rope(hidden_states, position_ids)

        # 4. Transformer blocks
        for block in self.layers:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, attention_mask, image_rotary_emb
                )
            else:
                hidden_states = block(hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb)

        # 5. Output norm & projection
        hidden_states = self.norm(hidden_states)
        # Only the trailing tokens belong to the output image; condition and time tokens are dropped.
        hidden_states = hidden_states[:, -num_tokens_for_output_image:]
        hidden_states = self.norm_out(hidden_states, temb=temb)
        hidden_states = self.proj_out(hidden_states)

        # Un-patchify: (B, H/p, W/p, p, p, C) -> (B, C, H, W).
        hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, p, p, -1)
        output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)