| import os |
| import torch |
| import torchaudio |
| import psutil |
| import time |
| import sys |
| import numpy as np |
| import gc |
| import gradio as gr |
| from pydub import AudioSegment |
| from audiocraft.models import MusicGen |
| from torch.cuda.amp import autocast |
| import warnings |
|
|
| |
# Silence every warning category for the whole process (this also hides
# deprecation notices from torch/audiocraft/gradio).
warnings.simplefilter("ignore")

# Limit the CUDA caching allocator's block splitting to reduce fragmentation.
# PyTorch reads this variable when the allocator first initialises, so it only
# needs to be set before the first CUDA allocation, not before `import torch`.
_CUDA_ALLOC_CONF = "max_split_size_mb:128"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = _CUDA_ALLOC_CONF
|
|
| |
# Versions this app was tested against; a mismatch only prints a warning.
_TESTED_NUMPY_VERSION = "1.23.5"
_TESTED_TORCH_PREFIXES = ("2.1.0", "2.3.1")

_numpy_matches = np.__version__ == _TESTED_NUMPY_VERSION
_torch_matches = torch.__version__.startswith(_TESTED_TORCH_PREFIXES)

if not _numpy_matches:
    print(f"WARNING: NumPy version {np.__version__} is being used. Tested with numpy==1.23.5.")
if not _torch_matches:
    print(f"WARNING: PyTorch version {torch.__version__} may not be compatible. Expected torch==2.1.0 or 2.3.1.")
|
|
| |
# Pick the compute device; this app refuses to run without a CUDA GPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

if device == "cpu":
    print("ERROR: CUDA is required for GPU rendering. CPU rendering is disabled.")
    sys.exit(1)

print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
|
|
| |
# Load the MusicGen medium weights from a local directory onto `device`.
# Any failure (including a missing directory) terminates the process.
try:
    print("Loading MusicGen medium model into VRAM...")
    local_model_path = "./models/musicgen-medium"
    if os.path.exists(local_model_path):
        musicgen_model = MusicGen.get_pretrained(local_model_path, device=device)
        # Default generation settings; generate_music overrides them per request.
        musicgen_model.set_generation_params(duration=15, two_step_cfg=False)
    else:
        print(f"ERROR: Local model path {local_model_path} does not exist.")
        print("Please download the MusicGen medium model weights and place them in the correct directory.")
        sys.exit(1)
except Exception as load_error:
    # SystemExit is not an Exception subclass, so the exit above is not caught here.
    print(f"ERROR: Failed to load MusicGen model: {load_error}")
    print("Ensure model weights are correctly placed and dependencies are installed.")
    sys.exit(1)
|
|
| |
def print_resource_usage(stage: str):
    """Print current CUDA allocator usage and system RAM usage under a *stage* banner.

    GPU figures come from torch.cuda.memory_allocated / memory_reserved for the
    current device and are reported in GiB; CPU usage is psutil's system-wide
    virtual-memory percentage.
    """
    bytes_per_gib = 1024 ** 3
    print(f"--- {stage} ---")
    for kind, query in (
        ("Allocated", torch.cuda.memory_allocated),
        ("Reserved", torch.cuda.memory_reserved),
    ):
        print(f"GPU Memory {kind}: {query() / bytes_per_gib:.2f} GB")
    print(f"CPU Memory Used: {psutil.virtual_memory().percent}%")
    print("---------------")
|
|
| |
def set_classic_rock_prompt():
    """Return the preset text prompt for the "Classic Rock" genre button."""
    genre_prompt = "Classic rock with bluesy electric guitars, steady drums, groovy bass, Hammond organ fills, and a Led Zeppelin-inspired raw energy."
    return genre_prompt
|
|
def set_alternative_rock_prompt():
    """Return the preset text prompt for the "Alternative Rock" genre button."""
    genre_prompt = "Alternative rock with distorted guitar riffs, punchy drums, melodic basslines, atmospheric synths, and a Nirvana-inspired grunge vibe."
    return genre_prompt
