# Source: Hugging Face Space by OnyxMunk
# Commit f3c8dbf — "Update environment example and Dockerfile health check"
import gradio as gr
import numpy as np
import torch
import os
import warnings
from dotenv import load_dotenv
from huggingface_hub import login
# Try to import the AudioLDM2 pipeline; the app degrades gracefully when
# diffusers (or the dedicated pipeline class) is missing.
try:
    from diffusers import AudioLDM2Pipeline
    AUDIO_LDM_AVAILABLE = True
except ImportError:
    try:
        # Alternative import path: some diffusers versions expose only the
        # generic DiffusionPipeline, not the AudioLDM2-specific class.
        from diffusers import DiffusionPipeline
        AUDIO_LDM_AVAILABLE = True
        AudioLDM2Pipeline = None # Will use DiffusionPipeline instead
    except ImportError:
        # No diffusers at all — only the synthesized fallback audio works.
        AUDIO_LDM_AVAILABLE = False
        AudioLDM2Pipeline = None

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=UserWarning)

# Load environment variables from a local .env file, if present
load_dotenv()

# Set up Hugging Face authentication (optional; auth failures are non-fatal)
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    try:
        login(token=hf_token)
        print("✅ Hugging Face authentication successful")
    except Exception as e:
        print(f"⚠️ Hugging Face authentication failed: {e}")
        print(" Continuing without authentication...")
else:
    print("ℹ️ No Hugging Face token found. Some models may have rate limits.")

# Model configuration
MODEL_ID = "cvssp/audioldm2"  # Hugging Face model repository id
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision saves GPU memory; CPU inference needs float32
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Global model cache so the pipeline is loaded at most once per process
model_cache = {
    "pipeline": None,  # loaded pipeline object (None until first success)
    "loaded": False    # True once from_pretrained has succeeded
}
def load_model():
    """
    Return the AudioLDM2 pipeline, loading and caching it on first use.

    Returns:
        The diffusers pipeline, moved to DEVICE with memory-saving
        features enabled where available.

    Raises:
        ImportError: if the diffusers library could not be imported.
        Exception: re-raises any failure from model download/initialization.
    """
    if not AUDIO_LDM_AVAILABLE:
        raise ImportError("diffusers library not available. Please install: pip install diffusers transformers accelerate")

    # Fast path: reuse the pipeline cached by an earlier call.
    cached = model_cache["pipeline"]
    if model_cache["loaded"] and cached is not None:
        print("Using cached model")
        return cached

    try:
        print(f"Loading AudioLDM2 model: {MODEL_ID}")
        print(f"Device: {DEVICE}, Dtype: {DTYPE}")

        # Prefer the dedicated AudioLDM2 class; otherwise fall back to the
        # generic DiffusionPipeline loader.
        if AudioLDM2Pipeline is None:
            from diffusers import DiffusionPipeline
            pipeline = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE)
        else:
            pipeline = AudioLDM2Pipeline.from_pretrained(MODEL_ID, torch_dtype=DTYPE)

        pipeline = pipeline.to(DEVICE)

        # Opt into memory-efficient features when the pipeline exposes them.
        for feature in ("enable_attention_slicing", "enable_vae_slicing"):
            if hasattr(pipeline, feature):
                getattr(pipeline, feature)()

        # Cache for subsequent requests.
        model_cache["pipeline"] = pipeline
        model_cache["loaded"] = True
        print("Model loaded successfully!")
        return pipeline
    except Exception as e:
        print(f"Error loading model: {e}")
        import traceback
        traceback.print_exc()
        model_cache["loaded"] = False
        raise
