|
|
|
|
|
|
|
|
from typing import Dict, Optional, Tuple, Union |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from einops import rearrange, repeat |
|
|
import torch.nn.functional as F |
|
|
from .misc import swish |
|
|
from infinity.models.videovae.modules.normalization import SpatialGroupNorm |
|
|
from infinity.models.videovae.utils.context_parallel import ContextParallelUtils as cp |
|
|
from infinity.models.videovae.utils.context_parallel import dist_conv_cache_send, dist_conv_cache_recv |
|
|
|
|
|
|
|
|
class DCDownBlock3d(nn.Module):
    """3D downsampling block: 2x spatial pixel-unshuffle with optional temporal pooling.

    The main path runs (optional) norm + swish, then a causal 3D conv that emits
    1/4 of the target channels; space-to-channel rearranges restore the width.
    A shortcut path reshapes the raw input the same way and averages channel
    groups so it can be added residually. When the temporal length is odd, the
    first frame is processed separately so it is never pooled with later frames.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 shortcut: bool = True,
                 group_norm=False,
                 compress_time=False,
                 ) -> None:
        super().__init__()
        self.shortcut = shortcut
        self.compress_time = compress_time
        if group_norm:
            # Pre-activation: GroupNorm + swish before the conv.
            self.norm = SpatialGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
            self.nonlinearity = swish
        else:
            self.norm = nn.Identity()
            self.nonlinearity = nn.Identity()
        self.spatial_factor = 2
        # NOTE(review): int(compress_time) is 1 when compress_time is True, so a
        # temporal factor > 1 only arises when an int > 1 is passed — confirm intended.
        self.temporal_factor = int(compress_time) if compress_time else 1
        out_ratio = self.spatial_factor**2
        assert out_channels % out_ratio == 0
        # Conv emits out_channels / 4; the spatial pixel-unshuffle below restores them.
        out_channels = out_channels // out_ratio

        self.conv = CogVideoXCausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode="first")

    def forward(self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None, temporal_compress = True) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Downsample `hidden_states` (B, C, T, H, W).

        Returns:
            Tuple of (downsampled tensor, new conv cache dict for the next call).
        """
        new_conv_cache = {}
        conv_cache = conv_cache or {}

        x = hidden_states
        x = self.nonlinearity(self.norm(x))
        assert x.ndim == 5, f"x.ndim must be (B C T H W)"

        x, new_conv_cache["conv"] = self.conv(x, conv_cache=conv_cache.get("conv"))

        # Split off the first frame when T is odd so the temporal pooling below
        # always sees an even-length remainder; y mirrors x on the shortcut path.
        if x.shape[2] > 1:
            if x.shape[2] % 2 == 1:
                x_first, x_rest = x[:, :, 0, ...], x[:, :, 1:, ...]
                y_first, y_rest = hidden_states[:, :, 0, ...], hidden_states[:, :, 1:, ...]
            else:
                x_first, x_rest = None, x
                y_first, y_rest = None, hidden_states
        elif x.shape[2] == 1:
            x_first, x_rest = x[:, :, 0, ...], None
            y_first, y_rest = hidden_states[:, :, 0, ...], None
        else:
            raise NotImplementedError

        if x_first is not None:
            # 2x spatial pixel-unshuffle on the lone first frame (both paths).
            x_first = rearrange(x_first, "b c (h ph) (w pw) -> b (ph pw c) h w", ph=self.spatial_factor, pw=self.spatial_factor)
            y_first = rearrange(y_first, "b c (h ph) (w pw) -> b (ph pw c) h w", ph=self.spatial_factor, pw=self.spatial_factor)
        if x_rest is not None:
            if temporal_compress:
                # Spatial unshuffle, then average groups of pt frames (temporal pooling).
                x_rest = rearrange(x_rest, "b c (t pt) (h ph) (w pw) -> b (ph pw c) t pt h w", pt=self.temporal_factor, ph=self.spatial_factor, pw=self.spatial_factor)
                x_rest = x_rest.mean(dim=3)
                y_rest = rearrange(y_rest, "b c (t pt) (h ph) (w pw) -> b (ph pw c) t pt h w", pt=self.temporal_factor, ph=self.spatial_factor, pw=self.spatial_factor)
                y_rest = y_rest.mean(dim=3)
            else:
                # Spatial unshuffle only; temporal length preserved.
                x_rest = rearrange(x_rest, "b c (t pt) (h ph) (w pw) -> b (ph pw c) (t pt) h w", pt=self.temporal_factor, ph=self.spatial_factor, pw=self.spatial_factor)
                y_rest = rearrange(y_rest, "b c (t pt) (h ph) (w pw) -> b (ph pw c) (t pt) h w", pt=self.temporal_factor, ph=self.spatial_factor, pw=self.spatial_factor)
        # Re-attach the separately processed first frame, if any.
        if x_first is not None and x_rest is not None:
            x = torch.cat([x_first[:,:, None,...], x_rest], dim=2)
            y = torch.cat([y_first[:,:, None,...], y_rest], dim=2)
        else:
            x = x_first[:,:, None,...] if x_first is not None else x_rest
            y = y_first[:,:, None,...] if y_first is not None else y_rest
        if self.shortcut:
            # Average channel groups so the shortcut matches the main path width.
            y = rearrange(y, "b (g c) t h w -> b g c t h w", c=x.shape[1]).mean(dim=1)
            hidden_states = x + y
        else:
            hidden_states = x
        return hidden_states, new_conv_cache
