| from __future__ import annotations |
|
|
| import os |
| import sys |
| import tempfile |
| import threading |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Dict, Tuple |
|
|
try:
    import spaces
except ImportError:

    class _SpacesShim:
        """Minimal stand-in for the ``spaces`` package when it is not installed.

        Only ``GPU`` is provided. It is a no-op decorator that supports both
        usage forms:

        * ``@spaces.GPU`` (bare: the function is passed directly), and
        * ``@spaces.GPU(duration=...)`` (called: returns a decorator).

        In both cases the decorated function is returned unchanged.
        """

        @staticmethod
        def GPU(*args, **kwargs):
            # Bare form: GPU(fn) -> fn. Without this branch a bare
            # ``@spaces.GPU`` would replace the function with ``decorator``.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesShim()
|
|
| import gradio as gr |
| import torch |
| from huggingface_hub import hf_hub_download, snapshot_download |
|
|
|
|
# Directory containing this app; the heartlib checkout and assets sit beside it.
APP_ROOT = Path(__file__).resolve().parent
HEARTLIB_SRC = APP_ROOT.joinpath("heartlib", "src")
HEARTLIB_PACKAGE_INIT = HEARTLIB_SRC.joinpath("heartlib", "__init__.py")
# When a heartlib source checkout is present, put it first on sys.path so it
# shadows any installed copy. Skip it when the package is missing.
if HEARTLIB_PACKAGE_INIT.is_file():
    sys.path.insert(0, os.fspath(HEARTLIB_SRC))
|
|
| from heartlib import HeartMuLaGenPipeline |
|
|
|
|
@dataclass(eq=True, frozen=True)
class ModelConfig:
    """Immutable description of which Hub repos to fetch and where to put them.

    ``version`` is forwarded to the pipeline loader. The ``*_repo`` fields
    name the Hugging Face repositories to download. The ``*_dirname`` fields
    name the sub-directories of the model root that hold each checkpoint.
    """

    # Version string passed through to HeartMuLaGenPipeline.from_pretrained.
    version: str
    # Repo supplying tokenizer.json and gen_config.json.
    generator_repo: str
    # Repo snapshotted into <model root>/<mula_dirname>.
    mula_repo: str
    # Repo snapshotted into <model root>/<codec_dirname>.
    codec_repo: str
    # Sub-directory name for the HeartMuLa checkpoint.
    mula_dirname: str
    # Sub-directory name for the HeartCodec checkpoint.
    codec_dirname: str
|
|
|
|
# Every ModelConfig field can be overridden by the environment variable paired
# with it below; when that variable is unset, the listed default is used.
MODEL_CONFIG = ModelConfig(
    **{
        field_name: os.getenv(env_var, fallback)
        for field_name, env_var, fallback in (
            ("version", "HEARTMULA_VERSION", "3B"),
            ("generator_repo", "HEARTMULA_GENERATOR_REPO", "HeartMuLa/HeartMuLaGen"),
            (
                "mula_repo",
                "HEARTMULA_MULA_REPO",
                "HeartMuLa/HeartMuLa-oss-3B-happy-new-year",
            ),
            ("codec_repo", "HEARTMULA_CODEC_REPO", "HeartMuLa/HeartCodec-oss-20260123"),
            ("mula_dirname", "HEARTMULA_MULA_DIRNAME", "HeartMuLa-oss-3B"),
            ("codec_dirname", "HEARTMULA_CODEC_DIRNAME", "HeartCodec-oss"),
        )
    }
)
|
|
def _env_int(name: str, default: str) -> int:
    """Read an integer setting from the environment, falling back to *default*."""
    return int(os.getenv(name, default))


def _env_flag(name: str, default: str) -> bool:
    """Read an on/off switch: any value other than the literal ``"0"`` means on."""
    return os.getenv(name, default) != "0"