def generate_audio_with_model(prompt, duration, seed):
    """
    Generate audio from a text prompt using the AudioLDM2 pipeline.

    Args:
        prompt: Text description of the desired audio.
        duration: Requested length in seconds; clamped to [1, 30].
        seed: Optional seed for reproducible generation; values that cannot
            be converted to int are silently ignored.

    Returns:
        (sample_rate, audio): sample rate in Hz and a mono float32 numpy
        array peak-normalized to at most 0.95.

    Raises:
        RuntimeError: if the pipeline rejects every parameter combination
            or returns None.
        ValueError: if no audio could be extracted from the pipeline output.
    """
    try:
        # Load model (will use cache if already loaded)
        pipeline = load_model()

        # Build a deterministic generator when a usable seed is provided.
        generator = None
        if seed is not None:
            try:
                seed_int = int(seed)
                generator = torch.Generator(device=DEVICE).manual_seed(seed_int)
            except (ValueError, TypeError, OverflowError):
                generator = None

        print(f"Generating audio: prompt='{prompt}', duration={duration}s, seed={seed}")

        # Clamp duration: the model may have limits on generation length.
        audio_duration = float(max(1.0, min(30.0, duration)))

        # Different diffusers versions accept different kwargs, so try the
        # AudioLDM2-specific call first and progressively fall back.
        output = None
        try:
            # AudioLDM2 standard API
            output = pipeline(
                prompt=prompt,
                num_inference_steps=50,  # balance between quality and speed
                audio_length_in_s=audio_duration,
                generator=generator,
            )
        except TypeError as e1:
            try:
                # Some pipelines use 'duration' instead of 'audio_length_in_s'
                output = pipeline(
                    prompt=prompt,
                    num_inference_steps=50,
                    duration=audio_duration,
                    guidance_scale=3.5,  # add guidance for better quality
                    generator=generator,
                )
            except TypeError as e2:
                try:
                    # Last resort: let the model pick its default length
                    output = pipeline(
                        prompt=prompt,
                        num_inference_steps=50,
                        generator=generator,
                    )
                    print(f"Warning: Duration parameter not supported, using model default")
                except Exception as e3:
                    raise RuntimeError(f"Failed to generate audio with any parameter combination: {e1}, {e2}, {e3}")

        if output is None:
            raise RuntimeError("Pipeline returned None")

        # Pull the waveform out of whichever output shape this diffusers
        # version produced.
        audio = _extract_audio_array(output)
        if audio is None:
            raise ValueError("Could not extract audio from pipeline output")

        # Sample rate: honor the output's value when present.
        # NOTE(review): 44100 is a guessed default — AudioLDM2 typically
        # emits 16 kHz audio; confirm against the pipeline documentation.
        sample_rate = 44100
        if hasattr(output, 'sample_rate'):
            sample_rate = output.sample_rate
        elif isinstance(output, dict):
            sample_rate = output.get('sample_rate', 44100)

        # Downmix multi-channel audio to mono by averaging the CHANNEL axis,
        # assumed to be the smaller dimension. (The previous code averaged
        # the larger axis, which collapsed the time dimension and reduced
        # e.g. a (2, 44100) waveform to two samples.)
        if len(audio.shape) > 1:
            channel_axis = 0 if audio.shape[0] <= audio.shape[1] else 1
            audio = audio.mean(axis=channel_axis)

        # Ensure audio is a float32 numpy array for Gradio.
        if isinstance(audio, torch.Tensor):
            audio = audio.cpu().numpy()
        audio = audio.astype(np.float32)

        # Peak-normalize to prevent clipping, leaving a little headroom.
        max_val = np.abs(audio).max()
        if max_val > 0:
            audio = audio / max_val * 0.95

        print(f"Audio generated: shape={audio.shape}, dtype={audio.dtype}, sample_rate={sample_rate}")
        return sample_rate, audio
    except Exception as e:
        print(f"Error in model generation: {e}")
        raise


def _extract_audio_array(output):
    """Return the first audio array found in a diffusers pipeline output.

    Handles the attribute (`.audios`/`.audio`), dict, sequence, and raw
    array/tensor forms returned by different diffusers versions. Returns
    None when nothing recognizable is found.
    """
    for attr in ('audios', 'audio'):
        if hasattr(output, attr):
            data = getattr(output, attr)
            if isinstance(data, (list, tuple)) and len(data) > 0:
                return data[0]
            return data
    if isinstance(output, dict):
        audio = output.get('audios', output.get('audio', None))
        if isinstance(audio, (list, tuple)) and len(audio) > 0:
            return audio[0]
        return audio
    if isinstance(output, (list, tuple)) and len(output) > 0:
        return output[0]
    if isinstance(output, (np.ndarray, torch.Tensor)):
        return output
    return None
