Spaces:
Runtime error
Runtime error
# --- Imports and runtime setup -------------------------------------------
import gradio as gr
import numpy as np
import torch
import os
import warnings
from dotenv import load_dotenv
from huggingface_hub import login

# Try to import AudioLDM2 pipeline.  Fall back to the generic
# DiffusionPipeline for older diffusers versions, and finally record that
# diffusers is unavailable so load_model() can raise a helpful error.
try:
    from diffusers import AudioLDM2Pipeline
    AUDIO_LDM_AVAILABLE = True
except ImportError:
    try:
        # Alternative import path
        from diffusers import DiffusionPipeline
        AUDIO_LDM_AVAILABLE = True
        AudioLDM2Pipeline = None  # Will use DiffusionPipeline instead
    except ImportError:
        AUDIO_LDM_AVAILABLE = False
        AudioLDM2Pipeline = None

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=UserWarning)

# Load environment variables from a local .env file, if present
load_dotenv()
# --- Hugging Face authentication (optional, best-effort) ------------------
# Auth failures are reported but never fatal: unauthenticated access works
# with rate limits.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    try:
        login(token=hf_token)
        print("✅ Hugging Face authentication successful")
    except Exception as e:
        print(f"⚠️ Hugging Face authentication failed: {e}")
        print(" Continuing without authentication...")
else:
    print("ℹ️ No Hugging Face token found. Some models may have rate limits.")

# Model configuration
MODEL_ID = "cvssp/audioldm2"
# Half precision on GPU, full precision on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Global model cache so the pipeline is loaded at most once per process.
model_cache = {
    "pipeline": None,  # the loaded diffusers pipeline, or None
    "loaded": False    # True once a load has succeeded
}
def load_model():
    """Load the AudioLDM2 pipeline, reusing the process-wide cache.

    Returns:
        The diffusers pipeline, moved to DEVICE with memory-saving options
        enabled where supported.

    Raises:
        ImportError: when the diffusers library is not installed.
        Exception: re-raises any underlying load failure after printing it.
    """
    if not AUDIO_LDM_AVAILABLE:
        raise ImportError("diffusers library not available. Please install: pip install diffusers transformers accelerate")

    cached = model_cache["pipeline"]
    if model_cache["loaded"] and cached is not None:
        print("Using cached model")
        return cached

    try:
        print(f"Loading AudioLDM2 model: {MODEL_ID}")
        print(f"Device: {DEVICE}, Dtype: {DTYPE}")

        # Prefer the dedicated AudioLDM2 pipeline; otherwise fall back to
        # the generic DiffusionPipeline entry point.
        if AudioLDM2Pipeline is not None:
            pipe = AudioLDM2Pipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE)
        else:
            from diffusers import DiffusionPipeline
            pipe = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE)

        pipe = pipe.to(DEVICE)

        # Opt into memory-efficient execution where the pipeline supports it.
        for optimization in ("enable_attention_slicing", "enable_vae_slicing"):
            if hasattr(pipe, optimization):
                getattr(pipe, optimization)()

        # Cache for subsequent calls.
        model_cache["pipeline"] = pipe
        model_cache["loaded"] = True
        print("Model loaded successfully!")
        return pipe
    except Exception as e:
        print(f"Error loading model: {e}")
        import traceback
        traceback.print_exc()
        model_cache["loaded"] = False
        raise
