testmula / app.py
ABLingss's picture
plz
5857256
# ==============================================================================
# Original app.py (for diff reference):
# ==============================================================================
# import gradio as gr
#
# def greet(name):
# return "Hello " + name + "!!"
#
# demo = gr.Interface(fn=greet, inputs="text", outputs="text")
# demo.launch()
# ==============================================================================
# HF Spaces compatibility: src/ is added to the Python path further below (after these imports) so heartlib can be imported.
import sys
import os
import uuid
import subprocess
import tempfile
import wave
from pathlib import Path
from typing import Dict, Any, Iterator, Optional, Tuple, TYPE_CHECKING
import numpy as np
# HF Spaces: import spaces before any CUDA-related packages (e.g., torch).
try:
    import spaces
except Exception:
    # The ZeroGPU helper package is unavailable; run without it.
    spaces = None

# Truthy only when the `spaces` package imported AND the SPACE_ID variable is set.
IS_SPACE = spaces is not None and os.environ.get("SPACE_ID")

# GPU time budget (seconds) handed to the GPU decorator. An explicit
# GPU_MAX_DURATION wins; otherwise the budget depends on IS_SPACE and ENABLE_AOTI.
_gpu_duration_env = os.environ.get("GPU_MAX_DURATION")
if _gpu_duration_env is not None:
    GPU_MAX_DURATION = int(_gpu_duration_env)
else:
    aoti_env = os.environ.get("ENABLE_AOTI")
    # An unset ENABLE_AOTI counts as enabled.
    enable_aoti_default = aoti_env is None or aoti_env.strip().lower() in ("1", "true", "yes", "y", "on")
    GPU_MAX_DURATION = 600 if IS_SPACE and enable_aoti_default else 100
def _env_bool(name: str) -> Optional[bool]:
val = os.environ.get(name)
if val is None:
return None
return val.strip().lower() in ("1", "true", "yes", "y", "on")
# KEEP_MODEL_LOADED_DEFAULT overrides the default; when it is unset the model
# is kept loaded everywhere except on a Space.
_default_keep_model_loaded_env = _env_bool("KEEP_MODEL_LOADED_DEFAULT")
DEFAULT_KEEP_MODEL_LOADED = (
    (not IS_SPACE)
    if _default_keep_model_loaded_env is None
    else _default_keep_model_loaded_env
)
def _gpu_guard(fn):
    """Wrap *fn* with ``spaces.GPU`` when the ``spaces`` package is available.

    Without ``spaces`` the function is returned untouched. With it, the
    configured ``GPU_MAX_DURATION`` is passed as ``duration`` when positive;
    if that call signature is rejected with ``TypeError`` (or the duration is
    not positive) the plain ``spaces.GPU(fn)`` form is used instead.
    """
    if spaces is None:
        return fn
    if GPU_MAX_DURATION <= 0:
        return spaces.GPU(fn)
    try:
        return spaces.GPU(fn, duration=GPU_MAX_DURATION)
    except TypeError:
        # This spaces.GPU call form does not accept `duration`.
        return spaces.GPU(fn)
# Make the bundled src/ directory importable (it provides heartlib).
_src_dir = os.path.join(os.path.dirname(__file__), 'src')
sys.path.insert(0, _src_dir)

# The settings below only apply when the variable is not already set.
# HF Spaces: disable SSR to avoid Node proxy/port issues.
os.environ.setdefault("GRADIO_SSR_MODE", "0")
# Mitigate CUDA memory fragmentation on small GPUs.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
if TYPE_CHECKING:
from heartlib import HeartMuLaGenPipeline
import gradio as gr
import re
try:
from google import genai # google-genai package
from google.genai import types
except Exception:
genai = None
types = None
from openai import OpenAI
from transformers import BitsAndBytesConfig
# Pipelines are not created at import time; they are built lazily on first use.
# Default model path (HF Spaces + local override via env)
MODEL_PATH = os.environ.get("MODEL_PATH", "./ckpt")
# LLM API Presets
# Each preset describes one lyrics-generation provider:
#   name          - display name used in progress messages
#   api_type      - which client branch generate_lyrics() takes ("gemini" or "openai")
#   default_model - model used when the user does not supply a custom model name
#   env_key       - environment variable consulted for the API key when none is
#                   typed into the UI (None = no environment fallback)
#   base_url      - endpoint override for OpenAI-compatible clients (None = client default)
LLM_PRESETS = {
    "gemini": {
        "name": "Google Gemini",
        "api_type": "gemini",
        "default_model": "gemini-2.0-flash-lite",
        "env_key": "GEMINI_API_KEY",
        "base_url": None,
    },
    "openai": {
        "name": "OpenAI",
        "api_type": "openai",
        "default_model": "gpt-4o-mini",
        "env_key": "OPENAI_API_KEY",
        "base_url": None,
    },
    "deepseek": {
        "name": "DeepSeek",
        "api_type": "openai",
        "default_model": "deepseek-chat",
        "env_key": "DEEPSEEK_API_KEY",
        "base_url": "https://api.deepseek.com",
    },
    # Generic OpenAI-compatible endpoint; the user is expected to supply the
    # key, base URL and model through the UI fields.
    "custom": {
        "name": "Custom OpenAI-Compatible",
        "api_type": "openai",
        "default_model": "custom-model",
        "env_key": None,
        "base_url": None,
    }
}
# Default example from assets
# Pre-filled into the lyrics textbox of the UI. The text is a runtime value
# and is intentionally left exactly as shipped.
EXAMPLE_LYRICS = """[Intro]
[Verse]
The sun creeps in across the floor
I hear the traffic outside the door
The coffee pot begins to hiss
It is another morning just like this
[Prechorus]
The world keeps spinning round and round
Feet are planted on the ground
I find my rhythm in the sound
[Chorus]
Every day the light returns
Every day the fire burns
We keep on walking down this street
Moving to the same steady beat
It is the ordinary magic that we meet
[Verse]
The hours tick deeply into noon
Chasing shadows,chasing the moon
Work is done and the lights go low
Watching the city start to glow
[Bridge]
It is not always easy,not always bright
Sometimes we wrestle with the night
But we make it to the morning light
[Chorus]
Every day the light returns
Every day the fire burns
We keep on walking down this street
Moving to the same steady beat
[Outro]
Just another day
Every single day"""
# Pre-filled into the tags textbox (comma-separated, no spaces).
EXAMPLE_TAGS = "piano,happy"
# Tag categories for selection
# Maps a category label to the tag choices offered for it; each category is
# rendered as one CheckboxGroup in the UI and the selections are merged into
# the comma-separated tag string by update_tag_string().
TAG_DATA = {
    "Gender": [
        "Male", "Female"
    ],
    "Genre": [
        "Pop", "Folk", "Ballad", "Electronic", "Rock", "Acoustic", "R&B",
        "Indie", "Dance", "Indie Pop", "J-Pop", "Hip-Hop", "Country",
        "Latin", "Alternative", "Christian", "Cantopop", "Gospel", "Soul",
        "Mandopop"
    ],
    "Instrument": [
        "Drums", "Piano", "Guitar", "Strings", "Synthesizer", "Bass",
        "Acoustic Guitar", "Keyboard", "Electronic Drums", "Vocals",
        "Drum Machine", "Electric Guitar", "Percussion", "Beat",
        "Orchestra", "Saxophone", "Accordion", "Voice", "String", "Vocal"
    ],
    "Mood": [
        "Melancholy", "Romantic", "Energetic", "Hopeful", "Dreamy",
        "Relaxed", "Sad", "Calm", "Cheerful", "Reflective", "Emotional",
        "Joyful", "Sentimental", "Uplifting", "Warm", "Peaceful", "Upbeat",
        "Gentle", "Nostalgic", "Epic"
    ],
    "Scene": [
        "Driving", "Road Trip", "Cafe", "Relaxing", "Wedding", "Meditation",
        "Workout", "Walking", "Alone", "Travel", "Reflection", "Rainy Day",
        "Night", "Church", "Coffee Shop", "Gym", "Gaming", "Study",
        "Dating", "Date"
    ],
    # Some labels (e.g. "Warm", "Gentle", "Emotional", "Energetic") also appear
    # under "Mood"; update_tag_string() de-duplicates the merged selection.
    "Singer Timbre": [
        "Soft", "Clear", "Warm", "Gentle", "Smooth", "Sweet", "Emotional",
        "Mellow", "Powerful", "Youthful", "Bright", "Rough", "Raspy",
        "Melodic", "Deep", "Soulful", "Strong", "Energetic", "Breathy",
        "Passionate"
    ],
    "Topic": [
        "Love", "Relationship", "Hope", "Longing", "Loss", "Heartbreak",
        "Memory", "Reflection", "Life", "Faith", "Regret", "Freedom",
        "Breakup", "Nature", "Loneliness", "Dreams", "Nostalgia", "Romance",
        "Friendship", "Youth"
    ]
}
DATA_DIR = Path(os.environ.get("HEARTMULA_DATA_DIR", os.path.join(tempfile.gettempdir(), "heartmula_stream")))
DATA_DIR.mkdir(parents=True, exist_ok=True)
print(f"DATA_DIR = {DATA_DIR}")
# Clear ZeroGPU offload cache to avoid disk-full errors.
offload_path = "/data-nvme/zerogpu-offload"
try:
if os.path.exists(offload_path):
for name in os.listdir(offload_path):
os.system(f"rm -rf {offload_path}/{name}")
except Exception as e:
print(f"WARN: failed to clear ZeroGPU offload cache: {e}")
# --- Gradio queue / concurrency limits (all overridable via environment) ---
GRADIO_QUEUE_MAX_SIZE = int(os.getenv("GRADIO_QUEUE_MAX_SIZE", "24"))
GRADIO_DEFAULT_CONCURRENCY = int(os.getenv("GRADIO_DEFAULT_CONCURRENCY", "1"))
GPU_CONCURRENCY_LIMIT = int(os.getenv("GRADIO_GPU_CONCURRENCY", "1"))

