| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | import math |
| | from typing import List, Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| |
|
| | from ..utils import deprecate |
| | from .activations import FP32SiLU, get_activation |
| | from .attention_processor import Attention |
| |
|
| |
|
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Create sinusoidal timestep embeddings, as in Denoising Diffusion Probabilistic Models.

    Args:
        timesteps (`torch.Tensor`): 1-D tensor of N timesteps, one per batch element; values may be fractional.
        embedding_dim (`int`): Width of the output embedding.
        flip_sin_to_cos (`bool`): If True, the cosine half comes before the sine half.
        downscale_freq_shift (`float`): Subtracted from ``embedding_dim // 2`` in the frequency-exponent denominator.
        scale (`float`): Multiplier applied to the angles before taking sin/cos.
        max_period (`int`): Controls the minimum frequency of the embeddings.

    Returns:
        `torch.Tensor`: Float tensor of shape ``[N, embedding_dim]``. When ``embedding_dim`` is odd, the last
        column is zero padding.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half = embedding_dim // 2
    idx = torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
    # Geometric ladder of frequencies from 1 down towards 1 / max_period.
    freqs = torch.exp(-math.log(max_period) * idx / (half - downscale_freq_shift))

    angles = scale * (timesteps[:, None].float() * freqs[None, :])
    sin_part, cos_part = torch.sin(angles), torch.cos(angles)
    halves = (cos_part, sin_part) if flip_sin_to_cos else (sin_part, cos_part)
    emb = torch.cat(halves, dim=-1)

    # Odd widths get a single zero column on the right.
    if embedding_dim % 2:
        emb = F.pad(emb, (0, 1, 0, 0))
    return emb
| |
|
| |
|
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    Build fixed 2D sine-cosine positional embeddings for a grid of patches.

    Args:
        embed_dim (`int`): Width of each positional embedding (must be even).
        grid_size (`int` or `Tuple[int, int]`): Grid height and width; an int means a square grid.
        cls_token (`bool`): Together with ``extra_tokens > 0``, prepend ``extra_tokens`` all-zero rows.
        extra_tokens (`int`): Number of zero rows to prepend when ``cls_token`` is True.
        interpolation_scale (`float`): Grid coordinates are divided by this value.
        base_size (`int`): Grid coordinates are normalized so that each axis spans ``base_size`` units.

    Returns:
        `np.ndarray`: ``[grid_size[0] * grid_size[1], embed_dim]``, or with ``extra_tokens`` extra leading rows.
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)
    rows, cols = grid_size[0], grid_size[1]

    h_coords = np.arange(rows, dtype=np.float32) / (rows / base_size) / interpolation_scale
    w_coords = np.arange(cols, dtype=np.float32) / (cols / base_size) / interpolation_scale
    # The w coordinates are passed first, so grid[0] holds column coordinates and grid[1] row coordinates.
    grid = np.stack(np.meshgrid(w_coords, h_coords), axis=0).reshape([2, 1, cols, rows])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not (cls_token and extra_tokens > 0):
        return pos_embed
    padding = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([padding, pos_embed], axis=0)
| |
|
| |
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a 2D coordinate grid with 1D sine-cosine embeddings per axis.

    Args:
        embed_dim (`int`): Output width; must be even. Each of ``grid[0]`` and ``grid[1]`` gets half of it.
        grid (`np.ndarray`): Array whose first axis indexes the two coordinate planes; each plane is flattened.

    Returns:
        `np.ndarray`: ``[M, embed_dim]``, with the ``grid[0]`` encoding in the first half of the columns.

    Raises:
        ValueError: If ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    half = embed_dim // 2
    per_axis = [get_1d_sincos_pos_embed_from_grid(half, plane) for plane in (grid[0], grid[1])]
    return np.concatenate(per_axis, axis=1)
| |
|
| |
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode positions with fixed 1D sine-cosine embeddings.

    Args:
        embed_dim (`int`): Output width per position; must be even.
        pos (`np.ndarray`): Positions to encode; flattened to shape ``(M,)``.

    Returns:
        `np.ndarray`: ``(M, embed_dim)`` float64 array; the first half holds sines, the second half cosines.

    Raises:
        ValueError: If ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    half = embed_dim // 2
    # Frequencies 1 / 10000 ** (i / half) for i in [0, half).
    inv_freq = 1.0 / 10000 ** (np.arange(half, dtype=np.float64) / (embed_dim / 2.0))

    angles = np.outer(pos.reshape(-1), inv_freq)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
