5090_test / custom_code /fastvideo /layers /visual_embedding.py
yitongl's picture
Upload FastVideo 5090 safetensors checkpoint2950
d4cc469 verified
# SPDX-License-Identifier: Apache-2.0
import math
import torch
import torch.nn as nn
from fastvideo.layers.activation import get_act_fn
from fastvideo.layers.linear import ReplicatedLinear
from fastvideo.layers.mlp import MLP
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding
Image to Patch Embedding using Conv2d
A convolution based approach to patchifying a 2D image w/ embedding projection.
Based on the impl in https://github.com/google-research/vision_transformer
Hacked together by / Copyright 2020 Ross Wightman
Remove the _assert function in forward function to be compatible with multi-resolution images.
"""
def __init__(self,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
dtype=None,
prefix: str = ""):
super().__init__()
# Convert patch_size to 2-tuple
if isinstance(patch_size, list | tuple):
if len(patch_size) == 1:
patch_size = (patch_size[0], patch_size[0])
else:
patch_size = (patch_size, patch_size)
self.patch_size = patch_size
self.flatten = flatten
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
class WanCamControlPatchEmbedding(nn.Module):
"""Lingbot World Patch embedding for Plucker features."""
def __init__(
self,
patch_size=(1, 2, 2),
in_chans=384, # 6 * 64
embed_dim=2048,
bias=True,
dtype=None,
prefix: str = ""):
super().__init__()
# must be 3-tuple
if isinstance(patch_size, list | tuple):
if len(patch_size) != 3:
raise ValueError(f"patch_size must have length 3, got {len(patch_size)}")
else:
raise ValueError(f"Unsupported patch_size type: {type(patch_size)}")
self.patch_size = patch_size
pt, ph, pw = self.patch_size
self.in_features = in_chans * pt * ph * pw
self.proj = nn.Linear(self.in_features, embed_dim, bias=bias, dtype=dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.dim() != 5:
raise ValueError(f"Expected camera embedding shape [B, C, F, H, W], got {x.shape}")
bsz, channels, frames, height, width = x.shape
pt, ph, pw = self.patch_size
if (frames % pt) != 0 or (height % ph) != 0 or (width % pw) != 0:
raise ValueError(f"Input shape {x.shape} must be divisible by patch_size {self.patch_size}")
# '1 c (f c1) (h c2) (w c3) -> 1 (f h w) (c c1 c2 c3)',
x = x.view(
bsz,
channels,
frames // pt,
pt,
height // ph,
ph,
width // pw,
pw,
)
x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(bsz, -1, self.in_features)
return self.proj(x)
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    The timesteps are first encoded with ``timestep_embedding`` into
    ``frequency_embedding_size`` sinusoidal features, cast to the dtype of
    ``self.mlp.fc_in.weight``, and then passed through an ``MLP`` with input
    width ``frequency_embedding_size`` and hidden/output width ``hidden_size``.

    Args:
        hidden_size: Hidden and output width of the MLP.
        act_layer: Activation name forwarded to ``MLP`` as ``act_type``.
        frequency_embedding_size: Width of the sinusoidal encoding.
        max_period: ``max_period`` forwarded to ``timestep_embedding``.
        dtype: Parameter dtype forwarded to ``MLP``.
        freq_dtype: dtype used to build the sinusoidal frequency table.
        prefix: Accepted for interface compatibility; unused here.
    """

    def __init__(
        self,
        hidden_size,
        act_layer="silu",
        frequency_embedding_size=256,
        max_period=10000,
        dtype=None,
        freq_dtype=torch.float32,
        prefix: str = "",
    ):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        self.freq_dtype = freq_dtype
        self.mlp = MLP(frequency_embedding_size, hidden_size, hidden_size, act_type=act_layer, dtype=dtype)

    def forward(self, t: torch.Tensor, timestep_seq_len: int | None = None) -> torch.Tensor:
        """Embed ``t`` (presumably shape [B], as ``timestep_embedding`` expects).

        When ``timestep_seq_len`` is given, dim 0 of the sinusoidal features is
        split into ``(1, timestep_seq_len)`` before the MLP, which requires ``t``
        to contain exactly ``timestep_seq_len`` entries.
        """
        sinusoid = timestep_embedding(t, self.frequency_embedding_size, self.max_period, dtype=self.freq_dtype)
        # Match the MLP's parameter dtype before the first linear layer.
        sinusoid = sinusoid.to(self.mlp.fc_in.weight.dtype)
        if timestep_seq_len is not None:
            sinusoid = sinusoid.unflatten(0, (1, timestep_seq_len))
        return self.mlp(sinusoid)
