# Hive / models / flowsep / latent_encoder / autoencoder.py
# Source: JusperLee, "update all code" (commit 859ee84).
# NOTE: the lines above were raw Hugging Face page residue; commented out so
# the module is valid Python.
import torch
import os
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
import numpy as np
from latent_diffusion.modules.ema import *
# from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from taming.modules.vqvae.quantize import VectorQuantizer
from torch.optim.lr_scheduler import LambdaLR
from latent_diffusion.modules.diffusionmodules.model import Encoder, Decoder
from latent_diffusion.modules.distributions.distributions import (
DiagonalGaussianDistribution,
)
import wandb
from latent_diffusion.util import instantiate_from_config
import soundfile as sf
from huggingface_hub import hf_hub_download
from utilities.model import get_vocoder
from utilities.tools import synth_one_sample
import itertools
from latent_encoder.wavedecoder import Generator
class AutoencoderKL(pl.LightningModule):
    """KL-regularized variational autoencoder over audio spectrograms.

    The encoder maps a mel spectrogram (``image_key == "fbank"``) or an STFT
    representation (``image_key == "stft"``) to a diagonal-Gaussian posterior;
    the decoder reconstructs the spectrogram from a sampled latent.  Waveforms
    are recovered through a vocoder (fbank) or a learned wave decoder (stft).
    Optimization is manual (``automatic_optimization = False``), GAN-style.
    """

    def __init__(
        self,
        ddconfig=None,
        lossconfig=None,
        batchsize=None,
        embed_dim=None,
        time_shuffle=1,
        subband=1,
        sampling_rate=16000,
        ckpt_path=None,
        reload_from_ckpt=None,
        ignore_keys=None,
        image_key="fbank",
        colorize_nlabels=None,
        monitor=None,
        base_learning_rate=1e-5,
    ):
        """Build encoder/decoder, loss, and optionally restore weights.

        Args:
            ddconfig: Encoder/Decoder hyper-parameters; must contain
                "mel_bins", "z_channels" and "double_z".
            lossconfig: config consumed by ``instantiate_from_config``.
            batchsize: unused here; kept for config compatibility.
            embed_dim: channel width of the latent space.
            time_shuffle: stored only; shuffle ops are disabled below.
            subband: frequency sub-band decomposition factor (stft mode only).
            sampling_rate: audio sample rate for saved/logged waveforms.
            ckpt_path: optional checkpoint to partially restore from
                (non-strict, honoring ``ignore_keys``).
            reload_from_ckpt: checkpoint path for a full ``load_state_dict``;
                downloaded from Hugging Face if not found locally.
            ignore_keys: state-dict key prefixes to drop when restoring from
                ``ckpt_path``.  Defaults to none (was a mutable ``[]``).
            image_key: "fbank" or "stft" input representation.
            colorize_nlabels: if set, registers a random RGB projection buffer.
            monitor: metric name for checkpointing callbacks.
            base_learning_rate: initial learning rate.
        """
        super().__init__()
        # GAN-style training: optimizers are stepped manually elsewhere.
        self.automatic_optimization = False
        assert "mel_bins" in ddconfig.keys(), "mel_bins is not specified in the Autoencoder config"
        num_mel = ddconfig["mel_bins"]
        self.image_key = image_key
        self.sampling_rate = sampling_rate
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.subband = int(subband)
        if self.subband > 1:
            print("Use subband decomposition %s" % self.subband)
        # Encoder must emit mean+logvar, hence the 2 * z_channels below.
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if self.image_key == "fbank":
            self.vocoder = get_vocoder(None, "cpu", num_mel)
        elif self.image_key == "stft":
            self.wave_decoder = Generator(input_channel=512)
            self.wave_decoder.train()
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            # isinstance is the idiomatic type check (was: type(...) == int).
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.learning_rate = float(base_learning_rate)
        print("Initial learning rate %s" % self.learning_rate)
        self.time_shuffle = time_shuffle
        self.reload_from_ckpt = reload_from_ckpt
        self.reloaded = False
        self.mean, self.std = None, None
        self.feature_cache = None
        self.flag_first_run = True
        self.train_step = 0
        self.logger_save_dir = None
        self.logger_exp_name = None
        if not self.reloaded and self.reload_from_ckpt is not None:
            ckpt_path = self.reload_from_ckpt
            if not os.path.exists(ckpt_path):
                print(
                    "--> Local autoencoder ckpt not found: %s, downloading vae.ckpt from Hugging Face..."
                    % ckpt_path
                )
                ckpt_path = hf_hub_download(
                    repo_id="ShandaAI/FlowSep-hive",
                    filename="vae.ckpt",
                )
            print("--> Reload weight of autoencoder from %s" % ckpt_path)
            # map_location="cpu" so loading works on CPU-only hosts
            # (consistent with init_from_ckpt); Lightning moves the module
            # to the target device afterwards.
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            self.load_state_dict(checkpoint["state_dict"])
            self.reloaded = True
        else:
            print("Train from scratch")

    def get_log_dir(self):
        """Return the directory used for saving logs/artifacts.

        Falls back to the wandb logger's own save dir when no explicit
        directory was registered via :meth:`set_log_dir`.
        """
        if self.logger_save_dir is None and self.logger_exp_name is None:
            return os.path.join(self.logger.save_dir, self.logger._project)
        else:
            return os.path.join(self.logger_save_dir, self.logger_exp_name)

    def set_log_dir(self, save_dir, exp_name):
        """Override the log directory components used by :meth:`get_log_dir`."""
        self.logger_save_dir = save_dir
        self.logger_exp_name = exp_name

    def init_from_ckpt(self, path, ignore_keys=None):
        """Partially restore weights, dropping keys with ignored prefixes.

        Args:
            path: checkpoint file containing a "state_dict" entry.
            ignore_keys: iterable of key prefixes to delete before loading
                (default: none; was a mutable ``list()`` default).
        """
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            # Delete at most once per key even if several prefixes match;
            # the original looped over prefixes and could double-delete,
            # raising KeyError.
            if any(k.startswith(ik) for ik in (ignore_keys or ())):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        """Encode a spectrogram into a DiagonalGaussianDistribution posterior."""
        # x = self.time_shuffle_operation(x)
        x = self.freq_split_subband(x)
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode a latent ``z`` back to spectrogram space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        # bs, ch, shuffled_timesteps, fbins = dec.size()
        # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins)
        dec = self.freq_merge_subband(dec)
        return dec

    def decode_to_waveform(self, dec):
        """Convert a batch of decoded spectrograms to waveforms.

        Uses the vocoder for "fbank" inputs and the trained wave decoder
        for "stft" inputs.
        """
        from utilities.model import vocoder_infer

        if self.image_key == "fbank":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = vocoder_infer(dec, self.vocoder)
        elif self.image_key == "stft":
            dec = dec.squeeze(1).permute(0, 2, 1)
            wav_reconstruction = self.wave_decoder(dec)
        else:
            # Previously fell through to an UnboundLocalError.
            raise ValueError("Unsupported image_key: %s" % self.image_key)
        return wav_reconstruction

    def visualize_latent(self, input):
        """Debug helper: dump latents of time/freq-masked inputs to .npy/.png files."""
        import matplotlib.pyplot as plt

        np.save("input.npy", input.cpu().detach().numpy())
        time_input = input.clone()
        # Zero the first 32 frequency bins then shift to -11.59
        # (presumably the log-mel silence floor — TODO confirm).
        time_input[:, :, :, :32] *= 0
        time_input[:, :, :, :32] -= 11.59
        np.save("time_input.npy", time_input.cpu().detach().numpy())
        posterior = self.encode(time_input)
        latent = posterior.sample()
        np.save("time_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("freq_%s.png" % i)
            plt.close()
        freq_input = input.clone()
        # Same masking applied over the first 512 time steps instead.
        freq_input[:, :, :512, :] *= 0
        freq_input[:, :, :512, :] -= 11.59
        np.save("freq_input.npy", freq_input.cpu().detach().numpy())
        posterior = self.encode(freq_input)
        latent = posterior.sample()
        np.save("freq_latent.npy", latent.cpu().detach().numpy())
        avg_latent = torch.mean(latent, dim=1)
        for i in range(avg_latent.size(0)):
            plt.imshow(avg_latent[i].cpu().detach().numpy().T)
            plt.savefig("time_%s.png" % i)
            plt.close()

    def forward(self, input, sample_posterior=True):
        """Encode then decode; returns (reconstruction, posterior)."""
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        if self.flag_first_run:
            print("Latent size: ", z.size())
            self.flag_first_run = False
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch):
        """Unpack a dataloader batch into the dict layout this model consumes.

        Adds a channel dimension to fbank/stft/waveform tensors.
        """
        fname, text, label_indices, waveform, stft, fbank = (
            batch["fname"],
            batch["text"],
            batch["label_vector"],
            batch["waveform"],
            batch["stft"],
            batch["log_mel_spec"],
        )
        ret = {}
        ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
            fbank.unsqueeze(1),
            stft.unsqueeze(1),
            fname,
            waveform.unsqueeze(1),
        )
        return ret

    def freq_split_subband(self, fbank):
        """Split the frequency axis into ``self.subband`` channels (stft mode only)."""
        if self.subband == 1 or self.image_key != "stft":
            return fbank
        bs, ch, tstep, fbins = fbank.size()
        assert fbank.size(-1) % self.subband == 0
        assert ch == 1
        return (
            fbank.squeeze(1)
            .reshape(bs, tstep, self.subband, fbins // self.subband)
            .permute(0, 2, 1, 3)
        )

    def freq_merge_subband(self, subband_fbank):
        """Inverse of :meth:`freq_split_subband`: merge sub-band channels back."""
        if self.subband == 1 or self.image_key != "stft":
            return subband_fbank
        assert subband_fbank.size(1) == self.subband  # Channel dimension
        bs, _, tstep, _ = subband_fbank.size()
        return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)

    def save_wave(self, batch_wav, fname, save_dir):
        """Write each waveform in the batch to ``save_dir`` under its basename."""
        os.makedirs(save_dir, exist_ok=True)
        for wav, name in zip(batch_wav, fname):
            name = os.path.basename(name)
            sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)

    def get_last_layer(self):
        """Return the decoder's final conv weight (used for adaptive GAN weighting)."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
        """Reconstruct a batch and log input/reconstruction/sample images + audio."""
        log = dict()
        x = batch.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            log["samples"] = self.decode(posterior.sample())
            log["reconstructions"] = xrec
        log["inputs"] = x
        wavs = self._log_img(log, train=train, index=0, waveform=waveform)
        return wavs

    def _log_img(self, log, train=True, index=0, waveform=None):
        """Log spectrogram images and synthesized audio for one batch item.

        Returns (original, prediction, samples) waveforms as numpy arrays.
        """
        images_input = self.tensor2numpy(log["inputs"][index, 0]).T
        images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
        images_samples = self.tensor2numpy(log["samples"][index, 0]).T
        if train:
            name = "train"
        else:
            name = "val"
        if self.logger is not None:
            self.logger.log_image(
                "img_%s" % name,
                [images_input, images_reconstruct, images_samples],
                caption=["input", "reconstruct", "samples"],
            )
        inputs, reconstructions, samples = (
            log["inputs"],
            log["reconstructions"],
            log["samples"],
        )
        if self.image_key == "fbank":
            # Vocode mel spectrograms back to audio for comparison.
            wav_original, wav_prediction = synth_one_sample(
                inputs[index],
                reconstructions[index],
                labels="validation",
                vocoder=self.vocoder,
            )
            wav_original, wav_samples = synth_one_sample(
                inputs[index], samples[index], labels="validation", vocoder=self.vocoder
            )
            wav_original, wav_samples, wav_prediction = (
                wav_original[0],
                wav_samples[0],
                wav_prediction[0],
            )
        elif self.image_key == "stft":
            wav_prediction = (
                self.decode_to_waveform(reconstructions)[index, 0]
                .cpu()
                .detach()
                .numpy()
            )
            wav_samples = (
                self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
            )
            wav_original = waveform[index, 0].cpu().detach().numpy()
        if self.logger is not None:
            self.logger.experiment.log(
                {
                    "original_%s"
                    % name: wandb.Audio(
                        wav_original, caption="original", sample_rate=self.sampling_rate
                    ),
                    "reconstruct_%s"
                    % name: wandb.Audio(
                        wav_prediction, caption="reconstruct", sample_rate=self.sampling_rate
                    ),
                    "samples_%s"
                    % name: wandb.Audio(
                        wav_samples, caption="samples", sample_rate=self.sampling_rate
                    ),
                }
            )
        return wav_original, wav_prediction, wav_samples

    def tensor2numpy(self, tensor):
        """Detach a tensor and return it as a numpy array on CPU."""
        return tensor.cpu().detach().numpy()

    def to_rgb(self, x):
        """Project a segmentation map to RGB via a fixed random conv."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        # Normalize to [-1, 1] for display.
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x
class VQModel(pl.LightningModule):
    """Vector-quantized autoencoder (VQ-VAE/VQGAN-style).

    Encodes images to a discrete codebook latent via a VectorQuantizer and
    decodes back; optionally keeps an EMA copy of the weights and resizes
    batches on the fly during training.
    """

    def __init__(
        self,
        ddconfig,
        lossconfig,
        n_embed,
        embed_dim,
        ckpt_path=None,
        ignore_keys=None,
        image_key="image",
        colorize_nlabels=None,
        monitor=None,
        batch_resize_range=None,
        scheduler_config=None,
        lr_g_factor=1.0,
        remap=None,
        sane_index_shape=False,  # tell vector quantizer to return indices as bhw
        use_ema=False,
    ):
        """Build encoder/decoder, quantizer, loss, and optionally restore weights.

        Args:
            ddconfig: Encoder/Decoder hyper-parameters; must contain
                "z_channels".
            lossconfig: config consumed by ``instantiate_from_config``.
            n_embed: codebook size.
            embed_dim: codebook entry dimensionality.
            ckpt_path: optional checkpoint to restore (non-strict).
            ignore_keys: state-dict key prefixes to drop when restoring
                (default: none; was a mutable ``[]``).
            image_key: batch dict key holding the input images.
            colorize_nlabels: if set, registers a random RGB projection buffer.
            monitor: metric name for checkpointing callbacks.
            batch_resize_range: (lower, upper) for random per-batch resizing.
            scheduler_config: LR scheduler config, stored for configure_optimizers.
            lr_g_factor: generator LR multiplier.
            remap / sane_index_shape: forwarded to the VectorQuantizer.
            use_ema: keep an exponential moving average of the weights.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(
            n_embed,
            embed_dim,
            beta=0.25,
            remap=remap,
            sane_index_shape=sane_index_shape,
        )
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            # isinstance is the idiomatic type check (was: type(...) == int).
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(
                f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}."
            )
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in the EMA weights; restore training weights on exit."""
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=None):
        """Restore weights non-strictly, dropping keys with ignored prefixes.

        Args:
            path: checkpoint file containing a "state_dict" entry.
            ignore_keys: iterable of key prefixes to delete before loading
                (default: none; was a mutable ``list()`` default).
        """
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            # Delete at most once per key even if several prefixes match;
            # the original looped over prefixes and could double-delete,
            # raising KeyError.
            if any(k.startswith(ik) for ik in (ignore_keys or ())):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA shadow weights after every training batch."""
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode ``x`` and quantize; returns (quantized, codebook_loss, info)."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        """Encode ``x`` up to, but not including, the quantization step."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        """Decode a quantized latent back to image space."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Decode directly from codebook indices."""
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        """Full pass: encode, quantize, decode.

        Returns (reconstruction, codebook_loss) and additionally the
        predicted code indices when ``return_pred_indices`` is True.
        """
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        """Fetch batch[k] as a contiguous NCHW float tensor; optionally resize.

        During the first few global steps the maximum size is used to avoid
        late OOMs; afterwards a size is sampled from ``batch_resize_range``
        in steps of 16.
        """
        x = batch[k]
        if len(x.shape) == 3:
            # Add a trailing channel dim for grayscale inputs.
            x = x[..., None]
        x = (
            x.permute(0, 3, 1, 2)
            .to(memory_format=torch.contiguous_format)
            .float()
            .contiguous()
        )
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(
                    np.arange(lower_size, upper_size + 16, 16)
                )
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def get_last_layer(self):
        """Return the decoder's final conv weight (used for adaptive GAN weighting)."""
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Build a dict of input/reconstruction images for logging.

        With ``plot_ema`` also logs reconstructions under the EMA weights.
        Inputs with more than 3 channels are projected to RGB first.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3:
                    xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        """Project a segmentation map to RGB via a fixed random conv."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        # Normalize to [-1, 1] for display.
        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
        return x