"""Gradio ZeroGPU wrapper around Khala's backend_worker generation path. This is a best-effort port of an NGC-targeted, Megatron-Core based system to the ZeroGPU runtime. Many things can break — see the Space README for the list of caveats. If the import or first-call init fails, check the Space build/runtime logs. """ from __future__ import annotations import json import os import subprocess import sys import time import traceback import gradio as gr import spaces from huggingface_hub import snapshot_download # --------------------------------------------------------------------------- # Paths and sys.path setup # --------------------------------------------------------------------------- SPACE_ROOT = os.path.dirname(os.path.abspath(__file__)) BACKEND_DIR = os.path.join(SPACE_ROOT, "backend") MEGATRON_ROOT = os.path.join(SPACE_ROOT, "models", "Megatron") DECODER_ROOT = os.path.join(SPACE_ROOT, "models", "Decoder") CHECKPOINTS_DIR = os.path.join(SPACE_ROOT, "checkpoints") OUTPUT_DIR = os.path.join(BACKEND_DIR, "generated_audio") for path in (SPACE_ROOT, BACKEND_DIR, MEGATRON_ROOT, DECODER_ROOT): if path not in sys.path: sys.path.insert(0, path) os.makedirs(CHECKPOINTS_DIR, exist_ok=True) os.makedirs(OUTPUT_DIR, exist_ok=True) # --------------------------------------------------------------------------- # Ensure pybind11 + pre-compile the Megatron dataset helpers # --------------------------------------------------------------------------- # Belt-and-suspenders: even though requirements.txt lists pybind11, the # Megatron Makefile shells out to `python3 -m pybind11 --includes` and we've # seen that fail with `No module named pybind11` on the Space — likely a path # mismatch between the pip target and /usr/local/bin/python3. So we install # explicitly with the same interpreter app.py runs in, then pre-compile the # helpers ourselves. If Megatron later re-triggers `initialize` we want the # .so already built so its compile-on-startup never runs. 
def _ensure_dataset_helpers_compiled() -> None:
    """Best-effort pre-build of Megatron's ``helpers_cpp`` dataset extension.

    Returns immediately when a ``helpers_cpp*.so`` file already exists in
    ``<MEGATRON_ROOT>/megatron/core/datasets``. Otherwise it makes sure
    pybind11 is importable (installing it with the running interpreter if
    needed) and runs ``make`` in that directory.

    This function never raises for environment problems: a missing helpers
    directory, a missing ``make`` binary, or a failing pip/make invocation is
    logged and skipped, leaving Megatron to attempt its own compile at init
    time.
    """
    helpers_dir = os.path.join(MEGATRON_ROOT, "megatron", "core", "datasets")

    try:
        dir_entries = os.listdir(helpers_dir)
    except OSError as exc:
        # Directory missing or unreadable: there is nothing we can pre-compile,
        # and crashing at import time would take the whole Space down.
        print(f"[app] cannot list {helpers_dir} ({exc}); skipping dataset helpers pre-compile.")
        return

    prebuilt = [
        name
        for name in dir_entries
        if name.startswith("helpers_cpp") and name.endswith(".so")
    ]
    if prebuilt:
        print(f"[app] dataset helpers already compiled: {prebuilt[0]}")
        return

    try:
        import pybind11  # noqa: F401
    except ImportError:
        print("[app] pybind11 missing — installing via current interpreter.")
        try:
            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", "--quiet", "pybind11"]
            )
        except (subprocess.CalledProcessError, OSError) as exc:
            # Keep going: make will most likely fail too, and that failure is
            # handled (and logged) below.
            print(f"[app] pybind11 install failed ({exc}); attempting make anyway.")

    # Run make using the same interpreter so the Makefile's
    # `python3 -m pybind11 ...` resolves.
    env = os.environ.copy()
    env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "")
    print(f"[app] Compiling Megatron dataset helpers in {helpers_dir} ...")
    try:
        subprocess.check_call(["make"], cwd=helpers_dir, env=env)
    except (subprocess.CalledProcessError, OSError) as exc:
        # OSError covers the case where the `make` executable itself is absent,
        # which check_call reports as FileNotFoundError, not CalledProcessError.
        print(f"[app] make failed ({exc}); Megatron will try its own compile at init time.")
    else:
        print("[app] dataset helpers compile OK.")


_ensure_dataset_helpers_compiled()

# ---------------------------------------------------------------------------
# Model checkpoint download (CPU-only, runs at module import on cold start)
# ---------------------------------------------------------------------------
KHALA_REPO = "liujiafeng/Khala-MusicGeneration-v1.0"