| |
|
| |
|
class PatchEmbed(nn.Module):
    """
    2D image-to-patch embedding with optional fixed sin-cos positional embeddings, supporting SD3-style cropping.

    Args:
        height (`int`): Default input height in pixels.
        width (`int`): Default input width in pixels.
        patch_size (`int`): Side length of each square patch; also the convolution stride.
        in_channels (`int`): Number of input channels.
        embed_dim (`int`): Output width of each patch token.
        layer_norm (`bool`): Apply a non-affine LayerNorm (eps 1e-6) to the tokens.
        flatten (`bool`): Flatten the conv output from ``[B, C, H, W]`` to ``[B, H * W, C]``.
        bias (`bool`): Whether the patch projection convolution has a bias.
        interpolation_scale (`float`): Divides the sin-cos grid coordinates.
        pos_embed_type (`str` or `None`): ``"sincos"`` or ``None`` (no positional embedding).
        pos_embed_max_size (`int`, *optional*): If truthy, a square table with this many patches per side is built
            once and center-cropped to each input's patch grid. ``None`` or ``0`` disables cropping.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
        pos_embed_type="sincos",
        pos_embed_max_size=None,
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.flatten = flatten
        self.layer_norm = layer_norm
        self.pos_embed_max_size = pos_embed_max_size

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale

        # Either a large table to crop from, or a square table for the default input size.
        if pos_embed_max_size:
            grid_size = pos_embed_max_size
        else:
            grid_size = int(num_patches**0.5)

        if pos_embed_type is None:
            self.pos_embed = None
        elif pos_embed_type == "sincos":
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
            )
            # Only the croppable table is saved in the state dict; the default one is rebuilt at construction.
            persistent = bool(pos_embed_max_size)
            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
        else:
            raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")

    def cropped_pos_embed(self, height, width):
        """
        Center-crop the positional table to the patch grid of an input of ``height`` x ``width`` pixels.

        Returns:
            `torch.Tensor`: ``[1, (height // patch_size) * (width // patch_size), embed_dim]``.

        Raises:
            ValueError: If cropping is not configured, or the patch grid exceeds ``pos_embed_max_size``.
        """
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        height = height // self.patch_size
        width = width // self.patch_size
        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
        spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed

    def forward(self, latent):
        """Patchify ``latent`` (``[B, C, H, W]``) and add positional embeddings when configured."""
        # Use the same truthiness test as __init__ and the crop branch below. Previously this checked
        # `is not None`, so `pos_embed_max_size=0` yielded pixel sizes that were then compared with patch counts.
        if self.pos_embed_max_size:
            height, width = latent.shape[-2:]  # pixels; cropped_pos_embed divides by patch_size
        else:
            height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        latent = self.proj(latent)
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)
        if self.pos_embed is None:
            return latent

        if self.pos_embed_max_size:
            pos_embed = self.cropped_pos_embed(height, width)
        elif self.height != height or self.width != width:
            # Input differs from the default size: rebuild a table for this patch grid.
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)
| |
|
| |
|
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
    """
    RoPE for image tokens with a 2D layout.

    Args:
        embed_dim (`int`): Embedding dimension; half of it is used for each grid axis.
        crops_coords (`Tuple[Tuple, Tuple]`): ``(start, stop)`` corners of the crop, each indexed as ``[h, w]``.
        grid_size (`Tuple[int, int]`): Number of sample points along ``(h, w)``; ``stop`` is excluded.
        use_real (`bool`): If True, return ``(cos, sin)`` separately; otherwise return complex numbers.

    Returns:
        If ``use_real``: a ``(cos, sin)`` pair, each ``[grid_size[0] * grid_size[1], embed_dim]``.
        Otherwise: a complex tensor of shape ``[grid_size[0] * grid_size[1], embed_dim // 2]``.
    """
    start, stop = crops_coords
    axis_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
    axis_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)

    # w is passed to meshgrid first, so grid[0] carries the w coordinates and grid[1] the h coordinates.
    planes = np.stack(np.meshgrid(axis_w, axis_h), axis=0)
    planes = planes.reshape([2, 1, *planes.shape[1:]])
    return get_2d_rotary_pos_embed_from_grid(embed_dim, planes, use_real=use_real)
| |
|
| |
|
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
    """
    Combine 1D rotary embeddings for the two coordinate planes of ``grid``.

    Args:
        embed_dim (`int`): Total rotary dimension; must be divisible by 4.
        grid (`np.ndarray`): Array whose first axis indexes two coordinate planes; each plane is flattened.
        use_real (`bool`): If True, return ``(cos, sin)``; otherwise return a complex tensor.

    Returns:
        ``(cos, sin)`` concatenated along dim 1 when ``use_real``, else the concatenated complex tensor.
    """
    assert embed_dim % 4 == 0

    per_axis = embed_dim // 2
    emb_h, emb_w = [get_1d_rotary_pos_embed(per_axis, plane.reshape(-1), use_real=use_real) for plane in (grid[0], grid[1])]

    if not use_real:
        return torch.cat([emb_h, emb_w], dim=1)

    cos = torch.cat([emb_h[0], emb_w[0]], dim=1)
    sin = torch.cat([emb_h[1], emb_w[1]], dim=1)
    return cos, sin
| |
|
| |
|
def get_1d_rotary_pos_embed(
    dim: int, pos: Union[np.ndarray, int, torch.Tensor], theta: float = 10000.0, use_real=False
):
    """
    Precompute rotary-embedding frequencies for the positions in ``pos``.

    The angle for position ``p`` and frequency index ``i`` (``0 <= i < dim // 2``) is ``p / theta ** (2 * i / dim)``.

    Args:
        dim (`int`): Rotary dimension; ``dim // 2`` distinct frequencies are used.
        pos (`np.ndarray`, `torch.Tensor`, sequence, or `int`): 1-D position indices ``[S]``. An integer ``S``
            means positions ``0 .. S-1``.
        theta (`float`, *optional*, defaults to 10000.0): Base used to spread the frequencies.
        use_real (`bool`, *optional*): If True, return ``(cos, sin)`` separately; otherwise return complex numbers.

    Returns:
        If ``use_real``: a ``(cos, sin)`` pair of float32 tensors ``[S, 2 * (dim // 2)]``, each angle repeated
        for its channel pair. Otherwise: a complex64 tensor ``[S, dim // 2]`` equal to ``exp(i * angle)``.
    """
    if isinstance(pos, (int, np.integer)):
        pos = np.arange(pos)
    # Tensors are used as-is; anything else goes through NumPy (torch.from_numpy alone rejected tensors and lists).
    t = pos if isinstance(pos, torch.Tensor) else torch.from_numpy(np.asarray(pos))

    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=t.device)[: (dim // 2)].float() / dim))  # [D/2]
    freqs = torch.outer(t, freqs).float()  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64, [S, D/2]
| |
|
| |
|
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
) -> torch.Tensor:
    """
    Apply rotary position embeddings to a query or key tensor.

    Consecutive channel pairs ``(x[..., 2i], x[..., 2i + 1])`` are treated as the real and imaginary parts of a
    complex number and rotated by the per-position angle encoded in ``freqs_cis``.

    Args:
        x (`torch.Tensor`):
            Query or key tensor ``[B, H, S, D]``. The result is flattened from dim 3, so ``x`` must be 4-D.
        freqs_cis (`Tuple[torch.Tensor, torch.Tensor]` or complex `torch.Tensor`):
            Either a ``(cos, sin)`` pair, each ``[S, D]`` with every angle repeated for its channel pair, or a
            complex tensor ``[S, D // 2]`` holding ``exp(i * angle)``.

    Returns:
        `torch.Tensor`: Same shape and dtype as ``x``, with the rotation applied.
    """
    if isinstance(freqs_cis, torch.Tensor) and torch.is_complex(freqs_cis):
        # Complex form (previously mis-unpacked as `cos, sin = freqs_cis`): multiply each channel pair, viewed
        # as one complex number, by exp(i * angle).
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2).contiguous())
        rotated = x_complex * freqs_cis[None, None].to(x.device)
        return torch.view_as_real(rotated).flatten(3).to(x.dtype)

    cos, sin = freqs_cis  # [S, D]
    cos = cos[None, None].to(x.device)
    sin = sin[None, None].to(x.device)

    # (a, b) -> (-b, a) for every channel pair; combined with cos/sin this is the complex rotation.
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
| |
|
| |
|
class TimestepEmbedding(nn.Module):
    """
    Two-layer MLP that maps a timestep embedding of width `in_channels` to `time_embed_dim` (or `out_dim`).

    An optional `condition` is projected to `in_channels` (no bias) and added to the input first. An optional
    post-activation is applied to the output.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        self.cond_proj = None if cond_proj_dim is None else nn.Linear(cond_proj_dim, in_channels, bias=False)
        self.act = get_activation(act_fn)

        final_dim = time_embed_dim if out_dim is None else out_dim
        self.linear_2 = nn.Linear(time_embed_dim, final_dim, sample_proj_bias)

        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)

        hidden = self.linear_1(sample)
        if self.act is not None:
            hidden = self.act(hidden)
        hidden = self.linear_2(hidden)

        return hidden if self.post_act is None else self.post_act(hidden)
| |
|
| |
|
class Timesteps(nn.Module):
    """Module wrapper around `get_timestep_embedding` with a fixed channel count and frequency options."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels, self.flip_sin_to_cos, self.downscale_freq_shift = (
            num_channels,
            flip_sin_to_cos,
            downscale_freq_shift,
        )

    def forward(self, timesteps):
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
| |
|
| |
|
class GaussianFourierProjection(nn.Module):
    """
    Gaussian Fourier embeddings for noise levels.

    The input (optionally log-transformed) is multiplied by fixed random frequencies `weight` and by 2*pi, then
    passed through sin and cos.

    Args:
        embedding_size (`int`, defaults to 256): Number of random frequencies; the output has
            `2 * embedding_size` features.
        scale (`float`, defaults to 1.0): Multiplier applied to the standard-normal frequencies.
        set_W_to_weight (`bool`, defaults to True): Legacy flag. When True, the frequencies are drawn a second time
            and the second draw is kept.
        log (`bool`, defaults to True): Apply `torch.log` to the input first.
        flip_sin_to_cos (`bool`, defaults to False): Put the cosine features before the sine features.
    """

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # Re-draw the frequencies and store them only as `weight`. Also keeping a `W` alias would register the
            # same tensor under two names and add an extra "W" entry to the state dict.
            self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

    def forward(self, x):
        if self.log:
            x = torch.log(x)

        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi

        sin, cos = torch.sin(x_proj), torch.cos(x_proj)
        parts = [cos, sin] if self.flip_sin_to_cos else [sin, cos]
        return torch.cat(parts, dim=-1)
| |
|
| |
|
class SinusoidalPositionalEmbedding(nn.Module):
    """Apply positional information to a sequence of embeddings.

    Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds fixed sinusoidal
    positional embeddings to them. Even channels hold sines and odd channels hold cosines.

    Args:
        embed_dim: (int): Dimension of the positional embedding. Odd values are supported.
        max_seq_length: Maximum sequence length to apply positional embeddings.

    """

    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        pe = torch.zeros(1, max_seq_length, embed_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # There are ceil(D/2) even channels but only floor(D/2) odd ones, so trim the frequencies for odd D
        # (previously a shape-mismatch error).
        pe[0, :, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        _, seq_length, _ = x.shape
        x = x + self.pe[:, :seq_length]
        return x
| |
|
| |
|
class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings and sums them with positional embeddings for the height
    and width of the latent space.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion, the output vector embeddings are used as input for the transformer. They are separate from the
    vector embeddings of the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixels.
        height (`int`):
            Height of the latent image, i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image, i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height, self.width = height, width
        self.num_embed, self.embed_dim = num_embed, embed_dim

        # Creation order (pixel, height, width) is kept so parameter initialization is unchanged.
        self.emb = nn.Embedding(num_embed, embed_dim)
        self.height_emb = nn.Embedding(height, embed_dim)
        self.width_emb = nn.Embedding(width, embed_dim)

    def forward(self, index):
        tokens = self.emb(index)
        device = index.device

        rows = self.height_emb(torch.arange(self.height, device=device).view(1, self.height))  # [1, H, D]
        cols = self.width_emb(torch.arange(self.width, device=device).view(1, self.width))  # [1, W, D]

        # Broadcast to [1, H, W, D], then flatten row-major to [1, H * W, D].
        grid = (rows.unsqueeze(2) + cols.unsqueeze(1)).view(1, self.height * self.width, -1)

        # Only the first `seq_len` positions are added when the sequence is shorter than H * W.
        return tokens + grid[:, : tokens.shape[1], :]
| |
|
| |
|
class LabelEmbedding(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

    When `dropout_prob > 0`, the table has one extra row at index `num_classes`; dropped labels are mapped to it.

    Args:
        num_classes (`int`): The number of classes.
        hidden_size (`int`): The size of the vector embeddings.
        dropout_prob (`float`): The probability of dropping a label.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.

        Args:
            labels (`torch.Tensor`): 1-D tensor of class indices.
            force_drop_ids (list, `np.ndarray` or `torch.Tensor`, *optional*): Per-label flags; entries equal to 1
                are dropped. When None, each label is dropped independently with probability `dropout_prob`.

        Returns:
            `torch.Tensor`: `labels` with dropped entries replaced by `num_classes`.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            # as_tensor accepts lists, arrays and tensors and places the mask on the labels' device.
            # (`list == 1` is a single False, so lists were previously ignored.)
            drop_ids = torch.as_tensor(force_drop_ids, device=labels.device) == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (self.training and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings
| |
|
| |
|
class TextImageProjection(nn.Module):
    """
    Builds a cross-attention context from an image embedding and text embeddings.

    The image embedding is projected into `num_image_text_embeds` tokens of width `cross_attention_dim`. The text
    embeddings are projected to `cross_attention_dim`. The image tokens are placed before the text tokens.
    """

    def __init__(
        self,
        text_embed_dim: int = 1024,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 10,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        # `image_embeds` is created before `text_proj`, keeping parameter-initialization order.
        self.image_embeds = nn.Linear(image_embed_dim, num_image_text_embeds * cross_attention_dim)
        self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        n = text_embeds.shape[0]
        image_tokens = self.image_embeds(image_embeds).reshape(n, self.num_image_text_embeds, -1)
        text_tokens = self.text_proj(text_embeds)
        return torch.cat([image_tokens, text_tokens], dim=1)
| |
|
| |
|
class ImageProjection(nn.Module):
    """
    Projects one image embedding per sample into `num_image_text_embeds` layer-normalized tokens of width
    `cross_attention_dim`.
    """

    def __init__(
        self,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 32,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, num_image_text_embeds * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        tokens = self.image_embeds(image_embeds)
        tokens = tokens.reshape(image_embeds.shape[0], self.num_image_text_embeds, -1)
        return self.norm(tokens)
| |
|
| |
|
class IPAdapterFullImageProjection(nn.Module):
    """Projects image embeddings with a `FeedForward` block (GELU, mult=1) followed by LayerNorm."""

    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
        super().__init__()
        # Imported locally, as in the original (presumably to avoid an import cycle; not verified here).
        from .attention import FeedForward

        self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        projected = self.ff(image_embeds)
        return self.norm(projected)
| |
|
| |
|
class IPAdapterFaceIDImageProjection(nn.Module):
    """
    Projects image embeddings into `num_tokens` layer-normalized tokens of width `cross_attention_dim` via a
    GELU `FeedForward` block.
    """

    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
        super().__init__()
        # Imported locally, as in the original (presumably to avoid an import cycle; not verified here).
        from .attention import FeedForward

        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        tokens = self.ff(image_embeds).reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)
| |
|
| |
|
class CombinedTimestepLabelEmbeddings(nn.Module):
    """Sum of a 256-channel sinusoidal timestep embedding (passed through an MLP) and a class-label embedding."""

    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None):
        t_freq = self.time_proj(timestep).to(dtype=hidden_dtype)
        # Left operand is evaluated first, so the timestep path still runs before label dropout.
        return self.timestep_embedder(t_freq) + self.class_embedder(class_labels)
| |
|
| |
|
class CombinedTimestepTextProjEmbeddings(nn.Module):
    """Sum of a 256-channel sinusoidal timestep embedding (passed through an MLP) and a projected pooled text vector."""

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        # The timestep features are cast to the dtype of the pooled text projection.
        t_freq = self.time_proj(timestep).to(dtype=pooled_projection.dtype)
        t_emb = self.timestep_embedder(t_freq)
        return t_emb + self.text_embedder(pooled_projection)
| |
|
| |
|
class HunyuanDiTAttentionPool(nn.Module):
    """
    Attention pooling that summarizes a token sequence into one vector per sample.

    The mean token is prepended as the query, learned positional embeddings (``spacial_dim + 1`` rows) are added,
    and the query attends over all tokens with `num_heads` heads; `c_proj` maps the result to `output_dim`
    (or `embed_dim`).

    Args:
        spacial_dim (`int`): Input sequence length.
        embed_dim (`int`): Channel width of the input tokens.
        num_heads (`int`): Number of attention heads.
        output_dim (`int`, *optional*): Output width; defaults to `embed_dim`.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # Creation order (pos, k, q, v, c) is kept so parameter initialization is unchanged.
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        # [B, L, C] -> [L, B, C]: F.multi_head_attention_forward takes sequence-first inputs.
        tokens = x.permute(1, 0, 2)
        tokens = torch.cat([tokens.mean(dim=0, keepdim=True), tokens], dim=0)
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)

        packed_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens[:1],
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=packed_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        return pooled.squeeze(0)
| |
|
| |
|
class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
    """
    Conditioning vector = timestep embedding + MLP(pooled text, image-meta-size embeddings, style embedding).
    """

    def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.pooler = HunyuanDiTAttentionPool(
            seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
        )
        # A single learned style row (index 0).
        self.style_embedder = nn.Embedding(1, embedding_dim)

        # Six 256-d size embeddings + style embedding + pooled text.
        extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
        self.extra_embedder = PixArtAlphaTextProjection(
            in_features=extra_in_dim,
            hidden_size=embedding_dim * 4,
            out_features=embedding_dim,
            act_fn="silu_fp32",
        )

    def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
        timesteps_emb = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_dtype))

        # Extra condition 1: attention-pooled text features.
        text_pooled = self.pooler(encoder_hidden_states)

        # Extra condition 2: each meta-size value gets a 256-d sinusoidal embedding; the view below assumes six
        # values per sample.
        size_emb = get_timestep_embedding(
            image_meta_size.view(-1), 256, flip_sin_to_cos=True, downscale_freq_shift=0
        )
        size_emb = size_emb.to(dtype=hidden_dtype).view(-1, 6 * 256)

        # Extra condition 3: learned style embedding.
        style_emb = self.style_embedder(style)

        extra_cond = torch.cat([text_pooled, size_emb, style_emb], dim=1)
        return timesteps_emb + self.extra_embedder(extra_cond)
| |
|
| |
|
class TextTimeEmbedding(nn.Module):
    """
    Pools encoder hidden states into one vector and projects it to the time-embedding width.

    Pipeline: LayerNorm(encoder_dim) -> AttentionPooling -> Linear(encoder_dim -> time_embed_dim) -> LayerNorm.
    """

    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
        super().__init__()
        self.norm1 = nn.LayerNorm(encoder_dim)
        self.pool = AttentionPooling(num_heads, encoder_dim)
        self.proj = nn.Linear(encoder_dim, time_embed_dim)
        self.norm2 = nn.LayerNorm(time_embed_dim)

    def forward(self, hidden_states):
        for stage in (self.norm1, self.pool, self.proj, self.norm2):
            hidden_states = stage(hidden_states)
        return hidden_states
| |
|
| |
|
class TextImageTimeEmbedding(nn.Module):
    """Adds a projected, layer-normalized text embedding to a projected image embedding."""

    def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
        self.text_norm = nn.LayerNorm(time_embed_dim)
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        text_part = self.text_norm(self.text_proj(text_embeds))
        image_part = self.image_proj(image_embeds)
        return image_part + text_part
| |
|
| |
|
class ImageTimeEmbedding(nn.Module):
    """Projects an image embedding to the time-embedding width and applies LayerNorm."""

    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

    def forward(self, image_embeds: torch.Tensor):
        return self.image_norm(self.image_proj(image_embeds))
| |
|
| |
|
class ImageHintTimeEmbedding(nn.Module):
    """
    Projects image embeddings to the time-embedding width and encodes a 3-channel hint image.

    The hint encoder is a stack of 3x3 convolutions (padding 1) with SiLU activations. Three of them use stride 2,
    so the spatial size is halved (rounding up) three times. The final convolution outputs 4 channels.
    """

    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

        # (in_channels, out_channels, stride) for each conv + SiLU pair, in order. The layout keeps the
        # Sequential indices (convs at 0, 2, ..., 14) and therefore the state-dict keys.
        conv_specs = [
            (3, 16, 1),
            (16, 16, 1),
            (16, 32, 2),
            (32, 32, 1),
            (32, 96, 2),
            (96, 96, 1),
            (96, 256, 2),
        ]
        layers = []
        for c_in, c_out, stride in conv_specs:
            layers.append(nn.Conv2d(c_in, c_out, 3, padding=1, stride=stride))
            layers.append(nn.SiLU())
        layers.append(nn.Conv2d(256, 4, 3, padding=1))
        self.input_hint_block = nn.Sequential(*layers)

    def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
        time_image_embeds = self.image_norm(self.image_proj(image_embeds))
        return time_image_embeds, self.input_hint_block(hint)
| |
|
| |
|
class AttentionPooling(nn.Module):
    """
    Pools a ``[B, L, C]`` sequence into ``[B, C]`` with a single attention query.

    The query token is the sequence mean plus a learned positional embedding. It is prepended to the sequence, and
    multi-head attention from that token over all ``L + 1`` tokens gives the pooled vector.
    """

    def __init__(self, num_heads, embed_dim, dtype=None):
        super().__init__()
        self.dtype = dtype
        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.num_heads = num_heads
        self.dim_per_head = embed_dim // self.num_heads

    def forward(self, x):
        batch, _, _ = x.size()
        heads, head_dim = self.num_heads, self.dim_per_head

        def split_heads(t):
            # [B, T, H * Dh] -> [B, T, H, Dh] -> [B, H, T, Dh] -> [B * H, T, Dh] -> [B * H, Dh, T]
            t = t.view(batch, -1, heads, head_dim).transpose(1, 2)
            return t.reshape(batch * heads, -1, head_dim).transpose(1, 2)

        query_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
        tokens = torch.cat([query_token, x], dim=1)

        q = split_heads(self.q_proj(query_token))  # [B * H, Dh, 1]
        k = split_heads(self.k_proj(tokens))  # [B * H, Dh, L + 1]
        v = split_heads(self.v_proj(tokens))

        # Scaling q and k each by Dh ** -0.25 equals scaling their product by Dh ** -0.5.
        qk_scale = 1 / math.sqrt(math.sqrt(head_dim))
        attn = torch.einsum("bct,bcs->bts", q * qk_scale, k * qk_scale)
        attn = torch.softmax(attn.float(), dim=-1).type(attn.dtype)

        mixed = torch.einsum("bts,bcs->bct", attn, v)  # [B * H, Dh, 1]
        mixed = mixed.reshape(batch, -1, 1).transpose(1, 2)  # [B, 1, H * Dh]
        return mixed[:, 0, :]
| |
|
| |
|
def get_fourier_embeds_from_boundingbox(embed_dim, box):
    """
    Fourier-encode bounding boxes for the GLIGEN pipeline.

    Each of the 4 box coordinates is multiplied by the `embed_dim` frequencies ``100 ** (k / embed_dim)`` and
    passed through sin and cos.

    Args:
        embed_dim (`int`): Number of frequencies per coordinate.
        box (`torch.Tensor`): Boxes of shape ``[B, N, 4]``.

    Returns:
        `torch.Tensor`: ``[B, N, embed_dim * 2 * 4]``. Along the last axis, the frequency varies slowest, then
        sin/cos, then the coordinate.
    """
    b, n = box.shape[:2]

    freqs = 100 ** (torch.arange(embed_dim) / embed_dim)
    freqs = freqs[None, None, None].to(device=box.device, dtype=box.dtype)  # [1, 1, 1, F]
    angles = box.unsqueeze(-1) * freqs  # [B, N, 4, F]

    features = torch.stack((angles.sin(), angles.cos()), dim=-1)  # [B, N, 4, F, 2]
    return features.permute(0, 1, 3, 4, 2).reshape(b, n, embed_dim * 2 * 4)
| |
|
| |
|
class GLIGENTextBoundingboxProjection(nn.Module):
    """
    Project GLIGEN grounding inputs (bounding boxes plus text and/or image features) into object tokens.

    Args:
        positive_len (`int`): Channel dimension of the text/image grounding features.
        out_dim (`int` or `tuple`): Output channel dimension. If a tuple, its first element is used for the
            projection layers; the original value is still stored on `self.out_dim`.
        feature_type (`str`, defaults to `"text-only"`): Either `"text-only"` or `"text-image"`.
        fourier_freqs (`int`, defaults to 8): Number of Fourier frequencies per box coordinate.

    Raises:
        ValueError: If `feature_type` is not `"text-only"` or `"text-image"`.
    """

    def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
        super().__init__()
        self.positive_len = positive_len
        self.out_dim = out_dim

        self.fourier_embedder_dim = fourier_freqs
        # sin + cos of each of the 4 box coordinates at every frequency
        self.position_dim = fourier_freqs * 2 * 4

        if isinstance(out_dim, tuple):
            out_dim = out_dim[0]

        in_dim = self.positive_len + self.position_dim
        if feature_type == "text-only":
            self.linears = self._build_mlp(in_dim, out_dim)
            self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
        elif feature_type == "text-image":
            self.linears_text = self._build_mlp(in_dim, out_dim)
            self.linears_image = self._build_mlp(in_dim, out_dim)
            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
        else:
            # Previously an unknown value silently produced a module without projection layers.
            raise ValueError(f"`feature_type` must be 'text-only' or 'text-image', got {feature_type!r}")

        self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

    @staticmethod
    def _build_mlp(in_dim, out_dim):
        """Three-layer SiLU MLP (in_dim -> 512 -> 512 -> out_dim) shared by every feature branch."""
        return nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

    def forward(
        self,
        boxes,
        masks,
        positive_embeddings=None,
        phrases_masks=None,
        image_masks=None,
        phrases_embeddings=None,
        image_embeddings=None,
    ):
        """
        Args:
            boxes: Box coordinates, Fourier-encoded with `get_fourier_embeds_from_boundingbox`.
            masks: Per-box mask. Where it is 1 the box embedding is kept; where it is 0 the embedding is
                replaced by a learned null embedding.
            positive_embeddings: Grounding features for the `"text-only"` layers. When given, they are masked
                with `masks` and projected with `self.linears`.
            phrases_masks, image_masks, phrases_embeddings, image_embeddings: Inputs for the `"text-image"`
                layers, used when `positive_embeddings` is None.

        Returns:
            Projected object tokens. In the `"text-image"` case, text tokens and image tokens are concatenated
            along dim 1.
        """
        masks = masks.unsqueeze(-1)

        # Fourier embedding of the box coordinates.
        xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes)

        # Learnable null embedding used where masks == 0.
        xyxy_null = self.null_position_feature.view(1, 1, -1)
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        if positive_embeddings is not None:
            # "text-only" path
            positive_null = self.null_positive_feature.view(1, 1, -1)
            positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null

            objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
        else:
            # "text-image" path: text and image features each have their own mask and null embedding.
            phrases_masks = phrases_masks.unsqueeze(-1)
            image_masks = image_masks.unsqueeze(-1)

            text_null = self.null_text_feature.view(1, 1, -1)
            image_null = self.null_image_feature.view(1, 1, -1)

            phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
            image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null

            objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
            objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
            objs = torch.cat([objs_text, objs_image], dim=1)

        return objs
| |
|
| |
|
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha: a timestep embedding, optionally summed with resolution and aspect-ratio embeddings.

    Args:
        embedding_dim (`int`): Output dimension of the timestep embedder.
        size_emb_dim (`int`): Output dimension of each of the resolution / aspect-ratio embedders.
        use_additional_conditions (`bool`, defaults to `False`): Whether to build and use the size embedders.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        self.outdim = size_emb_dim
        # 256-channel sinusoidal projection, then an embedder up to `embedding_dim`
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if not use_additional_conditions:
            return

        # One sinusoidal projection shared by both size conditions, each with its own embedder.
        self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
        self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        """
        Args:
            timestep: Timesteps to embed.
            resolution, aspect_ratio: Size conditions; only read when `use_additional_conditions` is set.
                They are flattened before projection.
            batch_size: Used to reshape each size embedding to `(batch_size, -1)`.
            hidden_dtype: dtype the sinusoidal projections are cast to before their embedders.

        Returns:
            The timestep embedding, plus (when enabled) the resolution and aspect-ratio embeddings
            concatenated along dim 1. NOTE(review): the sum presumably requires
            `embedding_dim == 2 * size_emb_dim`; confirm against the embedder output shapes.
        """
        conditioning = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_dtype))
        if not self.use_additional_conditions:
            return conditioning

        def embed_size_condition(values, embedder):
            projected = self.additional_condition_proj(values.flatten()).to(hidden_dtype)
            return embedder(projected).reshape(batch_size, -1)

        size_features = torch.cat(
            [
                embed_size_condition(resolution, self.resolution_embedder),
                embed_size_condition(aspect_ratio, self.aspect_ratio_embedder),
            ],
            dim=1,
        )
        return conditioning + size_features
| |
|
| |
|
class PixArtAlphaTextProjection(nn.Module):
    """
    Project caption embeddings with a two-layer MLP: linear -> activation -> linear.

    Args:
        in_features (`int`): Channel dimension of the incoming caption embeddings.
        hidden_size (`int`): Width of the hidden layer.
        out_features (`int`, *optional*): Output width; defaults to `hidden_size`.
        act_fn (`str`, defaults to `"gelu_tanh"`): One of `"gelu_tanh"`, `"silu"` or `"silu_fp32"`.

    Raises:
        ValueError: If `act_fn` is not one of the supported names.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
        super().__init__()
        out_features = hidden_size if out_features is None else out_features
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        self.act_1 = self._make_activation(act_fn)
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

    @staticmethod
    def _make_activation(act_fn):
        """Build the hidden-layer activation module named by `act_fn`."""
        if act_fn == "gelu_tanh":
            return nn.GELU(approximate="tanh")
        if act_fn == "silu":
            return nn.SiLU()
        if act_fn == "silu_fp32":
            return FP32SiLU()
        raise ValueError(f"Unknown activation function: {act_fn}")

    def forward(self, caption):
        """Apply linear_1, the activation and linear_2 to `caption` along its last dimension."""
        return self.linear_2(self.act_1(self.linear_1(caption)))
| |
|
| |
|
class IPAdapterPlusImageProjectionBlock(nn.Module):
    """
    One resampler layer of IP-Adapter Plus: attention from the latents over [image features; latents],
    followed by a LayerNorm + feed-forward branch with a residual connection.

    Args:
        embed_dims (`int`, defaults to 768): Channel dimension of both the features and the latents.
        dim_head (`int`, defaults to 64): Channels per attention head.
        heads (`int`, defaults to 16): Number of attention heads.
        ffn_ratio (`float`, defaults to 4): Hidden-size multiplier of the feed-forward network.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        dim_head: int = 64,
        heads: int = 16,
        ffn_ratio: float = 4,
    ) -> None:
        super().__init__()
        from .attention import FeedForward

        self.ln0 = nn.LayerNorm(embed_dims)
        self.ln1 = nn.LayerNorm(embed_dims)
        self.attn = Attention(query_dim=embed_dims, dim_head=dim_head, heads=heads, out_bias=False)
        feed_forward = FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False)
        self.ff = nn.Sequential(nn.LayerNorm(embed_dims), feed_forward)

    def forward(self, x, latents, residual):
        """
        Args:
            x: Image features, normalized with `ln0`.
            latents: Latent queries, normalized with `ln1`.
            residual: Tensor added to the attention output (the caller passes the un-normalized latents).
        """
        normed_features = self.ln0(x)
        normed_latents = self.ln1(latents)
        # The normalized latents are appended to the context along the token axis (dim -2).
        context = torch.cat([normed_features, normed_latents], dim=-2)
        attended = self.attn(normed_latents, context) + residual
        return self.ff(attended) + attended
