|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""The network definition for the discrete video tokenizer with VQ, LFQ, FSQ or ResidualFSQ quantizers."""
|
|
from collections import OrderedDict, namedtuple |
|
|
|
|
|
import torch |
|
|
from loguru import logger as logging |
|
|
from torch import nn |
|
|
|
|
|
from cosmos_predict1.tokenizer.modules import Decoder3DType, DiscreteQuantizer, Encoder3DType |
|
|
from cosmos_predict1.tokenizer.modules.layers3d import CausalConv3d |
|
|
from cosmos_predict1.tokenizer.modules.quantizers import InvQuantizerJit |
|
|
|
|
|
# Immutable result record: the tokenizer's forward pass returns this tuple
# type when the module is not in training mode.
NetworkEval = namedtuple(
    "NetworkEval",
    ("reconstructions", "quant_loss", "quant_info"),
)
|
|
|
|
|
|
|
|
class CausalDiscreteVideoTokenizer(nn.Module):
    """Encoder / quantizer / decoder network for discrete video tokenization.

    The module wires together, in this order:

    * ``encoder``: selected from ``Encoder3DType`` by name,
    * ``quant_conv``: a 1x1x1 ``CausalConv3d`` applied to the encoder output,
    * ``quantizer``: selected from ``DiscreteQuantizer`` by name,
    * ``post_quant_conv``: a 1x1x1 ``CausalConv3d`` applied before decoding,
    * ``decoder``: selected from ``Decoder3DType`` by name.
    """

    def __init__(self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs) -> None:
        """Build the encoder, decoder, projection convolutions and quantizer.

        Args:
            z_channels: Channel count passed to the decoder and used as the
                second argument of ``post_quant_conv``.
            z_factor: Multiplier applied to ``z_channels`` for the encoder's
                ``z_channels`` argument and the first argument of
                ``quant_conv``.
            embedding_dim: Channel count on the quantizer side of both
                projection convolutions.
            **kwargs: Forwarded unchanged to the encoder, decoder and
                quantizer constructors. Keys read here:
                ``name`` (defaults to ``"CausalDiscreteVideoTokenizer"``),
                ``encoder`` / ``decoder`` (enum member names, default
                ``BASE``) and ``quantizer`` (enum member name, default
                ``RESFSQ``). Each quantizer additionally requires its own
                keys, checked below.

        Raises:
            AssertionError: If a key required by the selected quantizer is
                missing from ``kwargs``.
            KeyError: If an encoder/decoder/quantizer name is not a member of
                its enum.
        """
        super().__init__()
        self.name = kwargs.get("name", "CausalDiscreteVideoTokenizer")
        self.embedding_dim = embedding_dim

        # NOTE: the encoder and decoder are built before ``embedding_dim`` is
        # (conditionally) injected into ``kwargs`` further down, so they never
        # receive it through ``kwargs``.
        encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name)
        self.encoder = Encoder3DType[encoder_name].value(z_channels=z_factor * z_channels, **kwargs)

        decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name)
        self.decoder = Decoder3DType[decoder_name].value(z_channels=z_channels, **kwargs)

        self.quant_conv = CausalConv3d(z_factor * z_channels, embedding_dim, kernel_size=1, padding=0)
        self.post_quant_conv = CausalConv3d(embedding_dim, z_channels, kernel_size=1, padding=0)

        quantizer_name = kwargs.get("quantizer", DiscreteQuantizer.RESFSQ.name)

        # Keyword arguments each quantizer flavour must receive, in the order
        # they are validated.
        required_kwargs = {
            DiscreteQuantizer.VQ.name: ("num_embeddings",),
            DiscreteQuantizer.LFQ.name: ("codebook_size", "codebook_dim"),
            DiscreteQuantizer.FSQ.name: ("levels",),
            DiscreteQuantizer.RESFSQ.name: ("levels", "num_quantizers"),
        }
        for key in required_kwargs.get(quantizer_name, ()):
            assert key in kwargs, f"`{key}` must be provided for {quantizer_name}."

        # Only the VQ quantizer is handed ``embedding_dim`` via ``kwargs``.
        if quantizer_name == DiscreteQuantizer.VQ.name:
            kwargs["embedding_dim"] = embedding_dim

        self.quantizer = DiscreteQuantizer[quantizer_name].value(**kwargs)
        logging.info(f"{self.name} based on {quantizer_name}-VAE, with {kwargs}.")

        num_parameters = sum(param.numel() for param in self.parameters())
        logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
        logging.info(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.")

    def to(self, *args, **kwargs):
        """Move/cast the module and record the target dtype on the quantizer.

        Accepts the same arguments as ``torch.nn.Module.to``. The dtype stored
        on ``self.quantizer.dtype`` is resolved from, in order:

        1. the ``dtype`` keyword argument,
        2. the first positional ``torch.dtype`` or ``torch.Tensor`` (whose
           dtype is used),
        3. ``torch.bfloat16`` when no dtype is specified at all.

        Returns:
            Whatever ``torch.nn.Module.to`` returns (the module itself).
        """
        dtype = kwargs.get("dtype")
        if dtype is None:
            # ``Module.to`` also accepts the dtype positionally, e.g.
            # ``to(torch.float32)``, ``to(device, torch.float16)`` or
            # ``to(reference_tensor)``; honour those forms instead of
            # silently falling back to the default.
            for arg in args:
                if isinstance(arg, torch.dtype):
                    dtype = arg
                    break
                if isinstance(arg, torch.Tensor):
                    dtype = arg.dtype
                    break
        if dtype is None:
            dtype = torch.bfloat16
        self.quantizer.dtype = dtype
        return super().to(*args, **kwargs)

    def encoder_jit(self) -> nn.Sequential:
        """Return ``encoder -> quant_conv -> quantizer`` as one ``nn.Sequential``.

        The returned container shares the submodules of this instance.
        """
        return nn.Sequential(
            OrderedDict(
                [
                    ("encoder", self.encoder),
                    ("quant_conv", self.quant_conv),
                    ("quantizer", self.quantizer),
                ]
            )
        )

    def decoder_jit(self) -> nn.Sequential:
        """Return ``inv_quant -> post_quant_conv -> decoder`` as one ``nn.Sequential``.

        ``inv_quant`` is an ``InvQuantizerJit`` built around this instance's
        quantizer; the other two stages are shared submodules.
        """
        return nn.Sequential(
            OrderedDict(
                [
                    ("inv_quant", InvQuantizerJit(self.quantizer)),
                    ("post_quant_conv", self.post_quant_conv),
                    ("decoder", self.decoder),
                ]
            )
        )

    def last_decoder_layer(self):
        """Return the decoder's ``conv_out`` attribute."""
        return self.decoder.conv_out

    def encode(self, x):
        """Run ``x`` through the encoder, ``quant_conv`` and the quantizer.

        Returns:
            The quantizer's output unchanged. ``forward`` unpacks it as
            ``(quant_info, quant_codes, quant_loss)``.
        """
        h = self.encoder(x)
        h = self.quant_conv(h)
        return self.quantizer(h)

    def decode(self, quant):
        """Apply ``post_quant_conv`` to ``quant`` and decode the result."""
        quant = self.post_quant_conv(quant)
        return self.decoder(quant)

    def decode_code(self, code_b):
        """Decode from indices: map them to codes via the quantizer, then decode."""
        quant_b = self.quantizer.indices_to_codes(code_b)
        quant_b = self.post_quant_conv(quant_b)
        return self.decoder(quant_b)

    def forward(self, input):
        """Encode, quantize and reconstruct ``input``.

        Returns:
            In training mode, a ``dict`` with keys ``reconstructions``,
            ``quant_loss`` and ``quant_info``; otherwise a ``NetworkEval``
            namedtuple carrying the same three fields.
        """
        quant_info, quant_codes, quant_loss = self.encode(input)
        reconstructions = self.decode(quant_codes)
        if self.training:
            return dict(
                reconstructions=reconstructions,
                quant_loss=quant_loss,
                quant_info=quant_info,
            )
        return NetworkEval(
            reconstructions=reconstructions,
            quant_loss=quant_loss,
            quant_info=quant_info,
        )
|
|
|