|
|
|
|
|
class DCUpBlock3d(nn.Module):
    """3D upsampling block: 2x spatial pixel-shuffle with optional temporal expansion.

    A causal 3D conv over-produces channels by spatial_factor**2 * temporal_factor;
    channel-to-space rearranges fold them back into (t, h, w). The shortcut path
    repeats input channels and folds them the same way before the residual add.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        shortcut: bool = True,
        interpolation_mode: str = "nearest",
        group_norm=False,
        compress_time=False
    ) -> None:
        super().__init__()

        self.compress_time = compress_time
        if group_norm:
            # Pre-activation: GroupNorm + swish before the conv.
            self.norm = SpatialGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
            self.nonlinearity = swish
        else:
            self.norm = nn.Identity()
            self.nonlinearity = nn.Identity()
        self.interpolation_mode = interpolation_mode
        self.shortcut = shortcut
        self.spatial_factor = 2
        # NOTE(review): int(compress_time) is 1 when compress_time is True, so a
        # temporal factor > 1 only arises when an int > 1 is passed — confirm intended.
        self.temporal_factor = int(compress_time) if compress_time else 1
        # Conv over-produces channels; the channel-to-space fold below consumes them.
        out_channels = out_channels * self.spatial_factor**2 * self.temporal_factor

        self.conv = CogVideoXCausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode="first")
        assert out_channels % in_channels == 0
        # Channel-repeat factor for the shortcut path.
        self.repeats = out_channels // in_channels

    def forward(self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Upsample `hidden_states` (B, C, T, H, W).

        Returns:
            Tuple of (upsampled tensor, new conv cache dict for the next call).
        """
        new_conv_cache = {}
        conv_cache = conv_cache or {}

        x = hidden_states
        x = self.nonlinearity(self.norm(x))

        # Odd input length: the expanded first-frame group is collapsed back below.
        compress_first = False
        if x.shape[2] % 2 == 1:
            compress_first = True

        x, new_conv_cache["conv"] = self.conv(x, conv_cache=conv_cache.get("conv"))

        # Fold the over-produced channels out into (t, h, w).
        x = rearrange(x, "b (pt ph pw c) t h w -> b c (t pt) (h ph) (w pw)", pt=self.temporal_factor, ph=self.spatial_factor, pw=self.spatial_factor)
        # Shortcut: repeat channels, then fold identically.
        y = repeat(hidden_states, "b c t h w -> b (r c) t h w", r=self.repeats)
        y = rearrange(y, "b (pt ph pw c) t h w -> b c (t pt) (h ph) (w pw)", pt=self.temporal_factor, ph=self.spatial_factor, pw=self.spatial_factor)

        if self.temporal_factor > 1 and compress_first:
            # Collapse the expanded first-frame group back to a single frame so
            # an odd-length input maps to the expected output length.
            if x.shape[2] > 1:
                x_first, x_rest = x[:, :, :self.temporal_factor, ...], x[:, :, self.temporal_factor:, ...]
                y_first, y_rest = y[:, :, :self.temporal_factor, ...], y[:, :, self.temporal_factor:, ...]
            elif x.shape[2] == 1:
                # NOTE(review): with temporal_factor > 1, x.shape[2] == t * pt >= pt > 1,
                # so this branch (and the x_rest=None cat below) looks unreachable — confirm.
                assert x.shape[2] == y.shape[2] == self.temporal_factor
                x_first, x_rest = x, None
                y_first, y_rest = y, None
            else:
                raise NotImplementedError
            x = torch.cat([x_first.mean(dim=2, keepdim=True), x_rest], dim=2)
            y = torch.cat([y_first.mean(dim=2, keepdim=True), y_rest], dim=2)
        if self.shortcut:
            hidden_states = x + y
        else:
            hidden_states = x
        return hidden_states, new_conv_cache
