Spaces:
Running
on
Zero
Running
on
Zero
| # Original code: https://github.com/VectorSpaceLab/OmniGen2 | |
| from typing import Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange, repeat | |
| from comfy.ldm.lightricks.model import Timesteps | |
| from comfy.ldm.flux.layers import EmbedND | |
| from comfy.ldm.modules.attention import optimized_attention_masked | |
| import comfy.model_management | |
| import comfy.ldm.common_dit | |
def apply_rotary_emb(x, freqs_cis):
    """Apply the 2x2 rotation blocks in ``freqs_cis`` to feature pairs of ``x``.

    The last dimension of ``x`` is viewed as consecutive pairs. Each pair is
    combined with the two slices of ``freqs_cis`` taken along its last axis,
    and the result is reshaped back to ``x.shape`` and cast to ``x.dtype``.

    If dimension 1 of ``x`` is empty, ``x`` itself is returned unchanged.
    """
    if x.shape[1] == 0:
        return x

    original_shape = x.shape
    # (..., D) -> (..., D/2, 1, 2): one trailing pair per rotation block.
    pairs = x.reshape(*original_shape[:-1], -1, 1, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = freqs_cis[..., 0] * first + freqs_cis[..., 1] * second
    return rotated.reshape(*original_shape).to(dtype=x.dtype)
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return ``silu(x) * y`` elementwise (the SwiGLU gating product)."""
    gate = F.silu(x)
    return gate * y
class TimestepEmbedding(nn.Module):
    """Two-layer MLP: ``Linear -> SiLU -> Linear``.

    Args:
        in_channels: width of the incoming features.
        time_embed_dim: width of both the hidden and the output features.
        dtype, device: forwarded to the linear layers.
        operations: namespace that provides the ``Linear`` class; it is
            dereferenced unconditionally, so it must not be ``None``.
    """

    def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        factory_kwargs = dict(dtype=dtype, device=device)
        self.linear_1 = operations.Linear(in_channels, time_embed_dim, **factory_kwargs)
        self.act = nn.SiLU()
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, **factory_kwargs)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """Run ``sample`` through ``linear_1``, the activation, then ``linear_2``."""
        return self.linear_2(self.act(self.linear_1(sample)))
| class LuminaRMSNormZero(nn.Module): | |
| def __init__(self, embedding_dim: int, norm_eps: float = 1e-5, dtype=None, device=None, operations=None): | |
| super().__init__() | |
| self.silu = nn.SiLU() | |
| self.linear = operations.Linear(min(embedding_dim, 1024), 4 * embedding_dim, dtype=dtype, device=device) | |
| self.norm = operations.RMSNorm(embedding_dim, eps=norm_eps, dtype=dtype, device=device) | |
| def forward(self, x: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
| emb = self.linear(self.silu(emb)) | |
| scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) | |
| x = self.norm(x) * (1 + scale_msa[:, None]) | |
| return x, gate_msa, scale_mlp, gate_mlp | |
class LuminaLayerNormContinuous(nn.Module):
    """LayerNorm whose output is scaled by a projection of a conditioning vector.

    Args:
        embedding_dim: feature width of ``x`` and of the projected scale.
        conditioning_embedding_dim: feature width of the conditioning input.
        elementwise_affine, eps: forwarded to ``operations.LayerNorm``.
        out_dim: when given, a final biased linear layer maps to this width;
            when ``None``, ``linear_2`` is ``None`` and no projection happens.
        dtype, device: forwarded to the created layers.
        operations: namespace providing ``Linear`` and ``LayerNorm``.
    """

    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, elementwise_affine: bool = False, eps: float = 1e-6, out_dim: Optional[int] = None, dtype=None, device=None, operations=None):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear_1 = operations.Linear(conditioning_embedding_dim, embedding_dim, dtype=dtype, device=device)
        self.norm = operations.LayerNorm(embedding_dim, eps, elementwise_affine, dtype=dtype, device=device)
        if out_dim is None:
            self.linear_2 = None
        else:
            self.linear_2 = operations.Linear(embedding_dim, out_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        """Return ``norm(x) * (1 + linear_1(silu(cond)))``, optionally projected."""
        activated = self.silu(conditioning_embedding).to(x.dtype)
        scale = 1 + self.linear_1(activated)
        scaled = self.norm(x) * scale[:, None, :]
        if self.linear_2 is None:
            return scaled
        return self.linear_2(scaled)
class LuminaFeedForward(nn.Module):
    """Gated feed-forward block: ``linear_2(silu(linear_1(x)) * linear_3(x))``.

    ``inner_dim`` is rounded up to the next multiple of ``multiple_of`` before
    the layers are created. All three linear layers are bias-free.
    """

    def __init__(self, dim: int, inner_dim: int, multiple_of: int = 256, dtype=None, device=None, operations=None):
        super().__init__()
        # Ceil-divide, then scale back up: the smallest multiple >= inner_dim.
        num_blocks = (inner_dim + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * num_blocks
        linear_kwargs = dict(bias=False, dtype=dtype, device=device)
        self.linear_1 = operations.Linear(dim, inner_dim, **linear_kwargs)
        self.linear_2 = operations.Linear(inner_dim, dim, **linear_kwargs)
        self.linear_3 = operations.Linear(dim, inner_dim, **linear_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection and map back to ``dim`` features."""
        gate = self.linear_1(x)
        value = self.linear_3(x)
        return self.linear_2(F.silu(gate) * value)
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    """Produce a timestep embedding and a caption embedding in one call.

    The timestep goes through ``Timesteps`` and then a ``TimestepEmbedding``
    whose output width is ``min(hidden_size, 1024)``. The caption features go
    through an RMSNorm followed by a biased linear layer to ``hidden_size``.
    """

    def __init__(self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size,
            flip_sin_to_cos=True,
            downscale_freq_shift=0.0,
            scale=timestep_scale,
        )
        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size,
            time_embed_dim=min(hidden_size, 1024),
            dtype=dtype,
            device=device,
            operations=operations,
        )
        caption_norm = operations.RMSNorm(text_feat_dim, eps=norm_eps, dtype=dtype, device=device)
        caption_proj = operations.Linear(text_feat_dim, hidden_size, bias=True, dtype=dtype, device=device)
        self.caption_embedder = nn.Sequential(caption_norm, caption_proj)

    def forward(self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(time_embed, caption_embed)``.

        The timestep projection is cast to ``dtype`` before it is embedded.
        """
        projected = self.time_proj(timestep).to(dtype=dtype)
        time_embed = self.timestep_embedder(projected)
        caption_embed = self.caption_embedder(text_hidden_states)
        return time_embed, caption_embed
class Attention(nn.Module):
    """Multi-head attention with per-head RMSNorm on queries and keys.

    Keys and values may use fewer heads than queries (``kv_heads < heads``);
    they are then repeated along the head axis to match the query heads.

    Args:
        query_dim: input and output feature width.
        dim_head: width of each attention head.
        heads: number of query heads.
        kv_heads: number of key/value heads.
        eps: epsilon for the query/key RMSNorm layers.
        bias: whether the linear projections carry a bias.
        dtype, device: forwarded to the created layers.
        operations: namespace providing ``Linear`` and ``RMSNorm``.
    """

    def __init__(self, query_dim: int, dim_head: int, heads: int, kv_heads: int, eps: float = 1e-5, bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = heads
        self.kv_heads = kv_heads
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5

        q_width = heads * dim_head
        kv_width = kv_heads * dim_head
        self.to_q = operations.Linear(query_dim, q_width, bias=bias, dtype=dtype, device=device)
        self.to_k = operations.Linear(query_dim, kv_width, bias=bias, dtype=dtype, device=device)
        self.to_v = operations.Linear(query_dim, kv_width, bias=bias, dtype=dtype, device=device)

        self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
        self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)

        out_proj = operations.Linear(q_width, query_dim, bias=bias, dtype=dtype, device=device)
        self.to_out = nn.Sequential(out_proj, nn.Dropout(0.0))

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend from ``hidden_states`` to ``encoder_hidden_states``.

        Queries come from ``hidden_states``; keys and values come from
        ``encoder_hidden_states``. When ``image_rotary_emb`` is given it is
        applied to both queries and keys after normalisation. Only the linear
        layer of ``to_out`` (index 0) is applied to the attention result.
        """
        # The three-way unpack also requires hidden_states to be 3-D.
        batch_size, _, _ = hidden_states.shape

        query = self.to_q(hidden_states).view(batch_size, -1, self.heads, self.dim_head)
        key = self.to_k(encoder_hidden_states).view(batch_size, -1, self.kv_heads, self.dim_head)
        value = self.to_v(encoder_hidden_states).view(batch_size, -1, self.kv_heads, self.dim_head)

        query = self.norm_q(query)
        key = self.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # Swap dims 1 and 2 so the head axis comes before the token axis.
        query, key, value = (t.transpose(1, 2) for t in (query, key, value))

        if self.kv_heads < self.heads:
            group_size = self.heads // self.kv_heads
            key = key.repeat_interleave(group_size, dim=1)
            value = value.repeat_interleave(group_size, dim=1)

        attended = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
        return self.to_out[0](attended)
class OmniGen2TransformerBlock(nn.Module):
    """Self-attention plus feed-forward block with optional modulation by ``temb``.

    With ``modulation=True`` the first norm is a ``LuminaRMSNormZero``, which
    also yields gates and an MLP scale derived from ``temb``. Otherwise it is
    a plain RMSNorm and ``temb`` is ignored. ``ffn_dim_multiplier`` is accepted
    but not used by this block; the feed-forward inner width is ``4 * dim``
    (before rounding to ``multiple_of``).
    """

    def __init__(self, dim: int, num_attention_heads: int, num_kv_heads: int, multiple_of: int, ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, dtype=None, device=None, operations=None):
        super().__init__()
        self.modulation = modulation
        shared = dict(dtype=dtype, device=device, operations=operations)

        self.attn = Attention(
            query_dim=dim,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            **shared,
        )
        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            **shared,
        )

        if modulation:
            self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, **shared)
        else:
            self.norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
        self.ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
        self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
        self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the attention and feed-forward sub-layers with residual adds.

        In the modulated path each residual update is multiplied by
        ``tanh(gate)`` and the feed-forward input is scaled by
        ``1 + scale_mlp``.
        """
        if not self.modulation:
            normed = self.norm1(hidden_states)
            attn_output = self.attn(normed, normed, attention_mask, image_rotary_emb)
            hidden_states = hidden_states + self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
            return hidden_states + self.ffn_norm2(mlp_output)

        normed, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
        attn_output = self.attn(normed, normed, attention_mask, image_rotary_emb)
        hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)

        mlp_input = self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))
        mlp_output = self.feed_forward(mlp_input)
        return hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
class OmniGen2RotaryPosEmbed(nn.Module):
    """Build 3-axis position ids for caption, reference-image and image tokens
    and turn them into rotary embeddings through ``EmbedND``.

    Each sample's sequence is laid out as caption tokens, then all reference
    image tokens, then the target image tokens.
    """

    def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size
        self.rope_embedder = EmbedND(dim=sum(axes_dim), theta=self.theta, axes_dim=axes_dim)

    @staticmethod
    def _grid_ids(h_tokens, w_tokens, device):
        """Return flattened (row, column) index tensors for an h x w token grid."""
        rows = repeat(torch.arange(h_tokens, dtype=torch.int32, device=device), "h -> h w", w=w_tokens).flatten()
        cols = repeat(torch.arange(w_tokens, dtype=torch.int32, device=device), "w -> h w", h=h_tokens).flatten()
        return rows, cols

    def forward(self, batch_size, encoder_seq_len, l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device):
        """Compute rotary embeddings for the joint sequence and its three segments.

        Returns a 6-tuple: ``(cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis,
        freqs_cis, l_effective_cap_len, seq_lengths)``. The three segment
        tensors are zero-initialised with dim 1 set to ``encoder_seq_len``,
        the largest total reference length and the largest image length
        respectively, then filled per sample from ``freqs_cis``.
        """
        p = self.patch_size

        # Total reference-token count per sample, reused throughout.
        ref_totals = [sum(ref_lens) for ref_lens in l_effective_ref_img_len]
        seq_lengths = [
            cap_len + ref_total + img_len
            for cap_len, ref_total, img_len in zip(l_effective_cap_len, ref_totals, l_effective_img_len)
        ]
        max_seq_len = max(seq_lengths)
        max_ref_img_len = max(ref_totals)
        max_img_len = max(l_effective_img_len)

        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            # Caption tokens: the same running index on all three axes.
            cap_ids = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
            position_ids[i, :cap_seq_len] = repeat(cap_ids, "l -> l 3")

            axis0_value = cap_seq_len  # value written to axis 0 for the next image
            cursor = cap_seq_len       # token offset at which the next image starts

            if ref_img_sizes[i] is not None:
                for (H, W), ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
                    ref_h_tokens, ref_w_tokens = H // p, W // p
                    row_ids, col_ids = self._grid_ids(ref_h_tokens, ref_w_tokens, device)
                    segment = slice(cursor, cursor + ref_img_len)
                    position_ids[i, segment, 0] = axis0_value
                    position_ids[i, segment, 1] = row_ids
                    position_ids[i, segment, 2] = col_ids
                    # Advance axis 0 by the larger side of this reference grid.
                    axis0_value += max(ref_h_tokens, ref_w_tokens)
                    cursor += ref_img_len

            H, W = img_sizes[i]
            row_ids, col_ids = self._grid_ids(H // p, W // p, device)
            position_ids[i, cursor:seq_len, 0] = axis0_value
            position_ids[i, cursor:seq_len, 1] = row_ids
            position_ids[i, cursor:seq_len, 2] = col_ids

        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)

        def zeros_with_length(length):
            # Same shape as freqs_cis except for dim 1.
            shape = list(freqs_cis.shape)
            shape[1] = length
            return torch.zeros(*shape, device=device, dtype=freqs_cis.dtype)

        cap_freqs_cis = zeros_with_length(encoder_seq_len)
        ref_img_freqs_cis = zeros_with_length(max_ref_img_len)
        img_freqs_cis = zeros_with_length(max_img_len)

        for i, (cap_seq_len, ref_total, img_len) in enumerate(zip(l_effective_cap_len, ref_totals, l_effective_img_len)):
            ref_end = cap_seq_len + ref_total
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, :ref_total] = freqs_cis[i, cap_seq_len:ref_end]
            img_freqs_cis[i, :img_len] = freqs_cis[i, ref_end:ref_end + img_len]

        return cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
class OmniGen2Transformer2DModel(nn.Module):
    """OmniGen2 transformer over patchified image latents, optional reference
    latents and caption features.

    Pipeline (see ``forward``): embed the timestep and caption, patchify the
    latents, refine caption / noise / reference tokens in separate short
    stacks, concatenate everything into one sequence, run the main layer
    stack, then project and un-patchify the trailing image tokens.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 2304,
        num_layers: int = 26,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 24,
        num_kv_heads: int = 8,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        text_feat_dim: int = 1024,
        timestep_scale: float = 1.0,
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels or in_channels
        self.hidden_size = hidden_size
        self.dtype = dtype

        factory = dict(dtype=dtype, device=device, operations=operations)
        patch_features = patch_size * patch_size * in_channels

        def build_stack(depth, modulation):
            # All stacks share one block configuration; only depth and the
            # modulation flag differ.
            return nn.ModuleList([
                OmniGen2TransformerBlock(
                    hidden_size, num_attention_heads, num_kv_heads,
                    multiple_of, ffn_dim_multiplier, norm_eps,
                    modulation=modulation, **factory
                )
                for _ in range(depth)
            ])

        self.rope_embedder = OmniGen2RotaryPosEmbed(
            theta=10000,
            axes_dim=axes_dim_rope,
            axes_lens=axes_lens,
            patch_size=patch_size,
        )
        self.x_embedder = operations.Linear(patch_features, hidden_size, dtype=dtype, device=device)
        self.ref_image_patch_embedder = operations.Linear(patch_features, hidden_size, dtype=dtype, device=device)
        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size,
            text_feat_dim=text_feat_dim,
            norm_eps=norm_eps,
            timestep_scale=timestep_scale,
            **factory
        )

        self.noise_refiner = build_stack(num_refiner_layers, True)
        self.ref_image_refiner = build_stack(num_refiner_layers, True)
        self.context_refiner = build_stack(num_refiner_layers, False)
        self.layers = build_stack(num_layers, True)

        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            out_dim=patch_size * patch_size * self.out_channels,
            **factory
        )
        # One learned vector per reference-image slot (5 rows). The storage is
        # left uninitialised by torch.empty.
        self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))

    def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
        """Patchify the image latents and, if given, the reference latents.

        Returns an 8-tuple: ``(flat_hidden_states, flat_ref_img_hidden_states,
        None, None, l_effective_ref_img_len, l_effective_img_len,
        ref_img_sizes, img_sizes)``. The two ``None`` entries sit where the
        caller unpacks ``img_mask`` and ``ref_img_mask``.
        """
        batch_size = len(hidden_states)
        p = self.patch_size

        def patchify(t):
            return rearrange(t, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)

        img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
        l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]

        if ref_image_hidden_states is None:
            ref_img_sizes = [None for _ in range(batch_size)]
            l_effective_ref_img_len = [[0] for _ in range(batch_size)]
            flat_ref_img_hidden_states = None
        else:
            ref_image_hidden_states = [
                comfy.ldm.common_dit.pad_to_patch_size(ref, (p, p))
                for ref in ref_image_hidden_states
            ]
            per_image_sizes = [
                (imgs.size(2), imgs.size(3)) if imgs is not None else None
                for imgs in ref_image_hidden_states
            ]
            # Every batch entry refers to the same list of reference sizes.
            ref_img_sizes = [per_image_sizes] * batch_size
            l_effective_ref_img_len = [
                [(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes]
                if _ref_img_sizes is not None else [0]
                for _ref_img_sizes in ref_img_sizes
            ]
            # Reference images are joined along the token axis.
            flat_ref_img_hidden_states = torch.cat(
                [patchify(ref_img) for ref_img in ref_image_hidden_states], dim=1
            )

        flat_hidden_states = patchify(hidden_states)

        return (
            flat_hidden_states, flat_ref_img_hidden_states,
            None, None,
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes,
        )

    def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
        """Embed and refine image tokens and optional reference tokens.

        When reference tokens are present, each reference image's token span
        gets the matching row of ``image_index_embedding`` added, the
        references are refined, and the result is concatenated in front of
        the refined image tokens along dim 1.
        """
        batch_size = len(hidden_states)
        has_reference = ref_image_hidden_states is not None

        hidden_states = self.x_embedder(hidden_states)

        if has_reference:
            ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
            index_embedding = comfy.model_management.cast_to(
                self.image_index_embedding, dtype=hidden_states.dtype, device=hidden_states.device
            )
            for i in range(batch_size):
                start = 0
                for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
                    stop = start + ref_img_len
                    ref_image_hidden_states[i, start:stop, :] = ref_image_hidden_states[i, start:stop, :] + index_embedding[j]
                    start = stop

        for layer in self.noise_refiner:
            hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)

        if has_reference:
            for layer in self.ref_image_refiner:
                ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
            hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)

        return hidden_states

    def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
        """Run the model on latents ``x`` and return a tensor cropped to ``x``'s H x W.

        Args:
            x: 4-D latent tensor; it is padded to a multiple of the patch size.
            timesteps: embedded as ``1.0 - timesteps``.
            context: caption features passed to the caption embedder.
            num_tokens: caption length used for every sample when building
                position ids.
            ref_latents: optional iterable of reference latents.
            attention_mask: used only by the context refiner; the main layer
                stack runs without a mask.
            **kwargs: accepted and ignored.

        Returns:
            The negated, un-patchified output for the trailing image tokens.
        """
        _, _, H, W = x.shape
        p = self.patch_size

        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (p, p))
        _, _, H_padded, W_padded = hidden_states.shape
        device = hidden_states.device

        temb, text_hidden_states = self.time_caption_embed(1.0 - timesteps, context, hidden_states[0].dtype)

        (
            hidden_states, ref_image_hidden_states,
            img_mask, ref_img_mask,
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes,
        ) = self.flat_and_pad_to_seq(hidden_states, ref_latents)

        text_batch, text_len = text_hidden_states.shape[0], text_hidden_states.shape[1]
        (
            context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
            rotary_emb, encoder_seq_lengths, seq_lengths,
        ) = self.rope_embedder(
            hidden_states.shape[0], text_len, [num_tokens] * text_batch,
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes, device,
        )

        for layer in self.context_refiner:
            text_hidden_states = layer(text_hidden_states, attention_mask, context_rotary_emb)

        # Token count of the target image, recorded before references are prepended.
        img_len = hidden_states.shape[1]

        combined_img_hidden_states = self.img_patch_embed_and_refine(
            hidden_states, ref_image_hidden_states,
            img_mask, ref_img_mask,
            noise_rotary_emb, ref_img_rotary_emb,
            l_effective_ref_img_len, l_effective_img_len,
            temb,
        )

        hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)

        # The joint stack always runs unmasked.
        for layer in self.layers:
            hidden_states = layer(hidden_states, None, rotary_emb, temb)

        hidden_states = self.norm_out(hidden_states, temb)

        image_tokens = hidden_states[:, -img_len:]
        unpatched = rearrange(
            image_tokens,
            'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
            h=H_padded // p, w=W_padded // p, p1=p, p2=p,
        )
        # Crop away the padding added at the top of forward.
        output = unpatched[:, :, :H, :W]
        return -output