File size: 7,226 Bytes
810f719
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
#!/usr/bin/env python3
"""
Text-to-Music Gradio 6 Demo using Riffusion
Generates music from text prompts via spectrogram diffusion.
"""

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
import numpy as np
import io
import os

from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.audio_utils import audio_buffer_to_wav, normalize_audio

# Module-level caches so the expensive objects are built at most once per
# process. Both start as None and are filled in lazily by get_pipeline() and
# get_converter() respectively.
_pipe = None
_converter = None


def get_pipeline():
    """Return the shared Riffusion pipeline, creating it on the first call.

    The pipeline is stored in the module-level ``_pipe`` cache. On first use
    it is loaded from the ``riffusion/riffusion-model-v1`` checkpoint, moved
    to CUDA when available (CPU otherwise), and attention slicing is enabled.
    Half precision is requested on CUDA and full precision on CPU.

    Returns:
        The cached ``StableDiffusionPipeline`` instance.
    """
    global _pipe

    # Fast path: the model has already been loaded in this process.
    if _pipe is not None:
        return _pipe

    on_gpu = torch.cuda.is_available()
    device = "cuda" if on_gpu else "cpu"
    print(f"Loading Riffusion model on {device}...")

    weights_dtype = torch.float16 if device == "cuda" else torch.float32
    _pipe = StableDiffusionPipeline.from_pretrained(
        "riffusion/riffusion-model-v1",
        torch_dtype=weights_dtype,
    )
    _pipe = _pipe.to(device)
    _pipe.enable_attention_slicing()

    print("Model loaded!")
    return _pipe


def get_converter():
    """Return the shared spectrogram converter, creating it on the first call.

    The instance is kept in the module-level ``_converter`` cache so that
    later calls reuse the same ``SpectrogramImageConverter``.
    """
    global _converter

    # Reuse the existing converter when one has already been built.
    if _converter is not None:
        return _converter

    _converter = SpectrogramImageConverter()
    return _converter


def generate_music(prompt: str, duration: float, bpm: float, seed: int | None = None, progress=gr.Progress()):
    """
    Generate music from a text prompt using Riffusion.

    The prompt (optionally suffixed with a BPM hint) is rendered to a
    512x512 spectrogram image by the diffusion pipeline, the image is
    converted to audio, and both the WAV file and the spectrogram PNG are
    written under the ``outputs/`` directory.

    Args:
        prompt: Text description of desired music.
        duration: Duration in seconds; clamped to the range [2.0, 10.0].
        bpm: Beats per minute. When positive it is appended to the prompt as
            "<bpm> bpm"; ``None`` or a non-positive value leaves the prompt
            unchanged.
        seed: Random seed for reproducibility. ``None`` or a negative value
            selects a random seed in [0, 2**32). Float values are truncated
            to an integer.
        progress: Gradio progress tracker, updated at each stage.

    Returns:
        Tuple of (audio_path, spectrogram_path) for Gradio.

    Note:
        Output file names are derived from ``seed % 10000``, so a later run
        whose seed has the same remainder overwrites the earlier files.
    """
    # Clamp duration to reasonable range (Riffusion works best ~5-10s)
    duration = max(2.0, min(duration, 10.0))

    # Adjust prompt with BPM hint if provided; a cleared/None tempo must not
    # crash the comparison.
    has_bpm = bpm is not None and bpm > 0
    full_prompt = f"{prompt}, {int(bpm)} bpm" if has_bpm else prompt

    pipe = get_pipeline()
    converter = get_converter()

    # Set seed for reproducibility
    if seed is None or seed < 0:
        # An explicit 64-bit dtype is required: NumPy's default integer can be
        # 32-bit on some platforms, where an upper bound of 2**32 raises
        # ValueError.
        seed = np.random.randint(0, 2**32, dtype=np.int64)
    # manual_seed() and the ":04d" file-name format below both need a plain
    # Python int, not a NumPy scalar or a float coming from the UI/API.
    seed = int(seed)
    generator = torch.Generator(device=pipe.device).manual_seed(seed)

    print(f"Generating: '{full_prompt}' ({duration}s @ {bpm} BPM, seed={seed})")

    progress(0.1, desc="Generating spectrogram...")

    # Generate spectrogram image
    # Riffusion generates 512x512 spectrograms ~5 seconds of audio
    image = pipe(
        full_prompt,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=generator,
        height=512,
        width=512,
    ).images[0]

    progress(0.6, desc="Converting to audio...")

    # Convert spectrogram to audio
    # NOTE(review): confirm that the installed riffusion version exposes
    # spectrogram_to_audio(image, duration=...), normalize_audio() and
    # audio_buffer_to_wav() with these signatures.
    audio = converter.spectrogram_to_audio(image, duration=duration)
    audio = normalize_audio(audio)

    progress(0.9, desc="Saving outputs...")

    # Save outputs
    os.makedirs("outputs", exist_ok=True)
    base_name = f"output_{seed % 10000:04d}"
    audio_path = f"outputs/{base_name}.wav"
    spec_path = f"outputs/{base_name}_spectrogram.png"

    # Save audio
    wav_buffer = audio_buffer_to_wav(audio, sample_rate=converter.sample_rate)
    with open(audio_path, "wb") as f:
        f.write(wav_buffer.getvalue())

    # Save spectrogram for visualization
    image.save(spec_path)

    progress(1.0, desc="Done!")
    print(f"Saved: {audio_path}")
    return audio_path, spec_path


