File size: 16,132 Bytes
505eff0
 
d0fa7b7
 
 
fa03fad
 
d0fa7b7
a82c7f0
d0fa7b7
a82c7f0
 
d0fa7b7
 
 
 
a82c7f0
 
d0fa7b7
a82c7f0
 
d0fa7b7
 
 
 
fa03fad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0fa7b7
a82c7f0
d0fa7b7
 
 
 
 
 
 
 
 
 
 
a82c7f0
d0fa7b7
a82c7f0
d0fa7b7
a82c7f0
d0fa7b7
 
 
a82c7f0
d0fa7b7
a82c7f0
d0fa7b7
a82c7f0
 
 
 
d0fa7b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e911600
a82c7f0
d0fa7b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a82c7f0
 
d0fa7b7
 
a82c7f0
d0fa7b7
 
 
 
 
 
 
 
f3c8dbf
d0fa7b7
 
 
f3c8dbf
a82c7f0
d0fa7b7
 
 
 
a82c7f0
d0fa7b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e911600
88a32e8
 
 
 
 
 
d0fa7b7
 
88a32e8
e911600
 
 
d0fa7b7
e911600
88a32e8
 
 
84525fb
88a32e8
e911600
 
 
 
 
 
d0fa7b7
e911600
d0fa7b7
e911600
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0fa7b7
 
e911600
 
 
d0fa7b7
e911600
 
 
 
d0fa7b7
e911600
 
 
 
d0fa7b7
e911600
d0fa7b7
e911600
d0fa7b7
505eff0
 
 
 
 
 
 
 
a82c7f0
505eff0
 
d0fa7b7
 
88a32e8
 
 
 
 
d0fa7b7
88a32e8
f86e88f
e911600
d0fa7b7
 
 
a82c7f0
d0fa7b7
 
 
 
 
 
 
 
 
 
 
 
e911600
d0fa7b7
505eff0
 
e911600
d0fa7b7
 
 
 
88a32e8
 
 
 
 
d0fa7b7
 
88a32e8
d0fa7b7
88a32e8
 
 
 
 
 
 
d0fa7b7
88a32e8
d0fa7b7
505eff0
 
d0fa7b7
a82c7f0
d0fa7b7
a82c7f0
 
 
d0fa7b7
505eff0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88a32e8
 
d0fa7b7
505eff0
 
 
 
 
 
 
 
d0fa7b7
505eff0
 
 
d0fa7b7
 
e911600
 
505eff0
d0fa7b7
505eff0
 
 
 
 
 
d0fa7b7
 
 
 
505eff0
 
 
 
a82c7f0
 
505eff0
 
a82c7f0
d0fa7b7
 
 
 
a82c7f0
505eff0
a82c7f0
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
import gradio as gr
import numpy as np
import torch
import os
import warnings
from dotenv import load_dotenv
from huggingface_hub import login

# Try to import AudioLDM2 pipeline; fall back to the generic
# DiffusionPipeline import, and finally record that diffusers is missing
# so load_model() can raise a helpful ImportError later.
try:
    from diffusers import AudioLDM2Pipeline
    AUDIO_LDM_AVAILABLE = True
except ImportError:
    try:
        # Alternative import path
        from diffusers import DiffusionPipeline
        AUDIO_LDM_AVAILABLE = True
        AudioLDM2Pipeline = None  # Will use DiffusionPipeline instead
    except ImportError:
        AUDIO_LDM_AVAILABLE = False
        AudioLDM2Pipeline = None

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=UserWarning)

# Load environment variables from a local .env file, if present
load_dotenv()

# Set up Hugging Face authentication. Failure is non-fatal: public models
# can still be downloaded without a token (possibly rate-limited).
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    try:
        login(token=hf_token)
        print("✅ Hugging Face authentication successful")
    except Exception as e:
        print(f"⚠️  Hugging Face authentication failed: {e}")
        print("   Continuing without authentication...")
else:
    print("ℹ️  No Hugging Face token found. Some models may have rate limits.")

# Model configuration: half precision on GPU, full precision on CPU
MODEL_ID = "cvssp/audioldm2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Global model cache so the pipeline is loaded at most once per process
model_cache = {
    "pipeline": None,
    "loaded": False
}