|
|
def set_detroit_techno_prompt():
    """Return the preset text prompt for the "Detroit Techno" genre button."""
    genre_prompt = "Detroit techno with deep pulsing synths, driving basslines, crisp hi-hats, and a rhythmic groove inspired by Juan Atkins."
    return genre_prompt
|
|
def set_deep_house_prompt():
    """Return the preset text prompt for the "Deep House" genre button."""
    genre_prompt = "Deep house with warm analog synth chords, soulful vocal chops, deep basslines, and a laid-back groove inspired by Larry Heard."
    return genre_prompt
|
|
def set_smooth_jazz_prompt():
    """Return the preset text prompt for the "Smooth Jazz" genre button."""
    genre_prompt = "Smooth jazz with warm saxophone leads, expressive Rhodes piano chords, soft bossa nova drums, and a George Benson-inspired feel."
    return genre_prompt
|
|
def set_bebop_jazz_prompt():
    """Return the preset text prompt for the "Bebop Jazz" genre button."""
    genre_prompt = "Bebop jazz with fast-paced saxophone solos, intricate piano runs, walking basslines, and a Charlie Parker-inspired style."
    return genre_prompt
|
|
def set_baroque_classical_prompt():
    """Return the preset text prompt for the "Baroque Classical" genre button."""
    genre_prompt = "Baroque classical with harpsichord, delicate violin, cello, and a Vivaldi-inspired melodic structure."
    return genre_prompt
|
|
def set_romantic_classical_prompt():
    """Return the preset text prompt for the "Romantic Classical" genre button."""
    genre_prompt = "Romantic classical with lush strings, expressive piano, dramatic brass, and a Chopin-inspired melodic flow."
    return genre_prompt
|
|
def set_boom_bap_hiphop_prompt():
    """Return the preset text prompt for the "Boom Bap Hip-Hop" genre button."""
    genre_prompt = "Boom bap hip-hop with gritty sampled drums, deep basslines, jazzy piano loops, and a J Dilla-inspired groove."
    return genre_prompt
|
|
def set_trap_hiphop_prompt():
    """Return the preset text prompt for the "Trap Hip-Hop" genre button."""
    genre_prompt = "Trap hip-hop with hard-hitting 808 bass, snappy snares, rapid hi-hats, and eerie synth melodies."
    return genre_prompt
|
|
def set_pop_rock_prompt():
    """Return the preset text prompt for the "Pop Rock" genre button."""
    genre_prompt = "Pop rock with catchy electric guitar riffs, uplifting synths, steady drums, and a Coldplay-inspired anthemic feel."
    return genre_prompt
|
|
def set_fusion_jazz_prompt():
    """Return the preset text prompt for the "Fusion Jazz" genre button."""
    genre_prompt = "Fusion jazz with electric piano, funky basslines, intricate drum patterns, and a Herbie Hancock-inspired groove."
    return genre_prompt
|
|
def set_edm_prompt():
    """Return the preset text prompt for the "EDM" genre button."""
    genre_prompt = "EDM with high-energy synth leads, pounding basslines, four-on-the-floor kicks, and a festival-ready drop."
    return genre_prompt
|
|
def set_indie_folk_prompt():
    """Return the preset text prompt for the "Indie Folk" genre button."""
    genre_prompt = "Indie folk with acoustic guitars, heartfelt vocals, gentle percussion, and a Bon Iver-inspired atmosphere."
    return genre_prompt
|
|
| |
def apply_chorus(segment, delay_ms=20, wet_attenuation_db=6.0):
    """Mix a quieter, delayed copy of *segment* back on top of itself.

    This is a single fixed-delay doubling effect (a short echo that also
    comb-filters the signal), not a modulated chorus.

    Args:
        segment: pydub AudioSegment to process.
        delay_ms: Offset in milliseconds at which the delayed copy starts.
        wet_attenuation_db: How many dB quieter the delayed copy is than the dry signal.

    Returns:
        A new AudioSegment with the same length as *segment*. pydub's overlay
        truncates the delayed copy at the end of the original.
    """
    # `segment - x` lowers the gain by x dB and keeps the frame rate, so the
    # former `set_frame_rate(segment.frame_rate)` call was a no-op.
    wet = segment - wet_attenuation_db
    return segment.overlay(wet, position=delay_ms)