# --- UI defaults; the fallback differs depending on whether we run on a Space ---
DEFAULT_DURATION_SEC = int(os.getenv("DEFAULT_DURATION_SEC", "60" if IS_SPACE else "180"))
DEFAULT_QUANT_MODE = os.getenv("DEFAULT_QUANT_MODE", "4bit" if IS_SPACE else "none")
DEFAULT_OFFLOAD_MODE = os.getenv("DEFAULT_OFFLOAD_MODE", "aggressive" if IS_SPACE else "auto")
DEFAULT_GENERATION_MODE = os.getenv(
    "DEFAULT_GENERATION_MODE",
    "Accelerated" if IS_SPACE else "Original (No Acceleration)",
)
DEFAULT_SPEED_SUBMODE = os.getenv("DEFAULT_SPEED_SUBMODE", "Standard")
DEFAULT_PRESET = os.getenv("DEFAULT_PRESET", "ZeroGPU AOTI FP8" if IS_SPACE else "Balanced")

# --- Block-streaming geometry ---
# One streamed block spans MODEL_DETOKENIZE_INTERVAL_SEC seconds of audio; its
# size is expressed both in FRAME_MS-long frames and in audio samples.
MODEL_DETOKENIZE_INTERVAL_SEC = float(os.getenv("MODEL_DETOKENIZE_INTERVAL_SEC", "29.76"))
AUDIO_SAMPLE_RATE = int(os.getenv("AUDIO_SAMPLE_RATE", "48000"))
FRAME_MS = 80.0
BLOCK_FRAMES = max(1, int(round((MODEL_DETOKENIZE_INTERVAL_SEC * 1000.0) / FRAME_MS)))
BLOCK_SAMPLES = max(1, int(round(MODEL_DETOKENIZE_INTERVAL_SEC * AUDIO_SAMPLE_RATE)))

# --- Streaming buffer sizing (counted in blocks) ---
MIN_BUFFER_BLOCKS = int(os.getenv("MIN_BUFFER_BLOCKS", "1"))
PREFETCH_BLOCKS = int(os.getenv("PREFETCH_BLOCKS", "2"))
MAX_QUEUE_BLOCKS = int(os.getenv("MAX_QUEUE_BLOCKS", "4"))

# Streaming is always allowed off-Space; on a Space it needs the explicit opt-in.
ALLOW_STREAMING_ZERO_GPU = _env_bool("ALLOW_STREAMING_ZERO_GPU")
STREAMING_ALLOWED = not (IS_SPACE and not ALLOW_STREAMING_ZERO_GPU)
# Named bundles of generation settings. The "Apply Preset" button copies a
# bundle into the corresponding UI controls, and the bundle selected by
# DEFAULT_PRESET also seeds the DEFAULT_* values that were not set via env.
# Keys per preset:
#   duration          - target length in seconds
#   quant_mode        - "none" | "4bit" | "8bit" | "fp8"
#   offload_mode      - "auto" | "aggressive"
#   keep_model_loaded - initial state of the "Keep Model Loaded" checkbox
#   temperature/topk/cfg_scale - sampling parameters passed to the pipeline
#   generation_mode   - "Accelerated" | "Original (No Acceleration)"
#   speed_submode     - "Standard" | "Streaming"
PRESET_CONFIGS = {
    "ZeroGPU Safe": {
        "duration": 60,
        "quant_mode": "4bit",
        "offload_mode": "aggressive",
        "keep_model_loaded": False,
        "temperature": 1.0,
        "topk": 50,
        "cfg_scale": 1.5,
        "generation_mode": "Accelerated",
        "speed_submode": "Standard",
    },
    "ZeroGPU AOTI FP8": {
        "duration": 60,
        "quant_mode": "fp8",
        "offload_mode": "aggressive",
        "keep_model_loaded": False,
        "temperature": 1.0,
        "topk": 50,
        "cfg_scale": 1.5,
        "generation_mode": "Accelerated",
        "speed_submode": "Standard",
    },
    "Balanced": {
        "duration": 120,
        # Quantize on a Space, run unquantized elsewhere.
        "quant_mode": "4bit" if IS_SPACE else "none",
        "offload_mode": "auto",
        "keep_model_loaded": DEFAULT_KEEP_MODEL_LOADED,
        "temperature": 1.0,
        "topk": 50,
        "cfg_scale": 1.5,
        "generation_mode": "Accelerated",
        "speed_submode": "Standard",
    },
    "Quality": {
        "duration": 180,
        "quant_mode": "none",
        "offload_mode": "auto",
        "keep_model_loaded": DEFAULT_KEEP_MODEL_LOADED,
        "temperature": 0.9,
        "topk": 50,
        "cfg_scale": 2.0,
        "generation_mode": "Original (No Acceleration)",
        "speed_submode": "Standard",
    },
}
# An unknown DEFAULT_PRESET silently falls back to "Balanced".
DEFAULT_PRESET = DEFAULT_PRESET if DEFAULT_PRESET in PRESET_CONFIGS else "Balanced"
_default_preset_config = PRESET_CONFIGS[DEFAULT_PRESET]