# Budgets (seconds) passed as ``duration`` to the ``spaces.GPU`` decorators:
# one for inference calls, a separate one for the first-request compile.
GPU_DURATION_SECONDS = _env_int("HEARTMULA_GPU_DURATION_SECONDS", "300")
COMPILE_DURATION_SECONDS = _env_int("HEARTMULA_COMPILE_DURATION_SECONDS", "600")
# Upper bound of the duration slider; the slider's initial value is capped to it.
MAX_DURATION_SECONDS = _env_int("HEARTMULA_MAX_DURATION_SECONDS", "180")
DEFAULT_DURATION_SECONDS = min(
    _env_int("HEARTMULA_DEFAULT_DURATION_SECONDS", "60"), MAX_DURATION_SECONDS
)
# Acceleration switches (only honoured on CUDA, see get_pipeline).
ENABLE_FLASH_ATTN = _env_flag("HEARTMULA_ENABLE_FLASH_ATTN", "1")
ENABLE_AOTI = _env_flag("HEARTMULA_ENABLE_AOTI", "1")
AOTI_MAX_BATCH = _env_int("HEARTMULA_AOTI_MAX_BATCH", "2")
AOTI_MAX_SEQ_LEN = _env_int("HEARTMULA_AOTI_MAX_SEQ_LEN", "4096")
# On CUDA, a component that is *not* kept loaded is requested with lazy_load=True.
KEEP_MULA_LOADED = _env_flag("HEARTMULA_KEEP_MULA_LOADED", "1")
KEEP_CODEC_LOADED = _env_flag("HEARTMULA_KEEP_CODEC_LOADED", "0")

# MODEL_LOCK serialises artifact downloads; PIPELINE_LOCK guards PIPELINE_CACHE,
# which maps (runtime, model_dir) to a constructed pipeline.
MODEL_LOCK = threading.Lock()
PIPELINE_LOCK = threading.Lock()
PIPELINE_CACHE: Dict[Tuple[str, str], HeartMuLaGenPipeline] = {}
# Flipped to True by generate_music after the one-time runtime compile succeeds.
_RUNTIME_PREPARED = False
|
|
|
|
| def _default_cache_root() -> Path: |
| env_home = os.getenv("HF_HOME") |
| if env_home: |
| return Path(env_home) |
|
|
| data_home = Path("/data/.huggingface") |
| if data_home.parent.exists(): |
| return data_home |
|
|
| return Path("/tmp/huggingface") |
|
|
|
|
def _model_root() -> Path:
    """Return the directory that holds all downloaded HeartMuLa artifacts.

    ``HEARTMULA_MODEL_DIR`` wins whenever it is set, even to an empty string.
    Otherwise the directory is ``heartmula_models`` under
    ``_default_cache_root()``.
    """
    override = os.environ.get("HEARTMULA_MODEL_DIR")
    if override is None:
        return _default_cache_root() / "heartmula_models"
    return Path(override)
|
|
|
|
| def _read_text(path: Path, fallback: str) -> str: |
| if path.is_file(): |
| return path.read_text(encoding="utf-8").strip() |
| return fallback |
|
|
|
|
def _cached_model_exists(model_dir: Path) -> bool:
    """Report whether every expected artifact path exists under *model_dir*.

    Checks the tokenizer, the generation config and both checkpoint
    directories. Only existence is tested, not completeness or contents.
    """
    expected_entries = (
        "tokenizer.json",
        "gen_config.json",
        MODEL_CONFIG.mula_dirname,
        MODEL_CONFIG.codec_dirname,
    )
    for entry in expected_entries:
        if not (model_dir / entry).exists():
            return False
    return True
|
|
|
|
# Sentinel created before a download starts and removed only after every
# download returned successfully. If the process dies mid-download it is left
# behind, so the next call knows the cache is incomplete even though the
# expected paths already exist. Directories populated by other means (no
# sentinel) are still treated as ready.
_DOWNLOAD_MARKER_NAME = ".heartmula_download_incomplete"


def _report_progress(progress: gr.Progress | None, fraction: float, desc: str) -> None:
    """Forward a progress update when a tracker was supplied."""
    if progress is not None:
        progress(fraction, desc=desc)


def _artifacts_ready(model_dir: Path) -> bool:
    """Return True when all artifacts exist and no download was left unfinished."""
    return _cached_model_exists(model_dir) and not (
        model_dir / _DOWNLOAD_MARKER_NAME
    ).exists()


