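# heartmula / app.py
# Hugging Face file-viewer metadata captured with this file:
#   uploader avatar: ABLingss; commit message: "ai-timeout"; commit: 0d2b17a
#   viewer links: Raw | History | Blame | Contribute | Delete; size: 11.4 kB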
from __future__ import annotations
import os
import sys
import tempfile
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Tuple
try:
    import spaces
except ImportError:

    class _SpacesShim:
        """Stand-in used when the ``spaces`` package cannot be imported.

        It exposes only ``GPU``, as a decorator that returns the wrapped
        function untouched, so the rest of this module can keep using
        ``@spaces.GPU(...)`` unconditionally.
        """

        @staticmethod
        def GPU(*args, **kwargs):
            """Return a no-op decorator, or the function itself for bare use.

            Supports both spellings:

            * ``@spaces.GPU(duration=...)`` -> returns a decorator that hands
              the function back unchanged (all arguments are ignored).
            * ``@spaces.GPU`` -> the function arrives as the sole positional
              argument and is returned as-is.
            """
            # Bare-decorator form: the only argument is the function itself.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesShim()
import gradio as gr
import torch
from huggingface_hub import hf_hub_download, snapshot_download
# Anchor every path to this file's directory rather than the working directory.
APP_ROOT = Path(__file__).resolve().parent
HEARTLIB_SRC = APP_ROOT.joinpath("heartlib", "src")
HEARTLIB_PACKAGE_INIT = HEARTLIB_SRC.joinpath("heartlib", "__init__.py")

# When a heartlib checkout sits next to the app, put its ``src`` directory at
# the front of ``sys.path`` before the import below runs.
if HEARTLIB_PACKAGE_INIT.is_file():
    sys.path.insert(0, str(HEARTLIB_SRC))

from heartlib import HeartMuLaGenPipeline
@dataclass(frozen=True)
class ModelConfig:
    """Immutable description of which model artifacts the app uses.

    Every field is a plain string. Instances cannot be modified after
    construction because the dataclass is frozen.
    """

    # Passed to ``HeartMuLaGenPipeline.from_pretrained`` as ``version``.
    version: str
    # Hub repo that provides ``tokenizer.json`` and ``gen_config.json``.
    generator_repo: str
    # Hub repo whose snapshot is downloaded into ``mula_dirname``.
    mula_repo: str
    # Hub repo whose snapshot is downloaded into ``codec_dirname``.
    codec_repo: str
    # Sub-directory of the model directory that receives the ``mula_repo`` snapshot.
    mula_dirname: str
    # Sub-directory of the model directory that receives the ``codec_repo`` snapshot.
    codec_dirname: str
# Process-wide model selection. Each field can be overridden through the
# environment variable named next to it; the second argument is the default.
MODEL_CONFIG = ModelConfig(
    version=os.environ.get("HEARTMULA_VERSION", "3B"),
    generator_repo=os.environ.get(
        "HEARTMULA_GENERATOR_REPO", "HeartMuLa/HeartMuLaGen"
    ),
    mula_repo=os.environ.get(
        "HEARTMULA_MULA_REPO", "HeartMuLa/HeartMuLa-oss-3B-happy-new-year"
    ),
    codec_repo=os.environ.get(
        "HEARTMULA_CODEC_REPO", "HeartMuLa/HeartCodec-oss-20260123"
    ),
    mula_dirname=os.environ.get("HEARTMULA_MULA_DIRNAME", "HeartMuLa-oss-3B"),
    codec_dirname=os.environ.get("HEARTMULA_CODEC_DIRNAME", "HeartCodec-oss"),
)
# Time budgets, in seconds, handed to ``spaces.GPU(duration=...)`` for the
# generation call and the runtime-preparation call respectively.
GPU_DURATION_SECONDS = int(os.environ.get("HEARTMULA_GPU_DURATION_SECONDS", "300"))
COMPILE_DURATION_SECONDS = int(
    os.environ.get("HEARTMULA_COMPILE_DURATION_SECONDS", "600")
)

# Upper bound and initial value of the "Max Duration" slider; the initial
# value is capped so it never exceeds the upper bound.
MAX_DURATION_SECONDS = int(os.environ.get("HEARTMULA_MAX_DURATION_SECONDS", "180"))
DEFAULT_DURATION_SECONDS = min(
    MAX_DURATION_SECONDS,
    int(os.environ.get("HEARTMULA_DEFAULT_DURATION_SECONDS", "60")),
)

# Boolean switches: any value other than the literal string "0" means "on".
ENABLE_FLASH_ATTN = os.environ.get("HEARTMULA_ENABLE_FLASH_ATTN", "1") != "0"
ENABLE_AOTI = os.environ.get("HEARTMULA_ENABLE_AOTI", "1") != "0"

# Limits forwarded to ``configure_runtime_acceleration``.
AOTI_MAX_BATCH = int(os.environ.get("HEARTMULA_AOTI_MAX_BATCH", "2"))
AOTI_MAX_SEQ_LEN = int(os.environ.get("HEARTMULA_AOTI_MAX_SEQ_LEN", "4096"))

# When "on", the matching component is created with ``lazy_load=False`` on CUDA.
KEEP_MULA_LOADED = os.environ.get("HEARTMULA_KEEP_MULA_LOADED", "1") != "0"
KEEP_CODEC_LOADED = os.environ.get("HEARTMULA_KEEP_CODEC_LOADED", "0") != "0"

# MODEL_LOCK is held while artifacts are downloaded; PIPELINE_LOCK is held
# while PIPELINE_CACHE is read or populated.
MODEL_LOCK = threading.Lock()
PIPELINE_LOCK = threading.Lock()

# Pipelines keyed by (runtime name, model directory as a string).
PIPELINE_CACHE: Dict[Tuple[str, str], HeartMuLaGenPipeline] = {}

# Flipped to True by ``generate_music`` once ``_compile_runtime`` has returned.
_RUNTIME_PREPARED = False
def _default_cache_root() -> Path:
    """Choose the base directory for cached Hugging Face data.

    Order of preference:

    1. ``$HF_HOME`` when it is set to a non-empty value.
    2. ``/data/.huggingface`` when ``/data`` exists.
    3. ``/tmp/huggingface`` otherwise.

    Returns:
        The selected directory; it is not created here.
    """
    hf_home = os.getenv("HF_HOME")
    if hf_home:
        return Path(hf_home)

    persistent = Path("/data/.huggingface")
    # Only the parent (/data) is probed; the .huggingface folder may not exist yet.
    return persistent if persistent.parent.exists() else Path("/tmp/huggingface")
def _model_root() -> Path:
    """Return the directory that holds the downloaded model artifacts.

    ``$HEARTMULA_MODEL_DIR`` wins when set; otherwise the directory is
    ``heartmula_models`` under ``_default_cache_root()``.
    """
    # The fallback is computed up front, even when the variable is set.
    fallback = _default_cache_root() / "heartmula_models"
    return Path(os.getenv("HEARTMULA_MODEL_DIR", str(fallback)))
def _read_text(path: Path, fallback: str) -> str:
    """Return the stripped UTF-8 contents of *path*, or *fallback*.

    *fallback* is returned when the file is missing, is not a regular file,
    cannot be read or decoded, or contains only whitespace. This helper runs
    at import time to load optional UI defaults, so a bad asset must not
    stop the app from starting.

    Args:
        path: File to read.
        fallback: Text to use when *path* yields nothing usable.

    Returns:
        The file's text with surrounding whitespace removed, or *fallback*.
    """
    # Attempt the read directly instead of checking is_file() first: one
    # filesystem round-trip and no window between the check and the read.
    try:
        text = path.read_text(encoding="utf-8").strip()
    except (OSError, UnicodeDecodeError):
        return fallback
    # A blank asset would give an empty default that the app rejects anyway.
    return text or fallback
def _cached_model_exists(model_dir: Path) -> bool:
    """Report whether every expected artifact path exists under *model_dir*.

    The expected entries are ``tokenizer.json``, ``gen_config.json`` and the
    two checkpoint sub-directories named in ``MODEL_CONFIG``.

    Only existence is tested; the contents of the entries are not inspected.
    """
    expected_names = (
        "tokenizer.json",
        "gen_config.json",
        MODEL_CONFIG.mula_dirname,
        MODEL_CONFIG.codec_dirname,
    )
    for name in expected_names:
        if not (model_dir / name).exists():
            return False
    return True
def ensure_model_artifacts(progress: gr.Progress | None = None) -> Path:
    """Make sure the model files are on disk and return their directory.

    The model directory is created if needed. If the expected artifacts are
    already present nothing is downloaded. Otherwise ``MODEL_LOCK`` is
    acquired, the presence check is repeated (another thread may have
    finished the download while this one waited), and the tokenizer/config
    files plus both checkpoint snapshots are fetched from the Hub.

    Args:
        progress: Optional Gradio progress callback; when given it is
            called with a fraction and a ``desc`` message at each step.

    Returns:
        The model directory.
    """

    def notify(fraction: float, message: str) -> None:
        # Progress reporting is optional; stay silent without a callback.
        if progress is not None:
            progress(fraction, desc=message)

    model_dir = _model_root()
    model_dir.mkdir(parents=True, exist_ok=True)

    # Fast path: no lock needed when everything is already in place.
    if _cached_model_exists(model_dir):
        notify(0.05, "Using cached model artifacts")
        return model_dir

    with MODEL_LOCK:
        # Re-check now that the lock is held.
        if _cached_model_exists(model_dir):
            notify(0.05, "Using cached model artifacts")
            return model_dir

        notify(0.05, "Downloading tokenizer and generation config")
        for filename in ("tokenizer.json", "gen_config.json"):
            hf_hub_download(
                repo_id=MODEL_CONFIG.generator_repo,
                filename=filename,
                local_dir=str(model_dir),
            )

        # (progress fraction, message, Hub repo, target sub-directory)
        checkpoints = (
            (
                0.25,
                "Downloading HeartMuLa checkpoint",
                MODEL_CONFIG.mula_repo,
                MODEL_CONFIG.mula_dirname,
            ),
            (
                0.6,
                "Downloading HeartCodec checkpoint",
                MODEL_CONFIG.codec_repo,
                MODEL_CONFIG.codec_dirname,
            ),
        )
        for fraction, message, repo_id, dirname in checkpoints:
            notify(fraction, message)
            snapshot_download(
                repo_id=repo_id,
                local_dir=str(model_dir / dirname),
            )

        notify(0.95, "Model artifacts ready")
        return model_dir
def _runtime_key() -> Tuple[str, str]:
    """Return ``(runtime, model_root)`` for the current process.

    ``runtime`` is ``"cuda"`` when ``torch.cuda.is_available()`` is true and
    ``"cpu"`` otherwise; ``model_root`` is ``_model_root()`` as a string.
    """
    if torch.cuda.is_available():
        runtime = "cuda"
    else:
        runtime = "cpu"
    return runtime, str(_model_root())
def get_pipeline(model_dir: Path) -> HeartMuLaGenPipeline:
    """Return the cached pipeline for *model_dir*, building it on first use.

    Pipelines are cached in ``PIPELINE_CACHE`` under
    ``(runtime, str(model_dir))``, where ``runtime`` is ``"cuda"`` or
    ``"cpu"``. ``PIPELINE_LOCK`` is held for the whole lookup-or-build
    sequence. Acceleration settings are stored on the new pipeline via
    ``configure_runtime_acceleration``; ``prepare_runtime`` is not called
    here.

    Args:
        model_dir: Directory containing the downloaded model artifacts.

    Returns:
        The pipeline instance associated with the current runtime and
        *model_dir*.
    """
    on_cuda = torch.cuda.is_available()
    runtime = "cuda" if on_cuda else "cpu"
    cache_key = (runtime, str(model_dir))

    with PIPELINE_LOCK:
        cached = PIPELINE_CACHE.get(cache_key)
        if cached is not None:
            return cached

        if on_cuda:
            # Per-component placement: both parts on the GPU, the "mula"
            # part in bfloat16 and the codec in float32.
            device = {"mula": torch.device("cuda"), "codec": torch.device("cuda")}
            dtype = {"mula": torch.bfloat16, "codec": torch.float32}
            lazy_load = {
                "mula": not KEEP_MULA_LOADED,
                "codec": not KEEP_CODEC_LOADED,
            }
        else:
            # CPU: single device/dtype for everything, no lazy loading.
            device = torch.device("cpu")
            dtype = torch.float32
            lazy_load = False

        pipeline = HeartMuLaGenPipeline.from_pretrained(
            str(model_dir),
            device=device,
            dtype=dtype,
            version=MODEL_CONFIG.version,
            lazy_load=lazy_load,
        )
        # Acceleration features are only ever requested on CUDA.
        pipeline.configure_runtime_acceleration(
            enable_flash_attn=on_cuda and ENABLE_FLASH_ATTN,
            enable_aoti=on_cuda and ENABLE_AOTI,
            max_batch_size=AOTI_MAX_BATCH,
            max_compile_seq_len=AOTI_MAX_SEQ_LEN,
        )

        PIPELINE_CACHE[cache_key] = pipeline
        return pipeline
@spaces.GPU(duration=COMPILE_DURATION_SECONDS)
def _compile_runtime(model_dir: str) -> None:
    """Run the pipeline's runtime preparation under its own GPU budget.

    The ``spaces.GPU`` duration for this call is ``COMPILE_DURATION_SECONDS``
    (default 600 s, overridable via ``HEARTMULA_COMPILE_DURATION_SECONDS``).
    It is separate from the ``GPU_DURATION_SECONDS`` budget used for
    generation, so first-time preparation does not compete with inference
    for GPU seconds.

    Per the original author's note, ``prepare_runtime`` performs AoTI
    compilation and FA3 injection, and on later calls the
    ``spaces.aoti_compile`` filesystem cache makes it effectively a no-op.
    That behaviour lives in heartlib and is not visible in this file.

    Args:
        model_dir: Directory containing the model artifacts, as a string;
            converted to ``Path`` before being handed to ``get_pipeline``.
    """
    pipeline = get_pipeline(Path(model_dir))
    pipeline.prepare_runtime()
@spaces.GPU(duration=GPU_DURATION_SECONDS)
def _run_generation(
    model_dir: str,
    lyrics: str,
    tags: str,
    max_duration_seconds: int,
    temperature: float,
    topk: int,
    cfg_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one audio clip and return the path of the resulting WAV file.

    Runs under a ``spaces.GPU`` allocation of ``GPU_DURATION_SECONDS``.

    Args:
        model_dir: Directory containing the model artifacts, as a string.
        lyrics: Lyrics text forwarded to the pipeline.
        tags: Style tags forwarded to the pipeline.
        max_duration_seconds: Upper bound on clip length, in seconds;
            converted to whole milliseconds for the pipeline.
        temperature: Forwarded to the pipeline unchanged.
        topk: Forwarded to the pipeline as an ``int``.
        cfg_scale: Forwarded to the pipeline unchanged.
        progress: Gradio progress callback.

    Returns:
        Path of a temporary ``.wav`` file written by the pipeline. The file
        is created with ``delete=False``, so it outlives this call.

    Raises:
        Whatever the pipeline raises; the temporary file is removed first.
    """
    pipeline = get_pipeline(Path(model_dir))
    # Slider values may arrive as floats (e.g. 60.0); normalise to ints.
    max_audio_length_ms = int(max_duration_seconds * 1000)

    # Reserve a unique file name; the pipeline writes to it via save_path.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        output_path = fp.name

    try:
        progress(0.05, desc="Generating audio")
        with torch.no_grad():
            pipeline(
                {
                    "lyrics": lyrics,
                    "tags": tags,
                },
                max_audio_length_ms=max_audio_length_ms,
                save_path=output_path,
                topk=int(topk),
                temperature=temperature,
                cfg_scale=cfg_scale,
            )
    except BaseException:
        # Nobody will ever receive this path, so do not leave the file
        # behind. Cleanup is best-effort; the original error is re-raised.
        try:
            os.remove(output_path)
        except OSError:
            pass
        raise
    return output_path
def generate_music(
    lyrics: str,
    tags: str,
    max_duration_seconds: int,
    temperature: float,
    topk: int,
    cfg_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio click handler: validate input, prepare the model, generate audio.

    Steps:

    1. Reject empty or missing lyrics/tags with a ``gr.Error``.
    2. Cap the requested duration at ``MAX_DURATION_SECONDS``.
    3. Ensure the model artifacts are on disk.
    4. On the first successful request only, run ``_compile_runtime``.
    5. Run ``_run_generation`` and return its result.

    Args:
        lyrics: Lyrics text from the UI.
        tags: Style tags from the UI.
        max_duration_seconds: Requested maximum clip length in seconds.
        temperature: Sampling temperature, passed through.
        topk: Top-K value, passed through.
        cfg_scale: CFG scale, passed through.
        progress: Gradio progress callback.

    Returns:
        Whatever ``_run_generation`` returns (the path of the generated
        audio file).

    Raises:
        gr.Error: If *lyrics* or *tags* is missing or blank.
    """
    global _RUNTIME_PREPARED

    # ``or ""`` turns a missing (None) value into the same friendly error
    # as a blank one instead of an AttributeError.
    if not (lyrics or "").strip():
        raise gr.Error("Please enter lyrics before generating.")
    if not (tags or "").strip():
        raise gr.Error("Please enter at least one style tag.")

    # The slider already enforces this bound; repeat it here so requests
    # that bypass the UI cannot ask for more than the configured maximum.
    max_duration_seconds = min(max_duration_seconds, MAX_DURATION_SECONDS)

    model_dir = ensure_model_artifacts(progress)

    if not _RUNTIME_PREPARED:
        progress(0.02, desc="Compiling runtime (first request, please wait)")
        _compile_runtime(str(model_dir))
        # Set only after success, so a failed preparation is retried.
        _RUNTIME_PREPARED = True

    return _run_generation(
        str(model_dir),
        lyrics,
        tags,
        max_duration_seconds,
        temperature,
        topk,
        cfg_scale,
    )
# Initial contents of the "Lyrics" box: read from the asset file via
# ``_read_text``, with the inline text below passed as the fallback.
DEFAULT_LYRICS = _read_text(
    APP_ROOT.joinpath("heartlib", "assets", "lyrics.txt"),
    fallback="""[Verse]
The city wakes before the sun
We keep moving one by one
[Chorus]
Hold the light and sing it through
Every road comes back to you""",
)

# Initial contents of the "Tags" box, resolved the same way.
DEFAULT_TAGS = _read_text(
    APP_ROOT.joinpath("heartlib", "assets", "tags.txt"),
    fallback="female,indie pop,piano,emotional,night,silky,memories",
)
# Gradio UI definition: two columns (inputs on the left, audio on the right)
# wired to ``generate_music``.
# NOTE(review): the ``with`` nesting below was reconstructed from a flattened
# copy of this file -- confirm that the accordion and the button belong to the
# first column and that the click binding sits at the Blocks level.
with gr.Blocks(title="HeartMuLa ZeroGPU Demo") as demo:
    gr.Markdown(
        """
# HeartMuLa ZeroGPU Demo
Generate music from lyrics and style tags with **HeartMuLa** on Hugging Face Spaces.
First use may take longer because model files need to be cached.
"""
    )
    with gr.Row():
        # Input column: prompt text plus collapsible sampling settings.
        with gr.Column(scale=1):
            lyrics_input = gr.Textbox(
                label="Lyrics",
                lines=18,
                value=DEFAULT_LYRICS,
                placeholder="Use structured sections such as [Verse], [Chorus], [Bridge].",
            )
            tags_input = gr.Textbox(
                label="Tags",
                value=DEFAULT_TAGS,
                placeholder="female,indie pop,piano,emotional,night,silky,memories",
                info="Comma-separated tags without spaces for best compatibility.",
            )
            with gr.Accordion("Generation Settings", open=False):
                # Range and initial value come from the module-level
                # duration constants.
                max_duration_input = gr.Slider(
                    minimum=30,
                    maximum=MAX_DURATION_SECONDS,
                    value=DEFAULT_DURATION_SECONDS,
                    step=10,
                    label="Max Duration (seconds)",
                )
                temperature_input = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Temperature",
                )
                topk_input = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-K",
                )
                cfg_scale_input = gr.Slider(
                    minimum=1.0,
                    maximum=3.0,
                    value=1.5,
                    step=0.1,
                    label="CFG Scale",
                )
            generate_button = gr.Button("Generate Music", variant="primary")
        # Output column: the handler returns a file path, hence type="filepath".
        with gr.Column(scale=1):
            audio_output = gr.Audio(label="Generated Audio", type="filepath")
    # Input order must match the positional parameters of ``generate_music``.
    generate_button.click(
        fn=generate_music,
        inputs=[
            lyrics_input,
            tags_input,
            max_duration_input,
            temperature_input,
            topk_input,
            cfg_scale_input,
        ],
        outputs=audio_output,
    )
# Enable the request queue with a default per-event concurrency limit of 1.
demo.queue(default_concurrency_limit=1)
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()