ghostai1 committed on
Commit
b2feece
·
verified ·
1 Parent(s): 3e74d68

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +444 -0
app.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import psutil
5
+ import time
6
+ import sys
7
+ import numpy as np
8
+ import gc
9
+ import gradio as gr
10
+ from pydub import AudioSegment
11
+ from audiocraft.models import MusicGen
12
+ from torch.cuda.amp import autocast
13
+ import warnings
14
+
15
# Silence library warnings so the console output stays readable.
warnings.filterwarnings("ignore")

# Configure the CUDA caching allocator (max split size of 128 MB) to help
# with memory fragmentation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# Dependency sanity checks: warn (but keep going) on untested versions.
if np.__version__ != "1.23.5":
    print(f"WARNING: NumPy version {np.__version__} is being used. Tested with numpy==1.23.5.")
if not torch.__version__.startswith(("2.1.0", "2.3.1")):
    print(f"WARNING: PyTorch version {torch.__version__} may not be compatible. Expected torch==2.1.0 or 2.3.1.")

# 1) DEVICE SETUP
# This app renders on GPU only; abort immediately when CUDA is missing.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    print("ERROR: CUDA is required for GPU rendering. CPU rendering is disabled.")
    sys.exit(1)
print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")

# 2) LOAD MUSICGEN INTO VRAM
# The weights are read from a local directory. Any loading failure is fatal.
# (sys.exit raises SystemExit, which the `except Exception` below does not
# intercept, so the missing-path exit is not re-reported as a load failure.)
try:
    print("Loading MusicGen medium model into VRAM...")
    local_model_path = "./models/musicgen-medium"
    if not os.path.exists(local_model_path):
        print(f"ERROR: Local model path {local_model_path} does not exist.")
        print("Please download the MusicGen medium model weights and place them in the correct directory.")
        sys.exit(1)
    musicgen_model = MusicGen.get_pretrained(local_model_path, device=device)
    musicgen_model.set_generation_params(
        duration=15,  # Default chunk duration
        two_step_cfg=False,  # Disable two-step CFG for stability
    )
except Exception as e:
    print(f"ERROR: Failed to load MusicGen model: {e}")
    print("Ensure model weights are correctly placed and dependencies are installed.")
    sys.exit(1)
51
+
52
+ # 3) RESOURCE MONITORING FUNCTION
53
def print_resource_usage(stage: str):
    """Print GPU memory (allocated/reserved, in GB) and CPU memory usage.

    Args:
        stage: Label printed in the header line to identify the checkpoint.
    """
    bytes_per_gib = 1024**3
    print(f"--- {stage} ---")
    for label, read_bytes in (
        ("Allocated", torch.cuda.memory_allocated),
        ("Reserved", torch.cuda.memory_reserved),
    ):
        print(f"GPU Memory {label}: {read_bytes() / bytes_per_gib:.2f} GB")
    print(f"CPU Memory Used: {psutil.virtual_memory().percent}%")
    print("---------------")
59
+
60
+ # 4) GENRE PROMPT FUNCTIONS
61
def set_red_hot_chili_peppers_prompt():
    """Return the preset Red Hot Chili Peppers-style funk rock prompt."""
    prompt = "Funk rock with groovy basslines, syncopated guitar riffs, energetic drums, and a Red Hot Chili Peppers-inspired vibe with dynamic vocal energy and funky breakdowns."
    return prompt
63
+
64
def set_nirvana_grunge_prompt():
    """Return the preset Nirvana-style grunge prompt."""
    prompt = "Grunge with raw distorted guitar riffs, heavy drums, melodic basslines, and a Nirvana-inspired angst-filled sound with quiet-loud dynamics."
    return prompt
66
+
67
def set_pearl_jam_grunge_prompt():
    """Return the preset Pearl Jam-style grunge prompt."""
    prompt = "Grunge with soulful guitar leads, driving rhythms, deep bass, and a Pearl Jam-inspired emotional intensity with soaring choruses."
    return prompt
69
+
70
def set_soundgarden_grunge_prompt():
    """Return the preset Soundgarden-style grunge prompt."""
    prompt = "Grunge with heavy, sludgy guitar riffs, complex drum patterns, and a Soundgarden-inspired dark, psychedelic edge with powerful vocals."
    return prompt
