Re-included multiple candidates to improve quality
app.py CHANGED
@@ -5,20 +5,27 @@ from diffusers import AudioLDMPipeline
 
 from transformers import AutoProcessor, ClapModel
 
-#
-
-
+# cuda code from AudioLDM's original app.py if using GPU
+# allows support for CPU
+if torch.cuda.is_available():
+    device = "cuda"
+    torch_dtype = torch.float16
+else:
+    device = "cpu"
+    torch_dtype = torch.float32
 
 # load AudioLDM Diffuser Pipeline
 pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm-m-full", torch_dtype=torch_dtype).to(device)
 pipe.unet = torch.compile(pipe.unet)
 
-#
+# include CLAP model because it improves quality
+clap_model = ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full").to(device)
+processor = AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full")
 
 generator = torch.Generator(device)
 
-#
-def text2audio(text, negative_prompt, duration, guidance_scale, random_seed):
+# from audioldm app.py
+def text2audio(text, negative_prompt, duration, guidance_scale, random_seed, n_candidates):
     if text is None:
         raise gr.Error("Please provide a text input.")
 
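The shared `torch.Generator(device)` created above is what makes generation reproducible: each request reseeds it with the user-supplied seed before sampling, so the same seed yields the same waveform. A minimal standalone sketch of that pattern (hypothetical values; not part of the commit):

    import torch

    generator = torch.Generator("cpu")  # mirrors torch.Generator(device) in app.py

    # Reseeding the same generator with the same seed reproduces the same draws,
    # which is how a fixed "Seed" value reproduces the same audio.
    a = torch.randn(3, generator=generator.manual_seed(45))
    b = torch.randn(3, generator=generator.manual_seed(45))
    assert torch.equal(a, b)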
@@ -27,14 +34,27 @@ def text2audio(text, negative_prompt, duration, guidance_scale, random_seed):
         audio_length_in_s=duration,
         guidance_scale=guidance_scale,
         negative_prompt=negative_prompt,
-        num_waveforms_per_prompt=1,
+        num_waveforms_per_prompt=n_candidates if n_candidates else 1,
         generator=generator.manual_seed(int(random_seed)),
     )["audios"]
 
-
+    if waveforms.shape[0] > 1:
+        waveform = score_waveforms(text, waveforms)
+    else:
+        waveform = waveforms[0]
 
     return gr.make_waveform((16000, waveform), bg_image="bg.png")
 
+def score_waveforms(text, waveforms):
+    inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
+    inputs = {key: inputs[key].to(device) for key in inputs}
+    with torch.no_grad():
+        logits_per_text = clap_model(**inputs).logits_per_text  # this is the audio-text similarity score
+    probs = logits_per_text.softmax(dim=-1)  # we can take the softmax to get the label probabilities
+    most_probable = torch.argmax(probs)  # and now select the most likely audio waveform
+    waveform = waveforms[most_probable]
+    return waveform
+
 # duplicate CSS config
 
 css = """
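This hunk is the heart of the commit: when more than one candidate is generated, `score_waveforms` ranks the candidates by CLAP's text-audio similarity (`logits_per_text`, an (n_texts, n_audios) matrix) and keeps the best one. The same selection logic on dummy tensors, as a self-contained sketch (hypothetical scores; no model download needed):

    import torch

    # Stand-ins: 3 candidate waveforms and one prompt's similarity row,
    # shaped like clap_model(**inputs).logits_per_text for 1 text and 3 audios.
    waveforms = torch.randn(3, 16000 * 5)              # (n_candidates, samples)
    logits_per_text = torch.tensor([[0.2, 1.7, 0.9]])  # (1, n_candidates)

    probs = logits_per_text.softmax(dim=-1)  # normalize similarities across candidates
    best = torch.argmax(probs)               # softmax is monotone, so argmax is unchanged
    waveform = waveforms[best]               # the waveform the app renders
    assert int(best) == 1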
@@ -171,13 +191,21 @@ with iface:
                 label="Guidance scale",
                 info="Large => better quality and relevancy to text; Small => better diversity",
             )
+            n_candidates = gr.Slider(
+                1,
+                3,
+                value=3,
+                step=1,
+                label="Number waveforms to generate",
+                info="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
+            )
 
         outputs = gr.Video(label="Output", elem_id="output-video")
         btn = gr.Button("Submit").style(full_width=True)
 
         btn.click(
             text2audio,
-            inputs=[textbox, negative_textbox, duration, guidance_scale, seed],
+            inputs=[textbox, negative_textbox, duration, guidance_scale, seed, n_candidates],
             outputs=[outputs],
         )
 
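With the extra `n_candidates` argument wired through `btn.click`, the updated function can also be exercised outside the Gradio UI. A hypothetical smoke test (assumes app.py's models are loaded and `bg.png` is present; the argument values are made up, not the sliders' defaults):

    # Returns the waveform video produced by gr.make_waveform.
    video = text2audio(
        "techno music with a strong, upbeat tempo",
        negative_prompt="low quality",
        duration=5,
        guidance_scale=2.5,
        random_seed=45,
        n_candidates=3,  # generate three candidates, keep the CLAP-best one
    )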