def generate_audio_fallback(prompt, duration, seed):
    """
    Synthesize simple placeholder audio when the diffusion model is unavailable.

    Keyword cues in the prompt ('high', 'low', 'fast', 'noise', 'pulse',
    'rich', 'natural', ...) steer the base frequency, waveform, and texture.

    Returns:
        (sample_rate, audio): 44100 Hz and a mono float32 array in [-0.95, 0.95].
    """
    # Sanitize inputs so the synthesizer always has something workable.
    if prompt is None:
        prompt = "gentle melody"
    if not isinstance(prompt, str):
        prompt = str(prompt)
    if duration is None or not isinstance(duration, (int, float)) or duration <= 0:
        duration = 10.0
    duration = min(max(duration, 1.0), 30.0)

    sample_rate = 44100
    duration_samples = int(duration * sample_rate)

    # Seed numpy's global RNG so the same seed reproduces the same audio.
    if seed is not None:
        try:
            np.random.seed(int(seed))
        except (ValueError, TypeError, OverflowError):
            pass

    words = prompt.lower()

    # Map prompt keywords onto synthesis parameters.
    base_freq = 220  # A3 note
    if 'high' in words or 'bright' in words:
        base_freq *= 2
    elif 'low' in words or 'deep' in words:
        base_freq /= 2

    if 'fast' in words or 'quick' in words:
        vibrato_freq, vibrato_depth = 5, 0.1
    else:
        vibrato_freq, vibrato_depth = 0, 0

    t = np.linspace(0, duration, duration_samples, endpoint=False)

    # Choose the base waveform from the prompt's texture keywords.
    if any(word in words for word in ('noise', 'wind', 'rain')):
        audio = np.random.normal(0, 0.3, duration_samples)
    elif 'pulse' in words or 'beep' in words:
        audio = 0.3 * np.sign(np.sin(2 * np.pi * base_freq * t))
    elif vibrato_freq > 0:
        phase_modulation = vibrato_depth * np.sin(2 * np.pi * vibrato_freq * t)
        audio = 0.3 * np.sin(2 * np.pi * base_freq * t + phase_modulation)
    else:
        audio = 0.3 * np.sin(2 * np.pi * base_freq * t)

    # Layer in an octave harmonic for "rich"-sounding prompts.
    if 'rich' in words or 'full' in words or 'warm' in words:
        audio += 0.2 * np.sin(2 * np.pi * (base_freq * 2) * t)

    # A touch of low-level noise for "natural" prompts.
    if 'natural' in words or 'organic' in words:
        audio += np.random.normal(0, 0.05, duration_samples)

    # Hard-limit and convert for playback.
    audio = np.clip(audio, -0.95, 0.95)
    return sample_rate, audio.astype(np.float32)