def download_checkpoints_if_needed() -> None:
    """Download the Khala checkpoint repo into ``CHECKPOINTS_DIR`` if absent.

    Presence of ``backbone/latest_checkpointed_iteration.txt`` under
    ``CHECKPOINTS_DIR`` is used as the "already downloaded" marker; no other
    file is checked. Errors from ``snapshot_download`` propagate to the caller.
    """
    backbone_marker = os.path.join(
        CHECKPOINTS_DIR, "backbone", "latest_checkpointed_iteration.txt"
    )
    if os.path.isfile(backbone_marker):
        print("[app] Khala checkpoints already present, skipping download.")
        return

    print(f"[app] Downloading {KHALA_REPO} → {CHECKPOINTS_DIR} (this takes a while on cold start)…")
    t0 = time.perf_counter()
    # NOTE(review): `local_dir_use_symlinks` appears to be deprecated in recent
    # huggingface_hub releases — confirm against the pinned version before
    # removing it.
    snapshot_download(
        repo_id=KHALA_REPO,
        local_dir=CHECKPOINTS_DIR,
        local_dir_use_symlinks=False,
    )
    print(f"[app] Download done in {time.perf_counter() - t0:.1f}s.")


download_checkpoints_if_needed()

# ---------------------------------------------------------------------------
# Distributed env vars (Megatron expects these even on single GPU)
# ---------------------------------------------------------------------------
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

# ---------------------------------------------------------------------------
# Lazy bootstrap of Megatron + Khala worker
# ---------------------------------------------------------------------------
_initialized = False


def _bootstrap_worker() -> None:
    """One-time Megatron init + model load. Runs inside a @spaces.GPU context.

    Mirrors the bootstrap sequence in Khala's backend_worker.main() — sets the
    required CLI args before initialize_megatron, applies the embedding patch,
    initialises Megatron, then preloads tokenizer / backbone / superres /
    decoder.

    The module-level ``_initialized`` flag is only set after
    ``bw.preload_runtime()`` returns, so a failed bootstrap is attempted again
    on the next call.
    """
    global _initialized
    if _initialized:
        return

    # Set the args Megatron's parser expects, before initialize_megatron reads sys.argv.
    first_backbone_path = os.path.join(CHECKPOINTS_DIR, "backbone")
    first_backbone_vocab = 130304
    sys.argv = [
        sys.argv[0],
        "--load", first_backbone_path,
        "--vocab-size", str(first_backbone_vocab),
        "--use-checkpoint-args",
    ]

    # Imports deferred to here because they pull in heavy CUDA-bound modules.
    import backend_worker as bw
    from megatron.training.initialize import initialize_megatron

    bw.patch_language_model_embedding()
    initialize_megatron(
        extra_args_provider=bw.add_worker_args,
        args_defaults={
            "no_load_rng": True,
            "no_load_optim": True,
            "micro_batch_size": 1,
            "exit_on_missing_checkpoint": True,
        },
    )

    # Disable Transformer-Engine and Apex code paths that the checkpoint args
    # request but that aren't installable in the ZeroGPU environment. The
    # checkpoint config is loaded with --use-checkpoint-args, so we have to
    # override AFTER initialize_megatron and BEFORE any model is built.
    from megatron.training import get_args

    args = get_args()
    args.transformer_impl = "local"  # was "transformer_engine"
    args.apply_rope_fusion = False
    args.bias_swiglu_fusion = False
    args.bias_dropout_fusion = False
    args.bias_gelu_fusion = False
    args.masked_softmax_fusion = False
    args.gradient_accumulation_fusion = False
    args.cross_entropy_loss_fusion = False
    args.use_flash_attn = False
    args.attention_softmax_in_fp32 = True
    args.no_persist_layer_norm = True  # Megatron uses negated flag — sets config.persist_layer_norm=False
    print(f"[app] Overrode args for TE-less env: transformer_impl={args.transformer_impl}, "
          f"all fusions disabled, no_persist_layer_norm={args.no_persist_layer_norm}.")

    bw.preload_runtime()
    _initialized = True


# ---------------------------------------------------------------------------
# Generation entry point
# ---------------------------------------------------------------------------
LANGUAGES = ["English", "Chinese", "Japanese", "Korean", "Cantonese", "Instrumental"]
GENRES = [
    "Pop", "Rock", "R&B", "Hip-Hop", "Electronic", "Jazz", "Classical",
    "Folk", "Country", "Metal", "Latin", "Reggae", "Blues", "Funk", "Soul",
    "Indie", "Alternative", "Dance", "Acoustic",
]


