from __future__ import annotations

import os
import sys
import tempfile
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Tuple

try:
    import spaces
except ImportError:
    # ``spaces`` is only available on Hugging Face Spaces. Outside of it, provide a
    # stand-in whose ``GPU(...)`` decorator factory returns functions unchanged so
    # the app still runs locally.
    class _SpacesShim:
        @staticmethod
        def GPU(*_args, **_kwargs):
            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesShim()

import gradio as gr
import torch
from huggingface_hub import hf_hub_download, snapshot_download

APP_ROOT = Path(__file__).resolve().parent
HEARTLIB_SRC = APP_ROOT / "heartlib" / "src"
HEARTLIB_PACKAGE_INIT = HEARTLIB_SRC / "heartlib" / "__init__.py"

# Prefer a heartlib checkout located next to this file when one is present; the
# path must be on sys.path before the import below.
if HEARTLIB_PACKAGE_INIT.is_file():
    sys.path.insert(0, str(HEARTLIB_SRC))

from heartlib import HeartMuLaGenPipeline  # noqa: E402


@dataclass(frozen=True)
class ModelConfig:
    """Where to fetch the model artifacts from and where to place them locally."""

    version: str  # passed to HeartMuLaGenPipeline.from_pretrained(version=...)
    generator_repo: str  # repo providing tokenizer.json and gen_config.json
    mula_repo: str  # HeartMuLa checkpoint repo
    codec_repo: str  # HeartCodec checkpoint repo
    mula_dirname: str  # sub-directory of the model root for the HeartMuLa snapshot
    codec_dirname: str  # sub-directory of the model root for the HeartCodec snapshot


MODEL_CONFIG = ModelConfig(
    version=os.getenv("HEARTMULA_VERSION", "3B"),
    generator_repo=os.getenv("HEARTMULA_GENERATOR_REPO", "HeartMuLa/HeartMuLaGen"),
    mula_repo=os.getenv(
        "HEARTMULA_MULA_REPO", "HeartMuLa/HeartMuLa-oss-3B-happy-new-year"
    ),
    codec_repo=os.getenv(
        "HEARTMULA_CODEC_REPO", "HeartMuLa/HeartCodec-oss-20260123"
    ),
    mula_dirname=os.getenv("HEARTMULA_MULA_DIRNAME", "HeartMuLa-oss-3B"),
    codec_dirname=os.getenv("HEARTMULA_CODEC_DIRNAME", "HeartCodec-oss"),
)

# spaces.GPU time budgets (seconds) for inference and for first-time compilation.
GPU_DURATION_SECONDS = int(os.getenv("HEARTMULA_GPU_DURATION_SECONDS", "300"))
COMPILE_DURATION_SECONDS = int(os.getenv("HEARTMULA_COMPILE_DURATION_SECONDS", "600"))

# Upper bound for requested audio length; also the UI slider maximum.
MAX_DURATION_SECONDS = int(os.getenv("HEARTMULA_MAX_DURATION_SECONDS", "180"))
DEFAULT_DURATION_SECONDS = min(
    int(os.getenv("HEARTMULA_DEFAULT_DURATION_SECONDS", "60")),
    MAX_DURATION_SECONDS,
)

# Acceleration switches; only honoured when running on CUDA (see get_pipeline).
ENABLE_FLASH_ATTN = os.getenv("HEARTMULA_ENABLE_FLASH_ATTN", "1") != "0"
ENABLE_AOTI = os.getenv("HEARTMULA_ENABLE_AOTI", "1") != "0"
AOTI_MAX_BATCH = int(os.getenv("HEARTMULA_AOTI_MAX_BATCH", "2"))
AOTI_MAX_SEQ_LEN = int(os.getenv("HEARTMULA_AOTI_MAX_SEQ_LEN", "4096"))

# On CUDA, a component that is *not* kept loaded is created with lazy_load=True.
KEEP_MULA_LOADED = os.getenv("HEARTMULA_KEEP_MULA_LOADED", "1") != "0"
KEEP_CODEC_LOADED = os.getenv("HEARTMULA_KEEP_CODEC_LOADED", "0") != "0"

