import spaces
import argparse, os, sys, glob
import pathlib

directory = pathlib.Path(os.getcwd())
print(directory)
sys.path.append(str(directory))

import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from ldm.util import instantiate_from_config
from ldm.models.diffusion.scheduling_lcm import LCMSampler
from ldm.models.diffusion.plms import PLMSSampler
import pandas as pd
from torch.utils.data import DataLoader
from icecream import ic
from pathlib import Path
import soundfile
import yaml
import datetime
from vocoder.bigvgan.models import VocoderBigVGAN

import gradio as gr

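# Instantiate the latent diffusion model described by `config` and, if a checkpoint
# path is given, restore its weights non-strictly (missing/unexpected keys are only reported).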
def load_model_from_config(config, ckpt=None, verbose=True):
    model = instantiate_from_config(config.model)
    if ckpt:
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        sd = pl_sd["state_dict"]

        # Non-strict loading: report missing/unexpected keys instead of failing.
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)
    else:
        print("Note that no ckpt is loaded!")

    # Guard the CUDA move so the loader also works on CPU-only machines.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model


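# Wraps the sampler, model, and vocoder: encodes a text prompt, samples mel-spectrogram
# latents with the LCM sampler, decodes them, and optionally saves .npy spectrograms
# and/or 16 kHz .wav files.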
class GenSamples:
    def __init__(self, sampler, model, outpath, vocoder=None, save_mel=True, save_wav=True,
                 original_inference_steps=None, ddim_steps=2, scale=5, num_samples=1) -> None:
        self.sampler = sampler
        self.model = model
        self.outpath = outpath
        if save_wav:
            assert vocoder is not None
            self.vocoder = vocoder
        self.save_mel = save_mel
        self.save_wav = save_wav
        self.channel_dim = self.model.channels
        self.original_inference_steps = original_inference_steps
        self.ddim_steps = ddim_steps
        self.scale = scale
        self.num_samples = num_samples

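    # Generate `num_samples` candidates for one prompt: build the conditioning (and an
    # unconditional empty-caption embedding when scale != 1), sample latents, decode them,
    # and write the requested mel/wav outputs.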
    def gen_test_sample(self, prompt, mel_name=None, wav_name=None):
        uc = None
        record_dicts = []

        if self.scale != 1.0:
            # Unconditional (empty-caption) embedding, built when classifier-free guidance is enabled.
            emptycap = {'ori_caption': self.num_samples * [""], 'struct_caption': self.num_samples * [""]}
            uc = self.model.get_learned_conditioning(emptycap)

        for n in range(1):
            # Repeat each caption field num_samples times so a single call yields a batch of candidates.
            for k, v in prompt.items():
                prompt[k] = self.num_samples * [v]
            c = self.model.get_learned_conditioning(prompt)
            # Latent shape for sampling, with or without an explicit channel dimension.
            if self.channel_dim > 0:
                shape = [self.channel_dim, 20, 312]
            else:
                shape = [20, 312]

            samples_ddim, _ = self.sampler.sample(S=self.ddim_steps,
                                                  conditioning=c,
                                                  batch_size=self.num_samples,
                                                  shape=shape,
                                                  verbose=False,
                                                  guidance_scale=self.scale,
                                                  original_inference_steps=self.original_inference_steps
                                                  )
            # Decode the sampled latents back into mel-spectrograms.
            x_samples_ddim = self.model.decode_first_stage(samples_ddim)

            for idx, spec in enumerate(x_samples_ddim):
                spec = spec.squeeze(0).cpu().numpy()
                record_dict = {'caption': prompt['ori_caption'][0]}
                if self.save_mel:
                    mel_path = os.path.join(self.outpath, mel_name + f'_{idx}.npy')
                    np.save(mel_path, spec)
                    record_dict['mel_path'] = mel_path
                if self.save_wav:
                    # Vocode the mel-spectrogram to a 16 kHz waveform.
                    wav = self.vocoder.vocode(spec)
                    wav_path = os.path.join(self.outpath, wav_name + f'_{idx}.wav')
                    soundfile.write(wav_path, wav, 16000)
                    record_dict['audio_path'] = wav_path
                record_dicts.append(record_dict)

        return record_dicts


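# End-to-end inference for one prompt: seed the RNGs, rebuild the model, sampler, and
# vocoder, generate the candidates, and return the path of the first waveform.
# On Hugging Face Spaces, the @spaces.GPU decorator requests GPU hardware for this call.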
@spaces.GPU(enable_queue=True)
def infer(ori_prompt, ddim_steps, num_samples, scale, seed):
    np.random.seed(seed)
    torch.manual_seed(seed)

    prompt = dict(ori_caption=ori_prompt, struct_caption=f'<{ori_prompt}& all>')

    config = OmegaConf.load("configs/audiolcm.yaml")

    model = load_model_from_config(config, "./model/000184.ckpt")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    sampler = LCMSampler(model)

    os.makedirs("results/test", exist_ok=True)

    vocoder = VocoderBigVGAN("./model/vocoder", device)

    generator = GenSamples(sampler, model, "results/test", vocoder, save_mel=False, save_wav=True,
                           original_inference_steps=config.model.params.num_ddim_timesteps,
                           ddim_steps=ddim_steps, scale=scale, num_samples=num_samples)

    with torch.no_grad():
        with model.ema_scope():
            # Derive the output filename from the caption (spaces replaced with dashes).
            wav_name = f'{prompt["ori_caption"].strip().replace(" ", "-")}'
            generator.gen_test_sample(prompt, wav_name=wav_name)

    print("Your samples are ready and waiting for you here: \nresults/test \nEnjoy.")
    return "results/test/" + wav_name + "_0.wav"


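# Thin wrapper used as the Gradio click callback.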
def my_inference_function(text_prompt, ddim_steps, num_samples, scale, seed):
    file_path = infer(text_prompt, ddim_steps, num_samples, scale, seed)
    return file_path


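# Gradio UI: a prompt box, advanced options (candidate count, sampling steps, guidance
# scale, seed), an audio player for the result, and a few example prompts.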
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## AudioLCM: Text-to-Audio Generation with Latent Consistency Models")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt: input your text here.")
            run_button = gr.Button()

            with gr.Accordion("Advanced options", open=False):
                num_samples = gr.Slider(
                    label="Number of candidate audios. This controls how many candidates are "
                          "generated (e.g., generate three audios and pick the best one to show). "
                          "A larger value usually leads to better quality at the cost of heavier computation.",
                    minimum=1, maximum=10, value=1, step=1)
                ddim_steps = gr.Slider(label="ddim_steps", minimum=1,
                                       maximum=50, value=2, step=1)
                scale = gr.Slider(
                    label="Guidance scale (larger = more relevant to the text, but quality may drop).",
                    minimum=0.1, maximum=8.0, value=5.0, step=0.1
                )
                seed = gr.Slider(
                    label="Seed: changing this value (any integer) leads to a different generation result.",
                    minimum=0,
                    maximum=2147483647,
                    step=1,
                    value=44,
                )

        with gr.Column():
            outaudio = gr.Audio()

    run_button.click(fn=my_inference_function, inputs=[
        prompt, ddim_steps, num_samples, scale, seed], outputs=[outaudio])

    with gr.Row():
        with gr.Column():
            gr.Examples(
                examples=[['An engine revving and then tires squealing', 2, 1, 5, 55],
                          ['A group of people laughing followed by farting', 2, 1, 5, 55],
                          ['Duck quacking repeatedly', 2, 1, 5, 88],
                          ['A man speaks as birds chirp and dogs bark', 2, 1, 5, 55],
                          ['Continuous snoring of a person', 2, 1, 5, 55]],
                inputs=[prompt, ddim_steps, num_samples, scale, seed],
                outputs=[outaudio]
            )
        with gr.Column():
            pass

demo.launch()