Carter-123 committed on
Commit
810f719
·
verified ·
1 Parent(s): a318a83

Update app.py from anycoder

Browse files
Files changed (1) hide show
  1. app.py +227 -0
app.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Text-to-Music Gradio 6 Demo using Riffusion
4
+ Generates music from text prompts via spectrogram diffusion.
5
+ """
6
+
7
+ import gradio as gr
8
+ import torch
9
+ from diffusers import StableDiffusionPipeline
10
+ import numpy as np
11
+ import io
12
+ import os
13
+
14
+ from riffusion.spectrogram_image_converter import SpectrogramImageConverter
15
+ from riffusion.audio_utils import audio_buffer_to_wav, normalize_audio
16
+
17
# Process-wide model cache. Both entries start as None and are filled lazily
# by get_pipeline() / get_converter() on first use.
_pipe = None
_converter = None


def get_pipeline():
    """Return the shared Riffusion pipeline, loading it on first use.

    The pipeline is loaded from "riffusion/riffusion-model-v1", moved to CUDA
    when available (float16) or kept on CPU (float32), and has attention
    slicing enabled. The module-level cache is only updated once the pipeline
    is fully configured, so a failure while moving it to the device or
    enabling attention slicing does not leave a half-initialized object
    cached for later calls.

    Returns:
        The cached StableDiffusionPipeline instance.
    """
    global _pipe
    if _pipe is not None:
        return _pipe

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading Riffusion model on {device}...")

    # Build into a local first; publishing to the global happens last.
    pipe = StableDiffusionPipeline.from_pretrained(
        "riffusion/riffusion-model-v1",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    pipe = pipe.to(device)
    pipe.enable_attention_slicing()

    _pipe = pipe
    print("Model loaded!")
    return _pipe
36
+
37
+
38
def get_converter():
    """Return the shared spectrogram converter, creating it on first call.

    The instance is stored in the module-level ``_converter`` cache and
    reused by every subsequent call.
    """
    global _converter
    if _converter is not None:
        return _converter

    # NOTE(review): some riffusion releases appear to require a
    # SpectrogramParams argument here - confirm the installed version
    # accepts a no-argument constructor.
    _converter = SpectrogramImageConverter()
    return _converter
44
+
45
+
46
def generate_music(prompt: str, duration: float, bpm: float, seed: int | None = None, progress=gr.Progress()):
    """
    Generate music from a text prompt using Riffusion.

    Args:
        prompt: Text description of the desired music. Must not be empty.
        duration: Requested duration in seconds; clamped to [2.0, 10.0].
        bpm: Beats per minute. When positive it is appended to the prompt as a
            textual hint ("<prompt>, <bpm> bpm").
        seed: Random seed for reproducibility. ``None`` or a negative value
            selects a random seed.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        Tuple of (audio_path, spectrogram_path) for Gradio.

    Raises:
        gr.Error: If the prompt is empty or contains only whitespace.
    """
    # Local import: only needed to build collision-free output file names.
    import uuid

    # Reject empty prompts before paying for model loading / inference.
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Please enter a music description.")

    # Clamp duration to reasonable range (Riffusion works best ~5-10s)
    duration = max(2.0, min(duration, 10.0))

    # Adjust prompt with BPM hint if provided
    full_prompt = f"{prompt}, {int(bpm)} bpm" if bpm > 0 else prompt

    pipe = get_pipeline()
    converter = get_converter()

    # Set seed for reproducibility. An explicit int64 dtype keeps the 2**32
    # upper bound valid where NumPy's default integer is 32-bit, and int()
    # guarantees a plain Python int for manual_seed() and the file name.
    if seed is None or seed < 0:
        seed = int(np.random.randint(0, 2**32, dtype=np.int64))
    else:
        seed = int(seed)
    generator = torch.Generator(device=pipe.device).manual_seed(seed)

    print(f"Generating: '{full_prompt}' ({duration}s @ {bpm} BPM, seed={seed})")

    progress(0.1, desc="Generating spectrogram...")

    # Generate spectrogram image
    # Riffusion generates 512x512 spectrograms ~5 seconds of audio
    image = pipe(
        full_prompt,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=generator,
        height=512,
        width=512,
    ).images[0]

    progress(0.6, desc="Converting to audio...")

    # Convert spectrogram to audio
    # NOTE(review): confirm that the installed riffusion package exposes
    # spectrogram_to_audio(), normalize_audio() and audio_buffer_to_wav()
    # with these signatures.
    audio = converter.spectrogram_to_audio(image, duration=duration)
    audio = normalize_audio(audio)

    progress(0.9, desc="Saving outputs...")

    # Save outputs. The name carries the full seed plus a random suffix so
    # different seeds and concurrent requests never overwrite each other's
    # files (the previous "seed % 10000" name collided).
    os.makedirs("outputs", exist_ok=True)
    base_name = f"output_{seed}_{uuid.uuid4().hex[:8]}"
    audio_path = f"outputs/{base_name}.wav"
    spec_path = f"outputs/{base_name}_spectrogram.png"

    # Save audio
    wav_buffer = audio_buffer_to_wav(audio, sample_rate=converter.sample_rate)
    with open(audio_path, "wb") as f:
        f.write(wav_buffer.getvalue())

    # Save spectrogram for visualization
    image.save(spec_path)

    progress(1.0, desc="Done!")
    print(f"Saved: {audio_path}")
    return audio_path, spec_path
113
+
114
+
115
# Gradio 6: gr.Blocks() is constructed without parameters; app-level options
# are supplied to launch() instead.
with gr.Blocks() as demo:
    # Page header, including the anycoder attribution link.
    gr.Markdown("""
    # 🎵 Text-to-Music Generator

    Generate music from text descriptions using **Riffusion** -
    a Stable Diffusion model trained on spectrograms.

    [Built with anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
    """)

    with gr.Row():
        # Left column: every control that feeds generate_music().
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Music Description",
                placeholder="Describe the music you want to hear...",
                value="smooth jazz saxophone solo, relaxing, nighttime",
                lines=2,
            )

            with gr.Row():
                duration_slider = gr.Slider(
                    label="Duration (seconds)",
                    minimum=2.0,
                    maximum=10.0,
                    step=0.5,
                    value=5.0,
                )
                bpm_slider = gr.Slider(
                    label="Tempo (BPM)",
                    minimum=60,
                    maximum=180,
                    step=5,
                    value=120,
                )

            seed_input = gr.Number(
                label="Seed (-1 for random)",
                value=-1,
                precision=0,
            )

            generate_btn = gr.Button("🎹 Generate Music", variant="primary")

        # Right column: the generated audio and its spectrogram image.
        with gr.Column(scale=1):
            audio_output = gr.Audio(
                label="Generated Music",
                type="filepath",
            )
            spec_output = gr.Image(
                label="Spectrogram Visualization",
                type="filepath",
            )

    # Component wiring shared by the example gallery and the button handler.
    generation_inputs = [prompt_input, duration_slider, bpm_slider, seed_input]
    generation_outputs = [audio_output, spec_output]

    # Each example row is (prompt, duration, bpm, seed), matching generation_inputs.
    example_rows = [
        ["piano ballad, emotional, cinematic", 6.0, 70, -1],
        ["funky bass guitar groove, 1970s style", 5.0, 110, -1],
        ["ethereal ambient pads, space atmosphere", 8.0, 60, -1],
        ["heavy metal guitar riff, aggressive", 4.0, 140, -1],
        ["classical violin concerto, elegant", 7.0, 90, -1],
    ]
    gr.Examples(
        examples=example_rows,
        inputs=generation_inputs,
        outputs=generation_outputs,
        fn=generate_music,
        cache_examples=False,
    )

    with gr.Accordion("How it works", open=False):
        gr.Markdown("""
        ### How it works

        1. Your text prompt is used to generate a **spectrogram image** via Stable Diffusion
        2. The spectrogram is converted back to **audio waveforms** using the Short-Time Fourier Transform (STFT)
        3. The resulting audio is normalized and returned as a playable WAV file

        *Note: First generation will download the model (~1.5GB).*
        """)

    # Button handler; Gradio 6 controls API exposure through api_visibility.
    generate_btn.click(
        fn=generate_music,
        inputs=generation_inputs,
        outputs=generation_outputs,
        api_visibility="public",
    )
203
+
204
+
205
# Gradio 6: all app-level parameters (theme, footer links, server settings)
# are passed to launch().
app_theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
    text_size="lg",
    spacing_size="lg",
    radius_size="md",
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
    block_title_text_weight="600",
)

# Links rendered in the page footer.
app_footer_links = [
    {"label": "Built with anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"},
    {"label": "Gradio", "url": "https://gradio.app"},
]

# Listen on all interfaces at port 7860 and surface handler errors in the UI.
demo.launch(
    theme=app_theme,
    footer_links=app_footer_links,
    server_name="0.0.0.0",
    server_port=7860,
    show_error=True,
)