def create_audio_generation_interface():
    """
    Build and return the Gradio Blocks UI for AudioLDM2 audio generation.

    Wires a prompt/duration/seed form to generate_audio, which tries the
    diffusion model first and degrades gracefully to synthesized audio.
    """
    def generate_audio(prompt, duration, seed):
        """
        Generate audio for the UI; returns ((sample_rate, audio), status_msg).

        Never raises: falls back to the simple synthesizer, then to a plain
        440 Hz tone, so the interface always receives playable audio.
        """
        try:
            # Coerce the prompt to a non-empty string BEFORE calling .strip().
            # (The original called prompt.strip() first, so a non-string
            # prompt — e.g. a number sent via the API — raised
            # AttributeError instead of being coerced.)
            if prompt is None:
                prompt = "gentle melody"
            if not isinstance(prompt, str):
                prompt = str(prompt)
            if prompt.strip() == "":
                prompt = "gentle melody"
            if duration is None or not isinstance(duration, (int, float)):
                duration = 10.0
            duration = float(max(1.0, min(30.0, duration)))

            print(f"Generating audio for prompt: '{prompt}', duration: {duration}s, seed: {seed}")

            # Prefer the real model; fall back to cheap synthesis on failure.
            try:
                sample_rate, audio = generate_audio_with_model(prompt, duration, seed)
                status_msg = f"✅ Audio generated successfully using AudioLDM2! ({len(audio)/sample_rate:.1f}s)"
            except Exception as model_error:
                print(f"Model generation failed: {model_error}")
                print("Falling back to simple synthesis...")
                sample_rate, audio = generate_audio_fallback(prompt, duration, seed)
                status_msg = f"⚠️ Model unavailable, using fallback synthesis. Error: {str(model_error)[:100]}"

            # Sanity-check before handing the result to Gradio.
            if audio is None or len(audio) == 0:
                raise ValueError("Generated audio is empty")

            print(f"Audio generated: shape={audio.shape}, dtype={audio.dtype}, sample_rate={sample_rate}")
            return (sample_rate, audio), status_msg
        except Exception as e:
            print(f"Error generating audio: {e}")
            import traceback
            traceback.print_exc()
            # Emergency fallback: plain 440 Hz tone of the requested length.
            try:
                safe_duration = float(max(1.0, min(30.0, duration if isinstance(duration, (int, float)) else 10.0)))
                sample_rate = 44100
                duration_samples = int(safe_duration * sample_rate)
                t = np.linspace(0, safe_duration, duration_samples, endpoint=False)
                audio = 0.3 * np.sin(2 * np.pi * 440 * t)
                audio = audio.astype(np.float32)
                return (sample_rate, audio), f"❌ Error: {str(e)[:100]}. Using emergency fallback."
            except Exception as fallback_error:
                print(f"Fallback also failed: {fallback_error}")
                # Absolute minimum fallback: fixed 10-second tone.
                sample_rate = 44100
                duration_samples = 441000  # 10 seconds
                t = np.linspace(0, 10.0, duration_samples, endpoint=False)
                audio = 0.3 * np.sin(2 * np.pi * 440 * t)
                audio = audio.astype(np.float32)
                return (sample_rate, audio), "❌ Critical error occurred. Using emergency fallback."

    # --- UI layout ---
    device_info = "GPU" if DEVICE == "cuda" else "CPU"
    with gr.Blocks(title="AudioLDM2 Audio Generation", theme=gr.themes.Soft()) as interface:
        # Markdown content kept at column 0 so it is not parsed as a
        # Markdown code block.
        gr.Markdown(f"""
# 🎵 AudioLDM2 Audio Generation
Generate high-quality audio from text prompts using AudioLDM2 technology.
**Device:** {device_info} | **Model:** {MODEL_ID}
""")
        with gr.Row():
            with gr.Column():
                prompt_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Describe the audio you want to generate...",
                    lines=3,
                    value="A gentle piano melody playing in a cozy room"
                )
                duration_input = gr.Slider(
                    label="Duration (seconds)",
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1
                )
                seed_input = gr.Number(
                    label="Random Seed (optional)",
                    value=None,
                    precision=0,
                    minimum=0,
                    maximum=999999
                )
                generate_btn = gr.Button("🎵 Generate Audio", variant="primary")
            with gr.Column():
                audio_output = gr.Audio(label="Generated Audio")
                status_output = gr.Textbox(label="Status", interactive=False)

        # Connect the generate button to the function
        generate_btn.click(
            fn=generate_audio,
            inputs=[prompt_input, duration_input, seed_input],
            outputs=[audio_output, status_output],
            show_progress=True
        )

        # Example prompts (not cached: generation is slow and model-dependent)
        examples = gr.Examples(
            examples=[
                ["A calming ocean wave sound with seagulls", 15, 42],
                ["Upbeat electronic dance music", 20, 123],
                ["Classical violin concerto", 25, 999],
                ["Rain falling on a tin roof", 10, 777]
            ],
            inputs=[prompt_input, duration_input, seed_input],
            outputs=[audio_output, status_output],
            fn=generate_audio,
            cache_examples=False
        )
    return interface
# Script entry point: report the runtime environment, build the UI, and
# serve it. Health monitoring can hit Gradio's built-in endpoints.
if __name__ == "__main__":
    print("Starting AudioLDM2 Audio Generation application...")
    print(f"PyTorch version: {torch.__version__}")
    cuda_available = torch.cuda.is_available()
    print(f"CUDA available: {cuda_available}")
    if cuda_available:
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")

    interface = create_audio_generation_interface()

    print("Application ready at: http://localhost:7860/")
    print("Health status: System is operational")
    # Bind to all interfaces so the containerized app is reachable.
    interface.launch(server_name="0.0.0.0", server_port=7860)