# For every default that was NOT explicitly supplied through the environment,
# take the value from the selected preset instead of the generic fallback.
DEFAULT_DURATION_SEC = (
    DEFAULT_DURATION_SEC if "DEFAULT_DURATION_SEC" in os.environ
    else _default_preset_config["duration"]
)
DEFAULT_QUANT_MODE = (
    DEFAULT_QUANT_MODE if "DEFAULT_QUANT_MODE" in os.environ
    else _default_preset_config["quant_mode"]
)
DEFAULT_OFFLOAD_MODE = (
    DEFAULT_OFFLOAD_MODE if "DEFAULT_OFFLOAD_MODE" in os.environ
    else _default_preset_config["offload_mode"]
)
DEFAULT_GENERATION_MODE = (
    DEFAULT_GENERATION_MODE if "DEFAULT_GENERATION_MODE" in os.environ
    else _default_preset_config["generation_mode"]
)
DEFAULT_SPEED_SUBMODE = (
    DEFAULT_SPEED_SUBMODE if "DEFAULT_SPEED_SUBMODE" in os.environ
    else _default_preset_config["speed_submode"]
)
class ModelManager:
    """Builds and caches HeartMuLa generation pipelines for one model directory.

    A pipeline is created on first request for a given
    ``(version, codec_version, quant_mode)`` combination and reused afterwards.
    """

    def __init__(self, model_path: str, use_deepspeed_override: Optional[bool] = None):
        """Record device/dtype choices and the DeepSpeed configuration.

        Args:
            model_path: directory passed to ``from_pretrained``.
            use_deepspeed_override: forces DeepSpeed on/off; ``None`` defers to
                the ``USE_DEEPSPEED_INFERENCE`` environment variable.
        """
        # torch and heartlib are imported here, not at module import time.
        import torch
        from heartlib import HeartMuLaGenPipeline

        has_cuda = torch.cuda.is_available()
        self.model_path = model_path
        self.device = torch.device("cuda" if has_cuda else "cpu")
        self.dtype = torch.bfloat16 if has_cuda else torch.float32
        self._gen_pipes: Dict[Tuple[str, str, str], "HeartMuLaGenPipeline"] = {}
        if use_deepspeed_override is not None:
            self.use_deepspeed = use_deepspeed_override
        else:
            self.use_deepspeed = os.getenv("USE_DEEPSPEED_INFERENCE", "0").lower() in ("1", "true", "yes")
        self._HeartMuLaGenPipeline = HeartMuLaGenPipeline
        self.ds_inference_config = self._make_ds_inference_config()

    def _make_ds_inference_config(self) -> Dict[str, Any]:
        """Return the DeepSpeed inference settings, or ``{}`` when DeepSpeed is off."""
        if not self.use_deepspeed:
            return {}
        truthy = ("1", "true", "yes")
        return {
            # Tensor-parallel size: DEEPSPEED_TP_SIZE, then WORLD_SIZE, then 1.
            "mp_size": int(os.getenv("DEEPSPEED_TP_SIZE", os.getenv("WORLD_SIZE", "1"))),
            "dtype": self.dtype,
            "replace_method": os.getenv("DEEPSPEED_REPLACE_METHOD", "auto"),
            "replace_with_kernel_inject": os.getenv("DEEPSPEED_KERNEL_INJECT", "1").lower() in truthy,
        }

    def _make_bnb_config(self, quant_mode: str) -> Optional[BitsAndBytesConfig]:
        """Translate *quant_mode* into a bitsandbytes config.

        Returns ``None`` for ``"none"`` and for ``"fp8"`` (fp8 is requested
        through a separate ``torchao_quantize`` flag by the caller).

        Raises:
            gr.Error: for any quantized mode without CUDA, or an unknown mode.
        """
        import torch

        if quant_mode == "none":
            return None
        if self.device.type != "cuda":
            raise gr.Error("Quantization requires CUDA.")
        if quant_mode == "fp8":
            return None
        if quant_mode == "8bit":
            return BitsAndBytesConfig(load_in_8bit=True)
        if quant_mode != "4bit":
            raise gr.Error(f"Unknown quant mode: {quant_mode}")

        # 4-bit: nf4 by default, fp4 when the device reports compute
        # capability major >= 10. A failed capability query keeps nf4.
        four_bit_type = "nf4"
        try:
            cc_major, _ = torch.cuda.get_device_capability()
            if cc_major >= 10:
                four_bit_type = "fp4"
        except Exception:
            pass
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=four_bit_type,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

    def get_gen_pipeline(self, version: str, codec_version: str, quant_mode: str) -> "HeartMuLaGenPipeline":
        """Return the cached pipeline for this configuration, creating it if needed."""
        cache_key = (version, codec_version, quant_mode)
        if cache_key in self._gen_pipes:
            return self._gen_pipes[cache_key]
        pipeline = self._HeartMuLaGenPipeline.from_pretrained(
            self.model_path,
            device=self.device,
            dtype=self.dtype,
            version=version,
            codec_version=codec_version,
            bnb_config=self._make_bnb_config(quant_mode),
            torchao_quantize=quant_mode == "fp8",
            lazy_load=True,
            use_deepspeed=self.use_deepspeed,
            ds_inference_config=self.ds_inference_config,
        )
        self._gen_pipes[cache_key] = pipeline
        return pipeline
# One ModelManager per generation mode, created on first use.
model_managers: Dict[str, ModelManager] = {}


def get_model_manager(use_acceleration: bool) -> ModelManager:
    """Return the shared ModelManager for the accelerated or original mode.

    On first use for a mode the checkpoint directory is created, missing
    models are downloaded, and the manager is built. The original mode forces
    DeepSpeed off; the accelerated mode leaves that decision to the environment.
    """
    mode_key = "accelerated" if use_acceleration else "original"
    try:
        return model_managers[mode_key]
    except KeyError:
        pass
    os.makedirs(MODEL_PATH, exist_ok=True)
    download_models_if_needed(MODEL_PATH)
    manager = ModelManager(
        MODEL_PATH,
        use_deepspeed_override=None if use_acceleration else False,
    )
    model_managers[mode_key] = manager
    return manager
def update_tag_string(*args):
    """
    Merge the selections of all tag CheckboxGroups into one comma-separated string.
    args: one selection per CheckboxGroup (a list of tags, a single tag, or empty/None)
    Duplicates are dropped, keeping the first occurrence of each tag.
    """
    merged = []
    for selection in args:
        if not selection:
            continue
        merged.extend(selection if isinstance(selection, list) else [selection])
    # dict.fromkeys keeps insertion order, which removes repeats in place.
    return ",".join(dict.fromkeys(merged))
def process_lyrics_correct(content):
    """
    Normalize lyrics text into the format the model was trained on.
    Steps, in order:
    1. Lowercase everything.
    2. Drop bracketed timestamps such as [00:12] or [00:12.34].
    3. Strip whitespace from every line.
    4. Trim empty lines at the start and end.
    5. Collapse runs of 3+ newlines down to 2.
    """
    without_timestamps = re.sub(r'\[[^\]]*\d{1,2}:\d{2}[^\]]*\]', '', content.lower())
    stripped = [raw_line.strip() for raw_line in without_timestamps.split('\n')]

    # Find the first and last non-empty lines instead of popping from the list.
    first, last = 0, len(stripped)
    while first < last and stripped[first] == '':
        first += 1
    while last > first and stripped[last - 1] == '':
        last -= 1

    joined = '\n'.join(stripped[first:last])
    return re.sub(r'\n{3,}', '\n\n', joined)
