# JiT-AnimeFace-Demo/model/denoiser.py
# Reference: https://github.com/LTH14/JiT/blob/main/model_jit.py
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
import torch.nn.functional as F
from .config import DenoiserConfig
# https://github.com/huggingface/diffusers/blob/66bf7ea5be7099c8a47b9cba135f276d55247447/src/diffusers/models/embeddings.py#L27
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
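    """Create sinusoidal timestep embeddings of shape (len(timesteps), embedding_dim)."""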
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class FP32RMSNorm(nn.RMSNorm):
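    """RMSNorm evaluated in float32 for numerical stability, with the result cast back to the input dtype."""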
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return F.rms_norm(
hidden_states.to(torch.float32),
self.normalized_shape,
weight=self.weight,
eps=self.eps,
).to(hidden_states.dtype)
class BottleneckPatchEmbed(nn.Module):
"""Image to Patch Embedding"""
def __init__(
self,
patch_size: int = 16,
in_channels: int = 3,
bottleneck_dim: int = 128,
hidden_dim: int = 768,
bias: bool = True,
):
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.bottleneck_dim = bottleneck_dim
self.hidden_dim = hidden_dim
self.bias = bias
self.proj_1 = nn.Conv2d(
in_channels,
bottleneck_dim,
kernel_size=patch_size,
stride=patch_size,
bias=False,
)
self.proj_2 = nn.Conv2d(
bottleneck_dim,
hidden_dim,
kernel_size=1,
stride=1,
bias=bias,
)
def forward(self, image: torch.Tensor) -> torch.Tensor:
# B, C, H, W = image.shape
# [B, C, H, W]
# -> [B, bottleneck_dim, H/patch_size, W/patch_size] (proj_1)
# -> [B, hidden_dim, H/patch_size, W/patch_size] (proj_2)
# -> [B, hidden_dim, num_patches] (flatten)
# -> [B, num_patches, hidden_dim] (transpose)
patches = (
self.proj_2(
self.proj_1(image),
)
.flatten(2)
.transpose(1, 2)
)
return patches
class TimestepEmbedder(nn.Module):
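    """Embeds scalar timesteps with sinusoidal features followed by a two-layer SiLU MLP."""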
def __init__(
self,
hidden_dim: int,
freq_embedding_size: int = 256,
):
super().__init__()
self.freq_embedding_size = freq_embedding_size
self.mlp = nn.Sequential(
nn.Linear(freq_embedding_size, hidden_dim, bias=True),
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim, bias=True),
)
def forward(self, timestep: torch.Tensor) -> torch.Tensor:
freq_emb = get_timestep_embedding(
timestep,
embedding_dim=self.freq_embedding_size,
flip_sin_to_cos=True,
downscale_freq_shift=0,
)
time_embed = self.mlp(freq_emb.to(dtype=self.mlp[0].weight.dtype))
return time_embed
def apply_rope(
inputs: torch.Tensor, # (batch_size, num_heads, seq_len, dim)
freqs_cis: torch.Tensor, # (batch_size, seq_len, dim//2) complex64
) -> torch.Tensor:
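    """Apply rotary position embedding by complex multiplication, computed in float32."""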
batch_size, num_heads, seq_len, dim = inputs.shape
with torch.autocast(device_type="cuda", enabled=False):
inputs_cis = torch.view_as_complex(
inputs.float().view(batch_size, num_heads, seq_len, dim // 2, 2)
)
freqs_cis = freqs_cis.unsqueeze(1) # (batch_size, 1, seq_len, dim//2)
output = torch.view_as_real(inputs_cis * freqs_cis).flatten(3)
return output.type_as(inputs)
class RopeEmbedder:
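    """Precomputes per-axis rotary frequencies (text index, height, width) and looks them up by position id."""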
def __init__(
self,
rope_theta: float = 256.0, # ref: Z-Image
axes_dims: list[int] = [32, 64, 64], # text, height, width
axes_lens: list[int] = [256, 128, 128], # text, height, width
zero_centered: list[bool] = [False, True, True],
):
self.rope_theta = rope_theta
self.axes_dims = axes_dims
self.axes_lens = axes_lens
self.zero_centered = zero_centered
# text starts with 0, image axes are zero-centered
self.freqs_cis = self.precompute_freqs_cis(
theta=self.rope_theta,
dims=self.axes_dims,
lens=self.axes_lens,
zero_centered=self.zero_centered,
)
@staticmethod
def get_rope_freqs(
dim: int,
min_position: int = 0,
max_position: int = 128,
theta: float = 10000.0,
) -> torch.Tensor:
freqs = 1.0 / (
theta
** (
torch.arange(0, dim, 2, dtype=torch.float64, device=torch.device("cpu"))
/ dim
)
)
positions = torch.arange(
start=min_position,
end=max_position,
dtype=torch.float64,
device=torch.device("cpu"),
)
freqs = torch.outer(positions, freqs).float() # (max_position, dim//2)
# ↓pos, → dim//2
# [ min_position * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)]
# ...
# 0 * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)]
# 1 * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)]
# 2 * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)]
# ...
        # (max_position - 1) * [1/θ^(0/dim), 1/θ^(2/dim), 1/θ^(4/dim), ..., 1/θ^((dim-2)/dim)] ]
freqs_cis = torch.polar(
abs=torch.ones_like(freqs),
angle=freqs,
).to(torch.complex64) # (min_position~max_position, dim//2) complex64
        # unit-magnitude complex numbers: each entry encodes a pure rotation (no change in scale)
return freqs_cis
@staticmethod
def precompute_freqs_cis(
theta: float,
dims: list[int],
lens: list[int],
zero_centered: list[bool],
):
freqs_cis = []
for i, (dim, len_) in enumerate(zip(dims, lens)):
freq_cis = RopeEmbedder.get_rope_freqs(
dim=dim,
min_position=(len_ // 2) - len_ if zero_centered[i] else 0,
max_position=len_ // 2 if zero_centered[i] else len_,
theta=theta,
) # (len_, dim//2) complex64
freqs_cis.append(freq_cis)
return freqs_cis
# get frequencies for given position ids
def __call__(self, position_ids: torch.Tensor):
# move to device
freqs_cis = [fc.to(position_ids.device) for fc in self.freqs_cis]
result = []
for i in range(len(self.axes_dims)):
index = (
position_ids[..., i : i + 1]
.repeat(
                    # broadcast the per-axis position ids over the dim//2 frequency slots
                    1,  # keep the batch dimension
                    1,  # keep the sequence dimension
freqs_cis[i].shape[-1],
)
.to(torch.int64)
)
result.append(
torch.gather(
freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1),
dim=1,
index=index,
)
)
return torch.cat(result, dim=-1)
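# Illustrative note (not in the original): with the default axes_dims [32, 64, 64], __call__ takes
# position_ids of shape (batch, seq_len, 3) and returns complex64 frequencies of shape
# (batch, seq_len, sum(axes_dims) // 2) = (batch, seq_len, 80), which apply_rope broadcasts over heads.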
class Attention(nn.Module):
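    """Multi-head self-attention with per-head QK RMSNorm and rotary position embeddings."""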
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
qk_norm: bool = True,
attn_dropout: float = 0.0,
proj_dropout: float = 0.0,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_norm = FP32RMSNorm(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = FP32RMSNorm(self.head_dim) if qk_norm else nn.Identity()
self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
self.to_k = nn.Linear(dim, dim, bias=qkv_bias)
self.to_v = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_dropout = nn.Dropout(attn_dropout)
self.to_o = nn.Linear(dim, dim)
self.proj_dropout = nn.Dropout(proj_dropout)
def _pre_attn_reshape(self, x: torch.Tensor):
batch_size, seq_len, dim = x.shape
# [B, N, D] -> [B, N, num_heads, D/num_heads] -> [B, num_heads, N, D/num_heads]
x = x.view(
batch_size,
seq_len,
self.num_heads,
self.head_dim,
).permute(0, 2, 1, 3) # [B, num_heads, N, head_dim]
return x
def _post_attn_reshape(self, x: torch.Tensor):
batch_size, num_heads, seq_len, head_dim = x.shape
# [B, num_heads, N, head_dim] -> [B, N, num_heads, head_dim] -> [B, N, D]
x = (
x.permute(0, 2, 1, 3)
.contiguous()
.view(batch_size, seq_len, num_heads * head_dim)
)
return x
def forward(
self,
hidden_states: torch.Tensor,
rope_freqs: torch.Tensor,
mask: torch.Tensor | None = None, # 1: attend, 0: ignore
) -> torch.Tensor:
batch_size, seq_len, _dim = hidden_states.shape
# QKV
q = self.to_q(hidden_states)
k = self.to_k(hidden_states)
v = self.to_v(hidden_states)
q = self._pre_attn_reshape(q) # [B, num_heads, N, head_dim]
k = self._pre_attn_reshape(k)
v = self._pre_attn_reshape(v)
# QKNorm
q = self.q_norm(q)
k = self.k_norm(k)
q = apply_rope(q, rope_freqs)
k = apply_rope(k, rope_freqs)
if mask is not None:
# mask: (batch_size, seq_len) -> (batch_size, num_heads, seq_len, seq_len)
mask = (
mask.bool()
.view(batch_size, 1, 1, seq_len)
.expand(-1, self.num_heads, seq_len, -1)
)
attn = F.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_dropout.p if self.training else 0.0,
attn_mask=mask,
is_causal=False,
).to(hidden_states.dtype)
attn = self._post_attn_reshape(attn)
# output
out = self.to_o(attn)
out = self.proj_dropout(out)
return out
class SwiGLU(nn.Module):
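    """SwiGLU feed-forward: silu(W1 x) * (W2 x) projected back by W3; the hidden width is scaled by 2/3."""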
def __init__(
self,
dim: int,
hidden_dim: int,
dropout: float = 0.0,
bias: bool = True,
):
super().__init__()
hidden_dim = int(hidden_dim * 2 / 3)
self.w_1 = nn.Linear(dim, hidden_dim, bias=bias)
self.w_2 = nn.Linear(dim, hidden_dim, bias=bias)
self.w_3 = nn.Linear(hidden_dim, dim, bias=bias)
self.ffn_dropout = nn.Dropout(dropout)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
x_1 = self.w_1(hidden_states)
x_2 = self.w_2(hidden_states)
x = F.silu(x_1) * x_2
x = self.w_3(self.ffn_dropout(x))
return x
class FinalLayer(nn.Module):
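    """Final RMSNorm + SwiGLU MLP + linear projection to patch_size * patch_size * out_channels."""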
def __init__(
self,
hidden_dim: int,
mlp_ratio: float,
patch_size: int,
out_channels: int,
):
super().__init__()
self.norm_final = FP32RMSNorm(hidden_dim)
self.mlp = SwiGLU(
dim=hidden_dim,
hidden_dim=int(hidden_dim * mlp_ratio),
dropout=0.0,
bias=True,
)
self.linear = nn.Linear(
hidden_dim,
patch_size * patch_size * out_channels,
bias=True,
)
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
x = self.norm_final(hidden_states)
x = self.mlp(x)
x = self.linear(x)
return x
class JiTBlock(nn.Module):
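    """Pre-norm transformer block: x + Attn(RMSNorm(x)), then x + SwiGLU(RMSNorm(x))."""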
def __init__(
self,
hidden_dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
attn_dropout: float = 0.0,
proj_dropout: float = 0.0,
ffn_dropout: float = 0.0,
qkv_bias: bool = True,
qk_norm: bool = True,
bias: bool = True,
):
super().__init__()
self.norm1 = FP32RMSNorm(hidden_dim, eps=1e-6)
self.attn = Attention(
dim=hidden_dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
)
self.norm2 = FP32RMSNorm(hidden_dim)
self.mlp = SwiGLU(
dim=hidden_dim,
hidden_dim=int(hidden_dim * mlp_ratio),
dropout=ffn_dropout,
bias=bias,
)
def forward(
self,
hidden_states: torch.Tensor,
rope_freqs: torch.Tensor,
mask: torch.Tensor | None = None,
):
# attn
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
rope_freqs,
mask=mask,
)
# mlp
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states
class JiT(nn.Module):
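    """JiT denoiser: patch-embeds the noisy image, appends timestep and context tokens, runs the
    transformer blocks, and unpatchifies the remaining patch tokens back into an image."""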
def __init__(self, config: DenoiserConfig):
super().__init__()
self.config = config
        assert (config.hidden_size // config.num_heads) == sum(config.rope_axes_dims), (
            "The sum of rope_axes_dims must equal hidden_size // num_heads (the head dimension)."
        )
self.num_axes = len(
config.rope_axes_dims
) # 0: image_index, 1: height, 2: width
# image patch embedder
self.patch_embedder = BottleneckPatchEmbed(
patch_size=config.patch_size,
in_channels=config.in_channels,
bottleneck_dim=config.bottleneck_dim,
hidden_dim=config.hidden_size,
bias=True,
)
# timestep embedder
self.time_embedder = TimestepEmbedder(
hidden_dim=config.hidden_size,
freq_embedding_size=256,
)
self.time_position_embeds = nn.Parameter(
torch.randn(
config.num_time_tokens,
config.hidden_size,
),
requires_grad=True,
)
# RoPE embedder
self.rope_embedder = RopeEmbedder(
rope_theta=config.rope_theta,
axes_dims=config.rope_axes_dims,
axes_lens=config.rope_axes_lens,
zero_centered=config.rope_zero_centered,
)
# class condition or text embedding
self.context_embedder = nn.Linear(
config.context_dim,
config.hidden_size,
bias=True,
)
self.blocks = nn.ModuleList(
[
JiTBlock(
hidden_dim=config.hidden_size,
num_heads=config.num_heads,
mlp_ratio=config.mlp_ratio,
attn_dropout=config.attn_dropout,
proj_dropout=config.proj_dropout,
ffn_dropout=0.0,
qkv_bias=True,
qk_norm=True,
bias=True,
)
for _ in range(config.depth)
]
)
self.final_layer = FinalLayer(
hidden_dim=config.hidden_size,
mlp_ratio=config.mlp_ratio,
patch_size=config.patch_size,
out_channels=config.in_channels,
)
self.gradient_checkpointing = False
def initialize_weights(self):
# Initialize weights
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.RMSNorm):
nn.init.ones_(m.weight)
# patch embed
w_1 = self.patch_embedder.proj_1.weight
nn.init.xavier_uniform_(w_1.view([w_1.shape[0], -1]))
w_2 = self.patch_embedder.proj_2.weight
nn.init.xavier_uniform_(w_2.view([w_2.shape[0], -1]))
if self.patch_embedder.proj_2.bias is not None:
nn.init.zeros_(self.patch_embedder.proj_2.bias)
# time position embeds
nn.init.normal_(
self.time_position_embeds,
std=0.02,
)
# time embedder
nn.init.normal_(
self.time_embedder.mlp[0].weight, # type: ignore
std=0.02,
)
nn.init.normal_(
self.time_embedder.mlp[2].weight, # type: ignore
std=0.02,
)
def set_gradient_checkpointing(self, enable: bool = True):
self.gradient_checkpointing = enable
def prepare_image_position_ids(
self,
height: int,
width: int,
image_index: int,
) -> torch.Tensor:
# [H/patch_size, W/patch_size]
patch_size = self.config.patch_size
h_patches = height // patch_size
w_patches = width // patch_size
position_ids = torch.zeros(
h_patches,
w_patches,
self.num_axes,
)
# image_index
position_ids[:, :, 0] = image_index # image
# height (y-index)
position_ids[:, :, 1] = (
torch.arange(
h_patches,
)
.unsqueeze(1)
.repeat(1, w_patches)
)
# width (x-index)
position_ids[:, :, 2] = (
torch.arange(
w_patches,
)
.unsqueeze(0)
.repeat(h_patches, 1)
)
return position_ids.view(-1, self.num_axes) # (num_patches, n_axes)
def prepare_context_position_ids(
self,
seq_len: int,
context_start_index: int = 0,
xy_position: int = 0,
) -> torch.Tensor:
position_ids = torch.zeros(
seq_len,
self.num_axes,
)
# context_index (0, ..., seq_len-1)
position_ids[:, 0] = torch.arange(
context_start_index,
context_start_index + seq_len,
) # text
        # spatial (height, width) positions are all fixed at xy_position
position_ids[:, 1] = xy_position
position_ids[:, 2] = xy_position
return position_ids
def prepare_time_position_ids(
self,
seq_len: int,
time_start_index: int,
xy_position: int = 0,
) -> torch.Tensor:
position_ids = torch.zeros(
seq_len,
self.num_axes,
)
# time_index
position_ids[:, 0] = torch.arange(
time_start_index, time_start_index + seq_len
) # time
        # spatial (height, width) positions are all fixed at xy_position
position_ids[:, 1] = xy_position
position_ids[:, 2] = xy_position
return position_ids
def unpatchify(
self,
patches: torch.Tensor,
height: int,
width: int,
) -> torch.Tensor:
batch_size, num_patches, _patch_dim = patches.shape
patch_size = self.config.patch_size
out_channels = self.config.out_channels
h_patches = height // patch_size
w_patches = width // patch_size
assert num_patches == h_patches * w_patches, "Mismatch in number of patches"
# [B, N, patch_size*patch_size*C] -> [B, H_patch, W_patch, patch_size, patch_size, C]
patches = patches.view(
batch_size,
h_patches,
w_patches,
patch_size,
patch_size,
out_channels,
)
# [B, H_patch, W_patch, patch_size, patch_size, C]
# -> [B, C, H_patch, patch_size, W_patch, patch_size]
patches = patches.permute(0, 5, 1, 3, 2, 4)
# -> [B, C, H_img, W_img]
images = patches.reshape(batch_size, out_channels, height, width)
return images
def forward(
self,
image: torch.Tensor, # [B, C, H, W]
timestep: torch.Tensor, # [B]
context: torch.Tensor, # [B, context_len, context_dim]
context_mask: torch.Tensor | None = None, # [B, context_len]
):
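        """Run the denoiser on a batch; returns a tensor of shape [B, out_channels, H, W]."""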
batch_size, _in_channels, height, width = image.shape
time_embed: torch.Tensor = self.time_embedder(timestep) # [B, hidden_dim]
time_tokens = time_embed.unsqueeze(1).repeat( # add seq_len dim
1,
self.time_position_embeds.shape[0], # num_time_tokens
1,
) + self.time_position_embeds.unsqueeze(0).repeat( # add batch dim
batch_size,
1,
1,
) # [B, num_time_tokens, hidden_dim]
num_time_tokens = time_tokens.shape[1]
context_embed = self.context_embedder(context)
context_len = context_embed.shape[1]
        patches = self.patch_embedder(image)  # [B, N, hidden_dim]
patches_len = patches.shape[1]
        # position indices along axis 0 are allocated in the order context -> time -> patches
context_position_ids = self.prepare_context_position_ids(
seq_len=context_len,
context_start_index=0,
)
time_position_ids = self.prepare_time_position_ids(
seq_len=num_time_tokens,
time_start_index=context_len,
)
patches_position_ids = self.prepare_image_position_ids(
height=height,
width=width,
image_index=context_len + num_time_tokens, # after context and time tokens
)
        # but the sequence itself is ordered patches -> time -> context (matching the token concat below)
position_ids = torch.cat(
[
patches_position_ids,
time_position_ids,
context_position_ids,
],
dim=0,
).view(1, -1, self.num_axes) # (1, total_seq_len, n_axes)
# prepare RoPE
freqs_cis = (
self.rope_embedder(position_ids=position_ids)
.repeat(
batch_size,
1,
1,
)
.to(device=image.device)
)
# attention mask
if context_mask is not None:
patches_mask = torch.ones(batch_size, patches_len, device=image.device)
time_mask = torch.ones(batch_size, num_time_tokens, device=image.device)
mask = torch.cat(
[
patches_mask,
time_mask,
context_mask.to(image.device),
],
dim=1,
)
else:
# attend all
mask = torch.ones(
batch_size,
patches_len + num_time_tokens + context_len,
device=image.device,
)
for _i, block in enumerate(self.blocks):
tokens = torch.cat(
[
                    patches,  # (H/patch_size) * (W/patch_size) image tokens
                    time_tokens,  # num_time_tokens
                    context_embed,  # context_len tokens
],
dim=1, # cat in seq_len dimension
)
if self.gradient_checkpointing and self.training:
patches = checkpoint.checkpoint( # type: ignore
block,
tokens,
freqs_cis,
mask,
)[:, :patches_len, :]
else:
patches = block(
tokens,
rope_freqs=freqs_cis,
mask=mask,
)[:, :patches_len, :] # only keep patch tokens
patches = self.final_layer(patches)
pred_image = self.unpatchify(
patches,
height=height,
width=width,
)
return pred_image
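if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes DenoiserConfig
    # accepts the field names used throughout this file as keyword arguments; adjust them to the
    # actual definition in .config if they differ.
    config = DenoiserConfig(
        hidden_size=256,
        num_heads=2,
        depth=2,
        patch_size=16,
        in_channels=3,
        out_channels=3,
        bottleneck_dim=64,
        mlp_ratio=4.0,
        attn_dropout=0.0,
        proj_dropout=0.0,
        context_dim=32,
        num_time_tokens=4,
        rope_theta=256.0,
        rope_axes_dims=[32, 48, 48],  # must sum to hidden_size // num_heads
        rope_axes_lens=[256, 128, 128],
        rope_zero_centered=[False, True, True],
    )
    model = JiT(config)
    model.initialize_weights()
    image = torch.randn(2, 3, 64, 64)  # [B, C, H, W]
    timestep = torch.rand(2)  # [B]
    context = torch.randn(2, 8, 32)  # [B, context_len, context_dim]
    pred = model(image, timestep, context)
    print(pred.shape)  # expected: torch.Size([2, 3, 64, 64])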