|
|
|
|
|
class DCDownBlock2d(nn.Module):
    """2x spatial downsampling block with an optional pixel-unshuffle residual shortcut.

    The main path is norm -> nonlinearity -> 3x3 conv, downsampling either inside
    the conv (stride 2) or by pixel-unshuffle after it. The shortcut pixel-unshuffles
    the raw input and averages channel groups to match the output width.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 downsample: bool = False,
                 shortcut: bool = True,
                 group_norm=False,
                 ) -> None:
        super().__init__()
        # Optional pre-activation; identity modules when group_norm is off.
        if group_norm:
            self.norm = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
            self.nonlinearity = swish
        else:
            self.norm = nn.Identity()
            self.nonlinearity = nn.Identity()

        self.downsample = downsample
        self.factor = 2
        # Stride-2 conv downsamples unless pixel-unshuffle is used instead.
        self.stride = 1 if downsample else 2
        # Input channels averaged into each shortcut output channel.
        self.group_size = in_channels * self.factor**2 // out_channels
        self.shortcut = shortcut

        conv_out = out_channels
        if downsample:
            # Pixel-unshuffle multiplies channels by factor**2, so the conv emits fewer.
            ratio = self.factor**2
            assert out_channels % ratio == 0
            conv_out = out_channels // ratio

        self.conv = nn.Conv2d(
            in_channels,
            conv_out,
            kernel_size=3,
            stride=self.stride,
            padding=1,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Downsample (B, C, H, W) by 2x spatially."""
        out = self.conv(self.nonlinearity(self.norm(hidden_states)))
        if self.downsample:
            out = F.pixel_unshuffle(out, self.factor)

        if not self.shortcut:
            return out

        # Residual path: space-to-channel on the raw input, then group-average channels.
        residual = F.pixel_unshuffle(hidden_states, self.factor)
        residual = residual.unflatten(1, (-1, self.group_size)).mean(dim=2)
        return out + residual
|
|
|
|
|
class DCUpBlock2d(nn.Module):
    """2x spatial upsampling block with a channel-repeat + pixel-shuffle shortcut.

    The main path is norm -> nonlinearity -> conv, upsampling either by
    interpolation before the conv or by pixel-shuffle after it. The shortcut
    repeats input channels and pixel-shuffles them for the residual add.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        interpolate: bool = False,
        shortcut: bool = True,
        interpolation_mode: str = "nearest",
        group_norm=False,
    ) -> None:
        super().__init__()

        # Optional pre-activation; identity modules when group_norm is off.
        if group_norm:
            self.norm = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
            self.nonlinearity = swish
        else:
            self.norm = nn.Identity()
            self.nonlinearity = nn.Identity()

        self.interpolate = interpolate
        self.interpolation_mode = interpolation_mode
        self.shortcut = shortcut
        self.factor = 2
        # Channel-repeat count so the shortcut survives pixel-shuffle at full width.
        self.repeats = out_channels * self.factor**2 // in_channels

        # Without interpolation the conv over-produces channels for pixel-shuffle.
        conv_out = out_channels if interpolate else out_channels * self.factor**2
        self.conv = nn.Conv2d(in_channels, conv_out, 3, 1, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Upsample (B, C, H, W) by 2x spatially."""
        out = self.nonlinearity(self.norm(hidden_states))
        if self.interpolate:
            out = F.interpolate(out, scale_factor=self.factor, mode=self.interpolation_mode)
            out = self.conv(out)
        else:
            out = F.pixel_shuffle(self.conv(out), self.factor)

        if self.shortcut:
            # Residual path: channel-repeat, then channel-to-space.
            skip = hidden_states.repeat_interleave(self.repeats, dim=1)
            out = out + F.pixel_shuffle(skip, self.factor)
        return out