| |
|
| |
|
class IPAdapterPlusImageProjection(nn.Module):
    """Resampler of IP-Adapter Plus.

    A set of learned latent queries is refined by `depth` resampler blocks that attend over the projected
    image features, then projected to `output_dims` and layer-normalized.

    Args:
        embed_dims (`int`, defaults to 768): Feature dimension of the input image embeddings.
        output_dims (`int`, defaults to 1024): Number of output channels, i.e. the same as
            `unet.config.cross_attention_dim`.
        hidden_dims (`int`, defaults to 1280): Number of hidden channels used inside the blocks.
        depth (`int`, defaults to 4): Number of resampler blocks.
        dim_head (`int`, defaults to 64): Channels per attention head.
        heads (`int`, defaults to 16): Number of attention heads.
        num_queries (`int`, defaults to 8): Number of learned latent queries (output tokens).
        ffn_ratio (`float`, defaults to 4): Expansion ratio of the feed-forward hidden layer.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        output_dims: int = 1024,
        hidden_dims: int = 1280,
        depth: int = 4,
        dim_head: int = 64,
        heads: int = 16,
        num_queries: int = 8,
        ffn_ratio: float = 4,
    ) -> None:
        super().__init__()
        initial_latents = torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5
        self.latents = nn.Parameter(initial_latents)

        self.proj_in = nn.Linear(embed_dims, hidden_dims)
        self.proj_out = nn.Linear(hidden_dims, output_dims)
        self.norm_out = nn.LayerNorm(output_dims)

        self.layers = nn.ModuleList(
            IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x (torch.Tensor): Image features whose last dimension is `embed_dims`; dim 0 is the batch.
        Returns:
            torch.Tensor: `num_queries` output tokens per batch element, with `output_dims` channels.
        """
        features = self.proj_in(x)
        # The single set of learned queries is copied once per batch element.
        latents = self.latents.repeat(x.size(0), 1, 1)
        for block in self.layers:
            # The block's residual input is the latents as they were before this block.
            latents = block(features, latents, latents)
        return self.norm_out(self.proj_out(latents))
