File size: 5,165 Bytes
4107d10
 
 
443dfa0
b4590cd
 
 
 
 
02cb994
4107d10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85970d3
37f416d
041c4aa
 
 
 
f6773ce
041c4aa
c371ba4
041c4aa
4107d10
 
 
 
 
 
 
 
87bb34a
4107d10
 
 
 
 
 
e16a5b5
 
 
 
 
 
37f416d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4107d10
 
 
8a1e5cf
443dfa0
 
 
b75f17a
ec1529a
37f416d
 
443dfa0
453e379
443dfa0
 
f6773ce
 
a45fa90
37f416d
 
a45fa90
 
 
37f416d
 
a45fa90
 
 
37f416d
 
a45fa90
 
 
37f416d
 
a45fa90
 
 
37f416d
 
f6773ce
b75f17a
 
6fb16db
443dfa0
db31a2c
443dfa0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
import torchaudio
from einops import rearrange
import gradio as gr
import spaces
import os
import random
import uuid


from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools import get_pretrained_model

# Select GPU when available; all model tensors are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"


def load_model():
    """Download Stable Audio Open 1.0 and place it on the active device.

    Returns:
        tuple: (model, sample_rate, sample_size) — the pretrained model
        moved to ``device`` plus the sample rate and sample size read
        from its configuration.
    """
    model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
    return (
        model.to(device),
        model_config["sample_rate"],
        model_config["sample_size"],
    )


@spaces.GPU(duration=120)
def inference(audio_path, prompt ="drums beats with snares", noise_level = 2.7):
    """Generate a 30-second audio clip conditioned on an input track and a text prompt.

    Args:
        audio_path: Path to the input audio file (e.g. a track without drums),
            used as the diffusion init audio.
        prompt: Text description of the audio to generate.
        noise_level: Init-noise level applied to the init audio.

    Returns:
        str: Path of the MP3 file the generated audio was saved to.
    """
    # SECURITY: never print secrets. Only report whether the token is set.
    hf_token = os.getenv('HF_TOKEN')
    print(f"Hugging Face token configured: {hf_token is not None}")

    print(f"audio path: {audio_path}")
    model, sample_rate, sample_size = load_model()
    print(f"sample size is: {sample_size} and sample rate is: {sample_rate}.")

    # Single source of truth for the clip length (was duplicated as a
    # hard-coded 30 in both seconds_total and len_in_sec).
    len_in_sec = 30
    steps = 100  # number of diffusion steps

    # Text and timing conditioning for the diffusion model.
    conditioning = [{
        "prompt": prompt,
        "seconds_start": 0,
        "seconds_total": len_in_sec,
    }]

    our_sample_size = sample_rate * len_in_sec

    with torch.no_grad():
        print(f"prompt: {prompt}")
        print(f"number of steps: {steps}")
        print(f"Noise level is: {noise_level}")

        audio, sr = torchaudio.load(audio_path)
        output = generate_diffusion_cond(
            model,
            steps=steps,
            cfg_scale=7,
            conditioning=conditioning,
            sample_size=our_sample_size,
            sigma_min=0.3,
            sigma_max=500,
            sampler_type="dpmpp-3m-sde",
            device=device,
            init_audio=(sr, audio),
            init_noise_level=noise_level,
        )

        # Rearrange audio batch to a single sequence.
        output = rearrange(output, "b d n -> d (b n)")
        print("rearranged the output into a single sequence")

        # Peak normalize, clip, convert to int16. Guard against a silent
        # (all-zero) output, which would otherwise divide by zero.
        peak = torch.max(torch.abs(output))
        output = (
            output.to(torch.float32)
            .div(peak if peak > 0 else 1.0)
            .clamp(-1, 1)
            .mul(32767)
            .to(torch.int16)
            .cpu()
        )
        print("Normalized the output, clip and convert to int16")

        # Unique filename so concurrent requests do not clobber each other.
        unique_filename = f"output_{uuid.uuid4().hex}.mp3"
        print(f"Saving audio to file: {unique_filename}")
        torchaudio.save(unique_filename, output, sample_rate)
        print(f"saved to filename {unique_filename}")

        return unique_filename




# Gradio UI: three inputs (init audio, text prompt, noise-level slider)
# mapped positionally onto inference(audio_path, prompt, noise_level).
interface = gr.Interface(
    fn=inference,
    inputs=[
        # gr.UploadButton(label="Audio without drums",file_types=['mp3']),
        gr.Audio(type="filepath", label="Audio without drums"),
        gr.Textbox(label="Text prompt", placeholder="Enter your text prompt here"),
        # Slider bounds (2.5-3.5) are narrower than the function default range;
        # the default 2.7 matches inference()'s noise_level default.
        gr.Slider(2.5, 3.5, step=0.1, value=2.7, label="Noise Level", info="Choose between 2.5 and 3.5"),
    ],
    outputs=gr.Audio(type="filepath", label="Generated Audio"),
    title="Stable Audio Generator",
    description="Generate variable-length stereo audio at 44.1kHz from text prompts using Stable Audio Open 1.0.",
    # Example rows: [audio path, prompt, noise level]. Paths are relative to
    # the app's working directory — assumed to exist in the Space repo; TODO confirm.
    examples=[
        [
            "the_chosen_ones/085838/no_drums.mp3",  # Audio without drums
            "A techno song with fast, outer space-themed drum beats.",  # Text prompt
            2.7 # Noise Level
        ],
        [
            "the_chosen_ones/103522/no_drums.mp3",  # Audio without drums
            "A slow country melody accompanied by drum beats.",  # Text prompt
            2.7 # Noise Level
        ],
        [
            "the_chosen_ones/103800/no_drums.mp3",  # Audio without drums
            "A rap song featuring slow, groovy drums with intermittent snares.",  # Text prompt
            2.7 # Noise Level
        ],
        [
            "the_chosen_ones/103808/no_drums.mp3",  # Audio without drums
            "Smooth, slow piano grooves paired with intense, rapid drum rhythms.",  # Text prompt
            2.7 # Noise Level
        ],
        [
            "the_chosen_ones/134796/no_drums.mp3",  # Audio without drums
            "A rap track with rapid drum beats and snares.",  # Text prompt
            2.7 # Noise Level
        ]
    ], 
    # cache_examples=True runs inference on each example at startup and
    # serves the cached outputs.
    cache_examples=True
)

# NOTE(review): this module-level load appears redundant — inference() calls
# load_model() again on every request and shadows these globals locally;
# consider removing or reusing it inside inference(). TODO confirm.
model, sample_rate, sample_size = load_model()
interface.launch()