|
|
import argparse, os, sys, glob |
|
|
import pathlib |
|
|
# Put the current working directory on the import path. The project-local
# imports further down (`ldm`, `vocoder`) presumably rely on the script being
# launched from the repository root -- TODO confirm.
directory = pathlib.Path.cwd()
print(directory)
sys.path.append(os.fspath(directory))
|
|
import torch |
|
|
import numpy as np |
|
|
from omegaconf import OmegaConf |
|
|
from PIL import Image |
|
|
from tqdm import tqdm, trange |
|
|
from ldm.util import instantiate_from_config |
|
|
from ldm.models.diffusion.ddim import DDIMSampler |
|
|
from ldm.models.diffusion.plms import PLMSSampler |
|
|
import pandas as pd |
|
|
from torch.utils.data import DataLoader |
|
|
from tqdm import tqdm |
|
|
from icecream import ic |
|
|
from pathlib import Path |
|
|
import yaml |
|
|
from vocoder.bigvgan.models import VocoderBigVGAN |
|
|
import soundfile |
|
|
|
|
|
|
|
|
def load_model_from_config(config, ckpt=None, verbose=True):
    """Build the model described by ``config.model`` and optionally load weights.

    Args:
        config: Configuration object whose ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: Path of a checkpoint readable by ``torch.load`` that stores the
            weights under the ``"state_dict"`` key. Any falsy value (``None``,
            ``""``) skips loading and keeps the freshly instantiated weights.
        verbose: When true, print the missing / unexpected keys reported by
            the non-strict ``load_state_dict`` call.

    Returns:
        The model in eval mode. It is moved to CUDA when a GPU is available
        and left on its current device otherwise.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    model = instantiate_from_config(config.model)

    if ckpt:
        print(f"Loading model from {ckpt}")
        # NOTE(review): torch.load unpickles the file -- only point this at
        # checkpoints you trust.
        checkpoint = torch.load(ckpt, map_location="cpu")
        state_dict = checkpoint["state_dict"]

        # strict=False tolerates key mismatches; they are reported below
        # instead of raising.
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if verbose and len(missing) > 0:
            print("missing keys:")
            print(missing)
        if verbose and len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)
    else:
        print("Note chat no ckpt is loaded !!!")

    # Calling .cuda() unconditionally raises on CPU-only hosts, while main()
    # explicitly falls back to the CPU. Only move the model when a GPU exists.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
|
|
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is the previous behaviour.

    Returns:
        argparse.Namespace with the attributes ``sample_rate``,
        ``test_dataset``, ``outdir``, ``resume``, ``base`` and
        ``vocoder_ckpt``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--sample_rate",
        type=int,
        default=16000,
        help="sample rate of wav",
    )
    parser.add_argument(
        "--test-dataset",
        default="none",
        # main() also dispatches on 'musiccap'; keep the help in sync with it.
        help="test which dataset: audiocaps/clotho/fsd50k/musiccap "
        "(default 'none' selects no dataset)",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2audio-samples",
    )
    # NOTE(review): argparse does not run `type` on `const`, so a bare `-r`
    # (no value) yields the boolean True rather than a path. Kept as-is for
    # compatibility; confirm whether a bare flag is meant to be supported.
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        type=str,
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default="",
    )
    parser.add_argument(
        "--vocoder-ckpt",
        type=str,
        help="paths to vocoder checkpoint",
        default="vocoder/logs/bigvnat16k93.5w",
    )

    return parser.parse_args(argv)
|
|
|
|
|
class GenSamples:
    """Run inputs through ``model`` and write the result as wav and/or ``.npy``."""

    def __init__(self, opt, model, outpath, vocoder=None, save_mel=False, save_wav=True) -> None:
        """Store the collaborators used by :meth:`gen_test_sample`.

        Args:
            opt: Options object; ``opt.sample_rate`` is read when writing wavs.
            model: Callable returning a ``(recon_mel, posterior)`` pair.
            outpath: Directory the output files are written into.
            vocoder: Object exposing ``vocode(spec)``; required when
                ``save_wav`` is true.
            save_mel: Write the reconstructed spectrogram as ``.npy``.
            save_wav: Vocode the spectrogram and write it as ``.wav``.

        Raises:
            ValueError: If ``save_wav`` is true but no vocoder was given.
        """
        # An explicit exception instead of `assert`, which is stripped by -O.
        if save_wav and vocoder is None:
            raise ValueError("a vocoder is required when save_wav=True")

        self.opt = opt
        self.model = model
        self.outpath = outpath
        # Always defined so attribute access never depends on save_wav.
        self.vocoder = vocoder
        self.save_mel = save_mel
        self.save_wav = save_wav

    def gen_test_sample(self, mel, mel_name=None, wav_name=None):
        """Reconstruct ``mel`` with the model and save the requested outputs.

        Args:
            mel: Model input. Assumed to be a tensor with a leading batch
                dimension of size 1 -- TODO confirm against the dataset.
            mel_name: File stem of the ``.npy`` output; required when
                ``save_mel`` is true.
            wav_name: File stem of the ``.wav`` output; required when
                ``save_wav`` is true.

        Returns:
            None. The results are the files written below ``self.outpath``.

        Raises:
            ValueError: If an output is enabled but its file stem is missing.
        """
        # Fail before running the model rather than with a TypeError on
        # `None + '.wav'` afterwards.
        if self.save_mel and mel_name is None:
            raise ValueError("mel_name is required when save_mel=True")
        if self.save_wav and wav_name is None:
            raise ValueError("wav_name is required when save_wav=True")

        recon_mel, _posterior = self.model(mel)
        # detach() keeps .numpy() valid even when called outside torch.no_grad().
        spec = recon_mel.squeeze(0).detach().cpu().numpy()

        if self.save_mel:
            mel_path = os.path.join(self.outpath, mel_name + '.npy')
            np.save(mel_path, spec)

        if self.save_wav:
            wav = self.vocoder.vocode(spec)
            wav_path = os.path.join(self.outpath, wav_name + '.wav')
            soundfile.write(wav_path, wav, self.opt.sample_rate)
        return
|
|
|
|
|
# Maps each supported --test-dataset value to the config key that holds its
# dataset definition ('musiccap' shares the key used by 'audiocaps').
_DATASET_CONFIG_KEYS = {
    'audiocaps': 'test_dataset',
    'clotho': 'test_dataset2',
    'fsd50k': 'test_dataset3',
    'musiccap': 'test_dataset',
}

# (substring of the checkpoint path, vocoder class name), checked in order.
_VOCODER_MARKERS = (
    ('mel', 'VocoderMelGan'),
    ('hifi', 'VocoderHifigan'),
    ('bigv', 'VocoderBigVGAN'),
)


def _build_vocoder(ckpt, device):
    """Instantiate the vocoder selected by a substring of the checkpoint path."""
    for marker, class_name in _VOCODER_MARKERS:
        if marker not in ckpt:
            continue
        # VocoderMelGan / VocoderHifigan are not imported in the visible part
        # of this file; resolve by name so a missing class produces an
        # actionable error instead of a bare NameError.
        vocoder_cls = globals().get(class_name)
        if vocoder_cls is None:
            raise RuntimeError(
                f"checkpoint path {ckpt!r} selects {class_name}, "
                "but that class is not imported in this script"
            )
        return vocoder_cls(ckpt, device)
    raise ValueError(
        f"cannot infer the vocoder type from {ckpt!r}; "
        "the path must contain 'mel', 'hifi' or 'bigv'"
    )


def _sample_name(f_name):
    """Turn ``<stem>_<num>`` into ``<stem>_sample_<num>``."""
    stem, sep, num = f_name.rpartition('_')
    if not sep:
        # Without an underscore the old rfind()-based slicing (index -1)
        # dropped the last character of the stem and repeated the whole name.
        return f'{f_name}_sample'
    return f'{stem}_sample_{num}'


def main():
    """Reconstruct every item of the selected test set and write the wavs.

    Reads the options from the command line, loads the model and the vocoder,
    then feeds each dataset item's ``'image'`` array through
    ``GenSamples.gen_test_sample`` under ``torch.no_grad()``.

    Raises:
        ValueError: For an unsupported ``--test-dataset`` value, for the
            default ``none`` when no ``prompt_txt`` option exists, or when the
            vocoder type cannot be inferred from ``--vocoder-ckpt``.
    """
    opt = parse_args()

    # Validate the cheap inputs before the expensive model / vocoder loads.
    prompt_txt = None
    if opt.test_dataset == 'none':
        # The parser visible in this file defines no prompt option, so reading
        # opt.prompt_txt directly would raise AttributeError after everything
        # had already been loaded.
        prompt_txt = getattr(opt, 'prompt_txt', None)
        if prompt_txt is None:
            raise ValueError(
                "nothing to do: pass --test-dataset with one of "
                f"{sorted(_DATASET_CONFIG_KEYS)}"
            )
    elif opt.test_dataset not in _DATASET_CONFIG_KEYS:
        raise ValueError(
            f"unknown --test-dataset {opt.test_dataset!r}; "
            f"expected one of {sorted(_DATASET_CONFIG_KEYS)}"
        )

    config = OmegaConf.load(opt.base)
    model = load_model_from_config(config, opt.resume)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    os.makedirs(opt.outdir, exist_ok=True)
    vocoder = _build_vocoder(opt.vocoder_ckpt, device)
    generator = GenSamples(opt, model, opt.outdir, vocoder, save_mel=False, save_wav=True)

    with torch.no_grad():
        if opt.test_dataset != 'none':
            test_dataset = instantiate_from_config(config[_DATASET_CONFIG_KEYS[opt.test_dataset]])
            print(f"Dataset: {type(test_dataset)} LEN: {len(test_dataset)}")
            for item in tqdm(test_dataset):
                mel, f_name = item['image'], item['f_name']
                # Add the batch dimension that gen_test_sample squeezes away.
                mel = torch.from_numpy(mel).to(device).unsqueeze(0)
                name = _sample_name(f_name)
                generator.gen_test_sample(mel, mel_name=name, wav_name=name)
        else:
            # NOTE(review): this branch passes raw text lines to
            # gen_test_sample, whereas the dataset branch passes mel tensors --
            # confirm the model accepts text before relying on it.
            with open(prompt_txt, 'r') as f:
                prompts = f.readlines()
            for prompt in prompts:
                wav_name = f'{prompt.strip().replace(" ", "-")}'
                generator.gen_test_sample(prompt, wav_name=wav_name)

    print(f"Your samples are ready and waiting four you here: \n{opt.outdir} \nEnjoy.")


if __name__ == "__main__":
    main()
|
|
|
|
|
|