added conditional diffusion, descriptions, and examples
app.py
CHANGED
@@ -1,5 +1,6 @@
 # Imports
 import gradio as gr
+import os
 import matplotlib.pyplot as plt
 import torch
 import torchaudio
@@ -109,19 +110,40 @@ def load_checkpoint(model, ckpt_path) -> None:
 
 
 # Generate Samples
-def generate_samples(model_name, num_samples, num_steps, duration=32768):
+def generate_samples(model_name, num_samples, num_steps, init_audio=None, noise_level=0.7, duration=32768):
     # load_checkpoint
     ckpt_path = models[model_name]
     load_checkpoint(model, ckpt_path)
-
+
     if num_samples > 1:
-        duration = duration / 2
+        duration = int(duration / 2)
 
+    # Generate samples
    with torch.no_grad():
-
+        if init_audio:
+            # load audio sample
+            audio_sample = torch.tensor(init_audio[1].T, dtype=torch.float32).unsqueeze(0).to(model.device)
+            audio_sample = audio_sample / torch.max(torch.abs(audio_sample)) # normalize init_audio
+
+            # Trim audio
+            og_shape = audio_sample.shape
+            if duration < og_shape[2]:
+                audio_sample = audio_sample[:,:,:duration]
+            elif duration > og_shape[2]:
+                # Pad tensor with zeros to match sample length
+                audio_sample = torch.concat((audio_sample, torch.zeros(og_shape[0], og_shape[1], duration - og_shape[2]).to(model.device)), dim=2)
+
+        else:
+            audio_sample = torch.zeros((1, 2, int(duration)), device=model.device)
+            noise_level = 1.0
+
+        all_samples = torch.zeros(2, 0)
         for i in range(num_samples):
-            noise = torch.
-
+            noise = torch.randn_like(audio_sample, device=model.device) * noise_level # [batch_size, in_channels, length]
+            audio = (audio_sample * abs(1-noise_level)) + noise # add noise
+
+            # generate samples
+            generated_sample = model.model_ema.ema_model.sample(audio, num_steps=num_steps).squeeze(0).cpu() # Suggested num_steps 10-100
 
             # concatenate all samples:
             all_samples = torch.concat((all_samples, generated_sample), dim=1)
@@ -133,6 +155,8 @@ def generate_samples(model_name, num_samples, num_steps, duration=32768):
 
     return (sr, all_samples.cpu().detach().numpy().T), fig # (sample rate, audio), plot
 
+
+# Define Constants & initialize model
 # load model & configs
 sr = 44100 # sampling rate
 config_path = "saved_models/config.yaml" # config path
@@ -147,19 +171,70 @@ models = {
     "Percussion": "saved_models/percussion/percussion_v0.ckpt"
 }
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+intro = """
+<h1 style="font-weight: 1400; text-align: center; margin-bottom: 6px;">
+  Tiny Audio Diffusion
+</h1>
+<h3 style="font-weight: 600; text-align: center;">
+  Christopher Landschoot - Audio waveform diffusion built to run on consumer-grade hardware (<2GB VRAM)
+</h3>
+<h4 style="text-align: center; margin-bottom: 6px;">
+  <a href="https://github.com/crlandsc/tiny-audio-diffusion" style="text-decoration: underline;" target="_blank">GitHub Repo</a>
+  | <a href="https://www.youtube.com/watch?v=m6Eh2srtTro&t=3s" style="text-decoration: underline;" target="_blank">Repo Tutorial Video</a>
+  | <a href="https://medium.com/towards-data-science/tiny-audio-diffusion-ddc19e90af9b" style="text-decoration: underline;" target="_blank">Towards Data Science Article</a>
+</h4>
+"""
+
+
+with gr.Blocks() as demo:
+    # Layout
+    gr.HTML(intro)
+
+    with gr.Row(equal_height=False):
+        with gr.Column():
+            # Inputs
+            model_name = gr.Dropdown(choices=list(models.keys()), value=list(models.keys())[3], label="Model")
+            num_samples = gr.Slider(1, 25, step=1, label="Number of Samples to Generate", value=3)
+            num_steps = gr.Slider(1, 100, step=1, label="Number of Diffusion Steps", value=15)
+
+            # Conditioning Audio Input
+            with gr.Accordion("Input Audio (optional)", open=False):
+                init_audio_description = gr.HTML('Upload an audio file to perform conditional "style transfer" diffusion.<br>Leaving input audio blank results in unconditional generation.')
+                init_audio = gr.Audio(label="Input Audio Sample")
+                init_audio_noise = gr.Slider(0, 1, step=0.01, label="Noise to add to input audio", value=0.70)#, visible=True)
+
+                # Examples
+                gr.Examples(
+                    examples=[
+                        os.path.join(os.path.dirname(__file__), "samples", "guitar.wav"),
+                        os.path.join(os.path.dirname(__file__), "samples", "snare.wav"),
+                        os.path.join(os.path.dirname(__file__), "samples", "kick.wav"),
+                        os.path.join(os.path.dirname(__file__), "samples", "hihat.wav")
+                    ],
+                    inputs=init_audio,
+                    label="Example Audio Inputs"
+                )
+
+            # Buttons
+            with gr.Row():
+                with gr.Column():
+                    clear_button = gr.Button(value="Reset All")
+                with gr.Column():
+                    generate_btn = gr.Button("Generate Samples!")
+
+        with gr.Column():
+            # Outputs
+            output_audio = gr.Audio(label="Generated Audio Sample")
+            output_plot = gr.Plot(label="Generated Audio Spectrogram")
+
+    # Functionality
+    # Generate samples
+    generate_btn.click(fn=generate_samples, inputs=[model_name, num_samples, num_steps, init_audio, init_audio_noise], outputs=[output_audio, output_plot])
+
+    # clear_button button to reset everything
+    clear_button.click(fn=lambda: [3, 15, None, 0.70, None, None], outputs=[num_samples, num_steps, init_audio, init_audio_noise, output_audio, output_plot])
+
+
+
+if __name__ == "__main__":
+    demo.launch()
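
For reference, here is a minimal sketch of how the new conditional "style transfer" path added in this commit could be driven outside the Gradio UI. It assumes app.py is importable as `app` from the repo root (so its module-level model setup runs on import) and uses the bundled `samples/snare.wav` listed in the Examples above; the import path and call pattern are assumptions for illustration, not part of the commit. Gradio's `gr.Audio` passes `init_audio` to `generate_samples` as an `(sample_rate, array)` tuple, which is what the sketch builds by hand.

# Usage sketch (assumptions: `app` is importable, samples/snare.wav exists on disk)
import torchaudio
from app import generate_samples  # app.py imported as a module (assumption)

# Build the (sample_rate, array) tuple that gr.Audio would normally pass in.
waveform, in_sr = torchaudio.load("samples/snare.wav")   # waveform: [channels, length]
init_audio = (in_sr, waveform.numpy().T)                  # gr.Audio layout: [length, channels]

# Conditional generation: the input is partially noised (noise_level < 1) and then denoised.
(out_sr, audio_out), fig = generate_samples(
    model_name="Percussion",   # key from the models dict above
    num_samples=2,
    num_steps=25,              # suggested range per the code comment: 10-100
    init_audio=init_audio,
    noise_level=0.7,           # 0 = keep the input unchanged, 1 = pure noise
)

# Unconditional generation: omit init_audio; generate_samples then forces noise_level to 1.0.
(out_sr, audio_out), fig = generate_samples("Percussion", num_samples=2, num_steps=25)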