import json

import torch
import torch.nn as nn

from Preprocessing.Codec.env import AttrDict
from Preprocessing.Codec.models import Encoder, Generator, Quantizer


class VQVAE(nn.Module):
    def __init__(self, config_path, ckpt_path, with_encoder=False):
        super().__init__()
        ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
        with open(config_path) as f:
            json_config = json.load(f)
        self.h = AttrDict(json_config)
        self.quantizer = Quantizer(self.h)
        self.generator = Generator(self.h)
        self.generator.load_state_dict(ckpt['generator'])
        self.quantizer.load_state_dict(ckpt['quantizer'])
        if with_encoder:
            self.encoder = Encoder(self.h)
            self.encoder.load_state_dict(ckpt['encoder'])
    def forward(self, x):
        # x holds the discrete code indices produced by encode(),
        # with shape (B, T, Nq).
        quant_emb = self.quantizer.embed(x)
        return self.generator(quant_emb)
    def encode(self, x):
        batch_size = x.size(0)
        # Drop a trailing channel dimension of size 1, if present.
        if len(x.shape) == 3 and x.shape[-1] == 1:
            x = x.squeeze(-1)
        # The encoder expects a channel axis: (B, 1, num_samples).
        c = self.encoder(x.unsqueeze(1))
        q, loss_q, c = self.quantizer(c)
        c = [code.reshape(batch_size, -1) for code in c]
        # Stack the per-codebook index streams: shape (B, T, Nq), Nq = 4 here.
        return torch.stack(c, -1)
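

# --- Minimal usage sketch (not part of the original module) ---
# Shows how the wrapper above is typically driven: encode a waveform into
# discrete codes, then reconstruct audio from them with forward(). The
# config/checkpoint paths below are placeholders, and the 1-second dummy
# input only assumes the encoder accepts (B, 1, num_samples) mono audio,
# as encode() implies; actual shapes depend on the codec config.
if __name__ == "__main__":
    model = VQVAE(
        config_path="config.json",   # placeholder: codec config file
        ckpt_path="checkpoint.pt",   # placeholder: trained codec weights
        with_encoder=True,
    )
    model.eval()

    wav = torch.randn(1, 16000)      # dummy batch of mono audio samples
    with torch.no_grad():
        codes = model.encode(wav)    # (B, T, Nq) integer code indices
        recon = model(codes)         # decode the codes back to a waveform
    print(codes.shape, recon.shape)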