khala / app.py
multimodalart's picture
multimodalart HF Staff
Fix persist_layer_norm override: set args.no_persist_layer_norm=True (Megatron uses the negated form)
da5ce3b verified
"""Gradio ZeroGPU wrapper around Khala's backend_worker generation path.
This is a best-effort port of an NGC-targeted, Megatron-Core based system
to the ZeroGPU runtime. Many things can break — see the Space README for the
list of caveats. If the import or first-call init fails, check the Space
build/runtime logs.
"""
from __future__ import annotations
import json
import os
import subprocess
import sys
import time
import traceback
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
# ---------------------------------------------------------------------------
# Paths and sys.path setup
# ---------------------------------------------------------------------------
# Every location is derived from the directory that contains this file.
SPACE_ROOT = os.path.dirname(os.path.abspath(__file__))
BACKEND_DIR = os.path.join(SPACE_ROOT, "backend")
_MODELS_DIR = os.path.join(SPACE_ROOT, "models")
MEGATRON_ROOT = os.path.join(_MODELS_DIR, "Megatron")
DECODER_ROOT = os.path.join(_MODELS_DIR, "Decoder")
CHECKPOINTS_DIR = os.path.join(SPACE_ROOT, "checkpoints")
OUTPUT_DIR = os.path.join(BACKEND_DIR, "generated_audio")

# Each missing entry is pushed to the front of sys.path, so the last root in
# the tuple ends up with the highest import priority.
for path in (SPACE_ROOT, BACKEND_DIR, MEGATRON_ROOT, DECODER_ROOT):
    if path in sys.path:
        continue
    sys.path.insert(0, path)

# Make sure the checkpoint and audio output directories exist.
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
# ---------------------------------------------------------------------------
# Ensure pybind11 + pre-compile the Megatron dataset helpers
# ---------------------------------------------------------------------------
# Belt-and-suspenders: even though requirements.txt lists pybind11, the
# Megatron Makefile shells out to `python3 -m pybind11 --includes` and we've
# seen that fail with `No module named pybind11` on the Space — likely a path
# mismatch between the pip target and /usr/local/bin/python3. So we install
# explicitly with the same interpreter app.py runs in, then pre-compile the
# helpers ourselves. If Megatron later re-triggers `initialize` we want the
# .so already built so its compile-on-startup never runs.
def _ensure_dataset_helpers_compiled() -> None:
    """Best-effort pre-build of Megatron's ``helpers_cpp`` dataset extension.

    Looks in ``<MEGATRON_ROOT>/megatron/core/datasets`` for an existing
    ``helpers_cpp*.so``. If none is found, makes sure ``pybind11`` is
    importable (installing it with the current interpreter when it is not)
    and runs ``make`` in that directory.

    Every failure is logged and swallowed: a missing or unreadable helpers
    directory, a failed ``pip install``, a missing ``make`` binary or a
    failing build never abort the module import.
    """
    helpers_dir = os.path.join(MEGATRON_ROOT, "megatron", "core", "datasets")
    try:
        entries = os.listdir(helpers_dir)
    except OSError as exc:
        # Directory missing or unreadable: nothing to build here.
        print(f"[app] cannot list {helpers_dir} ({exc}); skipping dataset helper pre-compile.")
        return

    prebuilt = [
        name for name in entries
        if name.startswith("helpers_cpp") and name.endswith(".so")
    ]
    if prebuilt:
        print(f"[app] dataset helpers already compiled: {prebuilt[0]}")
        return

    try:
        import pybind11  # noqa: F401
    except ImportError:
        print("[app] pybind11 missing — installing via current interpreter.")
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", "pybind11"])
        except (subprocess.CalledProcessError, OSError) as exc:
            print(f"[app] pybind11 install failed ({exc}); Megatron will try its own compile at init time.")
            return

    # Put this interpreter's directory first on PATH so the Makefile's
    # `python3 -m pybind11 ...` resolves to the interpreter running app.py.
    env = os.environ.copy()
    env["PATH"] = os.path.dirname(sys.executable) + os.pathsep + env.get("PATH", "")
    print(f"[app] Compiling Megatron dataset helpers in {helpers_dir} ...")
    try:
        subprocess.check_call(["make"], cwd=helpers_dir, env=env)
    except (subprocess.CalledProcessError, OSError) as exc:
        # OSError covers a `make` binary that is not installed at all.
        print(f"[app] make failed ({exc}); Megatron will try its own compile at init time.")
    else:
        print("[app] dataset helpers compile OK.")