def save_audio_to_wav(sample_rate: int, audio_np: np.ndarray, output_dir: Path) -> str:
    """Write float audio to a uniquely named 16-bit mono WAV file.

    Args:
        sample_rate: frame rate written to the WAV header.
        audio_np: float samples, nominally in [-1.0, 1.0]; written as a single channel.
        output_dir: destination directory (created if missing).

    Returns:
        The path of the written file as a string.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    wav_path = output_dir / f"{uuid.uuid4()}.wav"
    # Sanitize before the integer cast: NaN/inf have no defined int16 value and
    # out-of-range samples would otherwise wrap around instead of saturating.
    safe_audio = np.clip(
        np.nan_to_num(audio_np, nan=0.0, posinf=0.0, neginf=0.0), -1.0, 1.0
    )
    # WAV PCM data is little-endian regardless of the host byte order.
    audio_int16 = (safe_audio * 32767).astype("<i2")
    with wave.open(str(wav_path), "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(audio_int16.tobytes())
    return str(wav_path)
def convert_wav_to_mp3(wav_path: str, output_dir: Path) -> str:
    """Transcode *wav_path* to a uniquely named MP3 in *output_dir* using ffmpeg.

    Returns the MP3 path as a string. ``subprocess.run`` is called with
    ``check=True``, so a non-zero ffmpeg exit raises ``CalledProcessError``.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    mp3_path = output_dir / f"{uuid.uuid4()}.mp3"
    ffmpeg_cmd = ["ffmpeg", "-y", "-i", wav_path]       # overwrite, input file
    ffmpeg_cmd += ["-codec:a", "libmp3lame"]            # MP3 encoder
    ffmpeg_cmd += ["-qscale:a", "2"]                    # encoder quality setting
    ffmpeg_cmd.append(str(mp3_path))
    # Output is captured so ffmpeg's chatter does not reach the app log.
    subprocess.run(ffmpeg_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return str(mp3_path)
def check_models_exist(ckpt_dir):
    """Return True only if every required checkpoint entry exists under ckpt_dir"""
    required_names = (
        "gen_config.json",
        "tokenizer.json",
        "HeartCodec-oss",
        "HeartMuLa-oss-3B",
    )
    return all(
        os.path.exists(os.path.join(ckpt_dir, entry)) for entry in required_names
    )
def download_models_if_needed(ckpt_dir):
    """Fetch the three HeartMuLa repositories from ModelScope unless already present"""
    heavy_rule = "=" * 50
    light_rule = "━" * 50

    if check_models_exist(ckpt_dir):
        print(heavy_rule)
        print(f"✓ Checkpoints found in {ckpt_dir}")
        print("✓ Skipping download")
        print(heavy_rule)
        return

    print(heavy_rule)
    print("⬇ Starting model download from ModelScope")
    print(heavy_rule)
    print("")

    # Imported only when a download is actually required.
    from modelscope import snapshot_download

    # (repository id, target directory, start message, completion message)
    download_plan = (
        (
            'HeartMuLa/HeartMuLaGen',
            ckpt_dir,
            "📦 [1/3] Downloading HeartMuLaGen config and tokenizer...",
            "✓ HeartMuLaGen download completed",
        ),
        (
            'HeartMuLa/HeartMuLa-oss-3B',
            os.path.join(ckpt_dir, 'HeartMuLa-oss-3B'),
            "📦 [2/3] Downloading HeartMuLa-oss-3B model...",
            "✓ HeartMuLa-oss-3B download completed",
        ),
        (
            'HeartMuLa/HeartCodec-oss',
            os.path.join(ckpt_dir, 'HeartCodec-oss'),
            "📦 [3/3] Downloading HeartCodec-oss model...",
            "✓ HeartCodec-oss download completed",
        ),
    )
    for repo_id, target_dir, start_msg, done_msg in download_plan:
        print(light_rule)
        print(start_msg)
        print(light_rule)
        snapshot_download(repo_id, local_dir=target_dir)
        print(done_msg)
        print("")

    print(heavy_rule)
    print("✓ All models downloaded successfully!")
    print(heavy_rule)
    print("")
def load_pipeline(model_path, version, codec_version, quant_mode, use_acceleration: bool):
    """Return the generation pipeline for the requested configuration.

    NOTE: ``model_path`` is only echoed in the log line below; the manager
    returned by ``get_model_manager`` is what actually determines where the
    model is loaded from.
    """
    mgr = get_model_manager(use_acceleration)
    print(f"Using model from {model_path} on {mgr.device}...")
    pipeline = mgr.get_gen_pipeline(version, codec_version, quant_mode)
    return pipeline
def _validate_generation_inputs(lyrics: str, tags: str) -> None:
if not lyrics.strip():
raise gr.Error("Please enter lyrics")
if not tags.strip():
raise gr.Error("Please enter tags")
def generate(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    output_format,
    keep_model_loaded,
    offload_mode,
    backend,
    use_acceleration,
):
    """Generate a full song and return the path of the resulting audio file.

    The pipeline writes a WAV file into ``DATA_DIR``; when ``output_format`` is
    ``"mp3"`` the WAV is transcoded and the MP3 path is returned instead.

    Raises:
        gr.Error: for invalid inputs, an unsupported backend, or any failure
            during loading/generation/conversion (the original exception is
            chained as ``__cause__``).
    """
    import torch

    _validate_generation_inputs(lyrics, tags)
    max_audio_length_ms = int(duration_sec * 1000)
    try:
        if backend == "exllama_v2":
            raise gr.Error("ExLlamaV2 backend is not implemented yet.")
        pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode, use_acceleration)
        output_path = os.path.join(DATA_DIR, f"gen_{uuid.uuid4().hex}.wav")
        with torch.inference_mode():
            pipe(
                {
                    "lyrics": lyrics,
                    "tags": tags,
                },
                max_audio_length_ms=max_audio_length_ms,
                save_path=output_path,
                topk=topk,
                temperature=temperature,
                cfg_scale=cfg_scale,
                keep_model_loaded=keep_model_loaded,
                offload_mode=offload_mode,
            )
        # Size logging is purely diagnostic; never fail the request over it.
        try:
            file_size = os.path.getsize(output_path)
            print(f"Generated file: {output_path} ({file_size} bytes)")
        except Exception as e:
            print(f"WARN: failed to stat output file {output_path}: {e}")
        if output_format == "mp3":
            mp3_path = convert_wav_to_mp3(output_path, DATA_DIR)
            # The WAV was only an intermediate; drop it so repeated MP3
            # requests do not fill the disk with unused files.
            try:
                os.remove(output_path)
            except OSError as e:
                print(f"WARN: failed to remove intermediate file {output_path}: {e}")
            return mp3_path
        return output_path
    except gr.Error:
        # Already a user-facing message; do not wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"Generation error: {str(e)}") from e
    finally:
        if not keep_model_loaded and torch.cuda.is_available():
            torch.cuda.empty_cache()