@spaces.GPU(duration=360)
def generate(
    description: str,
    lyrics: str,
    language: str,
    genre: str,
    duration_min: int,
    tags: str,
    top_k_bb: int,
    top_k_sr: int,
    temperature: float,
    seed: int,
):
    """Run one generation request through the Khala worker.

    Bootstraps the worker on first use, builds a ``GenerateRequest`` from the
    UI values and hands it to ``run_generation``.

    Args:
        description: Free-text description of the track.
        lyrics: Lyrics text.
        language: One of the language dropdown values.
        genre: One of the genre dropdown values.
        duration_min: Requested duration, passed to the worker as ``duration``.
        tags: Optional tag text.
        top_k_bb: Backbone top-k.
        top_k_sr: Super-res top-k.
        temperature: Sampling temperature.
        seed: Seed override; ``None`` (a cleared number field) is treated as
            0, which the UI labels as "use process seed".

    Returns:
        A ``(audio_path, result_json)`` tuple: the MP3 path if that file
        exists, otherwise the WAV path, plus the worker result pretty-printed
        as JSON.

    Raises:
        gr.Error: If the worker reports a non-"ok" status, or if any other
            exception occurs (wrapped as ``"<ExceptionType>: <message>"``).
    """
    try:
        _bootstrap_worker()
        from backend_worker import GenerateRequest, run_generation

        # gr.Number hands over None when the user empties the field; int(None)
        # would raise TypeError, so fall back to the documented default of 0.
        seed_value = 0 if seed is None else int(seed)

        req = GenerateRequest(
            genre=genre,
            language=language,
            tags=tags,
            description=description,
            duration=int(duration_min),
            lyrics=lyrics,
            top_k_bb=int(top_k_bb),
            top_k_sr=int(top_k_sr),
            temperature=float(temperature),
            superres_text_mode="same_as_backbone",
            seed_override=seed_value,
        )
        result = run_generation(req)
        if result.get("status") != "ok":
            # `or` also covers an "error" key that is present but empty/None.
            raise gr.Error(result.get("error") or "Unknown error from worker")

        wav_path = os.path.join(OUTPUT_DIR, result["wav_filename"])
        mp3_path = os.path.join(OUTPUT_DIR, result["mp3_filename"])
        # Prefer the MP3 when the worker actually produced one.
        audio_path = mp3_path if os.path.isfile(mp3_path) else wav_path
        return audio_path, json.dumps(result, indent=2, default=str)
    except gr.Error:
        # Already user-facing; let Gradio display it unchanged.
        raise
    except Exception as exc:
        tb = traceback.format_exc()
        print(f"[app] generate() failed:\n{tb}")
        raise gr.Error(f"{type(exc).__name__}: {exc}") from exc


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="Khala — High-Fidelity Song Generation") as demo:
    # Explicit "\n" separators keep the heading, the paragraph and the warning
    # block-quote as separate Markdown blocks.
    gr.Markdown(
        "# Khala — High-Fidelity Song Generation\n"
        "\n"
        "Best-effort ZeroGPU wrapper around the official "
        "[Khala](https://github.com/Khala-Music-AI/Khala) backend. "
        "Cold start downloads ~several GB of weights; first generation also "
        "bootstraps Megatron, so the **first run will take many minutes**. "
        "Subsequent runs reuse the loaded models.\n"
        "\n"
        "> ⚠️ The upstream maintainers have an open quality bug as of 2026-05-07.\n"
        "> Output may sound off until they patch it. See the linked repo for status."
    )

    with gr.Row():
        with gr.Column():
            description = gr.Textbox(
                label="Description",
                placeholder="A dreamy synthwave track with female vocals",
                lines=2,
            )
            lyrics = gr.Textbox(
                label="Lyrics (one line per line; leave empty if Instrumental)",
                placeholder="Verse 1\nLine one\nLine two\n\nChorus\n…",
                lines=8,
            )
            with gr.Row():
                language = gr.Dropdown(LANGUAGES, value="English", label="Language")
                genre = gr.Dropdown(GENRES, value="Pop", label="Genre")
            with gr.Row():
                duration_min = gr.Slider(1, 4, value=2, step=1, label="Duration (min)")
                tags = gr.Textbox(label="Tags (optional)", placeholder="acoustic, melancholic")
            with gr.Accordion("Sampling parameters", open=False):
                top_k_bb = gr.Slider(1, 200, value=50, step=1, label="Backbone top-k")
                top_k_sr = gr.Slider(1, 50, value=10, step=1, label="Super-res top-k")
                temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature")
                seed = gr.Number(value=0, precision=0, label="Seed (0 = use process seed)")
            submit = gr.Button("Generate", variant="primary")
        with gr.Column():
            audio_out = gr.Audio(label="Generated track", type="filepath")
            meta_out = gr.Textbox(label="Worker result", lines=8, interactive=False)

    # The input order must match the positional parameters of generate().
    submit.click(
        fn=generate,
        inputs=[description, lyrics, language, genre, duration_min, tags,
                top_k_bb, top_k_sr, temperature, seed],
        outputs=[audio_out, meta_out],
    )


if __name__ == "__main__":
    demo.queue().launch()