from typing import Optional

import torch
from torch import nn

from diffusers.models.activations import FP32SiLU
from diffusers.utils import deprecate


def pixcell_get_2d_sincos_pos_embed(
    embed_dim,
    grid_size,
    cls_token=False,
    extra_tokens=0,
    interpolation_scale=1.0,
    base_size=16,
    device: Optional[torch.device] = None,
    phase=0,
    output_type: str = "np",
):
    """
    Creates 2D sinusoidal positional embeddings.

    Args:
        embed_dim (`int`):
            The embedding dimension.
        grid_size (`int`):
            The size of the grid height and width.
        cls_token (`bool`, defaults to `False`):
            Whether or not to add a classification token.
        extra_tokens (`int`, defaults to `0`):
            The number of extra tokens to add.
        interpolation_scale (`float`, defaults to `1.0`):
            The scale of the interpolation.
        base_size (`int`, defaults to `16`):
            The base grid size used to normalize the positions.
        device (`torch.device`, *optional*):
            The device to place the embeddings on.
        phase (`int`, defaults to `0`):
            A constant offset added to every position before the sinusoids are computed.
        output_type (`str`, defaults to `"np"`):
            The output type of the embedding. Only `"pt"` (PyTorch tensors) is supported; `"np"` is deprecated
            and raises an error.

    Returns:
        pos_embed (`torch.Tensor`):
            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size *
            grid_size, embed_dim]` if using cls_token
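
    Example (a minimal sketch; the 2x2 grid and `embed_dim=16` are illustrative, not PixCell defaults):

        >>> pos = pixcell_get_2d_sincos_pos_embed(16, 2, base_size=2, output_type="pt")
        >>> pos.shape
        torch.Size([4, 16])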
    """
    if output_type == "np":
        deprecation_message = (
            "`pixcell_get_2d_sincos_pos_embed` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt'` to use the new version now."
        )
        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
        raise ValueError("`output_type='np'` is not supported; pass `output_type='pt'`.")
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = (
        torch.arange(grid_size[0], device=device, dtype=torch.float32)
        / (grid_size[0] / base_size)
        / interpolation_scale
    )
    grid_w = (
        torch.arange(grid_size[1], device=device, dtype=torch.float32)
        / (grid_size[1] / base_size)
        / interpolation_scale
    )
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
    grid = torch.stack(grid, dim=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=phase, output_type=output_type)
    if cls_token and extra_tokens > 0:
        pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim], device=device), pos_embed], dim=0)
    return pos_embed


def pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=0, output_type="np"):
    r"""
    This function generates 2D sinusoidal positional embeddings from a grid.

    Args:
        embed_dim (`int`): The embedding dimension.
        grid (`torch.Tensor`): Grid of positions with shape `(2, 1, H, W)`.
        phase (`int`, defaults to `0`): A constant offset added to every position.
        output_type (`str`, defaults to `"np"`): Deprecated; only `"pt"` is supported.

    Returns:
        `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
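
    Example (a minimal sketch; the grid construction mirrors `pixcell_get_2d_sincos_pos_embed`):

        >>> ij = torch.meshgrid(torch.arange(2.0), torch.arange(2.0), indexing="xy")
        >>> grid = torch.stack(ij, dim=0).reshape(2, 1, 2, 2)
        >>> pixcell_get_2d_sincos_pos_embed_from_grid(8, grid, output_type="pt").shape
        torch.Size([4, 8])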
    """
    if output_type == "np":
        deprecation_message = (
            "`pixcell_get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt'` to use the new version now."
        )
        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
        raise ValueError("`output_type='np'` is not supported; pass `output_type='pt'`.")
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of the dimensions to encode grid_h, the other half for grid_w
    emb_h = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], phase=phase, output_type=output_type)  # (H*W, D/2)
    emb_w = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], phase=phase, output_type=output_type)  # (H*W, D/2)

    emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
    return emb


def pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim, pos, phase=0, output_type="np"):
    """
    This function generates 1D sinusoidal positional embeddings from a grid.

    Args:
        embed_dim (`int`): The embedding dimension `D`.
        pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`.
        phase (`int`, defaults to `0`): A constant offset added to every position.
        output_type (`str`, defaults to `"np"`): Deprecated; only `"pt"` is supported.

    Returns:
        `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
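
    Example (a minimal sketch; `M=4` positions with `D=6`):

        >>> pos = torch.arange(4, dtype=torch.float32)
        >>> pixcell_get_1d_sincos_pos_embed_from_grid(6, pos, output_type="pt").shape
        torch.Size([4, 6])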
    """
    if output_type == "np":
        deprecation_message = (
            "`pixcell_get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt'` to use the new version now."
        )
        deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
        raise ValueError("`output_type='np'` is not supported; pass `output_type='pt'`.")
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1) + phase  # (M,)
    out = torch.outer(pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb


class PixcellUNIProjection(nn.Module):
    """
    Projects UNI embeddings. Also provides a fixed unconditional embedding that can be swapped in for the
    condition to enable classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
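
    Example (a minimal sketch; the feature sizes are illustrative, not PixCell defaults):

        >>> proj = PixcellUNIProjection(in_features=1024, hidden_size=1152)
        >>> proj(torch.randn(2, 4, 1024)).shape
        torch.Size([2, 4, 1152])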
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", num_tokens=1):
        super().__init__()
        if out_features is None:
            out_features = hidden_size
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        elif act_fn == "silu_fp32":
            self.act_1 = FP32SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

        # Fixed (non-trainable) unconditional embedding. `register_buffer` stores a plain tensor,
        # so wrapping it in `nn.Parameter` is unnecessary and has no effect.
        self.register_buffer("uncond_embedding", torch.randn(num_tokens, in_features) / in_features**0.5)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class UNIPosEmbed(nn.Module):
    """
    Adds sinusoidal positional embeddings to the UNI conditions.

    Args:
        height (`int`, defaults to `1`): The height of the grid of UNI tokens.
        width (`int`, defaults to `1`): The width of the grid of UNI tokens.
        base_size (`int`, defaults to `16`): The base grid size used to normalize the positions.
        embed_dim (`int`, defaults to `768`): The dimension of the positional embedding.
        interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
        pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding. Only `"sincos"` is
            supported.
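
    Example (a minimal sketch with a 2x2 grid of UNI tokens; the sizes are illustrative):

        >>> pos_embed = UNIPosEmbed(height=2, width=2, base_size=2, embed_dim=768)
        >>> pos_embed(torch.randn(1, 4, 768)).shape
        torch.Size([1, 4, 768])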
    """

    def __init__(
        self,
        height=1,
        width=1,
        base_size=16,
        embed_dim=768,
        interpolation_scale=1,
        pos_embed_type="sincos",
    ):
        super().__init__()

        num_embeds = height * width
        grid_size = int(num_embeds**0.5)

        if pos_embed_type == "sincos":
            y_pos_embed = pixcell_get_2d_sincos_pos_embed(
                embed_dim,
                grid_size,
                base_size=base_size,
                interpolation_scale=interpolation_scale,
                output_type="pt",
                phase=base_size // num_embeds,
            )
            self.register_buffer("y_pos_embed", y_pos_embed.float().unsqueeze(0))
        else:
            raise ValueError(f"`pos_embed_type` {pos_embed_type} is not supported")

    def forward(self, uni_embeds):
        return (uni_embeds + self.y_pos_embed).to(uni_embeds.dtype)