Spaces:
Running on Zero
Running on Zero
multimodalart HF Staff
Fix persist_layer_norm override: set args.no_persist_layer_norm=True (Megatron uses the negated form)
da5ce3b verified | """Gradio ZeroGPU wrapper around Khala's backend_worker generation path. | |
| This is a best-effort port of an NGC-targeted, Megatron-Core based system | |
| to the ZeroGPU runtime. Many things can break — see the Space README for the | |
| list of caveats. If the import or first-call init fails, check the Space | |
| build/runtime logs. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import subprocess | |
| import sys | |
| import time | |
| import traceback | |
| import gradio as gr | |
| import spaces | |
| from huggingface_hub import snapshot_download | |
| # --------------------------------------------------------------------------- | |
| # Paths and sys.path setup | |
| # --------------------------------------------------------------------------- | |
# All paths hang off the directory containing this file, so the app behaves the
# same whatever the current working directory is at launch.
SPACE_ROOT = os.path.dirname(os.path.abspath(__file__))
BACKEND_DIR, CHECKPOINTS_DIR = [
    os.path.join(SPACE_ROOT, sub) for sub in ("backend", "checkpoints")
]
MEGATRON_ROOT, DECODER_ROOT = [
    os.path.join(SPACE_ROOT, "models", name) for name in ("Megatron", "Decoder")
]
OUTPUT_DIR = os.path.join(BACKEND_DIR, "generated_audio")

# Put the bundled source trees on the import path. Each root goes to the front
# of sys.path (so later entries in this tuple end up searched first); roots
# that are already present are left untouched.
for path in (SPACE_ROOT, BACKEND_DIR, MEGATRON_ROOT, DECODER_ROOT):
    if path in sys.path:
        continue
    sys.path.insert(0, path)

# Directories written later: downloaded weights and generated audio files.
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
| # --------------------------------------------------------------------------- | |
| # Ensure pybind11 + pre-compile the Megatron dataset helpers | |
| # --------------------------------------------------------------------------- | |
| # Belt-and-suspenders: even though requirements.txt lists pybind11, the | |
| # Megatron Makefile shells out to `python3 -m pybind11 --includes` and we've | |
| # seen that fail with `No module named pybind11` on the Space — likely a path | |
| # mismatch between the pip target and /usr/local/bin/python3. So we install | |
| # explicitly with the same interpreter app.py runs in, then pre-compile the | |
| # helpers ourselves. If Megatron later re-triggers `initialize` we want the | |
| # .so already built so its compile-on-startup never runs. | |
| def _ensure_dataset_helpers_compiled() -> None: | |
| helpers_dir = os.path.join(MEGATRON_ROOT, "megatron", "core", "datasets") | |
| so_glob = [f for f in os.listdir(helpers_dir) if f.startswith("helpers_cpp") and f.endswith(".so")] | |
| if so_glob: | |
| print(f"[app] dataset helpers already compiled: {so_glob[0]}") | |
| return | |
| try: | |
| import pybind11 # noqa: F401 | |
| except ImportError: | |
| print("[app] pybind11 missing — installing via current interpreter.") | |
| subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", "pybind11"]) | |
| # Run make using the same interpreter so the Makefile's `python3 -m pybind11 ...` resolves. | |
| env = os.environ.copy() | |
| env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "") | |
| print(f"[app] Compiling Megatron dataset helpers in {helpers_dir} ...") | |
| try: | |
| subprocess.check_call(["make"], cwd=helpers_dir, env=env) | |
| print("[app] dataset helpers compile OK.") | |
| except subprocess.CalledProcessError as exc: | |
| print(f"[app] make failed ({exc}); Megatron will try its own compile at init time.") | |
# Invoked once at module import. Megatron itself is only imported later, inside
# _bootstrap_worker(), so this runs before Megatron's own initialization.
_ensure_dataset_helpers_compiled()
| # --------------------------------------------------------------------------- | |
| # Model checkpoint download (CPU-only, runs at module import on cold start) | |
| # --------------------------------------------------------------------------- | |
# Hugging Face Hub repo holding all Khala checkpoints (backbone, etc.).
KHALA_REPO = "liujiafeng/Khala-MusicGeneration-v1.0"


def download_checkpoints_if_needed() -> None:
    """Download the Khala weights into ``CHECKPOINTS_DIR`` unless already present.

    Presence is judged solely by the backbone's
    ``latest_checkpointed_iteration.txt`` marker file; if it exists the whole
    download is skipped. Runs on CPU at module import, so a cold start blocks
    here until the snapshot is on disk.

    NOTE(review): a partially completed earlier download that happened to
    write the marker would also be skipped — snapshot_download is resumable,
    so deleting the marker forces a re-check.
    """
    backbone_marker = os.path.join(
        CHECKPOINTS_DIR, "backbone", "latest_checkpointed_iteration.txt"
    )
    if os.path.isfile(backbone_marker):
        print("[app] Khala checkpoints already present, skipping download.")
        return
    print(f"[app] Downloading {KHALA_REPO} → {CHECKPOINTS_DIR} (this takes a while on cold start)…")
    t0 = time.perf_counter()
    # `local_dir_use_symlinks` is deprecated (and ignored) in current
    # huggingface_hub, where files under `local_dir` are always real files;
    # passing it only triggers a deprecation warning, so it is omitted.
    snapshot_download(
        repo_id=KHALA_REPO,
        local_dir=CHECKPOINTS_DIR,
    )
    print(f"[app] Download done in {time.perf_counter() - t0:.1f}s.")


download_checkpoints_if_needed()
| # --------------------------------------------------------------------------- | |
| # Distributed env vars (Megatron expects these even on single GPU) | |
| # --------------------------------------------------------------------------- | |
# Megatron expects the distributed rendezvous variables even on a single GPU.
# Supply single-process defaults, never overriding values the runtime already set.
for _env_name, _env_default in (
    ("MASTER_ADDR", "127.0.0.1"),
    ("MASTER_PORT", "29500"),
    ("WORLD_SIZE", "1"),
    ("RANK", "0"),
    ("LOCAL_RANK", "0"),
    ("CUDA_VISIBLE_DEVICES", "0"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default
| # --------------------------------------------------------------------------- | |
| # Lazy bootstrap of Megatron + Khala worker | |
| # --------------------------------------------------------------------------- | |
# Set to True only at the very end of a successful _bootstrap_worker() call;
# later calls then return immediately. A call that raises part-way leaves it False.
_initialized = False


def _bootstrap_worker() -> None:
    """One-time Megatron init + model load; meant to run inside a @spaces.GPU context.

    Mirrors the bootstrap sequence in Khala's backend_worker.main():

    1. rewrite ``sys.argv`` with the CLI args Megatron's parser reads;
    2. apply ``bw.patch_language_model_embedding()``;
    3. call ``initialize_megatron`` with ``bw.add_worker_args`` as the extra
       args provider and inference-oriented ``args_defaults``;
    4. override the args that would select Transformer-Engine / Apex / fused
       kernel code paths;
    5. call ``bw.preload_runtime()`` (per upstream, this preloads the
       tokenizer, backbone, superres and decoder models).

    The steps are order-sensitive: ``sys.argv`` must be set before
    ``initialize_megatron`` parses it, and the overrides must be applied after
    initialization (which loads the checkpoint's args via
    ``--use-checkpoint-args``) but before any model is built.

    Side effects: replaces ``sys.argv`` wholesale (only ``argv[0]`` is kept)
    and sets the module-level ``_initialized`` flag on success.

    NOTE(review): if an earlier call raised after ``initialize_megatron``, a
    retry runs it again — confirm Megatron tolerates re-initialization.
    NOTE(review): confirm module globals persist between @spaces.GPU calls on
    ZeroGPU; if they do not, every call repeats the full bootstrap.
    """
    global _initialized
    if _initialized:
        return
    # Set the args Megatron's parser expects, before initialize_megatron reads sys.argv.
    first_backbone_path = os.path.join(CHECKPOINTS_DIR, "backbone")
    # NOTE(review): presumably the backbone checkpoint's vocab size; must match
    # the downloaded weights — confirm against the checkpoint config.
    first_backbone_vocab = 130304
    sys.argv = [
        sys.argv[0],
        "--load", first_backbone_path,
        "--vocab-size", str(first_backbone_vocab),
        "--use-checkpoint-args",
    ]
    # Imports deferred to here because they pull in heavy CUDA-bound modules.
    import backend_worker as bw
    from megatron.training.initialize import initialize_megatron
    # Patch is applied before initialize_megatron, matching backend_worker.main().
    bw.patch_language_model_embedding()
    initialize_megatron(
        extra_args_provider=bw.add_worker_args,
        args_defaults={
            "no_load_rng": True,
            "no_load_optim": True,
            "micro_batch_size": 1,
            "exit_on_missing_checkpoint": True,
        },
    )
    # Disable Transformer-Engine and Apex code paths that the checkpoint args
    # request but that aren't installable in the ZeroGPU environment. The
    # checkpoint config is loaded with --use-checkpoint-args, so we have to
    # override AFTER initialize_megatron and BEFORE any model is built.
    from megatron.training import get_args
    args = get_args()
    args.transformer_impl = "local"  # was "transformer_engine"
    args.apply_rope_fusion = False
    args.bias_swiglu_fusion = False
    args.bias_dropout_fusion = False
    args.bias_gelu_fusion = False
    args.masked_softmax_fusion = False
    args.gradient_accumulation_fusion = False
    args.cross_entropy_loss_fusion = False
    args.use_flash_attn = False
    args.attention_softmax_in_fp32 = True
    args.no_persist_layer_norm = True  # Megatron uses negated flag — sets config.persist_layer_norm=False
    print(f"[app] Overrode args for TE-less env: transformer_impl={args.transformer_impl}, "
          f"all fusions disabled, no_persist_layer_norm={args.no_persist_layer_norm}.")
    bw.preload_runtime()
    # Only reached if every step above succeeded.
    _initialized = True
| # --------------------------------------------------------------------------- | |
| # Generation entry point | |
| # --------------------------------------------------------------------------- | |
# Dropdown choices for the UI. The selected strings are forwarded unchanged to
# GenerateRequest(language=..., genre=...) by generate().
LANGUAGES = "English Chinese Japanese Korean Cantonese Instrumental".split()
GENRES = (
    "Pop Rock R&B Hip-Hop Electronic Jazz Classical "
    "Folk Country Metal Latin Reggae Blues Funk "
    "Soul Indie Alternative Dance Acoustic"
).split()
# GPU seconds requested from ZeroGPU per generate() call; override with the
# KHALA_GPU_DURATION environment variable.
# NOTE(review): the first call also bootstraps Megatron and loads every model,
# which the UI warns can take many minutes — size this against that and the
# Space's ZeroGPU quota.
GPU_DURATION_S = int(os.environ.get("KHALA_GPU_DURATION", "300"))


@spaces.GPU(duration=GPU_DURATION_S)
def generate(
    description: str,
    lyrics: str,
    language: str,
    genre: str,
    duration_min: int,
    tags: str,
    top_k_bb: int,
    top_k_sr: int,
    temperature: float,
    seed: int,
):
    """Generate one song with Khala and return it for the Gradio outputs.

    Decorated with ``@spaces.GPU`` so ZeroGPU attaches a GPU for the duration
    of the call; without it the bootstrap and generation would run with no GPU
    allocated. The first call also bootstraps Megatron via ``_bootstrap_worker()``.

    Args:
        description, lyrics, language, genre, tags: Prompt fields, forwarded
            unchanged to ``GenerateRequest``.
        duration_min: Requested song length, forwarded as ``duration``.
        top_k_bb, top_k_sr, temperature: Sampling parameters for the backbone
            and super-resolution stages.
        seed: Forwarded as ``seed_override``. An empty number field (``None``)
            is treated as ``0``, which the UI labels "use process seed".

    Returns:
        ``(audio_path, result_json)``: the generated MP3 if it exists, otherwise
        the WAV, plus the worker's result dict rendered as indented JSON.

    Raises:
        gr.Error: on any failure — a worker-reported error, no output file on
            disk, or any other exception (its traceback is printed to the logs).
    """
    try:
        _bootstrap_worker()
        # Deferred import: backend_worker pulls in heavy CUDA-bound modules.
        from backend_worker import GenerateRequest, run_generation

        req = GenerateRequest(
            genre=genre,
            language=language,
            tags=tags,
            description=description,
            duration=int(duration_min),
            lyrics=lyrics,
            top_k_bb=int(top_k_bb),
            top_k_sr=int(top_k_sr),
            temperature=float(temperature),
            superres_text_mode="same_as_backbone",
            # A cleared gr.Number yields None; int(None) would raise.
            seed_override=int(seed or 0),
        )
        result = run_generation(req)
        if result.get("status") != "ok":
            raise gr.Error(result.get("error", "Unknown error from worker"))

        # Prefer the MP3 encode, fall back to the WAV, and only hand gr.Audio a
        # path that actually exists on disk.
        candidates = [
            os.path.join(OUTPUT_DIR, result[key])
            for key in ("mp3_filename", "wav_filename")
            if result.get(key)
        ]
        audio_path = next((p for p in candidates if os.path.isfile(p)), None)
        if audio_path is None:
            raise gr.Error(
                f"Generation finished but no audio file was found in {OUTPUT_DIR}."
            )
        return audio_path, json.dumps(result, indent=2, default=str)
    except gr.Error:
        raise
    except Exception as exc:
        tb = traceback.format_exc()
        print(f"[app] generate() failed:\n{tb}")
        raise gr.Error(f"{type(exc).__name__}: {exc}")
| # --------------------------------------------------------------------------- | |
| # UI | |
| # --------------------------------------------------------------------------- | |
# Gradio UI. Components render in the order they are created inside the nested
# Row / Column / Accordion contexts, so statement order here defines the layout.
with gr.Blocks(title="Khala — High-Fidelity Song Generation") as demo:
    gr.Markdown(
        """# Khala — High-Fidelity Song Generation