MODEL_LOCK = threading.Lock()  # serialises artifact downloads
PIPELINE_LOCK = threading.Lock()  # guards PIPELINE_CACHE
PIPELINE_CACHE: Dict[Tuple[str, str], HeartMuLaGenPipeline] = {}
_RUNTIME_PREPARED = False  # set once _compile_runtime has completed successfully

# Written into the model root only after every artifact was downloaded, so an
# interrupted download is not mistaken for a complete cache.
_ARTIFACTS_MARKER = ".heartmula_artifacts_complete"


def _default_cache_root() -> Path:
    """Return the base cache directory.

    ``$HF_HOME`` wins if set; otherwise ``/data/.huggingface`` when a ``/data``
    directory exists; otherwise ``/tmp/huggingface``.
    """
    env_home = os.getenv("HF_HOME")
    if env_home:
        return Path(env_home)
    data_home = Path("/data/.huggingface")
    if data_home.parent.exists():
        return data_home
    return Path("/tmp/huggingface")


def _model_root() -> Path:
    """Return the model directory (``$HEARTMULA_MODEL_DIR`` or ``<cache>/heartmula_models``)."""
    return Path(
        os.getenv(
            "HEARTMULA_MODEL_DIR",
            str(_default_cache_root() / "heartmula_models"),
        )
    )


def _read_text(path: Path, fallback: str) -> str:
    """Return the stripped UTF-8 contents of *path*, or *fallback* if it is missing or blank."""
    if path.is_file():
        text = path.read_text(encoding="utf-8").strip()
        if text:
            return text
    return fallback


def _cached_model_exists(model_dir: Path) -> bool:
    """Return True when *model_dir* holds a completed artifact download.

    Besides the expected files/directories, the completion marker must exist:
    an interrupted download can leave a checkpoint directory in place that is
    only partially populated.
    """
    required_paths = [
        model_dir / "tokenizer.json",
        model_dir / "gen_config.json",
        model_dir / MODEL_CONFIG.mula_dirname,
        model_dir / MODEL_CONFIG.codec_dirname,
        model_dir / _ARTIFACTS_MARKER,
    ]
    return all(path.exists() for path in required_paths)


def _report_progress(progress: gr.Progress | None, value: float, desc: str) -> None:
    """Forward a progress update to *progress* when one was supplied."""
    if progress is not None:
        progress(value, desc=desc)


def ensure_model_artifacts(progress: gr.Progress | None = None) -> Path:
    """Download tokenizer, generation config and both checkpoints if needed.

    Args:
        progress: Optional Gradio progress tracker to report download stages to.

    Returns:
        The model root directory containing all artifacts.
    """
    model_dir = _model_root()
    model_dir.mkdir(parents=True, exist_ok=True)

    if _cached_model_exists(model_dir):
        _report_progress(progress, 0.05, "Using cached model artifacts")
        return model_dir

    with MODEL_LOCK:
        # Re-check under the lock: another caller may have finished meanwhile.
        if _cached_model_exists(model_dir):
            _report_progress(progress, 0.05, "Using cached model artifacts")
            return model_dir

        _report_progress(progress, 0.05, "Downloading tokenizer and generation config")
        for filename in ("tokenizer.json", "gen_config.json"):
            hf_hub_download(
                repo_id=MODEL_CONFIG.generator_repo,
                filename=filename,
                local_dir=str(model_dir),
            )

        _report_progress(progress, 0.25, "Downloading HeartMuLa checkpoint")
        snapshot_download(
            repo_id=MODEL_CONFIG.mula_repo,
            local_dir=str(model_dir / MODEL_CONFIG.mula_dirname),
        )

        _report_progress(progress, 0.6, "Downloading HeartCodec checkpoint")
        snapshot_download(
            repo_id=MODEL_CONFIG.codec_repo,
            local_dir=str(model_dir / MODEL_CONFIG.codec_dirname),
        )

        # Only now is the cache complete; record which repos it came from.
        (model_dir / _ARTIFACTS_MARKER).write_text(
            "\n".join(
                (
                    MODEL_CONFIG.generator_repo,
                    MODEL_CONFIG.mula_repo,
                    MODEL_CONFIG.codec_repo,
                )
            )
            + "\n",
            encoding="utf-8",
        )

        _report_progress(progress, 0.95, "Model artifacts ready")
        return model_dir