@_gpu_guard
def generate_original(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    output_format,
    keep_model_loaded,
    offload_mode,
    backend,
):
    """GPU-guarded entry point: run ``generate`` with acceleration disabled."""
    return generate(
        lyrics=lyrics,
        tags=tags,
        cfg_scale=cfg_scale,
        duration_sec=duration_sec,
        temperature=temperature,
        topk=topk,
        version=version,
        codec_version=codec_version,
        quant_mode=quant_mode,
        output_format=output_format,
        keep_model_loaded=keep_model_loaded,
        offload_mode=offload_mode,
        backend=backend,
        use_acceleration=False,
    )
@_gpu_guard
def generate_accelerated(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    output_format,
    keep_model_loaded,
    offload_mode,
    backend,
):
    """GPU-guarded entry point: run ``generate`` with acceleration enabled."""
    return generate(
        lyrics=lyrics,
        tags=tags,
        cfg_scale=cfg_scale,
        duration_sec=duration_sec,
        temperature=temperature,
        topk=topk,
        version=version,
        codec_version=codec_version,
        quant_mode=quant_mode,
        output_format=output_format,
        keep_model_loaded=keep_model_loaded,
        offload_mode=offload_mode,
        backend=backend,
        use_acceleration=True,
    )
def _normalize_stream_chunk(chunk: np.ndarray) -> np.ndarray:
chunk = np.nan_to_num(chunk, nan=0.0, posinf=0.0, neginf=0.0)
return np.clip(chunk, -1.0, 1.0)
def generate_music_streaming(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    keep_model_loaded,
    offload_mode,
    backend,
    use_acceleration,
) -> Iterator[Tuple[int, np.ndarray]]:
    """Yield ``(AUDIO_SAMPLE_RATE, samples)`` pairs as the pipeline streams audio.

    This is a generator: validation, pipeline loading and generation only
    start when the first item is requested. Each chunk coming from
    ``pipe.stream`` is moved to the CPU, converted to NumPy and sanitized with
    ``_normalize_stream_chunk`` before being yielded.

    Raises:
        gr.Error: for invalid inputs or the unimplemented ExLlamaV2 backend.
    """
    _validate_generation_inputs(lyrics, tags)
    if backend == "exllama_v2":
        raise gr.Error("ExLlamaV2 backend is not implemented yet.")
    pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode, use_acceleration)
    request = {"lyrics": lyrics, "tags": tags}
    stream_options = dict(
        max_audio_length_ms=int(duration_sec * 1000),
        temperature=temperature,
        topk=topk,
        cfg_scale=cfg_scale,
        chunk_frames=BLOCK_FRAMES,
        keep_model_loaded=keep_model_loaded,
        offload_mode=offload_mode,
    )
    for audio_tensor in pipe.stream(request, **stream_options):
        # A 2-D chunk has its first axis squeezed (a no-op unless that axis is size 1).
        if audio_tensor.dim() == 2:
            audio_tensor = audio_tensor.squeeze(0)
        samples = _normalize_stream_chunk(audio_tensor.cpu().numpy())
        print(f"stream chunk: samples={samples.shape[0]} sr={AUDIO_SAMPLE_RATE}")
        yield AUDIO_SAMPLE_RATE, samples
def stream_generate(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    keep_model_loaded,
    offload_mode,
    backend,
    use_acceleration,
):
    """Stream generated audio block by block with a small start-up buffer.

    Blocks from ``generate_music_streaming`` are held back until
    ``max(MIN_BUFFER_BLOCKS, PREFETCH_BLOCKS)`` of them are buffered (capped by
    ``MAX_QUEUE_BLOCKS``); from then on every block is forwarded as soon as it
    arrives. If generation finishes before the threshold is reached, the
    buffered blocks are flushed at the end so short songs still produce audio.

    Yields:
        ``(sample_rate, float32 samples)`` tuples.

    Raises:
        gr.Error: for invalid inputs or any failure while streaming (the
            original exception is chained as ``__cause__``).
    """
    try:
        _validate_generation_inputs(lyrics, tags)
        start_threshold = max(MIN_BUFFER_BLOCKS, PREFETCH_BLOCKS)
        pending = []  # (sample_rate, block) pairs not yet handed to the caller
        started = False
        print("block stream start:", f"block_sec={MODEL_DETOKENIZE_INTERVAL_SEC}", f"duration_sec={duration_sec}")
        for sr, chunk_np in generate_music_streaming(
            lyrics=lyrics,
            tags=tags,
            cfg_scale=cfg_scale,
            duration_sec=duration_sec,
            temperature=temperature,
            topk=topk,
            version=version,
            codec_version=codec_version,
            quant_mode=quant_mode,
            keep_model_loaded=keep_model_loaded,
            offload_mode=offload_mode,
            backend=backend,
            use_acceleration=use_acceleration,
        ):
            chunk_np = chunk_np.astype("float32", copy=False)
            pending.append((sr, chunk_np))
            # Keep buffering until the start threshold (or the hard cap) is hit.
            if not started and len(pending) < start_threshold and len(pending) < MAX_QUEUE_BLOCKS:
                continue
            if not started:
                started = True
                print(f"block stream start playback: buffered_blocks={len(pending)}")
            for block_sr, block in pending:
                print(f"block stream yield: samples={block.shape[0]}")
                yield block_sr, block
            pending.clear()
        # The source ended before playback started: emit what was buffered
        # instead of silently dropping it.
        if pending:
            print(f"block stream start playback: buffered_blocks={len(pending)}")
            for block_sr, block in pending:
                print(f"block stream yield: samples={block.shape[0]}")
                yield block_sr, block
            pending.clear()
    except gr.Error:
        # Already a user-facing message; do not wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"Streaming error: {str(e)}") from e
@_gpu_guard
def stream_generate_accelerated(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    keep_model_loaded,
    offload_mode,
    backend,
):
    """GPU-guarded streaming entry point (acceleration enabled).

    Implemented as a generator function (``yield from``) rather than a plain
    function returning a generator object, so that wrappers which decide how
    to run a callback by inspecting whether it is a generator function treat
    it as a streaming callback. Iterating the result behaves as before.

    Raises:
        gr.Error: when streaming is disabled on this deployment (raised when
            iteration starts), or for any error from ``stream_generate``.
    """
    if IS_SPACE and not STREAMING_ALLOWED:
        raise gr.Error("Streaming is disabled on ZeroGPU. Use Standard mode or set ALLOW_STREAMING_ZERO_GPU=1.")
    yield from stream_generate(
        lyrics,
        tags,
        cfg_scale,
        duration_sec,
        temperature,
        topk,
        version,
        codec_version,
        quant_mode,
        keep_model_loaded,
        offload_mode,
        backend,
        True,
    )