# Executed at import time; the Megatron imports themselves are deferred to
# _bootstrap_worker().
_ensure_dataset_helpers_compiled()
# ---------------------------------------------------------------------------
# Model checkpoint download (CPU-only, runs at module import on cold start)
# ---------------------------------------------------------------------------
# Hugging Face Hub repo id passed to snapshot_download() by
# download_checkpoints_if_needed().
KHALA_REPO = "liujiafeng/Khala-MusicGeneration-v1.0"
def download_checkpoints_if_needed() -> None:
    """Download the Khala checkpoints into ``CHECKPOINTS_DIR`` unless present.

    The file ``backbone/latest_checkpointed_iteration.txt`` inside
    ``CHECKPOINTS_DIR`` is used as the "already downloaded" marker. When it
    exists the function returns immediately; otherwise the whole
    ``KHALA_REPO`` snapshot is fetched and the elapsed time is logged.

    Any exception raised by ``snapshot_download`` propagates to the caller.
    """
    backbone_marker = os.path.join(
        CHECKPOINTS_DIR, "backbone", "latest_checkpointed_iteration.txt"
    )
    if os.path.isfile(backbone_marker):
        print("[app] Khala checkpoints already present, skipping download.")
        return

    # Source repo and destination directory are separated explicitly so the
    # log line does not read as a single path.
    print(f"[app] Downloading {KHALA_REPO} -> {CHECKPOINTS_DIR} (this takes a while on cold start)…")
    t0 = time.perf_counter()
    snapshot_download(
        repo_id=KHALA_REPO,
        local_dir=CHECKPOINTS_DIR,
        local_dir_use_symlinks=False,
    )
    print(f"[app] Download done in {time.perf_counter() - t0:.1f}s.")
# Executed at import time (a no-op once the backbone marker file exists).
download_checkpoints_if_needed()
# ---------------------------------------------------------------------------
# Distributed env vars (Megatron expects these even on single GPU)
# ---------------------------------------------------------------------------
# Single-process defaults; a value already present in the environment wins.
for _env_name, _env_default in (
    ("MASTER_ADDR", "127.0.0.1"),
    ("MASTER_PORT", "29500"),
    ("WORLD_SIZE", "1"),
    ("RANK", "0"),
    ("LOCAL_RANK", "0"),
    ("CUDA_VISIBLE_DEVICES", "0"),
):
    os.environ.setdefault(_env_name, _env_default)
# Drop the loop temporaries so the module namespace stays unchanged.
del _env_name, _env_default
# ---------------------------------------------------------------------------
# Lazy bootstrap of Megatron + Khala worker
# ---------------------------------------------------------------------------
# Flipped to True only after a bootstrap has run to completion.
_initialized = False