def _current_runtime() -> str:
    """Return ``"cuda"`` when torch reports CUDA as available, else ``"cpu"``."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def _runtime_key() -> Tuple[str, str]:
    """Return the ``(runtime, model_root)`` pair for the default model directory."""
    return _current_runtime(), str(_model_root())


def get_pipeline(model_dir: Path) -> HeartMuLaGenPipeline:
    """Create pipeline and store acceleration config. Does NOT trigger compilation.

    Pipelines are cached per ``(runtime, model_dir)`` so repeated calls reuse the
    same instance.
    """
    runtime = _current_runtime()
    cache_key = (runtime, str(model_dir))

    with PIPELINE_LOCK:
        cached = PIPELINE_CACHE.get(cache_key)
        if cached is not None:
            return cached

        if runtime == "cuda":
            # bf16 for the language model, fp32 for the codec.
            device = {
                "mula": torch.device("cuda"),
                "codec": torch.device("cuda"),
            }
            dtype = {
                "mula": torch.bfloat16,
                "codec": torch.float32,
            }
            lazy_load = {
                "mula": not KEEP_MULA_LOADED,
                "codec": not KEEP_CODEC_LOADED,
            }
        else:
            device = torch.device("cpu")
            dtype = torch.float32
            lazy_load = False

        pipeline = HeartMuLaGenPipeline.from_pretrained(
            str(model_dir),
            device=device,
            dtype=dtype,
            version=MODEL_CONFIG.version,
            lazy_load=lazy_load,
        )
        # Acceleration is only requested on CUDA; on CPU both flags are False.
        pipeline.configure_runtime_acceleration(
            enable_flash_attn=runtime == "cuda" and ENABLE_FLASH_ATTN,
            enable_aoti=runtime == "cuda" and ENABLE_AOTI,
            max_batch_size=AOTI_MAX_BATCH,
            max_compile_seq_len=AOTI_MAX_SEQ_LEN,
        )
        PIPELINE_CACHE[cache_key] = pipeline
        return pipeline


@spaces.GPU(duration=COMPILE_DURATION_SECONDS)
def _compile_runtime(model_dir: str):
    """AoTI compilation + FA3 injection on a real GPU.

    Uses its own GPU time budget (``COMPILE_DURATION_SECONDS``, 600 s unless
    ``HEARTMULA_COMPILE_DURATION_SECONDS`` overrides it) so that first-time
    compilation does not compete with inference for GPU seconds. Later calls are
    expected to hit the ``spaces.aoti_compile`` filesystem cache and be
    effectively a no-op -- NOTE(review): that caching lives in
    ``pipeline.prepare_runtime`` (heartlib); confirm there.
    """
    pipeline = get_pipeline(Path(model_dir))
    pipeline.prepare_runtime()


@spaces.GPU(duration=GPU_DURATION_SECONDS)
def _run_generation(
    model_dir: str,
    lyrics: str,
    tags: str,
    max_duration_seconds: int,
    temperature: float,
    topk: int,
    cfg_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one song and return the path of the written WAV file.

    The output file is created with ``delete=False`` so it outlives this call
    (Gradio serves it by path); it is removed again if generation fails.
    """
    pipeline = get_pipeline(Path(model_dir))
    # Slider values may arrive as floats; the pipeline receives integers.
    max_audio_length_ms = int(max_duration_seconds * 1000)

    progress(0.05, desc="Generating audio")

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        output_path = fp.name

    try:
        with torch.no_grad():
            pipeline(
                {
                    "lyrics": lyrics,
                    "tags": tags,
                },
                max_audio_length_ms=max_audio_length_ms,
                save_path=output_path,
                topk=int(topk),
                temperature=temperature,
                cfg_scale=cfg_scale,
            )
    except BaseException:
        # Includes cancellation: don't leave orphaned temp files behind.
        Path(output_path).unlink(missing_ok=True)
        raise
    return output_path