def ensure_model_artifacts(progress: gr.Progress | None = None) -> Path:
    """Make sure the tokenizer, generation config and both checkpoints are on disk.

    Downloads from the repos named in ``MODEL_CONFIG`` into ``_model_root()``
    when the cache is missing or was left incomplete by an interrupted
    download.

    Args:
        progress: Optional Gradio progress tracker. It receives coarse updates
            (0.05 / 0.25 / 0.6 / 0.95) as each stage starts.

    Returns:
        The model root directory.

    Raises:
        Whatever ``hf_hub_download`` / ``snapshot_download`` raise on network
        or Hub errors. In that case the sentinel file stays in place, so the
        next call retries the download.
    """
    model_dir = _model_root()
    model_dir.mkdir(parents=True, exist_ok=True)

    # Fast path without taking the lock.
    if _artifacts_ready(model_dir):
        _report_progress(progress, 0.05, "Using cached model artifacts")
        return model_dir

    with MODEL_LOCK:
        # Re-check: another thread may have finished downloading while we waited.
        if _artifacts_ready(model_dir):
            _report_progress(progress, 0.05, "Using cached model artifacts")
            return model_dir

        marker = model_dir / _DOWNLOAD_MARKER_NAME
        marker.touch()

        _report_progress(progress, 0.05, "Downloading tokenizer and generation config")
        for filename in ("tokenizer.json", "gen_config.json"):
            hf_hub_download(
                repo_id=MODEL_CONFIG.generator_repo,
                filename=filename,
                local_dir=str(model_dir),
            )

        _report_progress(progress, 0.25, "Downloading HeartMuLa checkpoint")
        snapshot_download(
            repo_id=MODEL_CONFIG.mula_repo,
            local_dir=str(model_dir / MODEL_CONFIG.mula_dirname),
        )

        _report_progress(progress, 0.6, "Downloading HeartCodec checkpoint")
        snapshot_download(
            repo_id=MODEL_CONFIG.codec_repo,
            local_dir=str(model_dir / MODEL_CONFIG.codec_dirname),
        )

        # Every download returned, so the cache is complete.
        marker.unlink(missing_ok=True)

        _report_progress(progress, 0.95, "Model artifacts ready")
        return model_dir
|
|
|
|
def _runtime_key() -> Tuple[str, str]:
    """Return ``(device kind, model root)`` for the current process.

    The device kind is ``"cuda"`` when torch reports an available GPU and
    ``"cpu"`` otherwise.

    NOTE(review): not called anywhere in the visible source. ``get_pipeline``
    builds its own cache key from its ``model_dir`` argument instead.
    """
    device_kind = "cpu"
    if torch.cuda.is_available():
        device_kind = "cuda"
    return (device_kind, str(_model_root()))
|
|
|
|
def get_pipeline(model_dir: Path) -> HeartMuLaGenPipeline:
    """Return the cached pipeline for *model_dir*, constructing it on first use.

    Pipelines are cached per ``(runtime, model_dir)``. The whole build happens
    under ``PIPELINE_LOCK``, so concurrent callers never construct the same
    pipeline twice. Only the acceleration settings are recorded here. This
    function does NOT trigger compilation; that is ``_compile_runtime``'s job.

    On CUDA, MuLa runs in bfloat16 and the codec in float32, and each component
    is lazily loaded unless its KEEP_*_LOADED flag is set. On CPU everything is
    float32 and eagerly loaded.
    """
    on_gpu = torch.cuda.is_available()
    runtime = "cuda" if on_gpu else "cpu"
    cache_key = (runtime, str(model_dir))

    with PIPELINE_LOCK:
        cached = PIPELINE_CACHE.get(cache_key)
        if cached is not None:
            return cached

        if on_gpu:
            cuda = torch.device("cuda")
            device = {"mula": cuda, "codec": cuda}
            dtype = {"mula": torch.bfloat16, "codec": torch.float32}
            lazy_load = {
                "mula": not KEEP_MULA_LOADED,
                "codec": not KEEP_CODEC_LOADED,
            }
        else:
            device, dtype, lazy_load = torch.device("cpu"), torch.float32, False

        built = HeartMuLaGenPipeline.from_pretrained(
            str(model_dir),
            device=device,
            dtype=dtype,
            version=MODEL_CONFIG.version,
            lazy_load=lazy_load,
        )
        # Acceleration features are only requested when running on CUDA.
        built.configure_runtime_acceleration(
            enable_flash_attn=on_gpu and ENABLE_FLASH_ATTN,
            enable_aoti=on_gpu and ENABLE_AOTI,
            max_batch_size=AOTI_MAX_BATCH,
            max_compile_seq_len=AOTI_MAX_SEQ_LEN,
        )
        PIPELINE_CACHE[cache_key] = built
        return built