| |
|
| |
|
class IPAdapterFaceIDPlusImageProjection(nn.Module):
    """FacePerceiverResampler of IP-Adapter Plus.

    ID embeddings are expanded into `num_tokens` latent tokens, which are refined by `depth` resampler
    blocks attending over `self.clip_embeds`. The caller must assign `clip_embeds` before calling `forward`.

    Args:
        embed_dims (`int`, defaults to 768): Feature dimension of the latent tokens.
        output_dims (`int`, defaults to 768): Number of output channels, i.e. the same as
            `unet.config.cross_attention_dim`.
        hidden_dims (`int`, defaults to 1280): Channel dimension of `clip_embeds`.
        id_embeddings_dim (`int`, defaults to 512): Channel dimension of the ID embeddings.
        depth (`int`, defaults to 4): Number of resampler blocks.
        dim_head (`int`, defaults to 64): Channels per attention head.
        heads (`int`, defaults to 16): Number of attention heads.
        num_tokens (`int`, defaults to 4): Number of latent tokens produced from each ID embedding.
        num_queries (`int`, defaults to 8): Accepted for config compatibility; not used by this module.
        ffn_ratio (`float`, defaults to 4): Expansion ratio of the resampler feed-forward hidden layer.
        ffproj_ratio (`int`, defaults to 2): Expansion ratio of the feed-forward network projecting the
            ID embeddings.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        output_dims: int = 768,
        hidden_dims: int = 1280,
        id_embeddings_dim: int = 512,
        depth: int = 4,
        dim_head: int = 64,
        heads: int = 16,
        num_tokens: int = 4,
        num_queries: int = 8,
        ffn_ratio: float = 4,
        ffproj_ratio: int = 2,
    ) -> None:
        super().__init__()
        from .attention import FeedForward

        self.num_tokens = num_tokens
        self.embed_dim = embed_dims
        # Must be assigned by the caller before `forward`.
        self.clip_embeds = None
        # When enabled, `forward` returns id tokens + shortcut_scale * resampler output.
        self.shortcut = False
        self.shortcut_scale = 1.0

        self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
        self.norm = nn.LayerNorm(embed_dims)

        self.proj_in = nn.Linear(hidden_dims, embed_dims)

        self.proj_out = nn.Linear(embed_dims, output_dims)
        self.norm_out = nn.LayerNorm(output_dims)

        self.layers = nn.ModuleList(
            [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
        )

    def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            id_embeds (torch.Tensor): ID embeddings whose last dimension is `id_embeddings_dim`.
        Returns:
            torch.Tensor: Tensor of shape (-1, num_tokens, output_dims).
        Raises:
            ValueError: If `self.clip_embeds` has not been set.
        """
        if self.clip_embeds is None:
            raise ValueError(
                "`clip_embeds` must be set on IPAdapterFaceIDPlusImageProjection before calling forward()."
            )

        id_embeds = id_embeds.to(self.clip_embeds.dtype)
        id_embeds = self.proj(id_embeds)
        id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
        id_embeds = self.norm(id_embeds)
        latents = id_embeds

        # clip_embeds is expected to be 4-D with `hidden_dims` channels last; the two leading
        # axes are merged so each entry becomes one attention context.
        clip_embeds = self.proj_in(self.clip_embeds)
        x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])

        for block in self.layers:
            residual = latents
            latents = block(x, latents, residual)

        latents = self.proj_out(latents)
        out = self.norm_out(latents)
        if self.shortcut:
            out = id_embeds + self.shortcut_scale * out
        return out