72
+
73
def set_foo_fighters_prompt():
    """Return the preset Foo Fighters-style alternative rock prompt."""
    prompt = "Alternative rock with punchy guitar riffs, tight drums, melodic hooks, and a Foo Fighters-inspired anthemic energy with gritty verses."
    return prompt
75
+
76
def set_smashing_pumpkins_prompt():
    """Return the preset Smashing Pumpkins-style alternative rock prompt."""
    prompt = "Alternative rock with dreamy guitar textures, heavy distortion, dynamic drums, and a Smashing Pumpkins-inspired blend of melancholy and aggression."
    return prompt
78
+
79
def set_radiohead_prompt():
    """Return the preset Radiohead-style experimental rock prompt."""
    prompt = "Experimental rock with atmospheric synths, intricate guitar layers, complex rhythms, and a Radiohead-inspired blend of introspective and innovative soundscapes."
    return prompt
81
+
82
def set_classic_rock_prompt():
    """Return the preset classic rock prompt."""
    prompt = "Classic rock with bluesy electric guitars, steady drums, groovy bass, and a Led Zeppelin-inspired raw energy with dynamic solos."
    return prompt
84
+
85
def set_alternative_rock_prompt():
    """Return the preset alternative rock prompt."""
    prompt = "Alternative rock with distorted guitar riffs, punchy drums, melodic basslines, and a Pixies-inspired quirky, energetic vibe."
    return prompt
87
+
88
def set_post_punk_prompt():
    """Return the preset post-punk prompt."""
    prompt = "Post-punk with jangly guitars, driving basslines, sharp drums, and a Joy Division-inspired moody, atmospheric sound."
    return prompt
90
+
91
def set_indie_rock_prompt():
    """Return the preset indie rock prompt."""
    prompt = "Indie rock with jangly guitars, heartfelt vocals, steady drums, and an Arctic Monkeys-inspired blend of witty lyrics and catchy riffs."
    return prompt
93
+
94
def set_funk_rock_prompt():
    """Return the preset funk rock prompt."""
    prompt = "Funk rock with slap bass, funky guitar chords, upbeat drums, and a Rage Against the Machine-inspired mix of groove and aggression."
    return prompt
96
+
97
def set_detroit_techno_prompt():
    """Return the preset Detroit techno prompt."""
    prompt = "Detroit techno with deep pulsing synths, driving basslines, crisp hi-hats, and a Juan Atkins-inspired rhythmic groove."
    return prompt
99
+
100
def set_deep_house_prompt():
    """Return the preset deep house prompt."""
    prompt = "Deep house with warm analog synth chords, soulful vocal chops, deep basslines, and a Larry Heard-inspired laid-back groove."
    return prompt
102
+
103
+ # 5) AUDIO PROCESSING FUNCTIONS
104
def apply_chorus(segment):
    """Overlay a 6 dB quieter copy of *segment* onto itself at position 20.

    Args:
        segment: Audio segment supporting ``-``, ``set_frame_rate`` and
            ``overlay`` (a pydub ``AudioSegment`` in this app).

    Returns:
        The result of overlaying the attenuated copy on the original.
    """
    attenuated_copy = (segment - 6).set_frame_rate(segment.frame_rate)
    return segment.overlay(attenuated_copy, position=20)
108
+
109
def apply_eq(segment):
    """Band-limit *segment*: low-pass at 8000, then high-pass at 80.

    Args:
        segment: Audio segment exposing ``low_pass_filter`` and
            ``high_pass_filter``.

    Returns:
        The filtered segment.
    """
    return segment.low_pass_filter(8000).high_pass_filter(80)
113
+
114
def apply_limiter(segment, max_db=-6.0):
    """Attenuate *segment* so its ``dBFS`` level does not exceed *max_db*.

    Args:
        segment: Audio segment exposing ``dBFS`` and supporting ``-`` for
            gain reduction.
        max_db: Ceiling for the segment's ``dBFS`` level.

    Returns:
        The original segment when it is already at or below the ceiling,
        otherwise a copy reduced by the excess.
    """
    if not segment.dBFS > max_db:
        return segment
    excess_db = segment.dBFS - max_db
    return segment - excess_db