def _bootstrap_worker() -> None:
    """Initialise Megatron and preload the Khala runtime, at most once.

    Called from generate(), i.e. inside the @spaces.GPU context. The steps
    follow the bootstrap sequence of Khala's backend_worker.main():

    1. rewrite sys.argv with the CLI flags Megatron's parser reads,
    2. apply the language-model embedding patch,
    3. run initialize_megatron,
    4. override the parsed args that request Transformer-Engine / Apex paths,
    5. preload the runtime through backend_worker.preload_runtime().

    Later calls return immediately once the first one has completed.
    """
    global _initialized
    if _initialized:
        return

    # initialize_megatron reads sys.argv, so the flags have to be in place
    # before it is called.
    backbone_dir = os.path.join(CHECKPOINTS_DIR, "backbone")
    backbone_vocab_size = 130304
    megatron_cli = [
        "--load", backbone_dir,
        "--vocab-size", str(backbone_vocab_size),
        "--use-checkpoint-args",
    ]
    sys.argv = [sys.argv[0], *megatron_cli]

    # Deferred imports: these pull in heavy CUDA-bound modules.
    import backend_worker as bw
    from megatron.training.initialize import initialize_megatron

    bw.patch_language_model_embedding()
    inference_defaults = {
        "no_load_rng": True,
        "no_load_optim": True,
        "micro_batch_size": 1,
        "exit_on_missing_checkpoint": True,
    }
    initialize_megatron(
        extra_args_provider=bw.add_worker_args,
        args_defaults=inference_defaults,
    )

    # The checkpoint config (pulled in by --use-checkpoint-args) requests
    # Transformer-Engine and Apex code paths that cannot be installed in the
    # ZeroGPU environment. The overrides must happen AFTER initialize_megatron
    # and BEFORE any model is built.
    from megatron.training import get_args

    megatron_args = get_args()
    megatron_args.transformer_impl = "local"  # was "transformer_engine"
    for disabled_flag in (
        "apply_rope_fusion",
        "bias_swiglu_fusion",
        "bias_dropout_fusion",
        "bias_gelu_fusion",
        "masked_softmax_fusion",
        "gradient_accumulation_fusion",
        "cross_entropy_loss_fusion",
        "use_flash_attn",
    ):
        setattr(megatron_args, disabled_flag, False)
    megatron_args.attention_softmax_in_fp32 = True
    # Megatron uses the negated flag — this sets config.persist_layer_norm=False.
    megatron_args.no_persist_layer_norm = True

    override_summary = (
        f"[app] Overrode args for TE-less env: transformer_impl={megatron_args.transformer_impl}, "
        f"all fusions disabled, no_persist_layer_norm={megatron_args.no_persist_layer_norm}."
    )
    print(override_summary)

    bw.preload_runtime()
    _initialized = True
# ---------------------------------------------------------------------------
# Generation entry point
# ---------------------------------------------------------------------------
# Choices offered by the "Language" dropdown.
LANGUAGES = [
    "English",
    "Chinese",
    "Japanese",
    "Korean",
    "Cantonese",
    "Instrumental",
]

