# final_proj_demo / app.py
# Author: YuvalShaffir
# Commit: 37f416d (verified) — "Update app.py"
import torch
import torchaudio
from einops import rearrange
import gradio as gr
import spaces
import os
import random
import uuid
from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools import get_pretrained_model
# Run on the GPU when CUDA is available, otherwise fall back to CPU.
# NOTE(review): on ZeroGPU Spaces, CUDA is typically only visible inside the
# @spaces.GPU context, so this module-level check may resolve to "cpu" — confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_model():
    """Download Stable Audio Open 1.0 and move it to the selected device.

    Returns:
        tuple: ``(model, sample_rate, sample_size)`` where the rate and size
        are taken from the pretrained model's configuration dict.
    """
    pretrained, config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
    return pretrained.to(device), config["sample_rate"], config["sample_size"]
@spaces.GPU(duration=120)
def inference(audio_path, prompt="drums beats with snares", noise_level=2.7):
    """Generate audio conditioned on an input track and a text prompt.

    Runs Stable Audio Open 1.0 diffusion with the uploaded audio as the
    init signal, then peak-normalizes the result and writes it to a
    uniquely named file.

    Parameters
    ----------
    audio_path : str
        Path to the conditioning audio file (the "no drums" track).
    prompt : str
        Text description of the desired output.
    noise_level : float
        Init-noise level applied to the conditioning audio.

    Returns
    -------
    str
        Filename of the generated audio.
    """
    # SECURITY FIX: never print the HF token — it is a secret and the old
    # print leaked it into the public Space logs. Only check it is present.
    if os.getenv('HF_TOKEN') is None:
        print("Warning: HF_TOKEN is not set; model download may fail.")
    print(f"audio path: {audio_path}")

    # The model is (re)loaded on every call; presumably required because the
    # GPU is only attached inside the @spaces.GPU context — TODO confirm.
    model, sample_rate, sample_size = load_model()
    print(f"sample size is: {sample_size} and sample rate is: {sample_rate}.")

    len_in_sec = 30  # length of the generated clip, in seconds
    steps = 100      # diffusion sampling steps (was a single-element list)
    conditioning = [{
        "prompt": prompt,
        "seconds_start": 0,
        "seconds_total": len_in_sec,
    }]
    # Generate len_in_sec seconds at the model's native rate, rather than
    # the model config's default sample_size.
    our_sample_size = sample_rate * len_in_sec

    with torch.no_grad():
        print(f"prompt: {prompt}")
        print(f"number of steps: {steps}")
        print(f"Noise level is: {noise_level}")
        audio, sr = torchaudio.load(audio_path)
        output = generate_diffusion_cond(
            model,
            steps=steps,
            cfg_scale=7,
            conditioning=conditioning,
            sample_size=our_sample_size,
            sigma_min=0.3,
            sigma_max=500,
            sampler_type="dpmpp-3m-sde",
            device=device,
            init_audio=(sr, audio),
            init_noise_level=noise_level,
        )

        # Collapse the batch dimension into one long sequence:
        # (batch, channels, n) -> (channels, batch * n).
        output = rearrange(output, "b d n -> d (b n)")
        print("rearranged the output into a single sequence")

        # Peak-normalize, clip to [-1, 1], scale to int16 PCM on the CPU.
        output = (
            output.to(torch.float32)
            .div(torch.max(torch.abs(output)))
            .clamp(-1, 1)
            .mul(32767)
            .to(torch.int16)
            .cpu()
        )
        print("Normalized the output, clip and convert to int16")

        # Unique filename so concurrent requests never overwrite each other.
        unique_filename = f"output_{uuid.uuid4().hex}.mp3"
        print(f"Saving audio to file: {unique_filename}")
        torchaudio.save(unique_filename, output, sample_rate)
        print(f"saved to filename {unique_filename}")
        return unique_filename
# Example rows for the UI: [input audio path, text prompt, noise level].
_EXAMPLES = [
    [
        "the_chosen_ones/085838/no_drums.mp3",
        "A techno song with fast, outer space-themed drum beats.",
        2.7,
    ],
    [
        "the_chosen_ones/103522/no_drums.mp3",
        "A slow country melody accompanied by drum beats.",
        2.7,
    ],
    [
        "the_chosen_ones/103800/no_drums.mp3",
        "A rap song featuring slow, groovy drums with intermittent snares.",
        2.7,
    ],
    [
        "the_chosen_ones/103808/no_drums.mp3",
        "Smooth, slow piano grooves paired with intense, rapid drum rhythms.",
        2.7,
    ],
    [
        "the_chosen_ones/134796/no_drums.mp3",
        "A rap track with rapid drum beats and snares.",
        2.7,
    ],
]

# Input widgets: the conditioning audio, the text prompt, and the noise level.
_audio_input = gr.Audio(type="filepath", label="Audio without drums")
_prompt_input = gr.Textbox(label="Text prompt", placeholder="Enter your text prompt here")
_noise_input = gr.Slider(
    2.5,
    3.5,
    step=0.1,
    value=2.7,
    label="Noise Level",
    info="Choose between 2.5 and 3.5",
)

# Gradio front-end wiring inference() to the three inputs and one audio output.
interface = gr.Interface(
    fn=inference,
    inputs=[_audio_input, _prompt_input, _noise_input],
    outputs=gr.Audio(type="filepath", label="Generated Audio"),
    title="Stable Audio Generator",
    description=(
        "Generate variable-length stereo audio at 44.1kHz from text prompts "
        "using Stable Audio Open 1.0."
    ),
    examples=_EXAMPLES,
    cache_examples=True,
)
# Preload the model into module globals before serving.
# NOTE(review): inference() reloads the model itself on every call and never
# reads these globals — presumably this preload warms the download cache
# and/or supports example caching at startup; verify whether it is needed.
model, sample_rate, sample_size = load_model()
interface.launch()