def load_model():
    """
    Return the AudioLDM2 pipeline, loading and caching it on first use.

    Returns:
        The diffusers pipeline, moved to the configured device.

    Raises:
        ImportError: if the diffusers library could not be imported.
        Exception: re-raises any loading failure after printing a traceback.
    """
    if not AUDIO_LDM_AVAILABLE:
        raise ImportError("diffusers library not available. Please install: pip install diffusers transformers accelerate")

    # Fast path: reuse the pipeline loaded by an earlier call.
    if model_cache["loaded"] and model_cache["pipeline"] is not None:
        print("Using cached model")
        return model_cache["pipeline"]

    try:
        print(f"Loading AudioLDM2 model: {MODEL_ID}")
        print(f"Device: {DEVICE}, Dtype: {DTYPE}")

        # Prefer the dedicated AudioLDM2Pipeline; fall back to the generic
        # DiffusionPipeline when only that import succeeded at module load.
        if AudioLDM2Pipeline is None:
            from diffusers import DiffusionPipeline
            pipeline_cls = DiffusionPipeline
        else:
            pipeline_cls = AudioLDM2Pipeline

        pipeline = pipeline_cls.from_pretrained(
            MODEL_ID,
            torch_dtype=DTYPE,
        ).to(DEVICE)

        # Opt into memory-saving features when the pipeline exposes them.
        for feature_name in ("enable_attention_slicing", "enable_vae_slicing"):
            if hasattr(pipeline, feature_name):
                getattr(pipeline, feature_name)()

        # Publish to the module-level cache for subsequent calls.
        model_cache["pipeline"] = pipeline
        model_cache["loaded"] = True

        print("Model loaded successfully!")
        return pipeline

    except Exception as e:
        print(f"Error loading model: {e}")
        import traceback
        traceback.print_exc()
        model_cache["loaded"] = False
        raise

def _invoke_pipeline(pipeline, prompt, audio_duration, generator):
    """Run the pipeline, retrying with the duration kwarg names used by
    different diffusers releases; raises RuntimeError if every call fails."""
    try:
        # AudioLDM2 standard API
        return pipeline(
            prompt=prompt,
            num_inference_steps=50,  # Balance between quality and speed
            audio_length_in_s=audio_duration,
            generator=generator,
        )
    except TypeError as e1:
        try:
            # Try alternative parameter name (some models use 'duration' instead of 'audio_length_in_s')
            return pipeline(
                prompt=prompt,
                num_inference_steps=50,
                duration=audio_duration,
                guidance_scale=3.5,  # Add guidance for better quality
                generator=generator,
            )
        except TypeError as e2:
            try:
                # Last resort: no duration control at all
                output = pipeline(
                    prompt=prompt,
                    num_inference_steps=50,
                    generator=generator,
                )
                print(f"Warning: Duration parameter not supported, using model default")
                return output
            except Exception as e3:
                raise RuntimeError(f"Failed to generate audio with any parameter combination: {e1}, {e2}, {e3}")


def _extract_audio(output):
    """Return the first audio array found in a pipeline output.

    Handles the output shapes diffusers pipelines are known to return:
    dataclasses with .audios/.audio, dicts, bare arrays/tensors, and
    lists/tuples. Returns None when nothing recognizable is found.
    """
    if isinstance(output, (np.ndarray, torch.Tensor)):
        return output
    audio = None
    if hasattr(output, 'audios'):
        audio = output.audios
    elif hasattr(output, 'audio'):
        audio = output.audio
    elif isinstance(output, dict):
        audio = output.get('audios', output.get('audio', None))
    elif isinstance(output, (list, tuple)) and len(output) > 0:
        return output[0]
    # Batched outputs: take the first clip.
    if isinstance(audio, (list, tuple)) and len(audio) > 0:
        audio = audio[0]
    return audio


def _extract_sample_rate(output, default=44100):
    """Return the output's sample rate when it exposes one, else *default*."""
    if hasattr(output, 'sample_rate'):
        return output.sample_rate
    if isinstance(output, dict):
        return output.get('sample_rate', default)
    return default


def generate_audio_with_model(prompt, duration, seed):
    """
    Generate audio from a text prompt using the AudioLDM2 pipeline.

    Args:
        prompt: text description of the desired audio.
        duration: requested length in seconds; clamped to [1.0, 30.0].
        seed: optional value convertible to int for deterministic output;
            unusable seeds are silently ignored.

    Returns:
        (sample_rate, audio): sample rate (44100 when the pipeline does not
        report one) and a mono float32 numpy array peak-normalized to 0.95.

    Raises:
        RuntimeError: if the pipeline rejects every call signature or
            returns None.
        ValueError: if no audio could be extracted from the output.
    """
    try:
        # Load model (will use cache if already loaded)
        pipeline = load_model()

        # Only build a torch Generator when the seed converts cleanly.
        generator = None
        if seed is not None:
            try:
                generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
            except (ValueError, TypeError, OverflowError):
                generator = None

        print(f"Generating audio: prompt='{prompt}', duration={duration}s, seed={seed}")

        # The model may have limits on duration, so we clamp it.
        audio_duration = float(max(1.0, min(30.0, duration)))

        output = _invoke_pipeline(pipeline, prompt, audio_duration, generator)
        if output is None:
            raise RuntimeError("Pipeline returned None")

        audio = _extract_audio(output)
        sample_rate = _extract_sample_rate(output)

        if audio is None:
            raise ValueError("Could not extract audio from pipeline output")

        # Work in numpy from here on so shape handling is uniform.
        if isinstance(audio, torch.Tensor):
            audio = audio.detach().cpu().numpy()
        audio = np.asarray(audio)

        # Collapse to mono by averaging over the channel axis. The channel
        # axis is the *smaller* dimension (e.g. (2, N) or (N, 2)); the
        # previous code averaged over the larger one, which collapsed the
        # time axis and left a few-sample result. Looping also flattens
        # leading batch dims such as (1, 1, N).
        while audio.ndim > 1:
            audio = audio.mean(axis=int(np.argmin(audio.shape)))

        audio = audio.astype(np.float32)

        # Peak-normalize to prevent clipping.
        max_val = float(np.abs(audio).max()) if audio.size else 0.0
        if max_val > 0:
            audio = audio / max_val * 0.95

        print(f"Audio generated: shape={audio.shape}, dtype={audio.dtype}, sample_rate={sample_rate}")

        return sample_rate, audio

    except Exception as e:
        print(f"Error in model generation: {e}")
        raise