def generate_audio_with_model(prompt, duration, seed):
    """
    Generate audio using the AudioLDM2 model.

    Args:
        prompt: free-text description of the desired audio.
        duration: requested length in seconds (clamped to [1, 30]).
        seed: optional value convertible to int for reproducible generation;
            unparseable seeds silently fall back to nondeterministic output.

    Returns:
        (sample_rate, audio): sample rate in Hz and a mono float32 numpy
        array peak-normalized to 0.95.

    Raises:
        Re-raises any model loading/generation failure after printing it.
    """
    try:
        # Load model (will use cache if already loaded)
        pipeline = load_model()

        # Prepare seed
        generator = None
        if seed is not None:
            try:
                seed_int = int(seed)
                generator = torch.Generator(device=DEVICE).manual_seed(seed_int)
            except (ValueError, TypeError, OverflowError):
                generator = None

        # Generate audio
        print(f"Generating audio: prompt='{prompt}', duration={duration}s, seed={seed}")

        # The model may have limits on duration, so we clamp it to [1, 30] s.
        audio_duration = float(max(1.0, min(30.0, duration)))

        # Different diffusers versions expose different keyword arguments,
        # so probe three call signatures in order, falling through on
        # TypeError (unexpected keyword).
        output = None
        try:
            # AudioLDM2 standard API
            output = pipeline(
                prompt=prompt,
                num_inference_steps=50,  # Balance between quality and speed
                audio_length_in_s=audio_duration,
                generator=generator,
            )
        except TypeError as e1:
            try:
                # Try alternative parameter name (some models use 'duration'
                # instead of 'audio_length_in_s')
                output = pipeline(
                    prompt=prompt,
                    num_inference_steps=50,
                    duration=audio_duration,
                    guidance_scale=3.5,  # Add guidance for better quality
                    generator=generator,
                )
            except TypeError as e2:
                try:
                    # Try without duration parameter
                    output = pipeline(
                        prompt=prompt,
                        num_inference_steps=50,
                        generator=generator,
                    )
                    print("Warning: Duration parameter not supported, using model default")
                except Exception as e3:
                    raise RuntimeError(f"Failed to generate audio with any parameter combination: {e1}, {e2}, {e3}")

        if output is None:
            raise RuntimeError("Pipeline returned None")

        # Extract audio array and sample rate.
        # Handle different output formats from diffusers.
        audio = None
        sample_rate = 44100  # Default when the output carries no rate

        # Try different output attribute names
        if hasattr(output, 'audios'):
            audio_data = output.audios
            if isinstance(audio_data, (list, tuple)) and len(audio_data) > 0:
                audio = audio_data[0]
            else:
                audio = audio_data
        elif hasattr(output, 'audio'):
            audio_data = output.audio
            if isinstance(audio_data, (list, tuple)) and len(audio_data) > 0:
                audio = audio_data[0]
            else:
                audio = audio_data
        elif isinstance(output, dict):
            audio = output.get('audios', output.get('audio', None))
            if isinstance(audio, (list, tuple)) and len(audio) > 0:
                audio = audio[0]
        elif isinstance(output, (list, tuple)) and len(output) > 0:
            audio = output[0]
        elif isinstance(output, np.ndarray):
            audio = output
        elif isinstance(output, torch.Tensor):
            audio = output

        # Get sample rate
        if hasattr(output, 'sample_rate'):
            sample_rate = output.sample_rate
        elif isinstance(output, dict):
            sample_rate = output.get('sample_rate', 44100)

        if audio is None:
            raise ValueError("Could not extract audio from pipeline output")

        # Downmix multi-channel audio to mono by averaging the channel axis.
        # Heuristic: the channel axis is the *smaller* dimension, e.g.
        # (samples, 2), (2, samples) or (1, samples).
        # BUG FIX: the previous code averaged the *larger* axis in both
        # branches, collapsing the samples and leaving a channel-count array.
        if len(audio.shape) > 1:
            if audio.shape[0] > audio.shape[1]:
                audio = audio.mean(axis=1)  # (samples, channels) layout
            else:
                audio = audio.mean(axis=0)  # (channels, samples) layout

        # Ensure audio is a numpy array and float32
        if isinstance(audio, torch.Tensor):
            audio = audio.cpu().numpy()
        audio = audio.astype(np.float32)

        # Peak-normalize to 0.95 to prevent clipping
        max_val = np.abs(audio).max()
        if max_val > 0:
            audio = audio / max_val * 0.95

        print(f"Audio generated: shape={audio.shape}, dtype={audio.dtype}, sample_rate={sample_rate}")
        return sample_rate, audio
    except Exception as e:
        print(f"Error in model generation: {e}")
        raise