def generate_lyrics(theme, tags, language, api_choice, api_key_input, custom_base_url, custom_model, progress=gr.Progress()):
    """Generate lyrics using selected LLM API

    Args:
        theme: free-text theme of the song (required).
        tags: style tags; falls back to "pop, emotional" when empty.
        language: language code ("en", "zh", "jp", "kr", "sp"); unknown codes
            fall back to English.
        api_choice: key into ``LLM_PRESETS``.
        api_key_input: API key typed by the user; when empty the preset's
            environment variable is used instead.
        custom_base_url: optional endpoint override for OpenAI-compatible APIs.
        custom_model: optional model name override.
        progress: Gradio progress reporter.

    Returns:
        The generated lyrics, normalized with ``process_lyrics_correct``.

    Raises:
        gr.Error: for missing inputs, a missing API key, an unavailable SDK,
            an empty model response, or any API failure (the original
            exception is chained as ``__cause__``).
    """
    if not (theme or "").strip():
        raise gr.Error("Please enter a theme")
    progress(0.1, desc="Preparing request...")
    # Get preset configuration
    if api_choice not in LLM_PRESETS:
        raise gr.Error(f"Unknown API choice: {api_choice}")
    preset = LLM_PRESETS[api_choice]
    # Determine API key
    api_key = api_key_input.strip() if api_key_input and api_key_input.strip() else None
    if not api_key and preset["env_key"]:
        api_key = os.environ.get(preset["env_key"])
    if not api_key:
        raise gr.Error("No API key provided. Please enter your API key in the field above.")
    # Determine base URL and model
    base_url = custom_base_url.strip() if custom_base_url and custom_base_url.strip() else preset["base_url"]
    model_name = custom_model.strip() if custom_model and custom_model.strip() else preset["default_model"]
    # Language mapping
    language_names = {
        "en": "English",
        "zh": "Chinese",
        "jp": "Japanese",
        "kr": "Korean",
        "sp": "Spanish"
    }
    lang_name = language_names.get(language, "English")
    # Tags processing (a missing value is treated like an empty one)
    tags_text = tags.strip() if tags and tags.strip() else "pop, emotional"
    # Create prompt
    prompt = f"""You are a professional songwriter. Generate song lyrics based on the following requirements:
**Theme**: {theme}
**Music Style/Tags**: {tags_text}
**Language**: {lang_name}
**Format Requirements** (CRITICAL):
1. Use lowercase for all lyrics text (except structure tags which are in brackets)
2. Include proper song structure tags: [Intro], [Verse], [Prechorus], [Chorus], [Bridge], [Outro]
3. Each structure tag should be on its own line
4. Separate different sections with a blank line (one empty line between sections)
5. NO timestamps like [00:12] - only structure tags allowed
6. Keep lyrics concise and suitable for a 3-4 minute song
**Structure Guidelines**:
- [Intro]: Optional, 1-2 lines if included
- [Verse]: Story-telling part, 4-6 lines, can repeat with different lyrics
- [Prechorus]: Optional, 2-4 lines, builds tension before chorus
- [Chorus]: Main hook, catchy and repetitive, 4-6 lines
- [Bridge]: Optional, provides contrast, 4-6 lines
- [Outro]: Closing, 1-2 lines
**Example Format**:
```
[Intro]
[Verse]
the sun creeps in across the floor
i hear the traffic outside the door
the coffee pot begins to hiss
it is another morning just like this
[Chorus]
every day the light returns
every day the fire burns
we keep on walking down this street
moving to the same steady beat
```
Now generate lyrics in {lang_name} based on the theme "{theme}" with style "{tags_text}".
Output ONLY the lyrics with structure tags, no explanations.
"""
    try:
        if preset["api_type"] == "gemini":
            if genai is None or types is None:
                raise gr.Error("Gemini SDK not available. Install `google-genai` or switch provider.")
            # Gemini API
            progress(0.3, desc=f"Connecting to {preset['name']}...")
            # Route through a proxy only when one is explicitly configured.
            # These variables are process-wide, so forcing a default proxy
            # would break every later network call on hosts without one.
            proxy_host = os.environ.get("PROXY_HOST")
            proxy_port = os.environ.get("PROXY_PORT")
            if proxy_host or proxy_port:
                proxy_url = f"http://{proxy_host or '127.0.0.1'}:{proxy_port or '7890'}"
                os.environ['http_proxy'] = proxy_url
                os.environ['https_proxy'] = proxy_url
            client = genai.Client(api_key=api_key)
            progress(0.5, desc=f"Generating lyrics with {preset['name']}...")
            response = client.models.generate_content(
                model=model_name,
                contents=[
                    types.Content(
                        role='user',
                        parts=[types.Part(text=prompt)]
                    )
                ],
                config=types.GenerateContentConfig(
                    temperature=0.8,
                    max_output_tokens=2000
                )
            )
            # Guard against a response that carries no text.
            generated_lyrics = (response.text or "").strip()
        elif preset["api_type"] == "openai":
            # OpenAI-compatible API (OpenAI, DeepSeek, Custom)
            progress(0.3, desc=f"Connecting to {preset['name']}...")
            # Create client with optional base_url
            if base_url:
                client = OpenAI(api_key=api_key, base_url=base_url)
            else:
                client = OpenAI(api_key=api_key)
            progress(0.5, desc=f"Generating lyrics with {preset['name']}...")
            response = client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": "You are a professional songwriter who creates well-structured lyrics."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.8,
                max_tokens=2000
            )
            # Guard against a message whose content is missing.
            generated_lyrics = (response.choices[0].message.content or "").strip()
        else:
            raise gr.Error(f"Unknown API type: {preset['api_type']}")
        progress(0.9, desc="Processing response...")
        # Clean up the response (remove markdown code blocks if present)
        if generated_lyrics.startswith("```"):
            fenced_lines = generated_lyrics.split("\n")[1:]  # drop the opening fence
            # Drop the closing fence only when it is really there, so an
            # unterminated block does not lose its last lyric line.
            if fenced_lines and fenced_lines[-1].strip().startswith("```"):
                fenced_lines = fenced_lines[:-1]
            generated_lyrics = "\n".join(fenced_lines)
        # Apply our lyrics processing function to ensure format consistency
        generated_lyrics = process_lyrics_correct(generated_lyrics)
        if not generated_lyrics:
            raise gr.Error("The model returned an empty response. Please try again.")
        progress(1.0, desc="Done!")
        return generated_lyrics
    except gr.Error:
        # Already a user-facing message; do not wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"Lyrics generation error: {str(e)}") from e