118
+
119
def apply_final_gain(segment, target_db=-18.0):
    """Shift *segment* so its ``dBFS`` level equals *target_db*.

    Args:
        segment: Audio segment exposing ``dBFS`` and supporting ``+`` for
            gain changes.
        target_db: Desired ``dBFS`` level of the result.

    Returns:
        The gain-adjusted segment. A completely silent segment is returned
        unchanged.
    """
    current_db = segment.dBFS
    # Silence reports -inf dBFS; the required gain would be +inf, which is
    # meaningless (and no gain can make silence louder), so leave it alone.
    if current_db == float("-inf"):
        return segment
    return segment + (target_db - current_db)
122
+
123
def apply_fade(segment, fade_in_duration=2000, fade_out_duration=2000):
    """Apply a fade-in followed by a fade-out to *segment*.

    Args:
        segment: Audio segment exposing ``fade_in`` and ``fade_out``.
        fade_in_duration: Value passed to ``fade_in``.
        fade_out_duration: Value passed to ``fade_out``.

    Returns:
        The faded segment.
    """
    return segment.fade_in(fade_in_duration).fade_out(fade_out_duration)
127
+
128
+ # 6) GENERATION & I/O FUNCTIONS
129
def _ensure_stereo(audio_chunk):
    """Return *audio_chunk* as a float32 CPU tensor shaped ``(2, samples)``."""
    audio_chunk = audio_chunk.cpu().to(dtype=torch.float32)
    if audio_chunk.dim() == 1:
        # Bare sample vector: duplicate it into two channels.
        audio_chunk = torch.stack([audio_chunk, audio_chunk], dim=0)
    elif audio_chunk.dim() == 2 and audio_chunk.shape[0] != 2:
        # Mono (or an unexpected channel count): duplicate the first channel.
        first_channel = audio_chunk[:1, :]
        audio_chunk = torch.cat([first_channel, first_channel], dim=0)
    elif audio_chunk.dim() > 2:
        audio_chunk = audio_chunk.reshape(2, -1)

    if audio_chunk.shape[0] != 2:
        raise ValueError(f"Expected stereo audio with shape (2, samples), got shape {audio_chunk.shape}")
    return audio_chunk