def generate_audio_fallback(prompt, duration, seed):
    """Synthesize a simple placeholder waveform when the model is unavailable.

    Keyword cues in *prompt* steer the synthesis: 'high'/'bright' and
    'low'/'deep' shift the pitch, 'fast'/'quick' add vibrato,
    'noise'/'wind'/'rain' produce noise, 'pulse'/'beep' a square wave,
    'rich'/'full'/'warm' add a harmonic, 'natural'/'organic' add jitter.

    Returns:
        (sample_rate, audio): 44100 Hz and a mono float32 numpy array
        clipped to [-0.95, 0.95].
    """
    # -- sanitize inputs ---------------------------------------------------
    if prompt is None:
        prompt = "gentle melody"
    if not isinstance(prompt, str):
        prompt = str(prompt)
    if duration is None or not isinstance(duration, (int, float)) or duration <= 0:
        duration = 10.0
    duration = min(max(duration, 1.0), 30.0)

    sample_rate = 44100
    n_samples = int(duration * sample_rate)

    # Seed the global RNG so noise-based prompts are reproducible;
    # unusable seeds are silently ignored.
    if seed is not None:
        try:
            np.random.seed(int(seed))
        except (ValueError, TypeError, OverflowError):
            pass

    # -- derive synthesis parameters from the prompt -------------------------
    text = prompt.lower()
    base_freq = 220  # A3 note
    if 'high' in text or 'bright' in text:
        base_freq *= 2
    elif 'low' in text or 'deep' in text:
        base_freq /= 2

    wants_vibrato = 'fast' in text or 'quick' in text
    vibrato_freq = 5 if wants_vibrato else 0
    vibrato_depth = 0.1 if wants_vibrato else 0

    # Time axis for the whole clip
    t = np.linspace(0, duration, n_samples, endpoint=False)

    # -- base waveform -------------------------------------------------------
    if 'noise' in text or 'wind' in text or 'rain' in text:
        audio = np.random.normal(0, 0.3, n_samples)
    elif 'pulse' in text or 'beep' in text:
        audio = 0.3 * np.sign(np.sin(2 * np.pi * base_freq * t))
    elif vibrato_freq > 0:
        wobble = vibrato_depth * np.sin(2 * np.pi * vibrato_freq * t)
        audio = 0.3 * np.sin(2 * np.pi * base_freq * t + wobble)
    else:
        audio = 0.3 * np.sin(2 * np.pi * base_freq * t)

    # -- optional color ------------------------------------------------------
    if 'rich' in text or 'full' in text or 'warm' in text:
        audio = audio + 0.2 * np.sin(2 * np.pi * (base_freq * 2) * t)
    if 'natural' in text or 'organic' in text:
        audio = audio + np.random.normal(0, 0.05, n_samples)

    # Clip to a safe range and emit float32
    return sample_rate, np.clip(audio, -0.95, 0.95).astype(np.float32)