|
|
|
|
|
class CogVideoXSafeConv3d(nn.Conv3d):
    r"""
    A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        parent_forward = super().forward
        # Rough activation footprint in GiB (assumes 2 bytes/element; batch dim
        # excluded — presumably matching the upstream heuristic; TODO confirm).
        memory_gib = input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4] * 2 / 1024**3
        if memory_gib <= 2:
            # Small enough: single plain Conv3d call.
            return parent_forward(input)

        # Split along time into roughly 2 GiB pieces.
        num_parts = int(memory_gib / 2) + 1
        chunks = list(torch.chunk(input, num_parts, dim=2))

        t_kernel = self.kernel_size[0]
        if t_kernel > 1:
            # Prepend the trailing frames of the previous chunk so every chunk
            # sees the full temporal receptive field (output stays seamless).
            stitched = [chunks[0]]
            for prev, cur in zip(chunks, chunks[1:]):
                stitched.append(torch.cat((prev[:, :, -t_kernel + 1:], cur), dim=2))
            chunks = stitched

        return torch.cat([parent_forward(c) for c in chunks], dim=2)
|
|
|
|
|
|
|
|
class CogVideoXCausalConv3d(nn.Module):
    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.

    Args:
        in_channels (`int`): Number of channels in the input tensor.
        out_channels (`int`): Number of output channels produced by the convolution.
        kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
        stride (`int`, defaults to `1`): Stride of the convolution.
        dilation (`int`, defaults to `1`): Dilation rate of the convolution.
        pad_mode (`str`, defaults to `"constant"`): Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: int = 1,
        dilation: int = 1,
        pad_mode: str = "constant",
    ):
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        self.pad_mode = pad_mode
        # Causal temporal padding: all pad frames go before the sequence.
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.height_pad = height_pad
        self.width_pad = width_pad
        self.time_pad = time_pad
        # F.pad ordering: (W_left, W_right, H_top, H_bottom, T_front, T_back=0).
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        self.temporal_dim = 2
        self.time_kernel_size = time_kernel_size

        # Stride/dilation act on time only; spatial stride is fixed at 1.
        stride = (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = CogVideoXSafeConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def fake_context_parallel_forward(
        self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Prepend cached frames (or replicas of the first frame) so the temporal conv is causal."""
        kernel_size = self.time_kernel_size

        if cp.cp_on():
            # Context parallelism: the cache frames come from the neighboring rank.
            conv_cache = dist_conv_cache_recv()

        if kernel_size > 1:
            # No cache yet: replicate the first frame (kernel_size - 1) times.
            cached_inputs = [conv_cache.to(inputs.device)] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
            inputs = torch.cat(cached_inputs + [inputs], dim=2)
        return inputs

    def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the causal conv; returns (output, conv cache for the next call)."""
        inputs = self.fake_context_parallel_forward(inputs, conv_cache)

        if cp.cp_on():
            # Ship the trailing frames to the next rank instead of returning them.
            dist_conv_cache_send(inputs[:, :, -self.time_kernel_size + 1 :])
        else:
            # NOTE(review): when time_kernel_size == 1 this slice is [:, :, 0:], i.e.
            # the whole input is cloned into the cache (unused but wasteful) — confirm.
            conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()

        # Temporal padding was handled above; only spatial zero-padding remains.
        padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
        inputs = F.pad(inputs, padding_2d, mode="constant", value=0)

        output = self.conv(inputs)
        return output, conv_cache
|
|
|
|
|
class FluxConv(nn.Module):
    """Conv wrapper that runs either a plain 2D conv (optionally frame-wise over a
    5D video tensor, or over a list of tensors) or a causal, temporally-sliced
    3D conv whose memory is bounded by processing `cnn_slice_seq_len` frames at a time.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: standard conv arguments.
        cnn_type: "2d" or "3d".
        cnn_slice_seq_len: temporal slice length for the 3d path (bounds peak memory).
        causal_offset: extra frames added to the leading causal temporal pad (3d only).
        temporal_down: 3d only; when False the temporal stride is forced to 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, cnn_type="2d", cnn_slice_seq_len=17, causal_offset=0, temporal_down=False):
        super().__init__()
        self.cnn_type = cnn_type
        self.slice_seq_len = cnn_slice_seq_len

        if cnn_type == "2d":
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        if cnn_type == "3d":
            if not temporal_down:
                stride = (1, stride, stride)
            else:
                stride = (stride, stride, stride)
            # Temporal padding is applied manually in forward() to keep the conv causal.
            self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=0)
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size, kernel_size)
            self.padding = (
                kernel_size[0] - 1 + causal_offset,  # leading (causal) temporal pad
                padding,
                padding
            )
            self.causal_offset = causal_offset
            self.stride = stride
            self.kernel_size = kernel_size

    def _causal_pad(self, chunk, source, pad_t):
        """Zero-pad `chunk` spatially and by `pad_t` leading frames, then fill the
        temporal pad region (interior spatial area only) with `source` frames
        (broadcast when `source` has a single frame)."""
        chunk = F.pad(chunk, (self.padding[2], self.padding[2],
                              self.padding[1], self.padding[1],
                              pad_t, 0))
        chunk[:, :, :pad_t,
              self.padding[1]:chunk.shape[-2] - self.padding[1],
              self.padding[2]:chunk.shape[-1] - self.padding[2]] = source.clone()
        return chunk

    def forward(self, x):
        if self.cnn_type == "2d":
            if isinstance(x, list):
                # Apply frame-wise to each tensor of the list, in place.
                for i in range(len(x)):
                    x[i] = self.forward(x[i])
                return x
            if x.ndim == 5:
                # Fold time into batch and run as a 2D conv.
                B, C, T, H, W = x.shape
                x = rearrange(x, "B C T H W -> (B T) C H W")
                x = self.conv(x)
                x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
                return x
            else:
                return self.conv(x)
        if self.cnn_type == "3d":
            if x.ndim == 5:
                assert self.stride[0] == 1 or self.stride[0] == 2, f"only temporal stride = 1 or 2 are supported"
                step = self.slice_seq_len + self.stride[0] - 1
                if self.stride[0] == 1:
                    # Stride 1: write results in place, back to front, so later slices
                    # still read un-convolved frames from x for their causal context.
                    for i in reversed(range(0, x.shape[2], step)):
                        st = i
                        en = min(i + self.slice_seq_len, x.shape[2])
                        if i == 0:
                            # First slice: pad with replicas of frame 0.
                            _x = self._causal_pad(x[:, :, st:en, :, :], x[:, :, 0:1, :, :], self.padding[0])
                        else:
                            # Later slices: pad with the preceding real frames.
                            pad_t = self.kernel_size[0] - 1
                            _x = self._causal_pad(x[:, :, st:en, :, :], x[:, :, i - pad_t:i, :, :], pad_t)
                        try:
                            _x = self.conv(_x)
                        except Exception:
                            # Best-effort fallback (e.g. OOM): convolve column-by-column
                            # along width. Fixed: was a bare `except:` and re-used the
                            # outer loop variable `i`, which corrupted the `if i == 0`
                            # check below whenever this path fired on the first slice.
                            cols = [_x[:, :, :, :, j - 1:j + 2] for j in range(1, _x.shape[-1] - 1)]
                            for j in range(len(cols)):
                                cols[j] = self.conv(cols[j])
                            _x = torch.cat(cols, dim=-1)
                        if i == 0:
                            x[:, :, st - self.causal_offset:en, :, :] = _x
                            # Drop the first (padding-derived) frame of the result.
                            x = x[:, :, 1:, :, :]
                        else:
                            x[:, :, st:en, :, :] = _x
                else:
                    # Stride 2: outputs shrink temporally, so collect and concatenate.
                    xs = []
                    for i in range(0, x.shape[2], step):
                        st = i
                        en = min(i + self.slice_seq_len, x.shape[2])
                        if i == 0:
                            _x = self._causal_pad(x[:, :, st:en, :, :], x[:, :, 0:1, :, :], self.padding[0])
                        else:
                            pad_t = self.kernel_size[0] - 1
                            _x = self._causal_pad(x[:, :, st:en, :, :], x[:, :, i - pad_t:i, :, :], pad_t)
                        xs.append(self.conv(_x))
                    try:
                        x = torch.cat(xs, dim=2)
                    except Exception:
                        # Fixed: was a bare `except:`. Fall back to concatenating on
                        # the CPU when the on-device cat runs out of memory.
                        device = x.device
                        del x
                        xs = [_x.cpu().pin_memory() for _x in xs]
                        torch.cuda.empty_cache()
                        x = torch.cat(xs, dim=2).to(device=device)
            else:
                # 4D (image) input: collapse the temporal kernel axis by summing
                # the weights and run as a 2D conv with the spatial stride.
                x = F.pad(x, (self.padding[2], self.padding[2],
                              self.padding[1], self.padding[1]))
                weight = torch.sum(self.conv.weight, dim=2)
                x = F.conv2d(x, weight=weight, bias=self.conv.bias, stride=self.conv.stride[1:])
            return x
|
|