def timestep_embedding(t: torch.Tensor,
                       dim: int,
                       max_period: int = 10000,
                       dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings.

    With ``half = dim // 2`` the frequencies are
    ``f_i = exp(-ln(max_period) * i / half) = max_period ** (-i / half)`` for
    ``i`` in ``0 .. half - 1``. The output is ``[cos(t * f), sin(t * f)]``
    concatenated along the last axis, followed by one zero column when
    ``dim`` is odd.

    Args:
        t: Tensor of shape [B] with timesteps (converted to float32 before
            multiplying by the frequencies).
        dim: Embedding dimension.
        max_period: Controls the minimum frequency of the embeddings.
        dtype: dtype used to build the frequency table.

    Returns:
        Tensor of shape [B, dim] with embeddings.
    """
    half = dim // 2
    # The frequency table is built on the default device and then moved to
    # t.device. NOTE(review): building it directly on t.device could change
    # results at the ulp level, so the original order is kept.
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=dtype) / half).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Append exactly one zero column. Slicing embedding[:, :1] to build it
        # yields width 0 when dim == 1 (half == 0), which returned [B, 0]
        # instead of [B, 1]; an explicit width of 1 avoids that.
        zeros = embedding.new_zeros((*embedding.shape[:-1], 1))
        embedding = torch.cat([embedding, zeros], dim=-1)
    return embedding
class ModulateProjection(nn.Module):
    """Modulation layer for DiT blocks.

    Applies the activation selected by ``act_layer`` (via ``get_act_fn``)
    and then a ``ReplicatedLinear`` that maps ``hidden_size`` features to
    ``hidden_size * factor`` features, with bias.

    Args:
        hidden_size: Input width of the projection.
        factor: Multiplier for the output width.
        act_layer: Activation name passed to ``get_act_fn``.
        dtype: Parameter dtype forwarded as ``params_dtype``.
        prefix: Accepted for interface compatibility; unused here.
    """

    def __init__(
        self,
        hidden_size: int,
        factor: int = 2,
        act_layer: str = "silu",
        dtype: torch.dtype | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.factor = factor
        out_features = hidden_size * factor
        self.linear = ReplicatedLinear(hidden_size, out_features, bias=True, params_dtype=dtype)
        self.act = get_act_fn(act_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``linear(act(x))``.

        ``self.linear`` returns a 2-tuple; only its first element is returned
        and the second is discarded.
        """
        projected, _unused = self.linear(self.act(x))
        return projected
def unpatchify(x, t, h, w, patch_size, channels) -> torch.Tensor:
    """
    Convert patched representation back to image space.

    Each of the ``t * h * w`` tokens holds one ``(C, P_t, P_h, P_w)`` patch
    flattened channel-major; the patches are scattered back onto the
    ``(t, h, w)`` grid.

    Args:
        x: Tensor of shape [B, T*H*W, C*P_t*P_h*P_w]
        t, h, w: Temporal and spatial dimensions of the patch grid
        patch_size: 3-element sequence ``(P_t, P_h, P_w)``
        channels: Number of output channels ``C``

    Returns:
        Unpatchified tensor of shape [B, C, T*P_t, H*P_h, W*P_w]

    Raises:
        AssertionError: if ``x`` is not 3-D, ``patch_size`` does not have
            three entries, or ``t * h * w`` differs from ``x.shape[1]``.
    """
    assert x.ndim == 3, f"x.ndim: {x.ndim}"
    assert len(patch_size) == 3, f"patch_size: {patch_size}"
    assert t * h * w == x.shape[1], f"t * h * w: {t * h * w}, x.shape[1]: {x.shape[1]}"
    p_t, p_h, p_w = patch_size
    batch = x.shape[0]
    grid = x.reshape(batch, t, h, w, channels, p_t, p_h, p_w)
    # [B, t, h, w, C, pt, ph, pw] -> [B, C, t, pt, h, ph, w, pw]
    grid = grid.permute(0, 4, 1, 5, 2, 6, 3, 7)
    return grid.reshape(batch, channels, t * p_t, h * p_h, w * p_w)
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings (as in Denoising Diffusion Probabilistic Models).

    With ``half = embedding_dim // 2`` the frequencies are
    ``f_i = exp(-ln(max_period) * i / (half - downscale_freq_shift))`` for
    ``i`` in ``0 .. half - 1``. The angles are ``scale * timesteps * f`` and
    the output is ``[sin, cos]`` (or ``[cos, sin]`` when ``flip_sin_to_cos``),
    followed by one zero column when ``embedding_dim`` is odd.

    Args:
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
        downscale_freq_shift (float):
            Subtracted from ``half`` in the frequency exponent's divisor; controls the
            spacing between frequencies. NOTE(review): when it equals ``half`` (e.g.
            ``embedding_dim=2`` with the default 1) the divisor is zero and the output
            is non-finite.
        scale (float):
            Scaling factor applied to the angles before sin/cos.
        max_period (int):
            Controls the minimum (slowest) frequency of the embeddings.

    Returns:
        torch.Tensor: an [N x embedding_dim] Tensor of positional embeddings.

    Raises:
        AssertionError: if ``timesteps`` is not 1-D.
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
    half_dim = embedding_dim // 2
    positions = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(-math.log(max_period) * positions / (half_dim - downscale_freq_shift))
    angles = scale * (timesteps[:, None].float() * freqs[None, :])

    sines, cosines = torch.sin(angles), torch.cos(angles)
    halves = (cosines, sines) if flip_sin_to_cos else (sines, cosines)
    emb = torch.cat(halves, dim=-1)

    # Odd widths get a single trailing zero column.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
class Timesteps(nn.Module):
    """Module wrapper around ``get_timestep_embedding``.

    Stores the embedding width and sinusoid options and applies them to the
    timesteps passed to ``forward``; ``max_period`` keeps the default of
    ``get_timestep_embedding``.

    Args:
        num_channels: Output embedding width (``embedding_dim``).
        flip_sin_to_cos: Emit ``[cos, sin]`` instead of ``[sin, cos]``.
        downscale_freq_shift: Forwarded as ``downscale_freq_shift``.
        scale: Forwarded as ``scale``.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        """Embed a 1-D tensor of timesteps into ``num_channels`` features."""
        options = {
            "flip_sin_to_cos": self.flip_sin_to_cos,
            "downscale_freq_shift": self.downscale_freq_shift,
            "scale": self.scale,
        }
        return get_timestep_embedding(timesteps, self.num_channels, **options)