def generate_music(
    lyrics: str,
    tags: str,
    max_duration_seconds: int,
    temperature: float,
    topk: int,
    cfg_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio handler: validate input, ensure artifacts/compilation, then generate.

    Raises:
        gr.Error: If lyrics or tags are empty.
    """
    global _RUNTIME_PREPARED

    if not lyrics or not lyrics.strip():
        raise gr.Error("Please enter lyrics before generating.")
    if not tags or not tags.strip():
        raise gr.Error("Please enter at least one style tag.")

    # The slider bounds the UI, but direct API callers can send any value.
    max_duration_seconds = max(1, min(int(max_duration_seconds), MAX_DURATION_SECONDS))

    model_dir = ensure_model_artifacts(progress)

    if not _RUNTIME_PREPARED:
        progress(0.02, desc="Compiling runtime (first request, please wait)")
        _compile_runtime(str(model_dir))
        # Only flipped after a successful compile, so a failure is retried.
        _RUNTIME_PREPARED = True

    return _run_generation(
        str(model_dir),
        lyrics,
        tags,
        max_duration_seconds,
        temperature,
        topk,
        cfg_scale,
    )


DEFAULT_LYRICS = _read_text(
    APP_ROOT / "heartlib" / "assets" / "lyrics.txt",
    """[Verse]
The city wakes before the sun
We keep moving one by one

[Chorus]
Hold the light and sing it through
Every road comes back to you""",
)
DEFAULT_TAGS = _read_text(
    APP_ROOT / "heartlib" / "assets" / "tags.txt",
    "female,indie pop,piano,emotional,night,silky,memories",
)


with gr.Blocks(title="HeartMuLa ZeroGPU Demo") as demo:
    gr.Markdown(
        """
# HeartMuLa ZeroGPU Demo

Generate music from lyrics and style tags with **HeartMuLa** on Hugging Face Spaces.
First use may take longer because model files need to be cached.
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            lyrics_input = gr.Textbox(
                label="Lyrics",
                lines=18,
                value=DEFAULT_LYRICS,
                placeholder="Use structured sections such as [Verse], [Chorus], [Bridge].",
            )
            tags_input = gr.Textbox(
                label="Tags",
                value=DEFAULT_TAGS,
                placeholder="female,indie pop,piano,emotional,night,silky,memories",
                info="Comma-separated tags without spaces for best compatibility.",
            )

            with gr.Accordion("Generation Settings", open=False):
                max_duration_input = gr.Slider(
                    minimum=30,
                    maximum=MAX_DURATION_SECONDS,
                    value=DEFAULT_DURATION_SECONDS,
                    step=10,
                    label="Max Duration (seconds)",
                )
                temperature_input = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Temperature",
                )
                topk_input = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-K",
                )
                cfg_scale_input = gr.Slider(
                    minimum=1.0,
                    maximum=3.0,
                    value=1.5,
                    step=0.1,
                    label="CFG Scale",
                )

            generate_button = gr.Button("Generate Music", variant="primary")

        with gr.Column(scale=1):
            audio_output = gr.Audio(label="Generated Audio", type="filepath")

    generate_button.click(
        fn=generate_music,
        inputs=[
            lyrics_input,
            tags_input,
            max_duration_input,
            temperature_input,
            topk_input,
            cfg_scale_input,
        ],
        outputs=audio_output,
    )

# One event at a time; generate_music's first-request compile relies on this.
demo.queue(default_concurrency_limit=1)


if __name__ == "__main__":
    demo.launch()