| import argparse, os, sys, glob
|
| import pathlib
|
# Add the current working directory to the module search path before the
# project-local imports further down (presumably so `ldm` / `vocoder` resolve
# when the script is started from the repository root -- confirm).
directory = pathlib.Path(os.getcwd())
print(directory)
sys.path += [str(directory)]
|
| import torch
|
| import numpy as np
|
| from omegaconf import OmegaConf
|
| from PIL import Image
|
| from tqdm import tqdm, trange
|
| from ldm.util import instantiate_from_config
|
| from ldm.models.diffusion.scheduling_lcm import LCMSampler
|
| from ldm.models.diffusion.plms import PLMSSampler
|
| import pandas as pd
|
| from torch.utils.data import DataLoader
|
| from tqdm import tqdm
|
| from icecream import ic
|
| from pathlib import Path
|
| import soundfile as sf
|
| import yaml
|
| import datetime
|
| from vocoder.bigvgan.models import VocoderBigVGAN
|
| import soundfile
|
|
|
| import gradio
|
|
|
def load_model_from_config(config, ckpt=None, verbose=True):
    """Instantiate the model described by ``config.model`` and optionally load weights.

    Args:
        config: Configuration object whose ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: Path of a checkpoint file whose top-level ``"state_dict"`` entry
            holds the weights. When falsy, no weights are loaded and a note is
            printed instead.
        verbose: When true, print the missing / unexpected key lists reported
            by the non-strict ``load_state_dict`` call.

    Returns:
        The model, switched to eval mode and moved to the GPU when CUDA is
        available.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    model = instantiate_from_config(config.model)

    if ckpt:
        print(f"Loading model from {ckpt}")
        # NOTE(security): depending on the torch version, torch.load may
        # unpickle arbitrary objects -- only load checkpoints you trust.
        pl_sd = torch.load(ckpt, map_location="cpu")
        sd = pl_sd["state_dict"]

        # strict=False: tolerate key mismatches and just report them.
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if missing and verbose:
            print("missing keys:")
            print(missing)
        if unexpected and verbose:
            print("unexpected keys:")
            print(unexpected)
    else:
        print("Note that no ckpt is loaded !!!")

    # An unconditional .cuda() raises on CPU-only hosts; only move the model
    # when a CUDA device is actually present.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
|
|
|
|
|
|
|
|
|
class GenSamples:
    """Sample from a model for one caption and optionally save the results.

    For every decoded sample the spectrogram array can be written as ``.npy``
    (``save_mel``) and/or passed through ``vocoder.vocode`` and written as a
    ``.wav`` file (``save_wav``) inside ``outpath``.
    """

    # Fixed sampling settings used for every generation.
    NUM_SAMPLING_STEPS = 2
    GUIDANCE_SCALE = 5
    # Trailing two dimensions of the shape requested from the sampler.
    LATENT_SIZE = (20, 312)
    # Sample rate passed to soundfile.write for the generated wav files.
    WAV_SAMPLE_RATE = 16000

    def __init__(self, sampler, model, outpath, vocoder=None, save_mel=True,
                 save_wav=True, original_inference_steps=None) -> None:
        """Store the sampling pipeline and output options.

        Args:
            sampler: Object exposing ``sample(...)``; it must return a pair
                whose first element is fed to ``model.decode_first_stage``.
            model: Object exposing ``channels``, ``get_learned_conditioning``
                and ``decode_first_stage``.
            outpath: Directory in which output files are written.
            vocoder: Object exposing ``vocode(spec)``; required when
                ``save_wav`` is true.
            save_mel: Save each decoded array as ``<mel_name>_<idx>.npy``.
            save_wav: Vocode each decoded array and save it as
                ``<wav_name>_<idx>.wav``.
            original_inference_steps: Forwarded unchanged to
                ``sampler.sample``.

        Raises:
            ValueError: If ``save_wav`` is true but no vocoder was given.
        """
        # Explicit check instead of `assert`, which is stripped under -O.
        if save_wav and vocoder is None:
            raise ValueError("a vocoder is required when save_wav is True")
        self.sampler = sampler
        self.model = model
        self.outpath = outpath
        # Always define the attribute so it exists even when save_wav is False.
        self.vocoder = vocoder
        self.save_mel = save_mel
        self.save_wav = save_wav
        self.channel_dim = self.model.channels
        self.original_inference_steps = original_inference_steps

    def gen_test_sample(self, prompt, mel_name=None, wav_name=None):
        """Generate a batch of one for ``prompt`` and save the outputs.

        Args:
            prompt: Mapping of conditioning keys to caption strings; it must
                contain ``'ori_caption'``. The mapping is not modified.
            mel_name: File-name stem for ``.npy`` output; required when
                ``save_mel`` is true.
            wav_name: File-name stem for ``.wav`` output; required when
                ``save_wav`` is true.

        Returns:
            A list with one dict per decoded sample, holding ``'caption'`` and,
            depending on the save flags, ``'mel_path'`` and ``'audio_path'``.

        Raises:
            ValueError: If a file-name stem needed by an enabled save flag is
                missing.
        """
        # Fail before the expensive sampling step rather than at save time.
        if self.save_mel and mel_name is None:
            raise ValueError("mel_name is required when save_mel is True")
        if self.save_wav and wav_name is None:
            raise ValueError("wav_name is required when save_wav is True")

        # Wrap every caption in a one-element list (batch of one). A new dict
        # is built so the caller's prompt is left untouched and can be reused.
        batched_prompt = {key: [value] for key, value in prompt.items()}
        c = self.model.get_learned_conditioning(batched_prompt)

        if self.channel_dim > 0:
            shape = [self.channel_dim, *self.LATENT_SIZE]
        else:
            shape = list(self.LATENT_SIZE)

        samples_ddim, _ = self.sampler.sample(
            S=self.NUM_SAMPLING_STEPS,
            conditioning=c,
            batch_size=1,
            shape=shape,
            verbose=False,
            guidance_scale=self.GUIDANCE_SCALE,
            original_inference_steps=self.original_inference_steps,
        )
        x_samples_ddim = self.model.decode_first_stage(samples_ddim)

        record_dicts = []
        for idx, spec in enumerate(x_samples_ddim):
            # Drop the leading singleton dimension and move to a NumPy array.
            spec = spec.squeeze(0).cpu().numpy()
            record_dict = {'caption': batched_prompt['ori_caption'][0]}
            if self.save_mel:
                mel_path = os.path.join(self.outpath, mel_name + f'_{idx}.npy')
                np.save(mel_path, spec)
                record_dict['mel_path'] = mel_path
            if self.save_wav:
                wav = self.vocoder.vocode(spec)
                wav_path = os.path.join(self.outpath, wav_name + f'_{idx}.wav')
                soundfile.write(wav_path, wav, self.WAV_SAMPLE_RATE)
                record_dict['audio_path'] = wav_path
            record_dicts.append(record_dict)
        return record_dicts
|
|
|
|
|
def infer(ori_prompt):
    """Generate audio for a text prompt and return the path of the wav file.

    Loads the config, model checkpoint and vocoder from their fixed paths,
    runs one generation with ``GenSamples`` and writes the result into
    ``results/test``.

    Args:
        ori_prompt: Caption text describing the audio to generate.

    Returns:
        Path of the generated file, ``results/test/<name>_0.wav``, where
        ``<name>`` is a file-system-safe form of the prompt.
    """
    import re  # local import: only needed for file-name sanitising

    out_dir = "results/test"
    prompt = dict(ori_caption=ori_prompt, struct_caption=f'<{ori_prompt}& all>')

    # The file name is derived from the caption *text*; `prompt` is a dict and
    # has no .strip(). Every run of characters other than word characters and
    # hyphens becomes a single hyphen, so user text can neither contain path
    # separators nor escape out_dir. The cap keeps the name well below common
    # file-system limits; an empty result falls back to a fixed stem.
    wav_name = re.sub(r"[^\w\-]+", "-", ori_prompt.strip())[:100] or "sample"

    # NOTE(review): config, checkpoint and vocoder are reloaded on every call.
    config = OmegaConf.load("configs/audiolcm.yaml")
    model = load_model_from_config(config, "../logs/2024-04-21T14-50-11_text2music-audioset-nonoverlap/epoch=000184.ckpt")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    sampler = LCMSampler(model)

    os.makedirs(out_dir, exist_ok=True)

    vocoder = VocoderBigVGAN("../vocoder/logs/bigvnat16k93.5w", device)

    generator = GenSamples(
        sampler,
        model,
        out_dir,
        vocoder,
        save_mel=False,
        save_wav=True,
        original_inference_steps=config.model.params.num_ddim_timesteps,
    )

    # Inference only: no gradients, and sample inside the model's EMA scope.
    with torch.no_grad(), model.ema_scope():
        generator.gen_test_sample(prompt, wav_name=wav_name)

    print(f"Your samples are ready and waiting for you here: \n{out_dir} \nEnjoy.")
    # Same path GenSamples writes for the first (index 0) sample.
    return os.path.join(out_dir, wav_name + "_0.wav")
|
|
|
def my_inference_function(prompt_oir):
    """Run ``infer`` on the given prompt text and return its result unchanged."""
    return infer(prompt_oir)
|
|
|
|
|
|
|
# Text-in / audio-out web UI around my_inference_function.
# NOTE: this runs at module level, so the interface is built and launched as
# soon as the file is executed or imported.
gradio_interface = gradio.Interface(fn=my_inference_function, inputs="text", outputs="audio")
gradio_interface.launch()