Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Text-to-Music Gradio 6 Demo using Riffusion | |
| Generates music from text prompts via spectrogram diffusion. | |
| """ | |
| import gradio as gr | |
| import torch | |
| from diffusers import StableDiffusionPipeline | |
| import numpy as np | |
| import io | |
| import os | |
| from riffusion.spectrogram_image_converter import SpectrogramImageConverter | |
| from riffusion.audio_utils import audio_buffer_to_wav, normalize_audio | |
# Process-wide singletons, filled in lazily by get_pipeline() / get_converter()
# so that importing this module does not load any model.
_pipe = _converter = None
def get_pipeline():
    """Return the shared Riffusion pipeline, loading it on first use.

    The pipeline is placed on CUDA (float16) when available, otherwise on
    CPU (float32), and has attention slicing enabled.

    The module-level cache is only populated once the pipeline is fully
    configured, so a failure while moving it to the device or enabling
    attention slicing does not leave a half-initialised object cached for
    later calls.
    """
    global _pipe
    if _pipe is not None:
        return _pipe

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading Riffusion model on {device}...")

    pipe = StableDiffusionPipeline.from_pretrained(
        "riffusion/riffusion-model-v1",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    pipe = pipe.to(device)
    pipe.enable_attention_slicing()

    # Publish to the cache last: everything above may raise.
    _pipe = pipe
    print("Model loaded!")
    return _pipe
def get_converter():
    """Return the shared spectrogram converter, constructing it on first use."""
    global _converter
    if _converter is not None:
        return _converter
    _converter = SpectrogramImageConverter()
    return _converter
def generate_music(prompt: str, duration: float, bpm: float, seed: int = None, progress=gr.Progress()):
    """
    Generate music from a text prompt using Riffusion.

    Args:
        prompt: Text description of the desired music.
        duration: Requested duration in seconds; clamped to the range 2-10.
        bpm: Beats per minute. When positive it is appended to the prompt as
            a textual hint (", <bpm> bpm").
        seed: Random seed for reproducibility. ``None`` or a negative value
            selects a fresh random seed.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        Tuple of (audio_path, spectrogram_path) for Gradio. Each call writes
        a new, uniquely named pair of files under ``outputs/``.
    """
    import uuid  # local import: only needed for unique output names

    # Clamp duration to reasonable range (Riffusion works best ~5-10s)
    duration = max(2.0, min(duration, 10.0))

    # Adjust prompt with BPM hint if provided
    full_prompt = f"{prompt}, {int(bpm)} bpm" if bpm > 0 else prompt

    pipe = get_pipeline()
    converter = get_converter()

    # Set seed for reproducibility
    if seed is None or seed < 0:
        # Explicit 64-bit dtype: the platform-default integer may be 32-bit,
        # for which an exclusive upper bound of 2**32 is out of range.
        seed = np.random.randint(0, 2**32, dtype=np.int64)
    # Plain Python int: API clients may send a float, and NumPy may hand back
    # a NumPy scalar; manual_seed and the filename below expect an int.
    seed = int(seed)
    generator = torch.Generator(device=pipe.device).manual_seed(seed)

    print(f"Generating: '{full_prompt}' ({duration}s @ {bpm} BPM, seed={seed})")
    progress(0.1, desc="Generating spectrogram...")

    # Generate spectrogram image
    # Riffusion generates 512x512 spectrograms ~5 seconds of audio
    image = pipe(
        full_prompt,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=generator,
        height=512,
        width=512,
    ).images[0]

    progress(0.6, desc="Converting to audio...")

    # Convert spectrogram to audio
    # NOTE(review): spectrogram_to_audio, normalize_audio, audio_buffer_to_wav
    # and converter.sample_rate are used exactly as this file imports them;
    # confirm they exist with these signatures in the installed riffusion.
    audio = converter.spectrogram_to_audio(image, duration=duration)
    audio = normalize_audio(audio)

    progress(0.9, desc="Saving outputs...")

    # Save outputs. The name carries the full seed plus a random suffix so
    # that two requests (even with the same seed but different prompts) never
    # overwrite each other's files.
    os.makedirs("outputs", exist_ok=True)
    base_name = f"output_{seed}_{uuid.uuid4().hex[:8]}"
    audio_path = f"outputs/{base_name}.wav"
    spec_path = f"outputs/{base_name}_spectrogram.png"

    # Save audio
    wav_buffer = audio_buffer_to_wav(audio, sample_rate=converter.sample_rate)
    with open(audio_path, "wb") as f:
        f.write(wav_buffer.getvalue())

    # Save spectrogram for visualization
    image.save(spec_path)

    progress(1.0, desc="Done!")
    print(f"Saved: {audio_path}")
    return audio_path, spec_path
# Gradio 6 - NO parameters in gr.Blocks() constructor! (they go in launch())
# NOTE(review): the component nesting below was reconstructed from a listing
# whose indentation had been stripped — confirm it matches the intended layout.
with gr.Blocks() as demo:
    # Header with anycoder link. Blank lines separate the heading, the intro
    # paragraph and the link so they render as distinct Markdown blocks.
    gr.Markdown("""
    # 🎵 Text-to-Music Generator

    Generate music from text descriptions using **Riffusion** -
    a Stable Diffusion model trained on spectrograms.

    [Built with anycoder](https://huggingface.co/spaces/akhaliq/anycoder)
    """)

    with gr.Row():
        # Left column: prompt and generation controls.
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Music Description",
                placeholder="Describe the music you want to hear...",
                value="smooth jazz saxophone solo, relaxing, nighttime",
                lines=2,
            )
            with gr.Row():
                duration_slider = gr.Slider(
                    minimum=2.0,
                    maximum=10.0,
                    value=5.0,
                    step=0.5,
                    label="Duration (seconds)",
                )
                bpm_slider = gr.Slider(
                    minimum=60,
                    maximum=180,
                    value=120,
                    step=5,
                    label="Tempo (BPM)",
                )
            seed_input = gr.Number(
                label="Seed (-1 for random)",
                value=-1,
                precision=0,
            )
            generate_btn = gr.Button("🎹 Generate Music", variant="primary")

        # Right column: generated audio and its spectrogram.
        with gr.Column(scale=1):
            audio_output = gr.Audio(
                label="Generated Music",
                type="filepath",
            )
            spec_output = gr.Image(
                label="Spectrogram Visualization",
                type="filepath",
            )

    # Single source of truth for the generate_music wiring, shared by the
    # examples table and the button so the two cannot drift apart.
    generation_inputs = [prompt_input, duration_slider, bpm_slider, seed_input]
    generation_outputs = [audio_output, spec_output]

    # Examples
    gr.Examples(
        examples=[
            ["piano ballad, emotional, cinematic", 6.0, 70, -1],
            ["funky bass guitar groove, 1970s style", 5.0, 110, -1],
            ["ethereal ambient pads, space atmosphere", 8.0, 60, -1],
            ["heavy metal guitar riff, aggressive", 4.0, 140, -1],
            ["classical violin concerto, elegant", 7.0, 90, -1],
        ],
        inputs=generation_inputs,
        outputs=generation_outputs,
        fn=generate_music,
        cache_examples=False,
    )

    with gr.Accordion("How it works", open=False):
        # Blank lines keep the trailing note out of list item 3.
        gr.Markdown("""
        ### How it works

        1. Your text prompt is used to generate a **spectrogram image** via Stable Diffusion
        2. The spectrogram is converted back to **audio waveforms** using the Short-Time Fourier Transform (STFT)
        3. The resulting audio is normalized and returned as a playable WAV file

        *Note: First generation will download the model (~1.5GB).*
        """)

    # Event handlers - Gradio 6 uses api_visibility
    generate_btn.click(
        fn=generate_music,
        inputs=generation_inputs,
        outputs=generation_outputs,
        api_visibility="public",
    )
# Gradio 6 - ALL app parameters go in launch()!
# Theme: Soft base with indigo/blue/slate hues and a few token overrides.
_app_theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
    text_size="lg",
    spacing_size="lg",
    radius_size="md",
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
    block_title_text_weight="600",
)

# Links shown in the page footer.
_footer_links = [
    {"label": "Built with anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"},
    {"label": "Gradio", "url": "https://gradio.app"},
]

# Listen on all interfaces on port 7860 and surface errors in the UI.
demo.launch(
    theme=_app_theme,
    footer_links=_footer_links,
    server_name="0.0.0.0",
    server_port=7860,
    show_error=True,
)