|
|
def apply_eq(segment):
    """Band-limit *segment*: low-pass at 8 kHz, then high-pass at 80 Hz.

    Returns a new pydub AudioSegment; the input is not modified.
    """
    return segment.low_pass_filter(8000).high_pass_filter(80)
|
|
def apply_limiter(segment, max_db=-3.0):
    """Attenuate *segment* so that its loudest sample sits at or below *max_db* dBFS.

    This is a static gain reduction applied to the whole segment, not a
    look-ahead/compressing limiter. It compares the *peak* level (`max_dBFS`).
    The previous version compared the average RMS level (`dBFS`), which lets
    peaks well above *max_db* through untouched.

    Args:
        segment: pydub AudioSegment to process.
        max_db: Peak ceiling in dBFS.

    Returns:
        *segment* unchanged if its peak is already at or below *max_db*
        (silence reports -inf and is left alone); otherwise an attenuated copy.
    """
    peak_db = segment.max_dBFS
    if peak_db > max_db:
        segment = segment.apply_gain(max_db - peak_db)
    return segment
|
|
def apply_final_gain(segment, target_db=-12.0):
    """Shift *segment*'s average (RMS) loudness to *target_db* dBFS.

    Args:
        segment: pydub AudioSegment to process.
        target_db: Desired RMS level in dBFS.

    Returns:
        A gain-adjusted copy. Digitally silent or empty segments (dBFS == -inf)
        are returned unchanged: the required gain would be +inf, and applying
        it turns every sample into NaN/garbage.

    Note: boosting can push peaks past 0 dBFS; pydub clips samples that overflow.
    """
    current_db = segment.dBFS
    if current_db == float("-inf"):
        return segment
    return segment + (target_db - current_db)
|
|
def apply_fade(segment, fade_in_duration=2000, fade_out_duration=2000):
    """Return *segment* with a fade-in and then a fade-out applied (durations in ms)."""
    return segment.fade_in(fade_in_duration).fade_out(fade_out_duration)