def create_audio_generation_interface():
    """
    Create a Gradio interface for Stable Audio generation.

    Builds a Blocks UI with a prompt box, duration slider, and optional seed,
    and wires the generate button and example rows to the handler below.
    Returns the (unlaunched) gr.Blocks interface.
    """
    def generate_audio(prompt, duration, seed):
        """
        Generate audio based on text prompt using AudioLDM2 model.

        Tries the real model first; on failure falls back to simple
        synthesis; as a last resort emits a plain 440 Hz tone so the UI
        always returns playable audio.
        Returns ((sample_rate, audio), status_message).
        """
        try:
            # Input validation: blank prompts get a default, duration is
            # clamped to [1, 30] seconds.
            if prompt is None or prompt.strip() == "":
                prompt = "gentle melody"
            if not isinstance(prompt, str):
                prompt = str(prompt)
            if duration is None or not isinstance(duration, (int, float)):
                duration = 10.0
            duration = float(max(1.0, min(30.0, duration)))

            print(f"Generating audio for prompt: '{prompt}', duration: {duration}s, seed: {seed}")

            # Try to use the model first
            try:
                sample_rate, audio = generate_audio_with_model(prompt, duration, seed)
                status_msg = f"✅ Audio generated successfully using AudioLDM2! ({len(audio)/sample_rate:.1f}s)"
            except Exception as model_error:
                print(f"Model generation failed: {model_error}")
                print("Falling back to simple synthesis...")
                # Fallback to simple synthesis
                sample_rate, audio = generate_audio_fallback(prompt, duration, seed)
                status_msg = f"⚠️ Model unavailable, using fallback synthesis. Error: {str(model_error)[:100]}"

            # Verify audio was generated correctly
            if audio is None or len(audio) == 0:
                raise ValueError("Generated audio is empty")

            print(f"Audio generated: shape={audio.shape}, dtype={audio.dtype}, sample_rate={sample_rate}")
            return (sample_rate, audio), status_msg
        except Exception as e:
            print(f"Error generating audio: {e}")
            import traceback
            traceback.print_exc()
            # Ultimate fallback: a 440 Hz sine at the (clamped) duration
            try:
                safe_duration = float(max(1.0, min(30.0, duration if isinstance(duration, (int, float)) else 10.0)))
                sample_rate = 44100
                duration_samples = int(safe_duration * sample_rate)
                t = np.linspace(0, safe_duration, duration_samples, endpoint=False)
                audio = 0.3 * np.sin(2 * np.pi * 440 * t)
                audio = audio.astype(np.float32)
                return (sample_rate, audio), f"❌ Error: {str(e)[:100]}. Using emergency fallback."
            except Exception as fallback_error:
                print(f"Fallback also failed: {fallback_error}")
                # Absolute minimum fallback: fixed 10-second tone
                sample_rate = 44100
                duration_samples = 441000  # 10 seconds
                t = np.linspace(0, 10.0, duration_samples, endpoint=False)
                audio = 0.3 * np.sin(2 * np.pi * 440 * t)
                audio = audio.astype(np.float32)
                return (sample_rate, audio), "❌ Critical error occurred. Using emergency fallback."

    # Create the Gradio interface
    device_info = "GPU" if DEVICE == "cuda" else "CPU"
    with gr.Blocks(title="AudioLDM2 Audio Generation", theme=gr.themes.Soft()) as interface:
        gr.Markdown(f"""
        # 🎵 AudioLDM2 Audio Generation
        Generate high-quality audio from text prompts using AudioLDM2 technology.
        **Device:** {device_info} | **Model:** {MODEL_ID}
        """)
        with gr.Row():
            with gr.Column():
                prompt_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Describe the audio you want to generate...",
                    lines=3,
                    value="A gentle piano melody playing in a cozy room"
                )
                duration_input = gr.Slider(
                    label="Duration (seconds)",
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1
                )
                seed_input = gr.Number(
                    label="Random Seed (optional)",
                    value=None,
                    precision=0,
                    minimum=0,
                    maximum=999999
                )
                generate_btn = gr.Button("🎵 Generate Audio", variant="primary")
            with gr.Column():
                audio_output = gr.Audio(label="Generated Audio")
                status_output = gr.Textbox(label="Status", interactive=False)

        # Connect the generate button to the function
        generate_btn.click(
            fn=generate_audio,
            inputs=[prompt_input, duration_input, seed_input],
            outputs=[audio_output, status_output],
            show_progress=True
        )

        # Add some example prompts
        examples = gr.Examples(
            examples=[
                ["A calming ocean wave sound with seagulls", 15, 42],
                ["Upbeat electronic dance music", 20, 123],
                ["Classical violin concerto", 25, 999],
                ["Rain falling on a tin roof", 10, 777]
            ],
            inputs=[prompt_input, duration_input, seed_input],
            outputs=[audio_output, status_output],
            fn=generate_audio,
            cache_examples=False
        )
    return interface
# Application is ready for health monitoring
# Script entry point: build the Gradio UI and serve it on port 7860.
if __name__ == "__main__":
    print("Starting AudioLDM2 Audio Generation application...")
    print(f"PyTorch version: {torch.__version__}")
    cuda_ok = torch.cuda.is_available()
    print(f"CUDA available: {cuda_ok}")
    if cuda_ok:
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")

    interface = create_audio_generation_interface()

    # Health check available via Gradio's built-in endpoints
    print("Application ready at: http://localhost:7860/")
    print("Health status: System is operational")
    interface.launch(server_name="0.0.0.0", server_port=7860)