Best-effort ZeroGPU wrapper around the official
[Khala](https://github.com/Khala-Music-AI/Khala) backend. Cold start downloads
~several GB of weights; first generation also bootstraps Megatron, so the
**first run will take many minutes**. Subsequent runs reuse the loaded models.
> ⚠️ The upstream maintainers have an open quality bug as of 2026-05-07.
> Output may sound off until they patch it. See the linked repo for status."""
    )
    with gr.Row():
        # Left column: prompt inputs, sampling controls and the Generate button.
        with gr.Column():
            description = gr.Textbox(
                label="Description",
                placeholder="A dreamy synthwave track with female vocals",
                lines=2,
            )
            lyrics = gr.Textbox(
                label="Lyrics (one line per line; leave empty if Instrumental)",
                placeholder="Verse 1\nLine one\nLine two\n\nChorus\n…",
                lines=8,
            )
            with gr.Row():
                language = gr.Dropdown(LANGUAGES, value="English", label="Language")
                genre = gr.Dropdown(GENRES, value="Pop", label="Genre")
            with gr.Row():
                # Integer minutes, 1–4; generate() forwards it as `duration`.
                duration_min = gr.Slider(1, 4, value=2, step=1, label="Duration (min)")
                tags = gr.Textbox(label="Tags (optional)", placeholder="acoustic, melancholic")
            # Collapsed by default; defaults here are what generate() receives
            # when the user does not touch them.
            with gr.Accordion("Sampling parameters", open=False):
                top_k_bb = gr.Slider(1, 200, value=50, step=1, label="Backbone top-k")
                top_k_sr = gr.Slider(1, 50, value=10, step=1, label="Super-res top-k")
                temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature")
                seed = gr.Number(value=0, precision=0, label="Seed (0 = use process seed)")
            submit = gr.Button("Generate", variant="primary")
        # Right column: outputs. type="filepath" matches generate() returning a path.
        with gr.Column():
            audio_out = gr.Audio(label="Generated track", type="filepath")
            meta_out = gr.Textbox(label="Worker result", lines=8, interactive=False)
    # `inputs` is positional: its order must match generate()'s parameters.
    # `outputs` receive the (audio_path, result_json) tuple generate() returns.
    submit.click(
        fn=generate,
        inputs=[description, lyrics, language, genre, duration_min, tags,
                top_k_bb, top_k_sr, temperature, seed],
        outputs=[audio_out, meta_out],
    )
if __name__ == "__main__":
    # Attach Gradio's request queue to the Blocks app, then start serving it.
    queued_demo = demo.queue()
    queued_demo.launch()