def generate_audio_fallback(prompt, duration, seed):
    """
    Fallback audio generation using simple synthesis
    """
    # Input validation and sanitization
    if prompt is None:
        prompt = "gentle melody"
    if not isinstance(prompt, str):
        prompt = str(prompt)
    if duration is None or not isinstance(duration, (int, float)) or duration <= 0:
        duration = 10.0
    duration = min(max(duration, 1.0), 30.0)

    sample_rate = 44100
    duration_samples = int(duration * sample_rate)

    # Set seed for reproducibility
    if seed is not None:
        try:
            seed_int = int(seed)
            np.random.seed(seed_int)
        except (ValueError, TypeError, OverflowError):
            pass

    # Extract features from prompt to influence audio
    prompt_lower = prompt.lower()
    base_freq = 220  # A3 note

    if 'high' in prompt_lower or 'bright' in prompt_lower:
        base_freq *= 2
    elif 'low' in prompt_lower or 'deep' in prompt_lower:
        base_freq /= 2

    if 'fast' in prompt_lower or 'quick' in prompt_lower:
        vibrato_freq = 5
        vibrato_depth = 0.1
    else:
        vibrato_freq = 0
        vibrato_depth = 0

    # Generate time array
    t = np.linspace(0, duration, duration_samples, endpoint=False)

    # Create base waveform
    if 'noise' in prompt_lower or 'wind' in prompt_lower or 'rain' in prompt_lower:
        audio = np.random.normal(0, 0.3, duration_samples)
    elif 'pulse' in prompt_lower or 'beep' in prompt_lower:
        audio = 0.3 * np.sign(np.sin(2 * np.pi * base_freq * t))
    else:
        if vibrato_freq > 0:
            phase_modulation = vibrato_depth * np.sin(2 * np.pi * vibrato_freq * t)
            audio = 0.3 * np.sin(2 * np.pi * base_freq * t + phase_modulation)
        else:
            audio = 0.3 * np.sin(2 * np.pi * base_freq * t)

    # Add harmonics
    if 'rich' in prompt_lower or 'full' in prompt_lower or 'warm' in prompt_lower:
        harmonic = 0.2 * np.sin(2 * np.pi * (base_freq * 2) * t)
        audio += harmonic

    # Add natural variation
    if 'natural' in prompt_lower or 'organic' in prompt_lower:
        variation = np.random.normal(0, 0.05, duration_samples)
        audio += variation

    # Normalize
    audio = np.clip(audio, -0.95, 0.95)
    audio = audio.astype(np.float32)

    return sample_rate, audio