# Choices offered by the "Genre" dropdown.
GENRES = [
    "Pop",
    "Rock",
    "R&B",
    "Hip-Hop",
    "Electronic",
    "Jazz",
    "Classical",
    "Folk",
    "Country",
    "Metal",
    "Latin",
    "Reggae",
    "Blues",
    "Funk",
    "Soul",
    "Indie",
    "Alternative",
    "Dance",
    "Acoustic",
]
@spaces.GPU(duration=360)
def generate(
    description: str,
    lyrics: str,
    language: str,
    genre: str,
    duration_min: int,
    tags: str,
    top_k_bb: int,
    top_k_sr: int,
    temperature: float,
    seed: int,
):
    """Generate one track and return ``(audio_path, result_json)``.

    Bootstraps the worker on first use, builds a ``GenerateRequest`` from the
    UI values and hands it to ``backend_worker.run_generation``. The numeric
    inputs are coerced with ``int`` / ``float`` before being passed on.

    Args:
        description: Free-text description of the track.
        lyrics: Lyrics text.
        language: One of the UI language choices.
        genre: One of the UI genre choices.
        duration_min: Requested duration, forwarded as ``duration``.
        tags: Optional tag string.
        top_k_bb: Top-k for the backbone.
        top_k_sr: Top-k for the super-resolution stage.
        temperature: Sampling temperature.
        seed: Seed override; ``None`` (cleared number field) is treated as 0,
            the UI default.

    Returns:
        The path of the mp3 if it exists on disk, otherwise of the wav, and
        the worker's result dict serialised as indented JSON.

    Raises:
        gr.Error: If the worker reports a non-"ok" status, if no audio file
            exists on disk, or if any other exception occurs (the traceback
            is printed and the original exception is chained as the cause).
    """
    try:
        _bootstrap_worker()
        from backend_worker import GenerateRequest, run_generation

        req = GenerateRequest(
            genre=genre,
            language=language,
            tags=tags,
            description=description,
            duration=int(duration_min),
            lyrics=lyrics,
            top_k_bb=int(top_k_bb),
            top_k_sr=int(top_k_sr),
            temperature=float(temperature),
            superres_text_mode="same_as_backbone",
            # gr.Number hands over None when the field is emptied; fall back
            # to the UI default instead of failing in int(None).
            seed_override=int(seed) if seed is not None else 0,
        )
        result = run_generation(req)
        if result.get("status") != "ok":
            raise gr.Error(result.get("error", "Unknown error from worker"))

        wav_path = os.path.join(OUTPUT_DIR, result["wav_filename"])
        mp3_path = os.path.join(OUTPUT_DIR, result["mp3_filename"])
        audio_path = mp3_path if os.path.isfile(mp3_path) else wav_path
        if not os.path.isfile(audio_path):
            # Report the problem here rather than handing gr.Audio a path
            # that does not exist.
            raise gr.Error("Worker reported success but no audio file was found.")
        return audio_path, json.dumps(result, indent=2, default=str)
    except gr.Error:
        raise
    except Exception as exc:
        tb = traceback.format_exc()
        print(f"[app] generate() failed:\n{tb}")
        raise gr.Error(f"{type(exc).__name__}: {exc}") from exc
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# NOTE(review): container nesting below is reconstructed from statement
# order — confirm which components belong to each Row / Accordion.
_INTRO_MARKDOWN = """# Khala — High-Fidelity Song Generation
Best-effort ZeroGPU wrapper around the official
[Khala](https://github.com/Khala-Music-AI/Khala) backend. Cold start downloads
~several GB of weights; first generation also bootstraps Megatron, so the
**first run will take many minutes**. Subsequent runs reuse the loaded models.
> ⚠️ The upstream maintainers have an open quality bug as of 2026-05-07.
> Output may sound off until they patch it. See the linked repo for status."""

with gr.Blocks(title="Khala — High-Fidelity Song Generation") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: prompt fields and generation controls.
        with gr.Column():
            description = gr.Textbox(
                label="Description",
                placeholder="A dreamy synthwave track with female vocals",
                lines=2,
            )
            lyrics = gr.Textbox(
                label="Lyrics (one line per line; leave empty if Instrumental)",
                placeholder="Verse 1\nLine one\nLine two\n\nChorus\n…",
                lines=8,
            )
            with gr.Row():
                language = gr.Dropdown(LANGUAGES, value="English", label="Language")
                genre = gr.Dropdown(GENRES, value="Pop", label="Genre")
            with gr.Row():
                duration_min = gr.Slider(1, 4, value=2, step=1, label="Duration (min)")
                tags = gr.Textbox(label="Tags (optional)", placeholder="acoustic, melancholic")
            with gr.Accordion("Sampling parameters", open=False):
                top_k_bb = gr.Slider(1, 200, value=50, step=1, label="Backbone top-k")
                top_k_sr = gr.Slider(1, 50, value=10, step=1, label="Super-res top-k")
                temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature")
                seed = gr.Number(value=0, precision=0, label="Seed (0 = use process seed)")
            submit = gr.Button("Generate", variant="primary")

        # Right column: generated audio and the raw worker result.
        with gr.Column():
            audio_out = gr.Audio(label="Generated track", type="filepath")
            meta_out = gr.Textbox(label="Worker result", lines=8, interactive=False)

    # Listed in the same order as generate()'s positional parameters.
    generation_inputs = [
        description,
        lyrics,
        language,
        genre,
        duration_min,
        tags,
        top_k_bb,
        top_k_sr,
        temperature,
        seed,
    ]
    # generate() returns (audio_path, result_json), matching these two outputs.
    generation_outputs = [audio_out, meta_out]
    submit.click(fn=generate, inputs=generation_inputs, outputs=generation_outputs)

if __name__ == "__main__":
    demo.queue().launch()