| |
|
| |
|
class MultiIPAdapterImageProjection(nn.Module):
    """
    Apply one image-projection layer per IP-Adapter to the matching entry of a sequence of image embeddings.

    Args:
        IPAdapterImageProjectionLayers: Projection modules, one per IP-Adapter, in the same order as the
            `image_embeds` passed to `forward`.
    """

    def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
        super().__init__()
        self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)

    def forward(self, image_embeds: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]]):
        """
        Args:
            image_embeds: List (or tuple) of tensors, one per projection layer. Each tensor has shape
                (batch_size, num_images, ...). A single tensor is still accepted (deprecated); it is treated
                as one entry with `num_images == 1`.

        Returns:
            List of projected tensors, each reshaped back to (batch_size, num_images, ...).

        Raises:
            ValueError: If the number of embeddings differs from the number of projection layers.
        """
        projected_image_embeds = []

        # Passing a bare tensor is deprecated. Wrap it as a single entry with a num_images axis of 1.
        # Tuples are accepted like lists; they used to fall into this branch and fail on `.unsqueeze`.
        if not isinstance(image_embeds, (list, tuple)):
            deprecation_message = (
                "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
                " Please make sure to update your script to pass `image_embeds` as a list of tensors to supress this warning."
            )
            deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
            image_embeds = [image_embeds.unsqueeze(1)]

        if len(image_embeds) != len(self.image_projection_layers):
            raise ValueError(
                f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
            )

        for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
            # Fold num_images into the batch axis for the projection, then unfold it again.
            batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
            image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
            image_embed = image_projection_layer(image_embed)
            image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])

            projected_image_embeds.append(image_embed)

        return projected_image_embeds
| |
|