|
|
| |
def _plan_chunks(total_duration, crossfade_ms, base_chunk_seconds=15):
    """Split the requested length into equal chunks; return (num_chunks, chunk_seconds, overlap_seconds).

    The track gets floor(total / base) chunks (at least one), so every chunk is
    at least ``base_chunk_seconds`` long. With more than one chunk, each chunk
    is generated ``crossfade_ms`` longer than its share. N chunks joined with
    N - 1 crossfades then span ``total_duration + crossfade`` seconds and can be
    trimmed to the exact length. The former 1 s overlap cap made tracks come out
    short whenever the crossfade exceeded 1000 ms. A single chunk is never
    crossfaded, so it needs no overlap.
    """
    num_chunks = max(1, int(total_duration // base_chunk_seconds))
    chunk_seconds = total_duration / num_chunks
    overlap_seconds = crossfade_ms / 1000.0 if num_chunks > 1 else 0.0
    return num_chunks, chunk_seconds, overlap_seconds


def _to_stereo(audio):
    """Return *audio* as a float32 CPU tensor shaped (2, samples).

    Mono input is duplicated to both channels. Input with more than two
    channels keeps only its first channel (on both sides), as before.
    """
    audio = audio.detach().cpu().to(dtype=torch.float32)
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    elif audio.dim() > 2:
        # Fold extra leading axes into the channel axis. The old view(2, -1)
        # interleaved unrelated samples instead.
        audio = audio.reshape(-1, audio.shape[-1])
    if audio.shape[0] != 2:
        first_channel = audio[:1, :]
        audio = torch.cat([first_channel, first_channel], dim=0)
    if audio.shape[0] != 2:
        raise ValueError(f"Expected stereo audio with shape (2, samples), got shape {audio.shape}")
    return audio


def _tensor_to_segment(audio, sample_rate):
    """Convert a (2, samples) float tensor in [-1, 1] to a 32-bit stereo AudioSegment in memory.

    This replaces the old WAV -> MP3 -> decode round-trip per chunk. That
    round-trip added lossy generation loss and encoder padding that skewed the
    crossfade alignment, and it left files behind in the CWD on errors.
    """
    interleaved = audio.clamp(-1.0, 1.0).t().contiguous().numpy().astype(np.float64)
    pcm = np.round(interleaved * 2147483647.0).astype(np.int32)
    return AudioSegment(data=pcm.tobytes(), sample_width=4, frame_rate=sample_rate, channels=2)


def _crossfade_segments(segments, crossfade_ms):
    """Join *segments* in order, overlapping each boundary by *crossfade_ms* milliseconds."""
    combined = segments[0]
    for following in segments[1:]:
        combined = combined.append(following, crossfade=crossfade_ms)
    return combined


def _master_segment(segment, target_db=-12.0, ceiling_db=-1.0):
    """Run the EQ/chorus/limiter chain, then set loudness without clipping.

    The final gain aims for *target_db* RMS but never lifts the peak above
    *ceiling_db* dBFS, so the boost cannot hard-clip.
    """
    segment = apply_eq(segment)
    segment = apply_chorus(segment)
    segment = apply_limiter(segment, max_db=-3.0)
    # Positive headroom leaves the peak 6 dB below full scale. The previous
    # headroom=-6.0 normalised peaks to +6 dBFS, i.e. hard clipping.
    segment = segment.normalize(headroom=6.0)
    loudness_db = segment.dBFS
    if loudness_db == float("-inf"):
        return segment  # digital silence: no gain can be computed
    gain_db = min(target_db - loudness_db, ceiling_db - segment.max_dBFS)
    return segment.apply_gain(gain_db)


def generate_music(instrumental_prompt: str, cfg_scale: float, top_k: int, top_p: float, temperature: float, total_duration: int, crossfade_duration: int, num_variations: int = 1):
    """Generate instrumental track(s) with the global MusicGen model and export them as MP3.

    Each variation is seeded with 42 + index, generated in equal chunks from the
    same text prompt, crossfaded, trimmed to *total_duration*, mastered, and
    written to ``output_cleaned_variation_<n>.mp3`` in the working directory.

    Args:
        instrumental_prompt: Text description passed to MusicGen.
        cfg_scale: Classifier-free guidance coefficient (``cfg_coef``).
        top_k / top_p / temperature: MusicGen sampling settings. NOTE(review):
            in audiocraft a non-zero top_p appears to take precedence over
            top_k; confirm for the installed version.
        total_duration: Requested length in seconds, clamped to [10, 90].
        crossfade_duration: Crossfade between chunks in milliseconds (negative -> 0).
        num_variations: How many seeds to render (at least 1).

    Returns:
        (path of the first variation's MP3, status message) on success, or
        (None, error message) for an empty prompt or any generation failure.
    """
    global musicgen_model
    if not instrumental_prompt or not instrumental_prompt.strip():
        return None, "⚠️ Please enter a valid instrumental prompt!"
    import traceback
    try:
        start_time = time.time()
        total_duration = min(max(total_duration, 10), 90)
        crossfade_ms = max(0, int(crossfade_duration))
        # Gradio sliders may deliver floats; range() and top-k sampling need ints.
        num_variations = max(1, int(num_variations))
        num_chunks, chunk_seconds, overlap_seconds = _plan_chunks(total_duration, crossfade_ms)
        sample_rate = musicgen_model.sample_rate

        # The sampling settings are identical for every chunk and variation, so set them once.
        musicgen_model.set_generation_params(
            duration=chunk_seconds + overlap_seconds,
            use_sampling=True,
            top_k=int(top_k),
            top_p=float(top_p),
            temperature=float(temperature),
            cfg_coef=float(cfg_scale),
        )

        output_files = []
        for var in range(num_variations):
            print(f"Generating variation {var+1}/{num_variations}...")
            seed = 42 + var
            torch.manual_seed(seed)
            np.random.seed(seed)

            segments = []
            for i in range(num_chunks):
                print(f"Generating chunk {i+1}/{num_chunks} for variation {var+1} on GPU (prompt: {instrumental_prompt})...")
                print_resource_usage(f"Before Chunk {i+1} Generation (Variation {var+1})")
                with torch.no_grad(), autocast():
                    raw_chunk = musicgen_model.generate([instrumental_prompt], progress=True)[0]
                segments.append(_tensor_to_segment(_to_stereo(raw_chunk), sample_rate))
                # Drop the GPU tensor reference before asking CUDA to release cached blocks.
                del raw_chunk
                torch.cuda.empty_cache()
                gc.collect()
                print_resource_usage(f"After Chunk {i+1} Generation (Variation {var+1})")

            print(f"Combining audio chunks for variation {var+1}...")
            # Chunks are joined at equal level; the old arbitrary +1 dB on every
            # chunk after the first caused an audible level jump.
            final_segment = _crossfade_segments(segments, crossfade_ms)
            final_segment = final_segment[:int(round(total_duration * 1000))]

            print(f"Post-processing final track for variation {var+1}...")
            final_segment = _master_segment(final_segment)

            mp3_path = f"output_cleaned_variation_{var+1}.mp3"
            final_segment.export(
                mp3_path,
                format="mp3",
                bitrate="320k",
                tags={"title": f"GhostAI Instrumental Variation {var+1}", "artist": "GhostAI"}
            )
            print(f"Saved final audio to {mp3_path}")
            output_files.append(mp3_path)

        print_resource_usage("After Final Generation")
        print(f"Total Generation Time: {time.time() - start_time:.2f} seconds")
        # The UI has a single audio player, so only the first variation is returned.
        return output_files[0], f"✅ Done! Generated {num_variations} variations."
    except Exception as e:
        # Keep the full traceback in the server log; the UI only shows the message.
        traceback.print_exc()
        return None, f"❌ Generation failed: {e}"
    finally:
        torch.cuda.empty_cache()
        gc.collect()
|
|
def clear_inputs():
    """Return reset values for the prompt box followed by every settings slider.

    Order: prompt, CFG scale, top-k, top-p, temperature, total duration (s),
    crossfade (ms), number of variations.
    """
    empty_prompt = ""
    slider_defaults = (3.0, 250, 0.9, 1.0, 30, 500, 1)
    return (empty_prompt, *slider_defaults)
|
|
| |
# Custom stylesheet passed to gr.Blocks.
# - Gradient buttons are scoped to the genre buttons and the Generate/Clear
#   row; a bare `button` selector would also restyle Gradio's own controls
#   (e.g. the audio player's buttons).
# - The 14-button genre row wraps instead of overflowing.
# - The infinite glitch animations are disabled for users who request
#   reduced motion.
css = """
body {
    background: linear-gradient(135deg, #0A0A0A 0%, #1C2526 100%);
    color: #E0E0E0;
    font-family: 'Orbitron', sans-serif;
}
.header-container {
    text-align: center;
    padding: 15px 20px;
    background: rgba(0, 0, 0, 0.9);
    border-bottom: 1px solid #00FF9F;
}
#ghost-logo {
    font-size: 60px;
    animation: glitch-ghost 1.5s infinite;
}
h1 {
    color: #A100FF;
    font-size: 28px;
    animation: glitch-text 2s infinite;
}
.input-container, .settings-container, .output-container {
    max-width: 1000px;
    margin: 20px auto;
    padding: 20px;
    background: rgba(28, 37, 38, 0.8);
    border-radius: 10px;
}
.textbox {
    background: #1A1A1A;
    border: 1px solid #A100FF;
    color: #E0E0E0;
}
.genre-buttons {
    display: flex;
    flex-wrap: wrap;
    justify-content: center;
    gap: 15px;
}
.genre-btn, .action-buttons button {
    background: linear-gradient(45deg, #A100FF, #00FF9F);
    border: none;
    color: #0A0A0A;
    padding: 10px 20px;
    border-radius: 5px;
}
@keyframes glitch-ghost {
    0% { transform: translate(0, 0); opacity: 1; }
    20% { transform: translate(-5px, 2px); opacity: 0.8; }
    100% { transform: translate(0, 0); opacity: 1; }
}
@keyframes glitch-text {
    0% { transform: translate(0, 0); }
    20% { transform: translate(-2px, 1px); }
    100% { transform: translate(0, 0); }
}
@font-face {
    font-family: 'Orbitron';
    src: url('https://fonts.gstatic.com/s/orbitron/v29/yMJRMIlzdpvBhQQL_Qq7dy0.woff2') format('woff2');
    font-display: swap;
}
@media (prefers-reduced-motion: reduce) {
    #ghost-logo, h1 {
        animation: none;
    }
}
"""
|
|
| |
# Build the Gradio UI: a prompt box with genre presets, generation settings,
# and an output player/status box. Components are created and events are
# registered in the same order as before.
with gr.Blocks(css=css) as demo:
    gr.Markdown("""
<div class="header-container">
<div id="ghost-logo">👻</div>
<h1>GhostAI Music Generator</h1>
<p>Summon the Sound of the Unknown</p>
</div>
""")

    # Prompt entry plus one preset button per genre.
    with gr.Column(elem_classes="input-container"):
        instrumental_prompt = gr.Textbox(
            label="Instrumental Prompt",
            placeholder="Click a genre button or type your own prompt",
            lines=4,
            elem_classes="textbox",
        )
        with gr.Row(elem_classes="genre-buttons"):
            classic_rock_btn = gr.Button("Classic Rock", elem_classes="genre-btn")
            alternative_rock_btn = gr.Button("Alternative Rock", elem_classes="genre-btn")
            detroit_techno_btn = gr.Button("Detroit Techno", elem_classes="genre-btn")
            deep_house_btn = gr.Button("Deep House", elem_classes="genre-btn")
            smooth_jazz_btn = gr.Button("Smooth Jazz", elem_classes="genre-btn")
            bebop_jazz_btn = gr.Button("Bebop Jazz", elem_classes="genre-btn")
            baroque_classical_btn = gr.Button("Baroque Classical", elem_classes="genre-btn")
            romantic_classical_btn = gr.Button("Romantic Classical", elem_classes="genre-btn")
            boom_bap_hiphop_btn = gr.Button("Boom Bap Hip-Hop", elem_classes="genre-btn")
            trap_hiphop_btn = gr.Button("Trap Hip-Hop", elem_classes="genre-btn")
            pop_rock_btn = gr.Button("Pop Rock", elem_classes="genre-btn")
            fusion_jazz_btn = gr.Button("Fusion Jazz", elem_classes="genre-btn")
            edm_btn = gr.Button("EDM", elem_classes="genre-btn")
            indie_folk_btn = gr.Button("Indie Folk", elem_classes="genre-btn")

    # Sampling / length controls; defaults mirror clear_inputs().
    with gr.Column(elem_classes="settings-container"):
        cfg_scale = gr.Slider(
            label="Guidance Scale (CFG)",
            info="Higher values make the instrumental more closely follow the prompt.",
            minimum=1.0, maximum=10.0, step=0.1, value=3.0,
        )
        top_k = gr.Slider(
            label="Top-K Sampling",
            info="Limits sampling to the top k most likely tokens.",
            minimum=10, maximum=500, step=10, value=250,
        )
        top_p = gr.Slider(
            label="Top-P Sampling",
            info="Keeps tokens with cumulative probability above p.",
            minimum=0.0, maximum=1.0, step=0.05, value=0.9,
        )
        temperature = gr.Slider(
            label="Temperature",
            info="Controls randomness. Higher values make output more diverse.",
            minimum=0.1, maximum=2.0, step=0.1, value=1.0,
        )
        total_duration = gr.Slider(
            label="Total Duration (seconds)",
            info="Total duration of the track (10 to 90 seconds).",
            minimum=10, maximum=90, step=1, value=30,
        )
        crossfade_duration = gr.Slider(
            label="Crossfade Duration (ms)",
            info="Crossfade duration between chunks.",
            minimum=100, maximum=2000, step=100, value=500,
        )
        num_variations = gr.Slider(
            label="Number of Variations",
            info="Number of different versions to generate with varying random seeds.",
            minimum=1, maximum=4, step=1, value=1,
        )
        with gr.Row(elem_classes="action-buttons"):
            gen_btn = gr.Button("Generate Music")
            clr_btn = gr.Button("Clear Inputs")

    # The audio player shows the single file path generate_music returns.
    with gr.Column(elem_classes="output-container"):
        out_audio = gr.Audio(label="Generated Stereo Instrumental Track", type="filepath")
        status = gr.Textbox(label="Status", interactive=False)

    # Each genre button writes its preset text into the prompt box.
    _genre_presets = (
        (classic_rock_btn, set_classic_rock_prompt),
        (alternative_rock_btn, set_alternative_rock_prompt),
        (detroit_techno_btn, set_detroit_techno_prompt),
        (deep_house_btn, set_deep_house_prompt),
        (smooth_jazz_btn, set_smooth_jazz_prompt),
        (bebop_jazz_btn, set_bebop_jazz_prompt),
        (baroque_classical_btn, set_baroque_classical_prompt),
        (romantic_classical_btn, set_romantic_classical_prompt),
        (boom_bap_hiphop_btn, set_boom_bap_hiphop_prompt),
        (trap_hiphop_btn, set_trap_hiphop_prompt),
        (pop_rock_btn, set_pop_rock_prompt),
        (fusion_jazz_btn, set_fusion_jazz_prompt),
        (edm_btn, set_edm_prompt),
        (indie_folk_btn, set_indie_folk_prompt),
    )
    for _preset_button, _preset_fn in _genre_presets:
        _preset_button.click(_preset_fn, inputs=None, outputs=[instrumental_prompt])

    # Same component order as generate_music's parameters and clear_inputs' return tuple.
    _control_components = [instrumental_prompt, cfg_scale, top_k, top_p, temperature, total_duration, crossfade_duration, num_variations]
    gen_btn.click(
        generate_music,
        inputs=list(_control_components),
        outputs=[out_audio, status],
    )
    clr_btn.click(
        clear_inputs,
        inputs=None,
        outputs=list(_control_components),
    )
|
|
| |
# Start the web server and hide FastAPI's auto-generated API docs.
#
# FastAPI registers /docs, /redoc and /openapi.json when the app object is
# constructed, so those routes can only be disabled through constructor
# arguments (Gradio forwards `app_kwargs` to FastAPI). The previous approach
# never worked:
# - launch() blocks the main thread in a script, so the follow-up code only
#   ran after shutdown.
# - It relied on the private `demo._server` attribute.
# - Assigning `docs_url = None` after construction does not remove routes.
# - `except Exception: pass` hid all of this.
import inspect

_launch_options = dict(
    server_name="0.0.0.0",
    server_port=9999,
    share=False,
    inbrowser=False,
    show_error=True,
)
if "app_kwargs" in inspect.signature(demo.launch).parameters:
    _launch_options["app_kwargs"] = {"docs_url": None, "redoc_url": None, "openapi_url": None}
else:
    print("WARNING: This Gradio version does not accept app_kwargs; FastAPI docs endpoints may remain enabled.")

app = demo.launch(**_launch_options)