def create_ui():
"""Create Gradio UI"""
speed_submode_value = DEFAULT_SPEED_SUBMODE if STREAMING_ALLOWED else "Standard"
show_stream_default = (
STREAMING_ALLOWED
and DEFAULT_GENERATION_MODE != "Original (No Acceleration)"
and speed_submode_value == "Streaming"
)
with gr.Blocks(title="HeartMuLa Music Generation") as demo:
gr.Markdown("# HeartMuLa Music Generation")
gr.Markdown("Generate music from lyrics and style tags")
gr.Markdown("Tip: start with the **ZeroGPU Safe** preset for reliable generation on small GPUs.")
with gr.Tabs():
# Tab 1: Music Generation
with gr.Tab("Music Generation"):
with gr.Row():
with gr.Column():
lyrics = gr.Textbox(
label="Lyrics",
lines=15,
value=EXAMPLE_LYRICS,
placeholder="Enter lyrics here..."
)
# Add format button
format_btn = gr.Button("Format Lyrics", size="sm")
# Tag Selection
gr.Markdown("### Tags")
tags = gr.Textbox(
label="Selected Tags (comma-separated)",
value=EXAMPLE_TAGS,
placeholder="e.g., piano,happy,pop",
lines=2
)
# Tag categories in accordion
tag_checkboxes = []
with gr.Accordion("Tag Categories (Click to Expand)", open=False):
with gr.Row():
with gr.Column():
t1 = gr.CheckboxGroup(choices=TAG_DATA["Gender"], label="Gender")
tag_checkboxes.append(t1)
t2 = gr.CheckboxGroup(choices=TAG_DATA["Genre"], label="Genre")
tag_checkboxes.append(t2)
with gr.Column():
t3 = gr.CheckboxGroup(choices=TAG_DATA["Instrument"], label="Instrument")
tag_checkboxes.append(t3)
t4 = gr.CheckboxGroup(choices=TAG_DATA["Mood"], label="Mood")
tag_checkboxes.append(t4)
with gr.Column():
t5 = gr.CheckboxGroup(choices=TAG_DATA["Scene"], label="Scene")
tag_checkboxes.append(t5)
t6 = gr.CheckboxGroup(choices=TAG_DATA["Singer Timbre"], label="Singer Timbre")
tag_checkboxes.append(t6)
with gr.Column():
t7 = gr.CheckboxGroup(choices=TAG_DATA["Topic"], label="Topic")
tag_checkboxes.append(t7)
gr.Markdown("### Quick Presets")
with gr.Row():
preset_selector = gr.Dropdown(
choices=list(PRESET_CONFIGS.keys()),
value=DEFAULT_PRESET,
label="Preset"
)
apply_preset_btn = gr.Button("Apply Preset", size="sm")
# Generation parameters
with gr.Row():
cfg_scale = gr.Slider(0.0, 3.0, value=1.5, step=0.1, label="CFG Scale")
duration = gr.Slider(10, 300, value=DEFAULT_DURATION_SEC, step=10, label="Duration (sec)")
with gr.Row():
temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
topk = gr.Slider(1, 100, value=50, step=1, label="Top-K")
with gr.Accordion("Advanced Settings", open=False):
backend = gr.Dropdown(
choices=[("HF Pipeline", "hf"), ("ExLlamaV2 (Not Implemented)", "exllama_v2")],
value="hf",
label="Backend"
)
version = gr.Dropdown(
choices=["3B", "7B", "RL-oss-3B-20260123"],
value="3B",
label="Model Version"
)
codec_version = gr.Dropdown(
choices=["oss", "oss-20260123"],
value="oss",
label="Codec Version"
)
quant_mode = gr.Dropdown(
choices=[
("None", "none"),
("4-bit (NF4/FP4)", "4bit"),
("8-bit", "8bit"),
("FP8 (TorchAO)", "fp8"),
],
value=DEFAULT_QUANT_MODE,
label="Quantization (ZeroGPU recommended: 4-bit)"
)
keep_model_loaded = gr.Checkbox(
value=DEFAULT_KEEP_MODEL_LOADED,
label="Keep Model Loaded"
)
offload_mode = gr.Dropdown(
choices=["auto", "aggressive"],
value=DEFAULT_OFFLOAD_MODE,
label="Offload Mode"
)
output_format = gr.Radio(
choices=[("WAV", "wav"), ("MP3", "mp3")],
value="wav",
label="Output Format",
visible=not show_stream_default,
)
gr.Markdown(
f"Streaming is block-based: {MODEL_DETOKENIZE_INTERVAL_SEC:.2f}s per block "
f"({BLOCK_FRAMES} frames, {BLOCK_SAMPLES} samples)."
)
gr.Markdown("### 🚀 Generation")
generation_mode = gr.Radio(
choices=["Original (No Acceleration)", "Accelerated"],
value=DEFAULT_GENERATION_MODE,
label="Generation Mode",
)
speed_choices = [("Standard", "Standard")]
if STREAMING_ALLOWED:
speed_choices.append(("Block Streaming (Preview)", "Streaming"))
speed_submode = gr.Radio(
choices=speed_choices,
value=speed_submode_value,
label="Accelerated Options",
visible=DEFAULT_GENERATION_MODE != "Original (No Acceleration)",
)
btn_original = gr.Button(
"🎼 Generate Music (Original)",
variant="primary",
size="lg",
visible=DEFAULT_GENERATION_MODE == "Original (No Acceleration)",
)
btn_accel = gr.Button(
"🎼 Generate Music (Accelerated)",
variant="primary",
size="lg",
visible=DEFAULT_GENERATION_MODE != "Original (No Acceleration)"
and not show_stream_default,
)
btn_stream = gr.Button(
"🎼 Generate Music (Block Streaming)",
variant="primary",
size="lg",
visible=show_stream_default,
)
cancel_stream_btn = gr.Button(
"Cancel Streaming",
variant="secondary",
size="lg",
visible=show_stream_default,
)
cancel_state = gr.State()
with gr.Column():
# Notice section
with gr.Accordion("Usage Notice", open=True):
gr.Markdown("""
### ZeroGPU Tips
- Default preset is **ZeroGPU AOTI FP8** for fastest H200 path
- Prefer **4-bit quantization** if FP8 or AOTI is unavailable
- Keep durations short (<= 60s) for faster first results
- **Streaming is block-based**, not token streaming: each block is ~29.76s
- Use Streaming for preview; Standard mode outputs a full downloadable file
- Try **FP8 + AOTI** (preset: ZeroGPU AOTI FP8) for the fastest H200 path
- ZeroGPU does not support generator streaming in forked GPU; streaming is for local deploy only
### Lyrics Format Requirements
**Automatic Processing:**
1. All text will be converted to **lowercase**
2. Timestamps (e.g., [00:12]) will be **automatically removed**
3. Leading/trailing whitespace on each line will be **stripped**
4. Leading/trailing empty lines will be **removed**
5. Multiple consecutive empty lines (3+) will be **collapsed to 2**
**Recommended Format:**
- Use standard song structure tags: `[Intro]`, `[Verse]`, `[Chorus]`, `[Bridge]`, `[Outro]`, etc.
- Separate sections with **blank lines**
- Case doesn't matter (will be auto-converted)
**Example:**
```
[Intro]
[Verse]
The sun creeps in across the floor
I hear the traffic outside the door
[Chorus]
Every day the light returns
Every day the fire burns
```
---
### Tags Format
- Use **commas** to separate multiple tags: `piano,happy,pop`
- Tags influence the style and mood of the generated music
- Select from categories below or type directly
""")
output_audio_file = gr.Audio(
label="Generated Music (Full)",
type="filepath",
interactive=False,
visible=not show_stream_default,
)
stream_audio = gr.Audio(
label="Block Streaming Audio (Preview)",
streaming=True,
autoplay=True,
type="numpy",
format="wav",
interactive=False,
visible=show_stream_default,
)
# Event handlers for tag selection
for cb in tag_checkboxes:
cb.change(fn=update_tag_string, inputs=tag_checkboxes, outputs=tags)
# Button callbacks
format_btn.click(
fn=process_lyrics_correct,
inputs=[lyrics],
outputs=[lyrics]
)
def update_visibility(gen_mode, spd_mode):
    """Compute visibility updates for the mode-dependent controls.

    Returns a tuple of eight ``gr.update`` objects, in this order:
    speed_submode, btn_original, btn_accel, btn_stream, cancel_stream_btn,
    output_format, output_audio_file, stream_audio.
    """
    if gen_mode == "Original (No Acceleration)":
        # Original mode: only the original button and the file output
        # (plus its format selector) are shown.
        flags = (False, True, False, False, False, True, True, False)
    else:
        # Streaming widgets appear only when streaming is permitted AND
        # the "Streaming" sub-mode is selected; otherwise the standard
        # accelerated widgets are shown instead.
        streaming = STREAMING_ALLOWED and spd_mode == "Streaming"
        standard = not streaming
        flags = (
            True,       # speed_submode
            False,      # btn_original
            standard,   # btn_accel
            streaming,  # btn_stream
            streaming,  # cancel_stream_btn
            standard,   # output_format
            standard,   # output_audio_file
            streaming,  # stream_audio
        )
    return tuple(gr.update(visible=flag) for flag in flags)