def create_audio_generation_interface():
    """
    Build the Gradio Blocks UI for text-to-audio generation.

    Returns:
        The assembled ``gr.Blocks`` interface; the caller is responsible
        for launching it.
    """

    def generate_audio(prompt, duration, seed):
        """
        Generate audio based on text prompt using the AudioLDM2 model.

        Tries the diffusion model first, falls back to simple synthesis,
        and finally to a plain 440 Hz tone, so the UI always receives
        ``((sample_rate, audio), status_message)``.
        """
        try:
            # Input validation: default empty prompts, coerce types, clamp duration.
            # NOTE(review): prompt.strip() assumes prompt is a str (true for
            # gr.Textbox inputs); a non-str value would raise here before the
            # str() coercion below — confirm no other callers pass non-strings.
            if prompt is None or prompt.strip() == "":
                prompt = "gentle melody"
            if not isinstance(prompt, str):
                prompt = str(prompt)
            if duration is None or not isinstance(duration, (int, float)):
                duration = 10.0
            duration = float(max(1.0, min(30.0, duration)))

            print(f"Generating audio for prompt: '{prompt}', duration: {duration}s, seed: {seed}")

            # Try to use the model first; any model failure switches to the
            # synthesized fallback and is surfaced in the status message.
            try:
                sample_rate, audio = generate_audio_with_model(prompt, duration, seed)
                status_msg = f"✅ Audio generated successfully using AudioLDM2! ({len(audio)/sample_rate:.1f}s)"
            except Exception as model_error:
                print(f"Model generation failed: {model_error}")
                print("Falling back to simple synthesis...")
                # Fallback to simple synthesis
                sample_rate, audio = generate_audio_fallback(prompt, duration, seed)
                status_msg = f"⚠️ Model unavailable, using fallback synthesis. Error: {str(model_error)[:100]}"

            # Verify audio was generated correctly
            if audio is None or len(audio) == 0:
                raise ValueError("Generated audio is empty")

            print(f"Audio generated: shape={audio.shape}, dtype={audio.dtype}, sample_rate={sample_rate}")

            return (sample_rate, audio), status_msg

        except Exception as e:
            print(f"Error generating audio: {e}")
            import traceback
            traceback.print_exc()

            # Ultimate fallback: a 440 Hz sine of the requested length
            try:
                safe_duration = float(max(1.0, min(30.0, duration if isinstance(duration, (int, float)) else 10.0)))
                sample_rate = 44100
                duration_samples = int(safe_duration * sample_rate)
                t = np.linspace(0, safe_duration, duration_samples, endpoint=False)
                audio = 0.3 * np.sin(2 * np.pi * 440 * t)
                audio = audio.astype(np.float32)

                return (sample_rate, audio), f"❌ Error: {str(e)[:100]}. Using emergency fallback."
            except Exception as fallback_error:
                print(f"Fallback also failed: {fallback_error}")
                # Absolute minimum fallback: fixed 10-second tone
                sample_rate = 44100
                duration_samples = 441000  # 10 seconds
                t = np.linspace(0, 10.0, duration_samples, endpoint=False)
                audio = 0.3 * np.sin(2 * np.pi * 440 * t)
                audio = audio.astype(np.float32)

                return (sample_rate, audio), "❌ Critical error occurred. Using emergency fallback."

    # Create the Gradio interface
    device_info = "GPU" if DEVICE == "cuda" else "CPU"
    with gr.Blocks(title="AudioLDM2 Audio Generation", theme=gr.themes.Soft()) as interface:
        gr.Markdown(f"""
        # 🎵 AudioLDM2 Audio Generation
        Generate high-quality audio from text prompts using AudioLDM2 technology.

        **Device:** {device_info} | **Model:** {MODEL_ID}
        """)

        with gr.Row():
            with gr.Column():
                # Left column: user inputs
                prompt_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Describe the audio you want to generate...",
                    lines=3,
                    value="A gentle piano melody playing in a cozy room"
                )

                # Bounds mirror the 1–30 s clamp applied in generate_audio
                duration_input = gr.Slider(
                    label="Duration (seconds)",
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1
                )

                seed_input = gr.Number(
                    label="Random Seed (optional)",
                    value=None,
                    precision=0,
                    minimum=0,
                    maximum=999999
                )

                generate_btn = gr.Button("🎵 Generate Audio", variant="primary")

            with gr.Column():
                # Right column: generated audio and status message
                audio_output = gr.Audio(label="Generated Audio")
                status_output = gr.Textbox(label="Status", interactive=False)

        # Connect the generate button to the function
        generate_btn.click(
            fn=generate_audio,
            inputs=[prompt_input, duration_input, seed_input],
            outputs=[audio_output, status_output],
            show_progress=True
        )

        # Add some example prompts (not cached: each click runs the model)
        examples = gr.Examples(
            examples=[
                ["A calming ocean wave sound with seagulls", 15, 42],
                ["Upbeat electronic dance music", 20, 123],
                ["Classical violin concerto", 25, 999],
                ["Rain falling on a tin roof", 10, 777]
            ],
            inputs=[prompt_input, duration_input, seed_input],
            outputs=[audio_output, status_output],
            fn=generate_audio,
            cache_examples=False
        )

    return interface

# Application is ready for health monitoring

# Launch the interface
if __name__ == "__main__":
    # Startup diagnostics: library version and GPU availability.
    print(f"Starting AudioLDM2 Audio Generation application...")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")

    interface = create_audio_generation_interface()

    # Health check available via Gradio's built-in endpoints
    print("Application ready at: http://localhost:7860/")
    print("Health status: System is operational")

    # Bind to all interfaces so the app is reachable from containers/LAN.
    interface.launch(server_name="0.0.0.0", server_port=7860)