# Gradio 6 layout: gr.Blocks() is constructed without any arguments here.
with gr.Blocks() as demo:
    # Page header, including the anycoder attribution link.
    gr.Markdown("""
    # 🎵 Text-to-Music Generator
    
    Generate music from text descriptions using **Riffusion** - 
    a Stable Diffusion model trained on spectrograms.
    
    [Built with anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
    """)

    with gr.Row():
        # Left column (scale=2): prompt text and generation controls.
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                lines=2,
                value="smooth jazz saxophone solo, relaxing, nighttime",
                placeholder="Describe the music you want to hear...",
                label="Music Description",
            )

            with gr.Row():
                # Slider bounds mirror the 2-10 second clamp in generate_music.
                duration_slider = gr.Slider(
                    label="Duration (seconds)",
                    value=5.0,
                    minimum=2.0,
                    maximum=10.0,
                    step=0.5,
                )
                bpm_slider = gr.Slider(
                    label="Tempo (BPM)",
                    value=120,
                    minimum=60,
                    maximum=180,
                    step=5,
                )

            # A negative seed asks generate_music to pick a random one.
            seed_input = gr.Number(
                value=-1,
                precision=0,
                label="Seed (-1 for random)",
            )

            generate_btn = gr.Button("🎹 Generate Music", variant="primary")

        # Right column (scale=1): generated audio and its spectrogram image.
        with gr.Column(scale=1):
            audio_output = gr.Audio(
                type="filepath",
                label="Generated Music",
            )
            spec_output = gr.Image(
                type="filepath",
                label="Spectrogram Visualization",
            )

    # The examples table and the button share the same input/output wiring.
    _controls = [prompt_input, duration_slider, bpm_slider, seed_input]
    _results = [audio_output, spec_output]

    # Each row is (prompt, duration, bpm, seed), in the order of _controls.
    _example_rows = [
        ["piano ballad, emotional, cinematic", 6.0, 70, -1],
        ["funky bass guitar groove, 1970s style", 5.0, 110, -1],
        ["ethereal ambient pads, space atmosphere", 8.0, 60, -1],
        ["heavy metal guitar riff, aggressive", 4.0, 140, -1],
        ["classical violin concerto, elegant", 7.0, 90, -1],
    ]

    # Clickable examples; caching is disabled.
    gr.Examples(
        examples=_example_rows,
        fn=generate_music,
        inputs=_controls,
        outputs=_results,
        cache_examples=False,
    )

    # Collapsed-by-default explanation of the generation steps.
    with gr.Accordion("How it works", open=False):
        gr.Markdown("""
        ### How it works
        
        1. Your text prompt is used to generate a **spectrogram image** via Stable Diffusion
        2. The spectrogram is converted back to **audio waveforms** using the Short-Time Fourier Transform (STFT)
        3. The resulting audio is normalized and returned as a playable WAV file
        
        *Note: First generation will download the model (~1.5GB).*
        """)

    # Wire the button to the generator; Gradio 6 uses api_visibility to
    # control how the endpoint is exposed.
    generate_btn.click(
        generate_music,
        inputs=_controls,
        outputs=_results,
        api_visibility="public",
    )


# Gradio 6: app-wide settings (theme, footer links, server options) are all
# passed to launch() rather than to the gr.Blocks() constructor.

# Soft base theme with custom hues, font and sizing, plus overrides for the
# primary button colours and block title weight.
_app_theme = gr.themes.Soft(
    font=gr.themes.GoogleFont("Inter"),
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    text_size="lg",
    spacing_size="lg",
    radius_size="md",
).set(
    block_title_text_weight="600",
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
)

# Links passed to launch() as footer_links.
_footer_links = [
    {"label": "Built with anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"},
    {"label": "Gradio", "url": "https://gradio.app"},
]

# NOTE(review): this runs at import time (there is no __main__ guard), so
# importing the module starts the server.
demo.launch(
    theme=_app_theme,
    footer_links=_footer_links,
    server_name="0.0.0.0",
    server_port=7860,
    show_error=True,
)