def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p: float, temperature: float, total_duration: int, crossfade_duration: int, num_variations: int = 1):
    """Generate one or more instrumental tracks from a text prompt.

    The track is rendered as several chunks of roughly 15 seconds, each
    generated independently from the same prompt, which are crossfaded
    together, trimmed to the requested length, post-processed (EQ, chorus,
    limiter, normalize, final gain) and exported as a 320k MP3 named
    ``output_cleaned_variation_<n>.mp3`` in the working directory.

    Args:
        instrumental_prompt: Text description passed to the model.
        cfg_scale: Classifier-free guidance coefficient (``cfg_coef``).
        top_k: Top-k sampling parameter.
        top_p: Top-p sampling parameter.
        temperature: Sampling temperature.
        total_duration: Requested length in seconds; clamped to 10..90.
        crossfade_duration: Crossfade between chunks, in milliseconds.
        num_variations: Number of tracks to render; variation ``v``
            (0-based) is seeded with ``42 + v``.

    Returns:
        ``(path, status)`` where ``path`` is the first variation's MP3 path
        (``None`` on failure) and ``status`` is a user-facing message.
    """
    # Function-scope import: only this function needs a scratch directory.
    import tempfile

    if not instrumental_prompt or not instrumental_prompt.strip():
        return None, "⚠️ Please enter a valid instrumental prompt!"
    try:
        start_time = time.time()

        # UI sliders may deliver floats; range(), slicing and the crossfade
        # all need integers.
        total_duration = min(max(int(total_duration), 10), 90)
        crossfade_duration = max(0, int(crossfade_duration))
        num_variations = max(1, int(num_variations))
        top_k = int(top_k)

        num_chunks = max(1, total_duration // 15)
        chunk_duration = total_duration / num_chunks
        # Every join consumes one full crossfade, so each chunk must carry
        # the whole crossfade length as extra material; capping the overlap
        # below the crossfade made multi-chunk tracks shorter than requested.
        # A single chunk has no joins and needs no overlap.
        overlap_duration = crossfade_duration / 1000.0 if num_chunks > 1 else 0.0
        generation_duration = chunk_duration + overlap_duration

        output_files = []
        sample_rate = musicgen_model.sample_rate

        # Generation parameters are identical for every chunk: set them once.
        musicgen_model.set_generation_params(
            duration=generation_duration,
            use_sampling=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            cfg_coef=cfg_scale,
        )

        # Intermediate chunk files live in a private scratch directory that
        # is removed even when generation fails part-way, and cannot collide
        # with another run using the same working directory.
        with tempfile.TemporaryDirectory(prefix="musicgen_chunks_") as work_dir:
            for var in range(num_variations):
                print(f"Generating variation {var+1}/{num_variations}...")
                chunk_paths = []
                seed = 42 + var  # Use different seeds for variations
                torch.manual_seed(seed)
                np.random.seed(seed)

                for i in range(num_chunks):
                    chunk_prompt = instrumental_prompt
                    print(f"Generating chunk {i+1}/{num_chunks} for variation {var+1} on GPU (prompt: {chunk_prompt})...")
                    print_resource_usage(f"Before Chunk {i+1} Generation (Variation {var+1})")

                    with torch.no_grad(), autocast():
                        audio_chunk = musicgen_model.generate([chunk_prompt], progress=True)[0]

                    audio_chunk = _ensure_stereo(audio_chunk)

                    temp_wav_path = os.path.join(work_dir, f"temp_chunk_{var}_{i}.wav")
                    chunk_path = os.path.join(work_dir, f"chunk_{var}_{i}.mp3")
                    torchaudio.save(temp_wav_path, audio_chunk, sample_rate, bits_per_sample=24)
                    segment = AudioSegment.from_wav(temp_wav_path)
                    # export() hands back the open output file; close it so
                    # the handle is not leaked.
                    segment.export(chunk_path, format="mp3", bitrate="320k").close()
                    os.remove(temp_wav_path)
                    chunk_paths.append(chunk_path)

                    torch.cuda.empty_cache()
                    gc.collect()
                    time.sleep(0.5)
                    print_resource_usage(f"After Chunk {i+1} Generation (Variation {var+1})")

                print(f"Combining audio chunks for variation {var+1}...")
                final_segment = AudioSegment.from_mp3(chunk_paths[0])
                for path in chunk_paths[1:]:
                    # Each appended chunk is boosted by 1 dB before the join.
                    next_segment = AudioSegment.from_mp3(path) + 1
                    final_segment = final_segment.append(next_segment, crossfade=crossfade_duration)

                final_segment = final_segment[:total_duration * 1000]

                print(f"Post-processing final track for variation {var+1}...")
                final_segment = apply_eq(final_segment)
                final_segment = apply_chorus(final_segment)
                final_segment = apply_limiter(final_segment, max_db=-6.0)
                # Headroom is dB *below* full scale. A negative value targets
                # a peak above full scale and clips the signal.
                final_segment = final_segment.normalize(headroom=9.0)
                final_segment = apply_final_gain(final_segment, target_db=-18.0)

                mp3_path = f"output_cleaned_variation_{var+1}.mp3"
                final_segment.export(
                    mp3_path,
                    format="mp3",
                    bitrate="320k",
                    tags={"title": f"GhostAI Instrumental Variation {var+1}", "artist": "GhostAI"},
                ).close()
                print(f"Saved final audio to {mp3_path}")
                output_files.append(mp3_path)

        print_resource_usage("After Final Generation")
        print(f"Total Generation Time: {time.time() - start_time:.2f} seconds")

        # Return the first variation for Gradio display; others are saved to disk
        return output_files[0], f"✅ Done! Generated {num_variations} variations."
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, f"❌ Generation failed: {e}"
    finally:
        torch.cuda.empty_cache()
        gc.collect()
236
+
237
def clear_inputs():
    """Return the reset values for the prompt box and the seven sliders.

    The order is: prompt, CFG scale, top-k, top-p, temperature, total
    duration, crossfade duration, number of variations.
    """
    defaults = {
        "instrumental_prompt": "",
        "cfg_scale": 3.0,
        "top_k": 250,
        "top_p": 0.9,
        "temperature": 1.0,
        "total_duration": 30,
        "crossfade_duration": 500,
        "num_variations": 1,
    }
    return tuple(defaults.values())
239
+
240
+ # 7) CUSTOM CSS
241
# Stylesheet handed to gr.Blocks(css=...): dark gradient page, header with
# glitch animations, panel containers, textbox and button styling, and the
# Orbitron web font.
css = """
body {
background: linear-gradient(135deg, #0A0A0A 0%, #1C2526 100%);
color: #E0E0E0;
font-family: 'Orbitron', sans-serif;
}
.header-container {
text-align: center;
padding: 10px 20px;
background: rgba(0, 0, 0, 0.9);
border-bottom: 1px solid #00FF9F;
}
#ghost-logo {
font-size: 40px;
animation: glitch-ghost 1.5s infinite;
}
h1 {
color: #A100FF;
font-size: 24px;
animation: glitch-text 2s infinite;
}
p {
color: #E0E0E0;
font-size: 12px;
}
.input-container, .settings-container, .output-container {
max-width: 1200px;
margin: 20px auto;
padding: 20px;
background: rgba(28, 37, 38, 0.8);
border-radius: 10px;
}
.textbox {
background: #1A1A1A;
border: 1px solid #A100FF;
color: #E0E0E0;
}
.genre-buttons {
display: flex;
justify-content: center;
flex-wrap: wrap;
gap: 15px;
}
.genre-btn, button {
background: linear-gradient(45deg, #A100FF, #00FF9F);
border: none;
color: #0A0A0A;
padding: 10px 20px;
border-radius: 5px;
}
@keyframes glitch-ghost {
0% { transform: translate(0, 0); opacity: 1; }
20% { transform: translate(-5px, 2px); opacity: 0.8; }
100% { transform: translate(0, 0); opacity: 1; }
}
@keyframes glitch-text {
0% { transform: translate(0, 0); }
20% { transform: translate(-2px, 1px); }
100% { transform: translate(0, 0); }
}
@font-face {
font-family: 'Orbitron';
src: url('https://fonts.gstatic.com/s/orbitron/v29/yMJRMIlzdpvBhQQL_Qq7dy0.woff2') format('woff2');
}
"""
306
+
307
+ # 8) BUILD WITH BLOCKS
308
# Assemble the UI: header, prompt box with genre preset buttons, generation
# settings, action buttons, and the audio/status outputs.
with gr.Blocks(css=css) as demo:
    gr.Markdown("""
    <div class="header-container">
    <div id="ghost-logo">👻</div>
    <h1>GhostAI Music Generator</h1>
    <p>Summon the Sound of the Unknown</p>
    </div>
    """)

    with gr.Column(elem_classes="input-container"):
        instrumental_prompt = gr.Textbox(
            label="Instrumental Prompt",
            placeholder="Click a genre button or type your own prompt",
            lines=4,
            elem_classes="textbox",
        )
        with gr.Row(elem_classes="genre-buttons"):
            rhcp_btn = gr.Button("Red Hot Chili Peppers", elem_classes="genre-btn")
            nirvana_btn = gr.Button("Nirvana Grunge", elem_classes="genre-btn")
            pearl_jam_btn = gr.Button("Pearl Jam Grunge", elem_classes="genre-btn")
            soundgarden_btn = gr.Button("Soundgarden Grunge", elem_classes="genre-btn")
            foo_fighters_btn = gr.Button("Foo Fighters", elem_classes="genre-btn")
            smashing_pumpkins_btn = gr.Button("Smashing Pumpkins", elem_classes="genre-btn")
            radiohead_btn = gr.Button("Radiohead", elem_classes="genre-btn")
            classic_rock_btn = gr.Button("Classic Rock", elem_classes="genre-btn")
            alternative_rock_btn = gr.Button("Alternative Rock", elem_classes="genre-btn")
            post_punk_btn = gr.Button("Post-Punk", elem_classes="genre-btn")
            indie_rock_btn = gr.Button("Indie Rock", elem_classes="genre-btn")
            funk_rock_btn = gr.Button("Funk Rock", elem_classes="genre-btn")
            detroit_techno_btn = gr.Button("Detroit Techno", elem_classes="genre-btn")
            deep_house_btn = gr.Button("Deep House", elem_classes="genre-btn")

    with gr.Column(elem_classes="settings-container"):
        cfg_scale = gr.Slider(
            label="Guidance Scale (CFG)",
            minimum=1.0,
            maximum=10.0,
            value=3.0,
            step=0.1,
            info="Higher values make the instrumental more closely follow the prompt.",
        )
        # MusicGen applies top-p whenever it is greater than 0 and only falls
        # back to top-k when top-p is 0, so the help texts say so explicitly.
        top_k = gr.Slider(
            label="Top-K Sampling",
            minimum=10,
            maximum=500,
            value=250,
            step=10,
            info="Limits sampling to the top k most likely tokens. Only used when Top-P is 0.",
        )
        top_p = gr.Slider(
            label="Top-P Sampling",
            minimum=0.0,
            maximum=1.0,
            value=0.9,
            step=0.05,
            info="Samples from the smallest set of tokens whose cumulative probability reaches p. Set to 0 to use Top-K instead.",
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            value=1.0,
            step=0.1,
            info="Controls randomness. Higher values make output more diverse.",
        )
        total_duration = gr.Slider(
            label="Total Duration (seconds)",
            minimum=10,
            maximum=90,
            value=30,
            step=1,
            info="Total duration of the track (10 to 90 seconds).",
        )
        crossfade_duration = gr.Slider(
            label="Crossfade Duration (ms)",
            minimum=100,
            maximum=2000,
            value=500,
            step=100,
            info="Crossfade duration between chunks.",
        )
        num_variations = gr.Slider(
            label="Number of Variations",
            minimum=1,
            maximum=4,
            value=1,
            step=1,
            info="Number of different versions to generate with varying random seeds.",
        )
        with gr.Row(elem_classes="action-buttons"):
            gen_btn = gr.Button("Generate Music")
            clr_btn = gr.Button("Clear Inputs")

    with gr.Column(elem_classes="output-container"):
        out_audio = gr.Audio(label="Generated Stereo Instrumental Track", type="filepath")
        status = gr.Textbox(label="Status", interactive=False)

    # Each genre button writes its preset text into the prompt box.
    for _genre_btn, _prompt_fn in (
        (rhcp_btn, set_red_hot_chili_peppers_prompt),
        (nirvana_btn, set_nirvana_grunge_prompt),
        (pearl_jam_btn, set_pearl_jam_grunge_prompt),
        (soundgarden_btn, set_soundgarden_grunge_prompt),
        (foo_fighters_btn, set_foo_fighters_prompt),
        (smashing_pumpkins_btn, set_smashing_pumpkins_prompt),
        (radiohead_btn, set_radiohead_prompt),
        (classic_rock_btn, set_classic_rock_prompt),
        (alternative_rock_btn, set_alternative_rock_prompt),
        (post_punk_btn, set_post_punk_prompt),
        (indie_rock_btn, set_indie_rock_prompt),
        (funk_rock_btn, set_funk_rock_prompt),
        (detroit_techno_btn, set_detroit_techno_prompt),
        (deep_house_btn, set_deep_house_prompt),
    ):
        _genre_btn.click(_prompt_fn, inputs=None, outputs=[instrumental_prompt])

    gen_btn.click(
        generate_music,
        inputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, crossfade_duration, num_variations],
        outputs=[out_audio, status],
    )
    # clear_inputs returns one value per component listed here, in this order.
    clr_btn.click(
        clear_inputs,
        inputs=None,
        outputs=[instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, crossfade_duration, num_variations],
    )
429
+
430
+ # 9) TURN OFF OPENAPI/DOCS
431
# Start the server without blocking: a plain launch() in a script only returns
# once the server stops, so any tweak placed after it would never take effect
# while the app is serving.
app = demo.launch(
    server_name="0.0.0.0",
    server_port=9999,
    share=False,
    inbrowser=False,
    show_error=True,
    prevent_thread_lock=True,
)
try:
    # launch() returns (FastAPI app, local URL, share URL).
    fastapi_app = app[0]
    fastapi_app.docs_url = None
    fastapi_app.redoc_url = None
    fastapi_app.openapi_url = None
    # Clearing the attributes does not unregister routes that already exist,
    # so drop any documentation routes explicitly.
    _doc_paths = {"/docs", "/redoc", "/openapi.json"}
    fastapi_app.router.routes[:] = [
        route
        for route in fastapi_app.router.routes
        if getattr(route, "path", None) not in _doc_paths
    ]
except (AttributeError, IndexError, TypeError) as e:
    # Best effort: the app still serves, but make the failure visible.
    print(f"WARNING: Could not disable OpenAPI/docs endpoints: {e}")

# Keep the process alive now that the server runs in the background.
demo.block_thread()