def apply_preset(preset_name):
    """Translate a preset name into value updates for the settings widgets.

    Unknown names fall back to ``PRESET_CONFIGS[DEFAULT_PRESET]``. The
    returned tuple holds one ``gr.update(value=...)`` per preset field, in
    the order listed in ``field_order`` below.
    """
    fallback = PRESET_CONFIGS[DEFAULT_PRESET]
    chosen = PRESET_CONFIGS.get(preset_name, fallback)
    # The order here must match the order of the components wired as outputs.
    field_order = (
        "cfg_scale",
        "duration",
        "temperature",
        "topk",
        "quant_mode",
        "keep_model_loaded",
        "offload_mode",
        "generation_mode",
        "speed_submode",
    )
    return tuple(gr.update(value=chosen[key]) for key in field_order)
generation_mode.change(
fn=update_visibility,
inputs=[generation_mode, speed_submode],
outputs=[
speed_submode,
btn_original,
btn_accel,
btn_stream,
cancel_stream_btn,
output_format,
output_audio_file,
stream_audio,
],
)
speed_submode.change(
fn=update_visibility,
inputs=[generation_mode, speed_submode],
outputs=[
speed_submode,
btn_original,
btn_accel,
btn_stream,
cancel_stream_btn,
output_format,
output_audio_file,
stream_audio,
],
)
preset_event = apply_preset_btn.click(
fn=apply_preset,
inputs=[preset_selector],
outputs=[
cfg_scale,
duration,
temperature,
topk,
quant_mode,
keep_model_loaded,
offload_mode,
generation_mode,
speed_submode,
],
)
preset_event.then(
fn=update_visibility,
inputs=[generation_mode, speed_submode],
outputs=[
speed_submode,
btn_original,
btn_accel,
btn_stream,
cancel_stream_btn,
output_format,
output_audio_file,
stream_audio,
],
)
btn_original.click(
fn=generate_original,
inputs=[
lyrics,
tags,
cfg_scale,
duration,
temperature,
topk,
version,
codec_version,
quant_mode,
output_format,
keep_model_loaded,
offload_mode,
backend,
],
outputs=[output_audio_file],
concurrency_id="gpu_queue",
concurrency_limit=GPU_CONCURRENCY_LIMIT,
)
btn_accel.click(
fn=generate_accelerated,
inputs=[
lyrics,
tags,
cfg_scale,
duration,
temperature,
topk,
version,
codec_version,
quant_mode,
output_format,
keep_model_loaded,
offload_mode,
backend,
],
outputs=[output_audio_file],
concurrency_id="gpu_queue",
concurrency_limit=GPU_CONCURRENCY_LIMIT,
)
stream_event = btn_stream.click(
fn=stream_generate_accelerated,
inputs=[
lyrics,
tags,
cfg_scale,
duration,
temperature,
topk,
version,
codec_version,
quant_mode,
keep_model_loaded,
offload_mode,
backend,
],
outputs=[stream_audio],
concurrency_id="gpu_queue",
concurrency_limit=GPU_CONCURRENCY_LIMIT,
)
cancel_stream_btn.click(
fn=lambda: None,
inputs=None,
outputs=[cancel_state],
cancels=[stream_event],
)
# Tab 2: Lyrics Generation
with gr.Tab("Lyrics Generation"):
with gr.Row():
with gr.Column():
gr.Markdown("### Generate Lyrics with AI")
api_selector = gr.Radio(
choices=[
("Google Gemini", "gemini"),
("OpenAI", "openai"),
("DeepSeek", "deepseek"),
("Custom (OpenAI-compatible)", "custom")
],
value="gemini",
label="Select LLM Provider"
)
with gr.Accordion("API Configuration", open=True):
api_key_input = gr.Textbox(
label="API Key (Required)",
type="password",
placeholder="Enter your API key or set environment variable",
info="Will use environment variable if not provided here"
)
custom_base_url = gr.Textbox(
label="Custom Base URL (Optional)",
placeholder="e.g., https://api.your-provider.com/v1",
info="Leave empty to use default. For custom providers only.",
visible=False
)
custom_model = gr.Textbox(
label="Model Name (Optional)",
placeholder="e.g., gpt-4o, deepseek-chat",
info="Leave empty to use recommended default",
visible=False
)
def update_custom_fields(choice):
    """Return visibility updates for the base-URL and model-name fields.

    The first update is visible only for the "custom" provider; the second
    is visible for both "custom" and "deepseek". Any other choice hides both.
    """
    base_url_visible = choice == "custom"
    model_visible = base_url_visible or choice == "deepseek"
    return gr.update(visible=base_url_visible), gr.update(visible=model_visible)
api_selector.change(
fn=update_custom_fields,
inputs=[api_selector],
outputs=[custom_base_url, custom_model]
)
theme_input = gr.Textbox(
label="Theme",
placeholder="e.g., Love lost in the city, Dreams and hope, Rainy day memories...",
lines=2
)
tags_gen = gr.Textbox(
label="Music Style/Tags",
placeholder="e.g., piano, melancholy, pop",
value="pop,emotional"
)
language_select = gr.Radio(
choices=[
("English", "en"),
("中文 (Chinese)", "zh"),
("日本語 (Japanese)", "jp"),
("한국어 (Korean)", "kr"),
("Español (Spanish)", "sp")
],
value="en",
label="Language"
)
generate_lyrics_btn = gr.Button(
"Generate Lyrics",
variant="primary",
size="lg"
)
with gr.Column():
with gr.Accordion("How to Use", open=True):
gr.Markdown("""
### How to Generate Lyrics
**Theme**: Describe your song's story or emotion
- Examples: "Lost love in Tokyo", "Overcoming obstacles", "Summer road trip"
**Music Style/Tags**: Define mood and genre
- Examples: "piano,melancholy,ballad", "upbeat,electronic,dance"
**Tips**
- Generated lyrics follow standard song structure ([Intro], [Verse], [Chorus], etc.)
- Edit lyrics before using for music generation
- Be specific with themes for better results
""")
generated_lyrics_output = gr.Textbox(
label="Generated Lyrics",
lines=20,
placeholder="Generated lyrics will appear here...",
interactive=False
)
copy_to_music_gen = gr.Button(
"Copy to Music Generation Tab",
size="sm"
)
# Lyrics generation button callback
generate_lyrics_btn.click(
fn=generate_lyrics,
inputs=[theme_input, tags_gen, language_select, api_selector, api_key_input, custom_base_url, custom_model],
outputs=[generated_lyrics_output]
)
# Copy lyrics to music generation tab
def copy_lyrics(lyrics_text):
    """Pass the given lyrics through unchanged (identity callback)."""
    forwarded = lyrics_text
    return forwarded
copy_to_music_gen.click(
fn=copy_lyrics,
inputs=[generated_lyrics_output],
outputs=[lyrics]
)
return demo
# Build the interface at import time so `demo` exists as a module attribute,
# then enable request queuing with the configured limits.
demo = create_ui()
demo.queue(
    max_size=GRADIO_QUEUE_MAX_SIZE,
    default_concurrency_limit=GRADIO_DEFAULT_CONCURRENCY,
)

if __name__ == "__main__":
    # Collect launch settings first; the port comes from the PORT
    # environment variable and defaults to 7860 when it is unset.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": int(os.environ.get("PORT", "7860")),
        "allowed_paths": [str(DATA_DIR), "/tmp"],
        "ssr_mode": False,
    }
    demo.launch(**launch_options)