|
|
|
|
|
""" |
|
|
Tiny AutoEncoder for Hunyuan Video https://github.com/madebyollin/taehv |
|
|
(DNN for encoding / decoding videos to Hunyuan Video's latent space) |
|
|
""" |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from tqdm.auto import tqdm |
|
|
from collections import namedtuple |
|
|
|
|
|
DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) |
|
|
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index")) |
|
|
|
|
|
def conv(n_in, n_out, **kwargs): |
|
|
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) |
|
|
|
|
|
class Clamp(nn.Module): |
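    """Soft-clamps activations to roughly [-3, 3] using a scaled tanh."""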
|
|
def forward(self, x): |
|
|
return torch.tanh(x / 3) * 3 |
|
|
|
|
|
class MemBlock(nn.Module): |
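    """Residual conv block that also sees the previous timestep's input to this block (its "memory")."""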
|
|
def __init__(self, n_in, n_out): |
|
|
super().__init__() |
|
|
self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True), conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out)) |
|
|
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() |
|
|
self.act = nn.ReLU(inplace=True) |
|
|
def forward(self, x, past): |
|
|
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x)) |
|
|
|
|
|
class TPool(nn.Module): |
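    """Temporal pooling: merges every `stride` consecutive timesteps into one via a 1x1 conv over the stacked channels."""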
|
|
def __init__(self, n_f, stride): |
|
|
super().__init__() |
|
|
self.stride = stride |
|
|
self.conv = nn.Conv2d(n_f*stride,n_f, 1, bias=False) |
|
|
def forward(self, x): |
|
|
_NT, C, H, W = x.shape |
|
|
return self.conv(x.reshape(-1, self.stride * C, H, W)) |
|
|
|
|
|
class TGrow(nn.Module): |
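    """Temporal growth: expands each timestep into `stride` timesteps via a 1x1 conv that multiplies the channel count."""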
|
|
def __init__(self, n_f, stride): |
|
|
super().__init__() |
|
|
self.stride = stride |
|
|
self.conv = nn.Conv2d(n_f, n_f*stride, 1, bias=False) |
|
|
def forward(self, x): |
|
|
_NT, C, H, W = x.shape |
|
|
x = self.conv(x) |
|
|
return x.reshape(-1, C, H, W) |
|
|
|
|
|
def apply_model_with_memblocks(model, x, parallel, show_progress_bar): |
|
|
""" |
|
|
Apply a sequential model with memblocks to the given input. |
|
|
Args: |
|
|
- model: nn.Sequential of blocks to apply |
|
|
- x: input data, of dimensions NTCHW |
|
|
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory) |
|
|
if False, each timestep will be processed sequentially (slow but uses O(1) memory) |
|
|
- show_progress_bar: if True, enables tqdm progressbar display |
|
|
|
|
|
Returns NTCHW tensor of output data. |
|
|
""" |
|
|
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor" |
|
|
N, T, C, H, W = x.shape |
|
|
if parallel: |
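        # parallel mode: fold time into the batch dimension and run all timesteps through each block at once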
|
|
x = x.reshape(N*T, C, H, W) |
|
|
|
|
|
for b in tqdm(model, disable=not show_progress_bar): |
|
|
if isinstance(b, MemBlock): |
|
|
NT, C, H, W = x.shape |
|
|
T = NT // N |
|
|
_x = x.reshape(N, T, C, H, W) |
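            # the "memory" each timestep sees is the previous timestep's input to this block (zeros at t=0)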
|
|
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape) |
|
|
x = b(x, mem) |
|
|
else: |
|
|
x = b(x) |
|
|
NT, C, H, W = x.shape |
|
|
T = NT // N |
|
|
x = x.view(N, T, C, H, W) |
|
|
else: |
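        # sequential mode: push one timestep at a time through the model via a work queue, keeping O(1) temporal state per block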
|
|
|
|
|
|
|
|
out = [] |
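        # one work item per input timestep, each entering the first block (index 0)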
|
|
|
|
|
|
|
|
|
|
|
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))] |
|
|
|
|
|
|
|
|
progress_bar = tqdm(range(T), disable=not show_progress_bar) |
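        # per-block temporal state: previous input for MemBlocks, buffered timesteps for TPools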
|
|
|
|
|
mem = [None] * len(model) |
|
|
while work_queue: |
|
|
xt, i = work_queue.pop(0) |
|
|
if i == 0: |
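                # a new timestep is entering the model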
|
|
|
|
|
progress_bar.update(1) |
|
|
if i == len(model): |
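                # this timestep has passed through every block, so it's a finished output frame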
|
|
|
|
|
out.append(xt) |
|
|
else: |
|
|
|
|
|
b = model[i] |
|
|
if isinstance(b, MemBlock): |
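                    # MemBlock: use the previous timestep's input as memory (zeros for the first timestep), then remember the current input for the next one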
|
|
|
|
|
if mem[i] is None: |
|
|
xt_new = b(xt, xt * 0) |
|
|
mem[i] = xt |
|
|
else: |
|
|
xt_new = b(xt, mem[i]) |
|
|
mem[i].copy_(xt) |
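                    # hand the result to the next block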
|
|
|
|
|
work_queue.insert(0, TWorkItem(xt_new, i+1)) |
|
|
elif isinstance(b, TPool): |
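                    # TPool: buffer incoming timesteps until `stride` of them are available, then pool them into one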
|
|
|
|
|
if mem[i] is None: |
|
|
mem[i] = [] |
|
|
mem[i].append(xt) |
|
|
if len(mem[i]) > b.stride: |
|
|
|
|
|
raise ValueError("???") |
|
|
elif len(mem[i]) < b.stride: |
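                        # not enough buffered timesteps yet, wait for more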
|
|
|
|
|
pass |
|
|
else: |
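                        # the buffer is full, so pool the buffered timesteps into a single output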
|
|
|
|
|
N, C, H, W = xt.shape |
|
|
xt = b(torch.cat(mem[i], 1).view(N*b.stride, C, H, W)) |
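                        # reset the buffer and pass the pooled result on to the next block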
|
|
|
|
|
mem[i] = [] |
|
|
|
|
|
work_queue.insert(0, TWorkItem(xt, i+1)) |
|
|
elif isinstance(b, TGrow): |
|
|
xt = b(xt) |
|
|
NT, C, H, W = xt.shape |
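                    # TGrow produced b.stride timesteps from one; queue them front-to-back (reversed + insert at the head) so they are processed in temporal order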
|
|
|
|
|
for xt_next in reversed(xt.view(N, b.stride*C, H, W).chunk(b.stride, 1)): |
|
|
|
|
|
work_queue.insert(0, TWorkItem(xt_next, i+1)) |
|
|
else: |
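                    # ordinary stateless block: apply it and pass the result along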
|
|
|
|
|
xt = b(xt) |
|
|
|
|
|
work_queue.insert(0, TWorkItem(xt, i+1)) |
|
|
progress_bar.close() |
|
|
x = torch.stack(out, 1) |
|
|
return x |
|
|
|
|
|
class TAEHV(nn.Module): |
|
|
def __init__(self, state_dict, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), dtype=torch.float16): |
|
|
"""Initialize pretrained TAEHV from the given checkpoint. |
|
|
|
|
|
Arg: |
|
|
checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1. |
|
|
decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview. |
|
|
decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview. |
|
|
""" |
|
|
super().__init__() |
|
|
self.image_channels = 3 |
|
|
self.latent_channels = state_dict["decoder.1.weight"].shape[1] |
|
|
self.patch_size = 1 |
|
|
if self.latent_channels == 48: |
|
|
self.patch_size = 2 |
|
|
self.dtype = dtype |
|
|
|
|
|
self.encoder = nn.Sequential( |
|
|
conv(self.image_channels*self.patch_size**2, 64), nn.ReLU(inplace=True), |
|
|
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64), |
|
|
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64), |
|
|
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64), |
|
|
conv(64, self.latent_channels), |
|
|
) |
|
|
n_f = [256, 128, 64, 64] |
|
|
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 |
|
|
self.decoder = nn.Sequential( |
|
|
Clamp(), conv(self.latent_channels, n_f[0]), nn.ReLU(inplace=True), |
|
|
MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False), |
|
|
MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False), |
|
|
MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False), |
|
|
nn.ReLU(inplace=True), conv(n_f[3], self.image_channels*self.patch_size**2), |
|
|
) |
|
|
if state_dict is not None: |
|
|
self.load_state_dict(self.patch_tgrow_layers(state_dict)) |
|
|
|
|
|
self.parallel = parallel |
|
|
|
|
|
def patch_tgrow_layers(self, sd): |
|
|
"""Patch TGrow layers to use a smaller kernel if needed. |
|
|
|
|
|
Args: |
|
|
sd: state dict to patch |
|
|
""" |
|
|
new_sd = self.state_dict() |
|
|
for i, layer in enumerate(self.decoder): |
|
|
if isinstance(layer, TGrow): |
|
|
key = f"decoder.{i}.conv.weight" |
|
|
if sd[key].shape[0] > new_sd[key].shape[0]: |
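                    # time upscaling is disabled for this TGrow layer, so keep only the trailing slice of the checkpoint's output channels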
|
|
|
|
|
sd[key] = sd[key][-new_sd[key].shape[0]:] |
|
|
return sd |
|
|
|
|
|
def encode_video(self, x, parallel=False, show_progress_bar=True): |
|
|
"""Encode a sequence of frames. |
|
|
|
|
|
Args: |
|
|
x: input NTCHW RGB (C=3) tensor with values in [0, 1]. |
|
|
parallel: if True, all frames will be processed at once. |
|
|
(this is faster but may require more memory). |
|
|
if False, frames will be processed sequentially. |
|
|
Returns NTCHW latent tensor with ~Gaussian values. |
|
|
""" |
|
|
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size) |
|
|
return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar) |
|
|
|
|
|
def decode_video(self, x, parallel=False, show_progress_bar=True): |
|
|
"""Decode a sequence of frames. |
|
|
|
|
|
Args: |
|
|
            x: input NTCHW latent tensor (C = self.latent_channels) with ~Gaussian values.
|
|
parallel: if True, all frames will be processed at once. |
|
|
(this is faster but may require more memory). |
|
|
if False, frames will be processed sequentially. |
|
|
Returns NTCHW RGB tensor with ~[0, 1] values. |
|
|
""" |
|
|
        x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
|
|
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size) |
|
|
return x[:, self.frames_to_trim:] |
|
|
|
|
|
def forward(self, x): |
|
|
        # `self.c` is not defined on this class; assume the intended default forward pass is decoding latents to frames
        return self.decode_video(x, parallel=self.parallel)
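

# Usage sketch (illustrative, not part of the library): round-trip a short random clip through TAEHV.
# Assumes a local checkpoint such as taehv.pth; adjust the path, device, and frame count to your setup.
if __name__ == "__main__":
    model = TAEHV(torch.load("taehv.pth", map_location="cpu")).eval()
    video = torch.rand(1, 8, 3, 256, 256)  # NTCHW RGB in [0, 1]; T is a multiple of 4 so temporal pooling divides evenly
    with torch.no_grad():
        latents = model.encode_video(video, parallel=False)           # -> N x T/4 x latent_channels x H/8 x W/8
        reconstruction = model.decode_video(latents, parallel=False)  # -> N x (T - 3) x 3 x H x W (the first frames_to_trim frames are dropped)
    print(tuple(latents.shape), tuple(reconstruction.shape))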
|
|
|