|
|
|
|
@spaces.GPU(duration=COMPILE_DURATION_SECONDS)
def _compile_runtime(model_dir: str):
    """Run the pipeline's one-time runtime preparation inside a GPU allocation.

    Builds (or reuses) the cached pipeline for *model_dir* and calls its
    ``prepare_runtime()``. That call presumably performs the AoTI compilation
    and flash-attention setup requested via ``configure_runtime_acceleration``
    in ``get_pipeline``.

    It runs under its own ``spaces.GPU`` budget, ``COMPILE_DURATION_SECONDS``.
    That budget defaults to 600 s and can be overridden with
    ``HEARTMULA_COMPILE_DURATION_SECONDS``. It is separate from the
    ``GPU_DURATION_SECONDS`` budget used for inference, so a slow first-time
    compile does not consume generation time.

    NOTE(review): whether repeat calls are cheap, for example thanks to an
    on-disk AoTI compile cache, depends on heartlib/spaces internals that are
    not visible here. Confirm before relying on it.
    """
    pipeline = get_pipeline(Path(model_dir))
    pipeline.prepare_runtime()
|
|
|
|
@spaces.GPU(duration=GPU_DURATION_SECONDS)
def _run_generation(
    model_dir: str,
    lyrics: str,
    tags: str,
    max_duration_seconds: int,
    temperature: float,
    topk: int,
    cfg_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one song inside a GPU allocation and return the ``.wav`` path.

    Args:
        model_dir: Model root, passed as ``str`` so it crosses the
            ``spaces.GPU`` boundary easily.
        lyrics: Lyrics text, forwarded as the ``"lyrics"`` input.
        tags: Style tags, forwarded as the ``"tags"`` input.
        max_duration_seconds: Length cap in seconds, converted to milliseconds.
        temperature: Sampling temperature.
        topk: Top-k sampling cutoff.
        cfg_scale: Classifier-free guidance scale.
        progress: Gradio progress tracker.

    Returns:
        Path of a temporary ``.wav`` file written by the pipeline. The caller
        (Gradio, via ``gr.Audio(type="filepath")``) owns it afterwards.

    Raises:
        Re-raises any pipeline error, after deleting the placeholder file.
    """
    pipeline = get_pipeline(Path(model_dir))
    # Slider values can arrive as floats (e.g. 50.0); the pipeline's length and
    # top-k arguments are integer quantities.
    max_audio_length_ms = int(max_duration_seconds * 1000)
    topk = int(topk)

    # delete=False: the file must outlive this block so it can be returned.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        output_path = fp.name

    try:
        progress(0.05, desc="Generating audio")
        with torch.no_grad():
            pipeline(
                {
                    "lyrics": lyrics,
                    "tags": tags,
                },
                max_audio_length_ms=max_audio_length_ms,
                save_path=output_path,
                topk=topk,
                temperature=temperature,
                cfg_scale=cfg_scale,
            )
    except BaseException:
        # Do not leak the empty placeholder .wav when generation fails.
        try:
            os.remove(output_path)
        except OSError:
            pass
        raise

    return output_path
|
|
|
|
def generate_music(
    lyrics: str,
    tags: str,
    max_duration_seconds: int,
    temperature: float,
    topk: int,
    cfg_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio click handler: validate input, prepare models, then generate audio.

    Steps:

    1. Ensure model artifacts are downloaded.
    2. On the first request only, run the GPU compile step.
    3. Delegate to ``_run_generation``.

    Returns:
        Path to the generated ``.wav`` file.

    Raises:
        gr.Error: If lyrics or tags are empty/missing, or the duration is not
            positive.
    """
    global _RUNTIME_PREPARED

    # Treat a missing value (None) like an empty one instead of crashing on .strip().
    if not (lyrics or "").strip():
        raise gr.Error("Please enter lyrics before generating.")
    if not (tags or "").strip():
        raise gr.Error("Please enter at least one style tag.")

    # The slider limits this in the UI, but programmatic callers may bypass the
    # widget, so enforce the configured cap here as well.
    duration_seconds = int(max_duration_seconds)
    if duration_seconds <= 0:
        raise gr.Error("Max duration must be a positive number of seconds.")
    duration_seconds = min(duration_seconds, MAX_DURATION_SECONDS)

    model_dir = ensure_model_artifacts(progress)

    # One-time compile in its own GPU allocation. The flag is only set after a
    # successful compile, so a failure is retried on the next request. No lock
    # is taken: the app queue is configured with default_concurrency_limit=1.
    if not _RUNTIME_PREPARED:
        progress(0.02, desc="Compiling runtime (first request, please wait)")
        _compile_runtime(str(model_dir))
        _RUNTIME_PREPARED = True

    return _run_generation(
        str(model_dir),
        lyrics,
        tags,
        duration_seconds,
        temperature,
        topk,
        cfg_scale,
    )
|
|
|
|
# Example inputs pre-filled in the UI. They are read from the bundled
# heartlib assets when those files exist; otherwise the inline fallbacks apply.
DEFAULT_LYRICS = _read_text(
    APP_ROOT.joinpath("heartlib", "assets", "lyrics.txt"),
    "[Verse]\n"
    "The city wakes before the sun\n"
    "We keep moving one by one\n"
    "\n"
    "[Chorus]\n"
    "Hold the light and sing it through\n"
    "Every road comes back to you",
)

DEFAULT_TAGS = _read_text(
    APP_ROOT.joinpath("heartlib", "assets", "tags.txt"),
    "female,indie pop,piano,emotional,night,silky,memories",
)
|
|
|
|
# Gradio interface. Components are laid out in the order they are created
# inside the Blocks context, so statement order here is layout order.
with gr.Blocks(title="HeartMuLa ZeroGPU Demo") as demo:
    gr.Markdown(
        """
# HeartMuLa ZeroGPU Demo

Generate music from lyrics and style tags with **HeartMuLa** on Hugging Face Spaces.

First use may take longer because model files need to be cached.
"""
    )

    with gr.Row():
        # Left column: user inputs, sampling settings and the trigger button.
        with gr.Column(scale=1):
            lyrics_input = gr.Textbox(
                label="Lyrics",
                lines=18,
                value=DEFAULT_LYRICS,
                placeholder="Use structured sections such as [Verse], [Chorus], [Bridge].",
            )
            tags_input = gr.Textbox(
                label="Tags",
                value=DEFAULT_TAGS,
                placeholder="female,indie pop,piano,emotional,night,silky,memories",
                info="Comma-separated tags without spaces for best compatibility.",
            )
            # Sampling controls, collapsed by default.
            with gr.Accordion("Generation Settings", open=False):
                # Upper bound and initial value come from the env-configurable
                # MAX_DURATION_SECONDS / DEFAULT_DURATION_SECONDS constants.
                max_duration_input = gr.Slider(
                    minimum=30,
                    maximum=MAX_DURATION_SECONDS,
                    value=DEFAULT_DURATION_SECONDS,
                    step=10,
                    label="Max Duration (seconds)",
                )
                temperature_input = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Temperature",
                )
                topk_input = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-K",
                )
                cfg_scale_input = gr.Slider(
                    minimum=1.0,
                    maximum=3.0,
                    value=1.5,
                    step=0.1,
                    label="CFG Scale",
                )

            generate_button = gr.Button("Generate Music", variant="primary")

        # Right column: the result. type="filepath" matches generate_music,
        # which returns the path of the written .wav file.
        with gr.Column(scale=1):
            audio_output = gr.Audio(label="Generated Audio", type="filepath")

    # The inputs list is passed positionally, so its order must match
    # generate_music's parameter order.
    generate_button.click(
        fn=generate_music,
        inputs=[
            lyrics_input,
            tags_input,
            max_duration_input,
            temperature_input,
            topk_input,
            cfg_scale_input,
        ],
        outputs=audio_output,
    )

# Queue events with a default concurrency limit of 1.
# NOTE(review): generate_music flips the module-level _RUNTIME_PREPARED flag
# without a lock; this limit is presumably what keeps two first requests from
# compiling at the same time. Revisit if the limit is raised.
demo.queue(default_concurrency_limit=1)


# Start the server only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()
|
|