Spaces:
Running on Zero
Running on Zero
| import lightning as L | |
| import torch | |
| import torch.nn as nn | |
| from dac import Encoder, Decoder | |
| from rvq import ResidualVectorQuantization | |
| from typing import List, Dict, Any, Union | |
class DacRVQ(L.LightningModule):
    """Audio codec: encoder -> residual vector quantizer -> decoder.

    The encoder maps a waveform to a latent sequence, the residual vector
    quantizer discretizes that latent, and the decoder reconstructs a
    waveform from the quantized latent.
    """

    # During quantizer dropout, probability of still using every quantizer
    # instead of a randomly drawn number of them.
    FULL_QUANTIZATION_PROB = 0.2

    def __init__(self, configs):
        """Build the encoder, decoder and quantizer from ``configs``.

        Args:
            configs: Mapping with ``'encoder'``, ``'decoder'`` and
                ``'quantizer'`` entries, each expanded as keyword arguments
                for the corresponding sub-module. The ``'quantizer'`` entry
                must also contain ``'q_dropout'``, which is kept on the model
                and not forwarded to the quantizer. A falsy value disables
                quantizer dropout; otherwise it is the inclusive upper bound
                on the randomly drawn number of quantizers.

        Raises:
            KeyError: If a required entry (including ``'q_dropout'``) is
                missing.
        """
        super().__init__()
        self.encoder = Encoder(**configs['encoder'])
        self.decoder = Decoder(**configs['decoder'])
        # Pop from a shallow copy: mutating the caller's config would make a
        # second instantiation from the same config fail with KeyError.
        quantizer_cfg = dict(configs['quantizer'])
        self.q_dropout = quantizer_cfg.pop('q_dropout')
        self.quantizer = ResidualVectorQuantization(**quantizer_cfg)
        # Applied to every sub-module, so all Conv1d layers are re-initialized.
        self.apply(self.init_weights)

    def init_weights(self, m):
        """Initialize ``nn.Conv1d`` layers; leave every other module as is.

        Weights are drawn from a truncated normal (std 0.02) and the bias,
        when the layer has one, is set to zero.
        """
        if isinstance(m, nn.Conv1d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            # Conv1d(bias=False) has ``bias is None``; skip instead of crashing.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, force_full_quantization: bool = False) -> tuple:
        """Encode, quantize and decode ``x``.

        Args:
            x: Input of shape (batch_size, channels, T).
            force_full_quantization: When True, quantizer dropout is skipped
                even in training mode.

        Returns:
            ``(x_pred, q_z, codes, commit_loss)``: the decoder output, followed
            by the three values returned by the quantizer.
        """
        assert x.dim() == 3, f"expected a 3-D input (batch_size, channels, T), got {x.dim()}-D"
        z = self.encoder(x)
        use_dropout = self.training and self.q_dropout and not force_full_quantization
        if use_dropout:
            # Quantizer dropout: occasionally keep all quantizers, otherwise
            # draw how many to use uniformly from [1, q_dropout].
            if torch.rand(()) < self.FULL_QUANTIZATION_PROB:
                rand_nq = self.quantizer.num_quantizers
            else:
                rand_nq = int(torch.randint(1, self.q_dropout + 1, (1,)).item())
            q_z, codes, commit_loss = self.quantizer(z, nq_to_use=rand_nq)
        else:
            q_z, codes, commit_loss = self.quantizer(z)
        x_pred = self.decoder(q_z)
        return x_pred, q_z, codes, commit_loss

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode an input waveform into the quantizer's discrete codes.

        Args:
            x: Input of rank 3 (same layout as ``forward``).

        Returns:
            The ``codes`` value produced by the quantizer with its default
            settings (no quantizer dropout is applied here).
        """
        assert x.dim() == 3, f"expected a 3-D input, got {x.dim()}-D"
        z = self.encoder(x)
        _, codes, _ = self.quantizer(z)
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode a tensor of discrete codes back into a waveform.

        Args:
            codes: Code tensor as returned by ``encode``.

        Returns:
            The decoder output for the de-quantized latent.
        """
        assert isinstance(codes, torch.Tensor), "Input `codes` must be a tensors."
        q_z = self.quantizer.decode(codes)  # q_z shape: (B, C, T)
        x_pred = self.decoder(q_z)
        return x_pred