import math
from typing import Any, Dict

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils import is_torch_version
from einops import rearrange, repeat

from ..model import flash_attention
from .s2v_utils import rope_precompute


def sinusoidal_embedding_1d(dim, position):
    """Standard 1D sinusoidal embedding: cos terms in the first half of the
    channel dimension, sin terms in the second half."""
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float64)

    # outer product of positions and inverse frequencies
    sinusoid = torch.outer(
        position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x
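
# Illustrative shape check (a sketch, not part of the original module):
#   emb = sinusoidal_embedding_1d(128, torch.arange(10))
#   emb.shape  # -> torch.Size([10, 128])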


@amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    """Precompute complex RoPE frequencies of shape [max_seq_len, dim // 2]."""
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta,
                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs
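
# Illustrative shape check (a sketch, not part of the original module):
#   f = rope_params(1024, 64)
#   f.shape, f.dtype  # -> (torch.Size([1024, 32]), torch.complex128)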


@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs, start=None):
    n, c = x.size(2), x.size(3) // 2

    # when freqs is a list, the second entry holds trainable token frequencies
    trainable_freqs = None
    if type(freqs) is list:
        trainable_freqs = freqs[1]
        freqs = freqs[0]
    # split the frequencies over the (frame, height, width) axes
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # rotate each sequence bucket of each sample in place
    output = x.clone()
    seq_bucket = [0]
    if not type(grid_sizes) is list:
        grid_sizes = [grid_sizes]
    for g in grid_sizes:
        # normalize each entry to [start, end, target] grid triples
        if not type(g) is list:
            g = [torch.zeros_like(g), g, g]
        batch_size = g[0].shape[0]
        for i in range(batch_size):
            if start is None:
                f_o, h_o, w_o = g[0][i]
            else:
                f_o, h_o, w_o = start[i]

            f, h, w = g[1][i]
            t_f, t_h, t_w = g[2][i]
            seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
            seq_len = int(seq_f * seq_h * seq_w)
            if seq_len > 0:
                if t_f > 0:
                    # sample target positions; a negative frame offset flips
                    # the direction and conjugates the frequencies below
                    if f_o >= 0:
                        f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1,
                                            seq_f).astype(int).tolist()
                    else:
                        f_sam = np.linspace(-f_o.item(),
                                            (-t_f - f_o).item() + 1,
                                            seq_f).astype(int).tolist()
                    h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1,
                                        seq_h).astype(int).tolist()
                    w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1,
                                        seq_w).astype(int).tolist()

                    assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
                    freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][
                        f_sam].conj()
                    freqs_0 = freqs_0.view(seq_f, 1, 1, -1)

                    freqs_i = torch.cat([
                        freqs_0.expand(seq_f, seq_h, seq_w, -1),
                        freqs[1][h_sam].view(1, seq_h, 1, -1).expand(
                            seq_f, seq_h, seq_w, -1),
                        freqs[2][w_sam].view(1, 1, seq_w, -1).expand(
                            seq_f, seq_h, seq_w, -1),
                    ], dim=-1).reshape(seq_len, 1, -1)
                elif t_f < 0:
                    # a negative target frame count selects the trainable
                    # token frequencies instead of the sampled ones
                    freqs_i = trainable_freqs.unsqueeze(1)

                # apply the rotation in complex space
                x_i = torch.view_as_complex(
                    x[i, seq_bucket[-1]:seq_bucket[-1] + seq_len].to(
                        torch.float64).reshape(seq_len, n, -1, 2))
                x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
                output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = x_i
        seq_bucket.append(seq_bucket[-1] + seq_len)
    return output.float()
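
# Illustrative call (a sketch, not part of the original module): rotate a
# [1, 4*8*8, 16, 64] tensor laid out as a 4x8x8 (frame, height, width) grid.
#   d = 64
#   freqs = torch.cat([rope_params(1024, d - 4 * (d // 6)),
#                      rope_params(1024, 2 * (d // 6)),
#                      rope_params(1024, 2 * (d // 6))], dim=1)
#   g = torch.tensor([[4, 8, 8]])
#   out = rope_apply(torch.randn(1, 256, 16, 64),
#                    [[torch.zeros_like(g), g, g]], freqs)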


class RMSNorm(nn.Module):

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self._norm(x.float()).type_as(x) * self.weight

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)


class LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalizes in float32, then casts back."""

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        return super().forward(x.float()).type_as(x)


class SelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs):
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # query, key, value projections
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        x = flash_attention(
            q=rope_apply(q, grid_sizes, freqs),
            k=rope_apply(k, grid_sizes, freqs),
            v=v,
            k_lens=seq_lens,
            window_size=self.window_size)

        # output projection
        x = x.flatten(2)
        x = self.o(x)
        return x


class SwinSelfAttention(SelfAttention):
    """Shifted-window variant: each frame attends to its temporal neighbors
    plus a shared reference (last) frame."""

    def forward(self, x, seq_lens, grid_sizes, freqs):
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
        assert b == 1, 'Only support batch_size 1'

        # query, key, value projections
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        q = rope_apply(q, grid_sizes, freqs)
        k = rope_apply(k, grid_sizes, freqs)
        T, H, W = grid_sizes[0].tolist()

        # fold frames into the batch dimension
        q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
        k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
        v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)

        # the last frame is the reference; its queries are dropped and its
        # output slot is filled with ref_v below
        q = q[:-1]

        ref_k = repeat(k[-1:], "1 s n d -> t s n d", t=k.shape[0] - 1)
        k = k[:-1]
        # pad the first/last frame, then gather [current, next, previous,
        # reference] keys for each frame
        k = torch.cat([k[:1], k, k[-1:]])
        k = torch.cat([k[1:-1], k[2:], k[:-2], ref_k], dim=1)

        ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=v.shape[0] - 1)
        v = v[:-1]
        v = torch.cat([v[:1], v, v[-1:]])
        v = torch.cat([v[1:-1], v[2:], v[:-2], ref_v], dim=1)

        out = flash_attention(
            q=q,
            k=k,
            v=v,
            window_size=self.window_size)
        # re-append the reference frame and restore the token layout
        out = torch.cat([out, ref_v[:1]], dim=0)
        out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)
        x = out

        # output projection
        x = x.flatten(2)
        x = self.o(x)
        return x
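
# Temporal window sketch (illustrative, not part of the original module):
# for T = 4 frames [f0, f1, f2, f3] with f3 the reference, frame f1 attends
# to keys gathered from [f1 (current), f2 (next), f0 (previous), f3 (ref)].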


class CasualSelfAttention(SelfAttention):
    """Causal variant: each frame attends to itself, the `shifting` previous
    frames, and a shared reference (last) frame."""

    def forward(self, x, seq_lens, grid_sizes, freqs):
        shifting = 3
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
        assert b == 1, 'Only support batch_size 1'

        # query, key, value projections
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        T, H, W = grid_sizes[0].tolist()

        # fold frames into the batch dimension
        q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
        k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
        v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)

        # the last frame is the reference; its queries are dropped and its
        # output slot is filled with ref_v below
        q = q[:-1]

        # queries sit at temporal position `shifting` within their window
        grid_sizes = torch.tensor([[1, H, W]] * q.shape[0], dtype=torch.long)
        start = [[shifting, 0, 0]] * q.shape[0]
        q = rope_apply(q, grid_sizes, freqs, start=start)

        # the reference keys get a distant temporal position
        ref_k = k[-1:]
        grid_sizes = torch.tensor([[1, H, W]], dtype=torch.long)
        start = [[shifting + 10, 0, 0]]
        ref_k = rope_apply(ref_k, grid_sizes, freqs, start)
        ref_k = repeat(ref_k, "1 s n d -> t s n d", t=k.shape[0] - 1)

        # left-pad with copies of the first frame, then gather the `shifting`
        # previous frames plus the current frame for each query frame
        k = k[:-1]
        k = torch.cat([*([k[:1]] * shifting), k])
        cat_k = []
        for i in range(shifting):
            cat_k.append(k[i:i - shifting])
        cat_k.append(k[shifting:])
        k = torch.cat(cat_k, dim=1)

        grid_sizes = torch.tensor(
            [[shifting + 1, H, W]] * q.shape[0], dtype=torch.long)
        k = rope_apply(k, grid_sizes, freqs)
        k = torch.cat([k, ref_k], dim=1)

        # same gather for the values, plus the reference values
        ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=q.shape[0])
        v = v[:-1]
        v = torch.cat([*([v[:1]] * shifting), v])
        cat_v = []
        for i in range(shifting):
            cat_v.append(v[i:i - shifting])
        cat_v.append(v[shifting:])
        v = torch.cat(cat_v, dim=1)
        v = torch.cat([v, ref_v], dim=1)

        # attend frame by frame to bound peak memory
        outs = []
        for i in range(q.shape[0]):
            out = flash_attention(
                q=q[i:i + 1],
                k=k[i:i + 1],
                v=v[i:i + 1],
                window_size=self.window_size)
            outs.append(out)
        out = torch.cat(outs, dim=0)
        out = torch.cat([out, ref_v[:1]], dim=0)
        out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)
        x = out

        # output projection
        x = x.flatten(2)
        x = self.o(x)
        return x
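
# Causal window sketch (illustrative, not part of the original module): with
# shifting = 3, frame f5 attends to keys from [f2, f3, f4, f5] plus the
# reference frame; frames near the start reuse copies of f0 as left padding.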


class MotionerAttentionBlock(nn.Module):

    def __init__(self,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6,
                 self_attn_block="SelfAttention"):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers
        self.norm1 = LayerNorm(dim, eps)
        if self_attn_block == "SelfAttention":
            self.self_attn = SelfAttention(dim, num_heads, window_size,
                                           qk_norm, eps)
        elif self_attn_block == "SwinSelfAttention":
            self.self_attn = SwinSelfAttention(dim, num_heads, window_size,
                                               qk_norm, eps)
        elif self_attn_block == "CasualSelfAttention":
            self.self_attn = CasualSelfAttention(dim, num_heads, window_size,
                                                 qk_norm, eps)
        else:
            raise ValueError(
                f"Unsupported self_attn_block: {self_attn_block}")

        self.norm2 = LayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim))

    def forward(
        self,
        x,
        seq_lens,
        grid_sizes,
        freqs,
    ):
        # self-attention
        y = self.self_attn(self.norm1(x).float(), seq_lens, grid_sizes, freqs)
        x = x + y
        # feed-forward
        y = self.ffn(self.norm2(x).float())
        x = x + y
        return x


class Head(nn.Module):

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers
        out_dim = math.prod(patch_size) * out_dim
        self.norm = LayerNorm(dim, eps)
        self.head = nn.Linear(dim, out_dim)

    def forward(self, x):
        x = self.head(self.norm(x))
        return x


class MotionerTransformers(nn.Module, PeftAdapterMixin):

    def __init__(
        self,
        patch_size=(1, 2, 2),
        in_dim=16,
        dim=2048,
        ffn_dim=8192,
        freq_dim=256,
        out_dim=16,
        num_heads=16,
        num_layers=32,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=False,
        eps=1e-6,
        self_attn_block="SelfAttention",
        motion_token_num=1024,
        enable_tsm=False,
        motion_stride=4,
        expand_ratio=2,
        trainable_token_pos_emb=False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        self.enable_tsm = enable_tsm
        self.motion_stride = motion_stride
        self.expand_ratio = expand_ratio
        self.sample_c = self.patch_size[0]

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)

        # blocks
        self.blocks = nn.ModuleList([
            MotionerAttentionBlock(
                dim,
                ffn_dim,
                num_heads,
                window_size,
                qk_norm,
                cross_attn_norm,
                eps,
                self_attn_block=self_attn_block) for _ in range(num_layers)
        ])

        # RoPE frequencies, split over (frame, height, width)
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ],
                               dim=1)

        self.gradient_checkpointing = False

        # learnable motion tokens arranged on a square grid
        self.motion_side_len = int(math.sqrt(motion_token_num))
        assert self.motion_side_len**2 == motion_token_num
        self.token = nn.Parameter(
            torch.zeros(1, motion_token_num, dim).contiguous())

        self.trainable_token_pos_emb = trainable_token_pos_emb
        if trainable_token_pos_emb:
            # initialize trainable token frequencies from the rotations that
            # rope_apply produces on a unit input over the token grid
            x = torch.zeros([1, motion_token_num, num_heads, d])
            x[..., ::2] = 1

            grid_sizes = [[
                torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1),
                torch.tensor([1, self.motion_side_len,
                              self.motion_side_len]).unsqueeze(0).repeat(1, 1),
                torch.tensor([1, self.motion_side_len,
                              self.motion_side_len]).unsqueeze(0).repeat(1, 1),
            ]]
            token_freqs = rope_apply(x, grid_sizes, self.freqs)
            token_freqs = token_freqs[0, :, 0].reshape(motion_token_num, -1, 2)
            token_freqs = token_freqs * 0.01
            self.token_freqs = torch.nn.Parameter(token_freqs)

    def after_patch_embedding(self, x):
        return x

    def forward(
        self,
        x,
    ):
        """
        x: A list of videos, each with shape [C, T, H, W].
        Returns the motion tokens with shape [B, motion_token_num, dim].
        """
        device = self.patch_embedding.weight.device
        freqs = self.freqs
        if freqs.device != device:
            freqs = freqs.to(device)

        if self.trainable_token_pos_emb:
            with amp.autocast(dtype=torch.float64):
                token_freqs = self.token_freqs.to(torch.float64)
                token_freqs = token_freqs / token_freqs.norm(
                    dim=-1, keepdim=True)
                freqs = [freqs, torch.view_as_complex(token_freqs)]

        # temporal stride sampling (TSM): keep recent frames dense and older
        # frames progressively sparser
        if self.enable_tsm:
            sample_idx = [
                sample_indices(
                    u.shape[1],
                    stride=self.motion_stride,
                    expand_ratio=self.expand_ratio,
                    c=self.sample_c) for u in x
            ]
            x = [
                torch.flip(torch.flip(u, [1])[:, idx], [1])
                for idx, u in zip(sample_idx, x)
            ]

        # patch embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        x = self.after_patch_embedding(x)

        seq_f, seq_h, seq_w = x[0].shape[-3:]
        batch_size = len(x)
        if not self.enable_tsm:
            grid_sizes = torch.stack(
                [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
            grid_sizes = [[
                torch.zeros_like(grid_sizes), grid_sizes, grid_sizes
            ]]
            seq_f = 0
        else:
            # one grid entry per sampled frame, placed at its original index
            grid_sizes = []
            for idx in sample_idx[0][::-1][::self.sample_c]:
                tsm_frame_grid_sizes = [[
                    torch.tensor([idx, 0,
                                  0]).unsqueeze(0).repeat(batch_size, 1),
                    torch.tensor([idx + 1, seq_h,
                                  seq_w]).unsqueeze(0).repeat(batch_size, 1),
                    torch.tensor([1, seq_h,
                                  seq_w]).unsqueeze(0).repeat(batch_size, 1),
                ]]
                grid_sizes += tsm_frame_grid_sizes
            seq_f = sample_idx[0][-1] + 1

        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        x = torch.cat(x)

        batch_size = len(x)

        # the motion tokens sit one frame past the video tokens
        token_grid_sizes = [[
            torch.tensor([seq_f, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
            torch.tensor(
                [seq_f + 1, self.motion_side_len,
                 self.motion_side_len]).unsqueeze(0).repeat(batch_size, 1),
            torch.tensor(
                [1 if not self.trainable_token_pos_emb else -1, seq_h,
                 seq_w]).unsqueeze(0).repeat(batch_size, 1),
        ]]

        grid_sizes = grid_sizes + token_grid_sizes
        token_len = self.token.shape[1]
        token = self.token.clone().repeat(x.shape[0], 1, 1).contiguous()
        seq_lens = seq_lens + torch.tensor([t.size(0) for t in token],
                                           dtype=torch.long)
        x = torch.cat([x, token], dim=1)

        kwargs = dict(
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=freqs,
        )

        def create_custom_forward(module, return_dict=None):

            def custom_forward(*inputs, **kwargs):
                if return_dict is not None:
                    return module(*inputs, **kwargs, return_dict=return_dict)
                else:
                    return module(*inputs, **kwargs)

            return custom_forward

        ckpt_kwargs: Dict[str, Any] = ({
            "use_reentrant": False
        } if is_torch_version(">=", "1.11.0") else {})

        for block in self.blocks:
            if self.training and self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x,
                    **kwargs,
                    **ckpt_kwargs,
                )
            else:
                x = block(x, **kwargs)

        # only the trailing motion tokens are returned
        out = x[:, -token_len:]
        return out
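
    # Illustrative usage (a sketch, not part of the original module):
    #   motioner = MotionerTransformers(dim=1024, ffn_dim=4096, num_heads=16,
    #                                   num_layers=2, motion_token_num=1024)
    #   out = motioner([torch.randn(16, 8, 32, 32)])  # -> [1, 1024, 1024]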

    def unpatchify(self, x, grid_sizes):
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))


class FramePackMotioner(nn.Module):

    def __init__(
            self,
            inner_dim=1024,
            num_heads=16,
            zip_frame_buckets=(1, 2, 16),
            drop_mode="drop",
            *args,
            **kwargs):
        super().__init__(*args, **kwargs)
        # projections for the three temporal compression rates
        self.proj = nn.Conv3d(
            16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.proj_2x = nn.Conv3d(
            16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
        self.proj_4x = nn.Conv3d(
            16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
        self.zip_frame_buckets = torch.tensor(
            zip_frame_buckets, dtype=torch.long)

        self.inner_dim = inner_dim
        self.num_heads = num_heads

        assert (inner_dim %
                num_heads) == 0 and (inner_dim // num_heads) % 2 == 0
        d = inner_dim // num_heads
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ],
                               dim=1)
        self.drop_mode = drop_mode

    def forward(self, motion_latents, add_last_motion=2):
        mot = []
        mot_remb = []
        for m in motion_latents:
            lat_height, lat_width = m.shape[2], m.shape[3]
            # right-align the motion latents in a fixed-size buffer
            padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(),
                                   lat_height, lat_width).to(
                                       device=m.device, dtype=m.dtype)
            overlap_frame = min(padd_lat.shape[1], m.shape[1])
            if overlap_frame > 0:
                padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]

            # "padd" mode: zero out the dropped buckets instead of dropping
            if add_last_motion < 2 and self.drop_mode != "drop":
                zero_end_frame = self.zip_frame_buckets[:len(
                    self.zip_frame_buckets) - add_last_motion - 1].sum()
                padd_lat[:, -zero_end_frame:] = 0

            padd_lat = padd_lat.unsqueeze(0)
            # split along time into the 4x, 2x and full-rate buckets
            clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[
                :, :, -self.zip_frame_buckets.sum():, :, :].split(
                    self.zip_frame_buckets.flip(0).tolist(), dim=2)

            # project each bucket at its own compression rate
            clean_latents_post = self.proj(clean_latents_post).flatten(
                2).transpose(1, 2)
            clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(
                2).transpose(1, 2)
            clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(
                2).transpose(1, 2)

            # "drop" mode: drop the most recent buckets entirely
            if add_last_motion < 2 and self.drop_mode == "drop":
                clean_latents_post = clean_latents_post[:, :0]
                if add_last_motion < 1:
                    clean_latents_2x = clean_latents_2x[:, :0]

            motion_lat = torch.cat(
                [clean_latents_post, clean_latents_2x, clean_latents_4x],
                dim=1)

            # rope positions: each bucket sits at negative time offsets
            start_time_id = -(self.zip_frame_buckets[:1].sum())
            end_time_id = start_time_id + self.zip_frame_buckets[0]
            grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \
                [
                    [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
                     torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
                     torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
                ]

            start_time_id = -(self.zip_frame_buckets[:2].sum())
            end_time_id = start_time_id + self.zip_frame_buckets[1] // 2
            grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \
                [
                    [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
                     torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1),
                     torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
                ]

            start_time_id = -(self.zip_frame_buckets[:3].sum())
            end_time_id = start_time_id + self.zip_frame_buckets[2] // 4
            grid_sizes_4x = [[
                torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
                torch.tensor([end_time_id, lat_height // 8,
                              lat_width // 8]).unsqueeze(0).repeat(1, 1),
                torch.tensor([
                    self.zip_frame_buckets[2], lat_height // 2, lat_width // 2
                ]).unsqueeze(0).repeat(1, 1),
            ]]

            grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x

            motion_rope_emb = rope_precompute(
                motion_lat.detach().view(1, motion_lat.shape[1],
                                         self.num_heads,
                                         self.inner_dim // self.num_heads),
                grid_sizes,
                self.freqs,
                start=None)

            mot.append(motion_lat)
            mot_remb.append(motion_rope_emb)
        return mot, mot_remb
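
# FramePack bucketing sketch (illustrative, not part of the original module):
# with zip_frame_buckets = (1, 2, 16), the last 1 + 2 + 16 = 19 latent frames
# are split into [16 oldest -> proj_4x, 2 middle -> proj_2x, 1 newest -> proj],
# so older motion context is kept at progressively coarser resolution.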


def sample_indices(N, stride, expand_ratio, c):
    """Sample frame indices with buckets that widen geometrically, so recent
    frames are sampled densely and older ones sparsely."""
    indices = []
    current_start = 0

    while current_start < N:
        # the bucket widens by expand_ratio for every `stride` samples taken
        bucket_width = int(stride * (expand_ratio**(len(indices) / stride)))

        interval = int(bucket_width / stride * c)
        current_end = min(N, current_start + bucket_width)
        bucket_samples = []
        for i in range(current_end - 1, current_start - 1, -interval):
            # take `c` consecutive frames at each sampled position
            for near in range(c):
                bucket_samples.append(i - near)

        indices += bucket_samples[::-1]
        current_start += bucket_width

    return indices
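
# Worked example (illustrative, not part of the original module):
#   sample_indices(8, stride=4, expand_ratio=2, c=1)  # -> [0, 1, 2, 3, 5, 7]
# the first bucket has width 4 and is sampled densely; the next bucket widens
# to int(4 * 2**(4 / 4)) = 8 and is sampled every 2 frames.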


if __name__ == '__main__':
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = FramePackMotioner(inner_dim=1024).to(device)
    # freqs is a plain tensor attribute rather than a registered buffer, so
    # move it alongside the module parameters explicitly
    model.freqs = model.freqs.to(device)
    batch_size = 2
    num_frame, height, width = (28, 32, 32)
    single_input = torch.ones([16, num_frame, height, width], device=device)
    # give each frame a distinct value so the temporal layout is visible
    for i in range(num_frame):
        single_input[:, num_frame - 1 - i] *= i
    x = [single_input] * batch_size
    mot, mot_remb = model(x)
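    # Shape check (illustrative, not part of the original script): with the
    # default zip_frame_buckets (1, 2, 16) and 32x32 latents, each sample
    # yields 1*16*16 + 1*8*8 + 4*4*4 = 384 motion tokens.
    print([m.shape for m in mot])  # -> [torch.Size([1, 384, 1024])] * 2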