final_proj_demo / app.py
Fabrice-TIERCELIN's picture
This Pull Request fixes the space by skipping example caching
6ef81c8 verified
raw
history blame
5.17 kB
# Standard library
import functools
import os
import random
import uuid

# Third-party
import gradio as gr
import spaces
import torch
import torchaudio
from einops import rearrange

# Project
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
device = "cuda" if torch.cuda.is_available() else "cpu"
@functools.lru_cache(maxsize=1)
def load_model():
    """Download Stable Audio Open 1.0 and move it to the selected device.

    Cached with ``lru_cache`` because ``inference()`` calls this on every
    request: without the cache, the model object was rebuilt (and the
    checkpoint re-resolved) per request, which is pure waste since the
    configuration never changes.

    Returns:
        tuple: ``(model, sample_rate, sample_size)`` — the model on
        ``device``, plus ``sample_rate`` / ``sample_size`` from the
        pretrained model's config.
    """
    model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
    sample_rate = model_config["sample_rate"]
    sample_size = model_config["sample_size"]
    model = model.to(device)
    return model, sample_rate, sample_size
@spaces.GPU(duration=120)
def inference(audio_path, prompt ="drums beats with snares", noise_level = 2.7):
    """Generate a 30-second audio clip conditioned on an input track and a text prompt.

    Args:
        audio_path: Path to the initialization audio file (the track without drums).
        prompt: Text description of the audio to generate.
        noise_level: Noise level applied to the init audio; higher values let the
            output deviate more from the input (UI exposes 2.5-3.5).

    Returns:
        str: Path to the generated .mp3 file (unique per call).

    Security note: the previous version printed the ``HF_TOKEN`` secret to the
    logs; the token was never used, so both the lookup and the print are removed.
    """
    print(f"audio path: {audio_path}")

    # lru-cached after the first call (see load_model), so this is cheap.
    model, sample_rate, sample_size = load_model()
    print(f"sample size is: {sample_size} and sample rate is: {sample_rate}.")

    # Text + timing conditioning for the diffusion model.
    conditioning = [{
        "prompt": prompt,
        "seconds_start": 0,
        "seconds_total": 30,
    }]

    steps = 100          # diffusion steps (a single-element loop before; flattened)
    len_in_sec = 30      # generated clip length, seconds
    our_sample_size = sample_rate * len_in_sec

    with torch.no_grad():
        print(f"prompt: {prompt}")
        print(f"number of steps: {steps}")
        print(f"Noise level is: {noise_level}")

        audio, sr = torchaudio.load(audio_path)
        output = generate_diffusion_cond(
            model,
            steps=steps,
            cfg_scale=7,
            conditioning=conditioning,
            sample_size=our_sample_size,
            sigma_min=0.3,
            sigma_max=500,
            sampler_type="dpmpp-3m-sde",
            device=device,
            init_audio=(sr, audio),
            init_noise_level=noise_level,
        )

        # Rearrange audio batch to a single sequence.
        output = rearrange(output, "b d n -> d (b n)")
        print("rearranged the output into a single sequence")

        # Peak normalize, clip, convert to int16, and move to CPU for saving.
        output = (
            output.to(torch.float32)
            .div(torch.max(torch.abs(output)))
            .clamp(-1, 1)
            .mul(32767)
            .to(torch.int16)
            .cpu()
        )
        print("Normalized the output, clip and convert to int16")

        # Unique filename so concurrent requests never clobber each other.
        unique_filename = f"output_{uuid.uuid4().hex}.mp3"
        print(f"Saving audio to file: {unique_filename}")
        torchaudio.save(unique_filename, output, sample_rate)
        print(f"saved to filename {unique_filename}")
        return unique_filename
# Canned demo rows: [input audio path, text prompt, noise level].
demo_examples = [
    [
        "the_chosen_ones/085838/no_drums.mp3",
        "A techno song with fast, outer space-themed drum beats.",
        2.7,
    ],
    [
        "the_chosen_ones/103522/no_drums.mp3",
        "A slow country melody accompanied by drum beats.",
        2.7,
    ],
    [
        "the_chosen_ones/103800/no_drums.mp3",
        "A rap song featuring slow, groovy drums with intermittent snares.",
        2.7,
    ],
    [
        "the_chosen_ones/103808/no_drums.mp3",
        "Smooth, slow piano grooves paired with intense, rapid drum rhythms.",
        2.7,
    ],
    [
        "the_chosen_ones/134796/no_drums.mp3",
        "A rap track with rapid drum beats and snares.",
        2.7,
    ],
]

# Gradio front end: wires the three inputs into inference() and plays back
# the generated file. cache_examples=False keeps the Space from running the
# model on every example at build time.
interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Audio(type="filepath", label="Audio without drums"),
        gr.Textbox(label="Text prompt", placeholder="Enter your text prompt here"),
        gr.Slider(
            2.5,
            3.5,
            step=0.1,
            value=2.7,
            label="Noise Level",
            info="Choose between 2.5 and 3.5",
        ),
    ],
    outputs=gr.Audio(type="filepath", label="Generated Audio"),
    title="Stable Audio Generator",
    description="Generate variable-length stereo audio at 44.1kHz from text prompts using Stable Audio Open 1.0.",
    examples=demo_examples,
    cache_examples=False,
)
# Warm-up load at startup. NOTE(review): the returned model/sample_rate/
# sample_size are never reused — inference() calls load_model() itself —
# so this presumably exists to pre-fetch the checkpoint before the first
# request; confirm against the Space's cold-start behavior.
model, sample_rate, sample_size = load_model()
# Start the Gradio server (blocks until shutdown).
interface.launch()