testmula / examples /web_demo.py
ABLingss's picture
second init
ed8503f
import argparse
import asyncio
import json
import os
import re
import subprocess
import tempfile
import uuid
import wave
from pathlib import Path
from typing import Any, Dict, Iterator, Optional, Tuple
import numpy as np
import gradio as gr
import torch
from fastapi import FastAPI, WebSocket
from google import genai
from google.genai import types
from openai import OpenAI
from starlette.concurrency import iterate_in_threadpool
from starlette.staticfiles import StaticFiles
from transformers import BitsAndBytesConfig
from heartlib import HeartMuLaGenPipeline, HeartTranscriptorPipeline
# Root directory for all checkpoints. It is None at import time and must be
# assigned before get_model_manager() runs (os.makedirs(None) raises TypeError).
# NOTE(review): presumably assigned by the CLI entry point, which is not visible here.
MODEL_PATH = None
# LLM providers offered in the "Lyrics Generation" tab, keyed by the radio value.
#   name:          display name used in progress messages
#   api_type:      "gemini" -> google-genai client; "openai" -> OpenAI SDK client
#   default_model: used when the "Model Name" field is left empty
#   env_key:       env var consulted when no API key is typed (None: no fallback)
#   base_url:      default endpoint for the OpenAI client; None keeps the SDK
#                  default. A non-empty "Custom Base URL" field overrides it.
LLM_PRESETS = {
"gemini": {
"name": "Google Gemini",
"api_type": "gemini",
"default_model": "gemini-2.0-flash-lite",
"env_key": "GEMINI_API_KEY",
"base_url": None,
},
"openai": {
"name": "OpenAI",
"api_type": "openai",
"default_model": "gpt-4o-mini",
"env_key": "OPENAI_API_KEY",
"base_url": None,
},
"deepseek": {
"name": "DeepSeek",
# DeepSeek is reached through the OpenAI-compatible client.
"api_type": "openai",
"default_model": "deepseek-chat",
"env_key": "DEEPSEEK_API_KEY",
"base_url": "https://api.deepseek.com",
},
"custom": {
"name": "Custom OpenAI-Compatible",
"api_type": "openai",
# Placeholder; users are expected to type a real model name.
"default_model": "custom-model",
# No env fallback: the API key must be entered in the UI.
"env_key": None,
"base_url": None,
}
}
# Example lyrics pre-filled in the "Lyrics" textbox of the Music Generation tab.
# Section tags ([Intro], [Verse], ...) follow the format described in that tab's
# usage notice; "Format Lyrics" (process_lyrics_correct) lowercases everything.
EXAMPLE_LYRICS = """[Intro]
[Verse]
The sun creeps in across the floor
I hear the traffic outside the door
The coffee pot begins to hiss
It is another morning just like this
[Prechorus]
The world keeps spinning round and round
Feet are planted on the ground
I find my rhythm in the sound
[Chorus]
Every day the light returns
Every day the fire burns
We keep on walking down this street
Moving to the same steady beat
It is the ordinary magic that we meet
[Verse]
The hours tick deeply into noon
Chasing shadows,chasing the moon
Work is done and the lights go low
Watching the city start to glow
[Bridge]
It is not always easy,not always bright
Sometimes we wrestle with the night
But we make it to the morning light
[Chorus]
Every day the light returns
Every day the fire burns
We keep on walking down this street
Moving to the same steady beat
[Outro]
Just another day
Every single day"""
# Example comma-separated style tags pre-filled in the "Selected Tags" textbox.
EXAMPLE_TAGS = "piano,happy"
# Tag vocabulary shown as one CheckboxGroup per category in the Music Generation
# tab. update_tag_string() merges the checked values into the comma-separated
# tags textbox. Some values appear in several categories (e.g. "Warm",
# "Gentle", "Reflection"); duplicates are dropped when the selections are merged.
TAG_DATA = {
"Gender": [
"Male", "Female"
],
"Genre": [
"Pop", "Folk", "Ballad", "Electronic", "Rock", "Acoustic", "R&B",
"Indie", "Dance", "Indie Pop", "J-Pop", "Hip-Hop", "Country",
"Latin", "Alternative", "Christian", "Cantopop", "Gospel", "Soul",
"Mandopop"
],
"Instrument": [
"Drums", "Piano", "Guitar", "Strings", "Synthesizer", "Bass",
"Acoustic Guitar", "Keyboard", "Electronic Drums", "Vocals",
"Drum Machine", "Electric Guitar", "Percussion", "Beat",
"Orchestra", "Saxophone", "Accordion", "Voice", "String", "Vocal"
],
"Mood": [
"Melancholy", "Romantic", "Energetic", "Hopeful", "Dreamy",
"Relaxed", "Sad", "Calm", "Cheerful", "Reflective", "Emotional",
"Joyful", "Sentimental", "Uplifting", "Warm", "Peaceful", "Upbeat",
"Gentle", "Nostalgic", "Epic"
],
"Scene": [
"Driving", "Road Trip", "Cafe", "Relaxing", "Wedding", "Meditation",
"Workout", "Walking", "Alone", "Travel", "Reflection", "Rainy Day",
"Night", "Church", "Coffee Shop", "Gym", "Gaming", "Study",
"Dating", "Date"
],
"Singer Timbre": [
"Soft", "Clear", "Warm", "Gentle", "Smooth", "Sweet", "Emotional",
"Mellow", "Powerful", "Youthful", "Bright", "Rough", "Raspy",
"Melodic", "Deep", "Soulful", "Strong", "Energetic", "Breathy",
"Passionate"
],
"Topic": [
"Love", "Relationship", "Hope", "Longing", "Loss", "Heartbreak",
"Memory", "Reflection", "Life", "Faith", "Regret", "Freedom",
"Breakup", "Nature", "Loneliness", "Dreams", "Nostalgia", "Romance",
"Friendship", "Youth"
]
}
# Directory that receives generated audio (streamed WAVs, MP3 conversions).
# Override with HEARTMULA_DATA_DIR; defaults to <system temp>/heartmula_stream.
# Created eagerly at import time.
DATA_DIR = Path(
    os.environ.get(
        "HEARTMULA_DATA_DIR",
        os.path.join(tempfile.gettempdir(), "heartmula_stream"),
    )
)
DATA_DIR.mkdir(exist_ok=True, parents=True)

# Streaming admission limits, read from the environment as integers.
STREAM_MAX_CONCURRENCY = int(os.getenv("STREAM_MAX_CONCURRENCY", "1"))   # generations running at once
STREAM_MAX_QUEUE = int(os.getenv("STREAM_MAX_QUEUE", "20"))               # requests allowed to wait
STREAM_MAX_PER_USER = int(os.getenv("STREAM_MAX_PER_USER", "1"))          # pending requests per user id
STREAM_SESSION_TTL_SEC = int(os.getenv("STREAM_SESSION_TTL_SEC", "900"))  # queue-entry lifetime (seconds)
class StreamingQueue:
    """FIFO admission control for streaming generation requests.

    A request ``join``s the wait list, polls ``acquire`` until it is at the
    head and a run slot is free, and finally ``release``s everything it holds.
    ``cancel`` withdraws a request that is still waiting.

    The user charged at ``join`` time is remembered per request. Per-user slots
    are therefore always refunded to that user, even when a later call passes a
    different (or no) ``user_id``. Refunds happen exactly once per request, so
    ``release``/``cancel`` are safe to call repeatedly.

    State is guarded by an ``asyncio.Lock``; use the instance from a single
    event loop.
    """

    def __init__(self, max_concurrency: int = 1, max_queue: int = 20, max_per_user: int = 1):
        self._max_concurrency = max(1, max_concurrency)
        self._max_queue = max(1, max_queue)
        self._max_per_user = max(1, max_per_user)
        self._queue = []         # waiting request ids, head first
        self._running = set()    # request ids currently holding a run slot
        self._per_user = {}      # user_id -> number of queued-or-running requests
        self._owners = {}        # request_id -> user_id charged by join()
        self._enqueued_at = {}   # request_id -> loop time when it started waiting
        self._lock = asyncio.Lock()

    @staticmethod
    def _now() -> float:
        """Current time on the running event loop's monotonic clock."""
        return asyncio.get_running_loop().time()

    def _forget(self, request_id: str) -> None:
        """Drop bookkeeping for ``request_id`` and refund its owner (lock must be held)."""
        self._enqueued_at.pop(request_id, None)
        if request_id not in self._owners:
            return  # already refunded (or never joined)
        owner = self._owners.pop(request_id)
        remaining = self._per_user.get(owner, 1) - 1
        if remaining > 0:
            self._per_user[owner] = remaining
        else:
            self._per_user.pop(owner, None)

    async def join(self, request_id: str, user_id: str) -> Tuple[bool, str]:
        """Put ``request_id`` at the back of the wait list on behalf of ``user_id``.

        Returns:
            ``(True, "OK")`` when queued (or already known), otherwise
            ``(False, reason)`` if the queue is full or the user is at their
            per-user limit.
        """
        async with self._lock:
            if len(self._queue) >= self._max_queue:
                return False, "Queue is full. Please try again later."
            if self._per_user.get(user_id, 0) >= self._max_per_user:
                return False, "You already have a pending request."
            # _owners holds exactly the requests that are queued or running.
            if request_id not in self._owners:
                self._queue.append(request_id)
                self._owners[request_id] = user_id
                self._per_user[user_id] = self._per_user.get(user_id, 0) + 1
                self._enqueued_at[request_id] = self._now()
            return True, "OK"

    async def acquire(self, request_id: str) -> bool:
        """Move ``request_id`` from the head of the queue into a run slot.

        Returns False when it is not at the head or all run slots are busy.
        """
        async with self._lock:
            if self._queue and self._queue[0] == request_id and len(self._running) < self._max_concurrency:
                self._queue.pop(0)
                self._running.add(request_id)
                # Running requests are no longer subject to queue expiry.
                self._enqueued_at.pop(request_id, None)
                return True
            return False

    async def release(self, request_id: str, user_id: Optional[str] = None):
        """Free everything ``request_id`` holds and refund its owner's slot.

        This covers the run slot and also any place still held in the wait
        list; the latter happens when the client leaves before ever running.
        Without that, a departed client would block the queue head.
        Idempotent. ``user_id`` is accepted for backward compatibility; the
        owner recorded by ``join`` is the one refunded.
        """
        async with self._lock:
            self._running.discard(request_id)
            if request_id in self._queue:
                self._queue.remove(request_id)
            self._forget(request_id)

    async def get_wait_info(self, request_id: str) -> Tuple[int, int]:
        """Return ``(requests ahead, rough wait in seconds)``.

        The wait is estimated at 60 s per request ahead. Unknown or running
        requests report 0 ahead.
        """
        async with self._lock:
            ahead = self._queue.index(request_id) if request_id in self._queue else 0
            est_wait = ahead * 60
            return ahead, est_wait

    async def cancel(self, request_id: str, user_id: Optional[str] = None):
        """Withdraw ``request_id`` if it is still waiting; running requests are untouched.

        ``user_id`` is accepted for backward compatibility; the owner recorded
        by ``join`` is refunded.
        """
        async with self._lock:
            if request_id in self._queue:
                self._queue.remove(request_id)
                self._forget(request_id)

    async def cleanup_expired(self, ttl_sec: int):
        """Drop waiting requests older than ``ttl_sec`` seconds and refund their owners."""
        async with self._lock:
            now = self._now()
            expired = [rid for rid in self._queue if now - self._enqueued_at.get(rid, now) > ttl_sec]
            for rid in expired:
                self._queue.remove(rid)
                self._forget(rid)
# Process-wide admission queue shared by every streaming websocket.
stream_queue = StreamingQueue(
    max_queue=STREAM_MAX_QUEUE,
    max_per_user=STREAM_MAX_PER_USER,
    max_concurrency=STREAM_MAX_CONCURRENCY,
)
# token -> request parameters recorded by prepare_streaming_session().
stream_sessions: Dict[str, Dict[str, Any]] = dict()
# token -> path of the finished audio file, read by load_audio_from_result().
stream_results: Dict[str, str] = dict()
class ModelManager:
    """Lazily builds and caches the pipelines used by the demo.

    Generation pipelines are memoized per ``(version, codec_version,
    quant_mode)``, and the transcription pipeline is built once. Device and
    dtype are chosen from CUDA availability at construction time: bfloat16 on
    CUDA, float32 on CPU.
    """

    def __init__(self, model_path: str):
        cuda_ok = torch.cuda.is_available()
        self.model_path = model_path
        self.device = torch.device("cuda" if cuda_ok else "cpu")
        self.dtype = torch.bfloat16 if cuda_ok else torch.float32
        # Built generation pipelines, keyed by (version, codec_version, quant_mode).
        self._gen_pipes: Dict[Tuple[str, str, str], HeartMuLaGenPipeline] = {}
        self._transcribe_pipe: Optional[HeartTranscriptorPipeline] = None
        deepspeed_flag = os.getenv("USE_DEEPSPEED_INFERENCE", "0").lower()
        self.use_deepspeed = deepspeed_flag in ("1", "true", "yes")
        self.ds_inference_config = self._make_ds_inference_config()

    def _make_ds_inference_config(self) -> Dict[str, Any]:
        """Assemble DeepSpeed inference kwargs from environment variables.

        Returns an empty dict when DeepSpeed inference is disabled. The
        tensor-parallel size comes from DEEPSPEED_TP_SIZE, falling back to
        WORLD_SIZE, then 1.
        """
        if not self.use_deepspeed:
            return {}
        tp_size = os.getenv("DEEPSPEED_TP_SIZE", os.getenv("WORLD_SIZE", "1"))
        config: Dict[str, Any] = {"mp_size": int(tp_size), "dtype": self.dtype}
        config["replace_method"] = os.getenv("DEEPSPEED_REPLACE_METHOD", "auto")
        inject_flag = os.getenv("DEEPSPEED_KERNEL_INJECT", "1").lower()
        config["replace_with_kernel_inject"] = inject_flag in ("1", "true", "yes")
        return config

    def _make_bnb_config(self, quant_mode: str) -> Optional[BitsAndBytesConfig]:
        """Translate a UI quantization choice into a bitsandbytes config.

        ``"none"`` -> None; ``"8bit"`` -> 8-bit loading; ``"4bit"`` -> double-
        quantized 4-bit with bfloat16 compute, using FP4 when the GPU's compute
        capability major version is >= 10 and NF4 otherwise.

        Raises:
            gr.Error: if quantization is requested without CUDA, or the mode is unknown.
        """
        if quant_mode == "none":
            return None
        if self.device.type != "cuda":
            raise gr.Error("Quantization requires CUDA.")
        if quant_mode == "8bit":
            return BitsAndBytesConfig(load_in_8bit=True)
        if quant_mode != "4bit":
            raise gr.Error(f"Unknown quant mode: {quant_mode}")

        quant_type = "nf4"
        try:
            major, _ = torch.cuda.get_device_capability()
            if major >= 10:
                quant_type = "fp4"
        except Exception:
            # Capability probing is best-effort; keep NF4 if it fails.
            pass
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=quant_type,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

    def get_gen_pipeline(self, version: str, codec_version: str, quant_mode: str) -> HeartMuLaGenPipeline:
        """Return the generation pipeline for these settings, building it on first request."""
        key = (version, codec_version, quant_mode)
        if key in self._gen_pipes:
            return self._gen_pipes[key]
        bnb_config = self._make_bnb_config(quant_mode)
        pipeline = HeartMuLaGenPipeline.from_pretrained(
            self.model_path,
            device=self.device,
            dtype=self.dtype,
            version=version,
            codec_version=codec_version,
            bnb_config=bnb_config,
            lazy_load=True,
            use_deepspeed=self.use_deepspeed,
            ds_inference_config=self.ds_inference_config,
        )
        self._gen_pipes[key] = pipeline
        return pipeline

    def get_transcriptor(self) -> HeartTranscriptorPipeline:
        """Return the transcription pipeline (float16 on CUDA, float32 otherwise), loading it once."""
        if self._transcribe_pipe is not None:
            return self._transcribe_pipe
        on_cuda = self.device.type == "cuda"
        self._transcribe_pipe = HeartTranscriptorPipeline.from_pretrained(
            self.model_path,
            device=self.device,
            dtype=torch.float16 if on_cuda else torch.float32,
        )
        return self._transcribe_pipe
# Lazily created singleton; see get_model_manager().
model_manager: Optional[ModelManager] = None


def get_model_manager() -> ModelManager:
    """Return the process-wide ModelManager, creating it on first use.

    On the first call this creates the MODEL_PATH directory if needed and
    downloads the generation checkpoints into it when they are missing.
    """
    global model_manager
    if model_manager is not None:
        return model_manager
    os.makedirs(MODEL_PATH, exist_ok=True)
    download_models_if_needed(MODEL_PATH)
    model_manager = ModelManager(MODEL_PATH)
    return model_manager
def update_tag_string(*args):
    """Merge the selections of every tag CheckboxGroup into one tag string.

    Each positional argument is one group's current value. It may be a list of
    checked tags, a single tag, or an empty/None value (ignored). Tags keep
    first-seen order, duplicates are dropped, and the result is joined with
    commas (no spaces).
    """
    flattened = []
    for group in args:
        if not group:
            continue
        flattened.extend(group if isinstance(group, list) else [group])
    # dicts preserve insertion order, so fromkeys() de-duplicates stably.
    return ",".join(dict.fromkeys(flattened))
def process_lyrics_correct(content):
    """Normalize lyrics into the layout the model was trained on.

    Steps, in order:
    1. Lowercase the whole text (section tags included).
    2. Delete bracketed timestamps such as ``[00:12]`` or ``[01:02.33]``.
    3. Strip surrounding whitespace from every line (this also removes ``\\r``).
    4. Drop blank lines at the start and end.
    5. Collapse runs of 3+ newlines to exactly 2, leaving at most one blank line.
    """
    text = re.sub(r"\[[^\]]*\d{1,2}:\d{2}[^\]]*\]", "", content.lower())
    rows = [row.strip() for row in text.split("\n")]
    start = 0
    while start < len(rows) and not rows[start]:
        start += 1
    end = len(rows)
    while end > start and not rows[end - 1]:
        end -= 1
    return re.sub(r"\n{3,}", "\n\n", "\n".join(rows[start:end]))
def save_audio_to_wav(sample_rate: int, audio_np: np.ndarray, output_dir: Path) -> str:
    """Write a float waveform to ``output_dir`` as a mono 16-bit PCM WAV.

    Args:
        sample_rate: Frame rate written into the WAV header.
        audio_np: Samples nominally in [-1.0, 1.0], written as a single channel.
            NaN/inf become silence and values outside the range are clamped.
        output_dir: Destination directory, created if missing.

    Returns:
        Path of the new file as a string; the file name is a random UUID.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    wav_path = output_dir / f"{uuid.uuid4()}.wav"
    # Sanitize before scaling: casting out-of-range floats to int16 wraps
    # around (e.g. 1.01 -> a large negative sample, an audible click), and
    # NaN/inf have no defined int16 value.
    safe = np.clip(
        np.nan_to_num(np.asarray(audio_np), nan=0.0, posinf=0.0, neginf=0.0),
        -1.0,
        1.0,
    )
    # WAV PCM data is little-endian; be explicit so big-endian hosts write valid files.
    audio_int16 = (safe * 32767).astype("<i2")
    with wave.open(str(wav_path), "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(audio_int16.tobytes())
    return str(wav_path)
def convert_wav_to_mp3(wav_path: str, output_dir: Path) -> str:
    """Transcode ``wav_path`` into a new MP3 inside ``output_dir`` with the ffmpeg CLI.

    ffmpeg is run with libmp3lame at ``-qscale:a 2`` and ``-y`` (overwrite
    without prompting). Its stdout/stderr are captured rather than printed.
    The source WAV is left in place.

    Returns:
        Path of the MP3 (random UUID file name) as a string.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    mp3_path = output_dir / f"{uuid.uuid4()}.mp3"
    ffmpeg_args = ["ffmpeg", "-y", "-i", wav_path]
    ffmpeg_args += ["-codec:a", "libmp3lame", "-qscale:a", "2"]
    ffmpeg_args.append(str(mp3_path))
    subprocess.run(ffmpeg_args, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return str(mp3_path)
def check_models_exist(ckpt_dir):
    """Return True when every generation artifact is present under ``ckpt_dir``.

    Checked: gen_config.json, tokenizer.json, and the HeartCodec-oss and
    HeartMuLa-oss-3B entries (existence only, not their contents).
    """
    expected = ("gen_config.json", "tokenizer.json", "HeartCodec-oss", "HeartMuLa-oss-3B")
    return all(os.path.exists(os.path.join(ckpt_dir, name)) for name in expected)
def download_models_if_needed(ckpt_dir):
    """Fetch the generation checkpoints from ModelScope unless they are already present.

    The HeartMuLaGen repo goes into ``ckpt_dir`` itself. HeartMuLa-oss-3B and
    HeartCodec-oss each go into a same-named subdirectory.
    """
    if check_models_exist(ckpt_dir):
        return
    # Imported here so modelscope is only needed when a download actually happens.
    from modelscope import snapshot_download
    snapshot_download("HeartMuLa/HeartMuLaGen", local_dir=ckpt_dir)
    for repo in ("HeartMuLa-oss-3B", "HeartCodec-oss"):
        snapshot_download(f"HeartMuLa/{repo}", local_dir=os.path.join(ckpt_dir, repo))
def check_transcriptor_exists(ckpt_dir):
    """Return True if a ``HeartTranscriptor-oss`` entry exists under ``ckpt_dir``."""
    transcriptor_dir = os.path.join(ckpt_dir, "HeartTranscriptor-oss")
    return os.path.exists(transcriptor_dir)
def download_transcriptor_if_needed(ckpt_dir):
    """Fetch HeartTranscriptor-oss into ``ckpt_dir/HeartTranscriptor-oss`` if it is missing."""
    if not check_transcriptor_exists(ckpt_dir):
        # Lazy import: modelscope is only required when downloading.
        from modelscope import snapshot_download
        target = os.path.join(ckpt_dir, "HeartTranscriptor-oss")
        snapshot_download("HeartMuLa/HeartTranscriptor-oss", local_dir=target)
def load_pipeline(model_path, version, codec_version, quant_mode):
    """Return the cached generation pipeline for the given settings.

    ``model_path`` is accepted for signature compatibility but is not used:
    the shared ModelManager is always rooted at the module-level MODEL_PATH.
    """
    return get_model_manager().get_gen_pipeline(version, codec_version, quant_mode)
def load_transcriptor(model_path):
    """Ensure transcriptor weights exist under ``model_path`` and return the cached pipeline.

    NOTE(review): the pipeline itself is loaded from the shared ModelManager's
    path (MODEL_PATH). Callers should pass that same directory.
    """
    download_transcriptor_if_needed(model_path)
    manager = get_model_manager()
    return manager.get_transcriptor()
def generate(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    output_format,
    keep_model_loaded,
    offload_mode,
    backend,
):
    """Generate a whole song in one call and return the path of the audio file.

    Args:
        lyrics, tags: Text conditioning; both must be non-blank.
        cfg_scale, temperature, topk: Sampling controls forwarded to the pipeline.
        duration_sec: Maximum audio length in seconds.
        version, codec_version, quant_mode: Select (and cache) the pipeline.
        output_format: "mp3" to transcode; anything else returns the WAV.
        keep_model_loaded, offload_mode: Memory-management flags for the pipeline.
        backend: Only the HF pipeline is implemented.

    Returns:
        Path to a WAV in the system temp directory, or to an MP3 in DATA_DIR
        when ``output_format == "mp3"``. In the MP3 case the intermediate WAV
        is deleted.

    Raises:
        gr.Error: for blank inputs, the unimplemented ExLlamaV2 backend, or any
            failure during generation or conversion.
    """
    if not lyrics.strip():
        raise gr.Error("Please enter lyrics")
    if not tags.strip():
        raise gr.Error("Please enter tags")
    if backend == "exllama_v2":
        raise gr.Error("ExLlamaV2 backend is not implemented yet.")
    max_audio_length_ms = int(duration_sec * 1000)
    pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode)
    # Reserve a unique path; delete=False because the pipeline presumably writes
    # to save_path itself after this handle is closed.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        output_path = f.name
    keep_wav = False
    try:
        with torch.no_grad():
            pipe(
                {
                    "lyrics": lyrics,
                    "tags": tags,
                },
                max_audio_length_ms=max_audio_length_ms,
                save_path=output_path,
                topk=topk,
                temperature=temperature,
                cfg_scale=cfg_scale,
                keep_model_loaded=keep_model_loaded,
                offload_mode=offload_mode,
            )
        if output_format == "mp3":
            return convert_wav_to_mp3(output_path, DATA_DIR)
        keep_wav = True
        return output_path
    except Exception as e:
        raise gr.Error(f"Generation error: {str(e)}") from e
    finally:
        # The WAV is only handed back for WAV output. After an MP3 conversion
        # or a failure it would be an orphan in the temp dir, so remove it.
        if not keep_wav:
            try:
                os.remove(output_path)
            except OSError:
                pass  # best-effort cleanup; must not mask the real result/error
def transcribe_audio(audio_path, task, max_new_tokens, num_beams, temperature):
    """Transcribe lyrics from an uploaded audio file with HeartTranscriptor.

    Decoding uses beam search controls from the UI plus fixed thresholds:
    compression ratio 1.8, logprob -1.0, no-speech 0.4, and no conditioning on
    previous tokens.

    Returns:
        The ``"text"`` field when the pipeline returns a dict, otherwise ``str(result)``.

    Raises:
        gr.Error: when no file is given or transcription fails.
    """
    if not audio_path:
        raise gr.Error("Please upload an audio file")
    pipe = load_transcriptor(MODEL_PATH)
    try:
        decode_kwargs = {
            "max_new_tokens": int(max_new_tokens),
            "num_beams": int(num_beams),
            "task": task,
            "condition_on_prev_tokens": False,
            "compression_ratio_threshold": 1.8,
            "temperature": float(temperature),
            "logprob_threshold": -1.0,
            "no_speech_threshold": 0.4,
        }
        with torch.no_grad():
            result = pipe(audio_path, **decode_kwargs)
        return result.get("text", "") if isinstance(result, dict) else str(result)
    except Exception as e:
        raise gr.Error(f"Transcription error: {str(e)}")
def _normalize_stream_chunk(chunk: np.ndarray) -> np.ndarray:
chunk = np.nan_to_num(chunk, nan=0.0, posinf=0.0, neginf=0.0)
return np.clip(chunk, -1.0, 1.0)
def generate_music_streaming(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    keep_model_loaded,
    offload_mode,
    backend,
    chunk_frames,
) -> Iterator[Tuple[int, np.ndarray]]:
    """Yield ``(sample_rate, samples)`` audio chunks as the pipeline streams them.

    This is a generator, so nothing runs (not even the backend check) until the
    first chunk is requested. Samples are a NumPy array sanitized by
    ``_normalize_stream_chunk``. The sample rate is the fixed 48000 this demo
    uses everywhere.

    Raises:
        gr.Error: if the unimplemented ExLlamaV2 backend is selected.
    """
    if backend == "exllama_v2":
        raise gr.Error("ExLlamaV2 backend is not implemented yet.")
    pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode)
    max_audio_length_ms = int(duration_sec * 1000)
    for chunk in pipe.stream(
        {"lyrics": lyrics, "tags": tags},
        max_audio_length_ms=max_audio_length_ms,
        temperature=temperature,
        topk=topk,
        cfg_scale=cfg_scale,
        chunk_frames=int(chunk_frames),
        keep_model_loaded=keep_model_loaded,
        offload_mode=offload_mode,
    ):
        if chunk.dim() == 2:
            # NOTE(review): assumes a (1, samples) mono chunk; confirm against pipe.stream.
            chunk = chunk.squeeze(0)
        # detach(): unlike generate(), this path is not wrapped in torch.no_grad(),
        # and Tensor.numpy() refuses tensors that require grad.
        # float(): Tensor.numpy() cannot convert bfloat16, the dtype ModelManager
        # uses on CUDA, in case the pipeline returns audio in that dtype.
        chunk_np = chunk.detach().float().cpu().numpy()
        yield 48000, _normalize_stream_chunk(chunk_np)
def prepare_streaming_session(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    keep_model_loaded,
    offload_mode,
    backend,
    output_format,
    chunk_frames,
    user_id=None,
):
    """Validate a streaming request and register it under a fresh session token.

    The browser (STREAMING_JS) opens ``/ws_stream/<token>`` with the returned
    token, and websocket_stream_endpoint reads these saved parameters from
    ``stream_sessions``.

    ``user_id`` defaults to None because the Gradio click handler passes only
    the 14 UI inputs; without a default every click raised TypeError. Callers
    without an id all share the None bucket for the per-user limit.

    Returns:
        The new session token (a UUID4 string).

    Raises:
        gr.Error: on blank lyrics/tags, the unimplemented ExLlamaV2 backend, or
            when this user already has STREAM_MAX_PER_USER sessions pending.
    """
    import time  # local import: the module does not import `time` at top level

    # Reject bad input here (same messages as generate()) rather than after the
    # websocket has been opened and a queue slot taken.
    if not lyrics.strip():
        raise gr.Error("Please enter lyrics")
    if not tags.strip():
        raise gr.Error("Please enter tags")
    if backend == "exllama_v2":
        raise gr.Error("ExLlamaV2 backend is not implemented yet.")

    # Snapshot the values: the websocket endpoint deletes sessions concurrently.
    pending = sum(1 for meta in list(stream_sessions.values()) if meta.get("user_id") == user_id)
    if pending >= STREAM_MAX_PER_USER:
        raise gr.Error("You already have a pending streaming request.")
    token = str(uuid.uuid4())
    stream_sessions[token] = {
        "lyrics": lyrics,
        "tags": tags,
        "cfg_scale": cfg_scale,
        "duration_sec": duration_sec,
        "temperature": temperature,
        "topk": topk,
        "version": version,
        "codec_version": codec_version,
        "quant_mode": quant_mode,
        "keep_model_loaded": keep_model_loaded,
        "offload_mode": offload_mode,
        "backend": backend,
        "output_format": output_format,
        "chunk_frames": chunk_frames,
        "user_id": user_id,
        # Gradio runs this sync handler in a worker thread, where
        # asyncio.get_event_loop() may raise. time.monotonic() is the same
        # clock asyncio's default loop.time() uses.
        "created_at": time.monotonic(),
    }
    return token
def load_audio_from_result(token):
    """Return the saved audio path for a finished stream, or None.

    None is returned when the token is empty, has no recorded result, or its
    file no longer exists on disk.
    """
    saved = stream_results.get(token) if token else None
    if saved and os.path.exists(saved):
        return saved
    return None
def cancel_stream(token, user_id):
    """Cancel a pending or running stream owned by ``user_id``; return a status message.

    Removing the session is the cancel signal: websocket_stream_endpoint checks
    for the token in ``stream_sessions`` while waiting and between chunks.
    """
    if not token:
        return "No active stream."
    session = stream_sessions.get(token)
    if session and session.get("user_id") == user_id:
        stream_sessions.pop(token, None)
        return "Canceled."
    return "No active stream for this session."
async def _send_ws_json(websocket: WebSocket, payload: Dict[str, Any]) -> None:
    """Serialize ``payload`` to JSON and send it as one websocket text frame."""
    await websocket.send_text(json.dumps(payload))


def _write_stream_result(chunks, output_format: str) -> str:
    """Concatenate streamed chunks and save them under DATA_DIR.

    Writes a 48 kHz mono WAV via save_audio_to_wav. When ``output_format`` is
    "mp3" it then transcodes with convert_wav_to_mp3 and returns the MP3 path
    instead. This is blocking disk/ffmpeg work, so run it off the event loop.
    """
    wav_path = save_audio_to_wav(48000, np.concatenate(chunks), DATA_DIR)
    if output_format == "mp3":
        return convert_wav_to_mp3(wav_path, DATA_DIR)
    return wav_path


async def websocket_stream_endpoint(websocket: WebSocket, token: str):
    """Stream generated audio for a session created by prepare_streaming_session.

    Server -> client protocol:
    - JSON text frames of type ``queue`` while waiting.
    - ``config`` (sample rate) when generation starts.
    - ``progress`` per chunk.
    - ``download`` (if any audio was produced), then ``complete``.

    Raw float32 PCM chunks arrive as binary frames, and failures are reported
    as an ``error`` frame. The queue slot and the session are always released.
    """
    await websocket.accept()
    params = stream_sessions.get(token)
    if params is None:
        await _send_ws_json(websocket, {"type": "error", "message": "Invalid or expired session"})
        await websocket.close()
        return

    request_id = token
    user_id = params.get("user_id")
    await stream_queue.cleanup_expired(STREAM_SESSION_TTL_SEC)
    ok, msg = await stream_queue.join(request_id, user_id or "")
    if not ok:
        await _send_ws_json(websocket, {"type": "error", "message": msg})
        await websocket.close()
        return

    loop = asyncio.get_running_loop()
    try:
        # Wait for a run slot. Stop on cancellation, and give up after the
        # session TTL: cleanup_expired() may drop this request from the queue,
        # after which acquire() could never succeed and we would poll forever.
        wait_started = loop.time()
        while not await stream_queue.acquire(request_id):
            if token not in stream_sessions:
                await _send_ws_json(websocket, {"type": "error", "message": "Stream canceled."})
                return
            if loop.time() - wait_started > STREAM_SESSION_TTL_SEC:
                await _send_ws_json(websocket, {"type": "error", "message": "Timed out waiting in queue."})
                return
            ahead_count, est_wait = await stream_queue.get_wait_info(request_id)
            await _send_ws_json(websocket, {
                "type": "queue",
                "ahead": ahead_count,
                "wait_seconds": int(est_wait),
                "message": f"Waiting in queue... {ahead_count} ahead (Est. {int(est_wait)}s)",
            })
            await asyncio.sleep(1.0)

        await _send_ws_json(websocket, {"type": "config", "sample_rate": 48000})
        iterator = generate_music_streaming(
            lyrics=params["lyrics"],
            tags=params["tags"],
            cfg_scale=params["cfg_scale"],
            duration_sec=params["duration_sec"],
            temperature=params["temperature"],
            topk=params["topk"],
            version=params["version"],
            codec_version=params["codec_version"],
            quant_mode=params["quant_mode"],
            keep_model_loaded=params["keep_model_loaded"],
            offload_mode=params["offload_mode"],
            backend=params["backend"],
            chunk_frames=params["chunk_frames"],
        )
        total_generated_sec = 0.0
        full_audio_buffer = []
        # Each next() on the sync generator runs model inference; pull it in a worker thread.
        async for sr, chunk_np in iterate_in_threadpool(iterator):
            if token not in stream_sessions:
                await _send_ws_json(websocket, {"type": "error", "message": "Stream canceled."})
                # Stop here. Previously the loop only broke, then still saved the
                # partial audio and reported "complete" after the error.
                return
            full_audio_buffer.append(chunk_np)
            total_generated_sec += chunk_np.shape[0] / sr
            progress_val = min(total_generated_sec / max(params["duration_sec"], 1), 0.99)
            await _send_ws_json(websocket, {
                "type": "progress",
                "value": progress_val,
                "text": f"Generated {total_generated_sec:.1f}s",
            })
            await websocket.send_bytes(chunk_np.astype("float32").tobytes())

        await _send_ws_json(websocket, {
            "type": "progress",
            "value": 1.0,
            "text": f"Generated {total_generated_sec:.1f}s (Complete)",
        })
        if full_audio_buffer:
            # Concatenation, WAV writing and ffmpeg all block; keep them off the event loop.
            saved_path = await asyncio.to_thread(
                _write_stream_result, full_audio_buffer, params["output_format"]
            )
            stream_results[token] = str(saved_path)
            filename = Path(saved_path).name
            await _send_ws_json(websocket, {
                "type": "download",
                "url": f"download/{filename}",
                "filename": filename,
            })
        await _send_ws_json(websocket, {"type": "complete"})
    except Exception as e:
        # The client may already be gone (often why we got here), so reporting
        # the error is best-effort and must not escape the handler.
        try:
            await _send_ws_json(websocket, {"type": "error", "message": str(e)})
        except Exception:
            pass
    finally:
        # First withdraw any still-waiting entry (the client left while queued)
        # so it cannot block the queue head. Then free the run slot.
        await stream_queue.cancel(request_id)
        await stream_queue.release(request_id, user_id)
        stream_sessions.pop(token, None)
        try:
            await websocket.close()
        except Exception:
            pass
def _strip_code_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence from an LLM reply, if present.

    The opening fence line (e.g. ``` or ```text) is always dropped. The closing
    fence is dropped only when the last line really is a fence, so a reply
    truncated before its closing fence keeps its final lyric line.
    """
    if not text.startswith("```"):
        return text
    lines = text.split("\n")[1:]
    if lines and lines[-1].strip().startswith("```"):
        lines = lines[:-1]
    return "\n".join(lines)


def _request_lyrics(preset, api_key, base_url, model_name, prompt, progress) -> str:
    """Send ``prompt`` to the provider described by ``preset``.

    Returns:
        The reply text with surrounding whitespace stripped ("" if the reply had no text).

    Raises:
        gr.Error: for an unknown ``api_type``. SDK errors propagate unchanged.
    """
    provider = preset["name"]
    if preset["api_type"] == "gemini":
        progress(0.3, desc=f"Connecting to {provider}...")
        client = genai.Client(api_key=api_key)
        progress(0.5, desc=f"Generating lyrics with {provider}...")
        response = client.models.generate_content(
            model=model_name,
            contents=[
                types.Content(
                    role='user',
                    parts=[types.Part(text=prompt)]
                )
            ],
            config=types.GenerateContentConfig(
                temperature=0.8,
                max_output_tokens=2000
            )
        )
        # response.text may be None when the reply has no text part; guard before strip().
        return (response.text or "").strip()
    if preset["api_type"] == "openai":
        progress(0.3, desc=f"Connecting to {provider}...")
        if base_url:
            client = OpenAI(api_key=api_key, base_url=base_url)
        else:
            client = OpenAI(api_key=api_key)
        progress(0.5, desc=f"Generating lyrics with {provider}...")
        response = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": "You are a professional songwriter who creates well-structured lyrics."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.8,
            max_tokens=2000
        )
        # message.content is optional in the OpenAI schema.
        return (response.choices[0].message.content or "").strip()
    raise gr.Error(f"Unknown API type: {preset['api_type']}")


def generate_lyrics(theme, tags, language, api_choice, api_key_input, custom_base_url, custom_model, progress=gr.Progress()):
    """Ask the selected LLM provider for song lyrics and normalize the reply.

    Args:
        theme: Required free-text theme of the song.
        tags: Style tags; blank means "pop, emotional".
        language: One of "en", "zh", "jp", "kr", "sp" (unknown codes fall back to English).
        api_choice: Key into LLM_PRESETS.
        api_key_input: API key from the UI; if blank, the preset's env var is used.
        custom_base_url, custom_model: Optional overrides of the preset defaults.
        progress: Gradio progress tracker.

    Returns:
        Lyrics with any Markdown fence removed, normalized by process_lyrics_correct.

    Raises:
        gr.Error: for missing theme/key, an unknown provider, an empty reply,
            or any provider/SDK failure.
    """
    if not theme.strip():
        raise gr.Error("Please enter a theme")
    progress(0.1, desc="Preparing request...")
    if api_choice not in LLM_PRESETS:
        raise gr.Error(f"Unknown API choice: {api_choice}")
    preset = LLM_PRESETS[api_choice]
    api_key = api_key_input.strip() if api_key_input and api_key_input.strip() else None
    if not api_key and preset["env_key"]:
        api_key = os.environ.get(preset["env_key"])
    if not api_key:
        raise gr.Error("No API key provided. Please enter your API key in the field above.")
    base_url = custom_base_url.strip() if custom_base_url and custom_base_url.strip() else preset["base_url"]
    model_name = custom_model.strip() if custom_model and custom_model.strip() else preset["default_model"]
    language_names = {
        "en": "English",
        "zh": "Chinese",
        "jp": "Japanese",
        "kr": "Korean",
        "sp": "Spanish"
    }
    lang_name = language_names.get(language, "English")
    tags_text = tags.strip() if tags.strip() else "pop, emotional"
    # Prompt text is sent verbatim to the model; keep continuation lines at column 0.
    prompt = f"""You are a professional songwriter. Generate song lyrics based on the following requirements:
**Theme**: {theme}
**Music Style/Tags**: {tags_text}
**Language**: {lang_name}
**Format Requirements** (CRITICAL):
1. Use lowercase for all lyrics text (except structure tags which are in brackets)
2. Include proper song structure tags: [Intro], [Verse], [Prechorus], [Chorus], [Bridge], [Outro]
3. Each structure tag should be on its own line
4. Separate different sections with a blank line (one empty line between sections)
5. NO timestamps like [00:12] - only structure tags allowed
6. Keep lyrics concise and suitable for a 3-4 minute song
**Structure Guidelines**:
- [Intro]: Optional, 1-2 lines if included
- [Verse]: Story-telling part, 4-6 lines, can repeat with different lyrics
- [Prechorus]: Optional, 2-4 lines, builds tension before chorus
- [Chorus]: Main hook, catchy and repetitive, 4-6 lines
- [Bridge]: Optional, provides contrast, 4-6 lines
- [Outro]: Closing, 1-2 lines
**Example Format**:
```
[Intro]
[Verse]
the sun creeps in across the floor
i hear the traffic outside the door
the coffee pot begins to hiss
it is another morning just like this
[Chorus]
every day the light returns
every day the fire burns
we keep on walking down this street
moving to the same steady beat
```
Now generate lyrics in {lang_name} based on the theme "{theme}" with style "{tags_text}".
Output ONLY the lyrics with structure tags, no explanations.
"""
    try:
        generated_lyrics = _request_lyrics(preset, api_key, base_url, model_name, prompt, progress)
    except Exception as e:
        raise gr.Error(f"Lyrics generation error: {str(e)}") from e
    progress(0.9, desc="Processing response...")
    if not generated_lyrics:
        raise gr.Error("Lyrics generation error: the model returned an empty response.")
    generated_lyrics = process_lyrics_correct(_strip_code_fence(generated_lyrics))
    progress(1.0, desc="Done!")
    return generated_lyrics
# Browser-side handler for the streaming button. It receives the session token
# from prepare_streaming_session and opens /ws_stream/<token>. Binary frames
# are scheduled back-to-back on a Web Audio context; JSON frames drive the
# progress bar, the pause/resume button and the download link. The returned
# promise resolves with the token on "complete" and rejects on any error or an
# unexpected close.
STREAMING_JS = """
async (token) => {
    if (!token) {
        console.error("No token provided");
        alert("Session initialization failed.");
        return;
    }
    return new Promise((resolve, reject) => {
        const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
        const host = window.location.host;
        const wsUrl = `${protocol}//${host}/ws_stream/${token}`;
        const ws = new WebSocket(wsUrl);
        ws.binaryType = "arraybuffer";
        const AudioContext = window.AudioContext || window.webkitAudioContext;
        const ctx = new AudioContext();
        let nextTime = 0;
        let sampleRate = 48000;
        ws.onopen = () => {
            if (ctx.state === "suspended") ctx.resume();
        };
        ws.onmessage = (event) => {
            const data = event.data;
            if (typeof data === "string") {
                const msg = JSON.parse(data);
                if (msg.type === "queue") {
                    const progressText = document.getElementById("stream-progress-text");
                    if (progressText) progressText.innerText = msg.message;
                } else if (msg.type === "config") {
                    sampleRate = msg.sample_rate;
                    const progressContainer = document.getElementById("stream-progress-container");
                    const controlsContainer = document.getElementById("stream-controls-container");
                    if (progressContainer) progressContainer.style.display = "block";
                    if (controlsContainer) controlsContainer.style.display = "flex";
                } else if (msg.type === "progress") {
                    const progressBar = document.getElementById("stream-progress-bar");
                    const progressText = document.getElementById("stream-progress-text");
                    if (progressBar) progressBar.style.width = (msg.value * 100) + "%";
                    if (progressText) progressText.innerText = msg.text;
                } else if (msg.type === "download") {
                    const downloadLink = document.getElementById("stream-download-link");
                    if (downloadLink) {
                        downloadLink.href = msg.url;
                        downloadLink.download = msg.filename;
                        downloadLink.style.display = "inline-block";
                        downloadLink.innerText = "Download Full Audio";
                    }
                } else if (msg.type === "error") {
                    alert("Error: " + msg.message);
                    ws.close();
                    // Nothing more will play; release the audio context, since
                    // browsers limit how many can exist at once.
                    ctx.close();
                    reject(msg.message);
                } else if (msg.type === "complete") {
                    ws.close();
                    resolve(token);
                }
                return;
            }
            if (data instanceof ArrayBuffer) {
                const float32 = new Float32Array(data);
                if (float32.length === 0) return;
                const buffer = ctx.createBuffer(1, float32.length, sampleRate);
                buffer.copyToChannel(float32, 0);
                const source = ctx.createBufferSource();
                source.buffer = buffer;
                source.connect(ctx.destination);
                const now = ctx.currentTime;
                if (nextTime < now) nextTime = now + 0.05;
                source.start(nextTime);
                nextTime += buffer.duration;
            }
        };
        window.toggleStreamPlayback = () => {
            if (ctx.state === "running") {
                ctx.suspend();
                document.getElementById("stream-play-btn").innerText = "Resume";
            } else if (ctx.state === "suspended") {
                ctx.resume();
                document.getElementById("stream-play-btn").innerText = "Pause";
            }
        };
        ws.onerror = (e) => {
            console.error("WS Error", e);
            reject(e);
        };
        // Settle the promise if the socket closes without a "complete" or
        // "error" message (server crash, dropped connection); otherwise the
        // Gradio event would wait forever. reject() after resolve()/reject() is a no-op.
        ws.onclose = () => reject("Connection closed before the stream completed.");
    });
}
"""
def create_ui():
with gr.Blocks(title="HeartMuLa Music Generation") as demo:
gr.Markdown("# HeartMuLa Music Generation")
gr.Markdown("Generate music from lyrics and style tags")
with gr.Tabs():
with gr.Tab("Music Generation"):
with gr.Row():
with gr.Column():
lyrics = gr.Textbox(
label="Lyrics",
lines=15,
value=EXAMPLE_LYRICS,
placeholder="Enter lyrics here..."
)
format_btn = gr.Button("Format Lyrics", size="sm")
gr.Markdown("### Tags")
tags = gr.Textbox(
label="Selected Tags (comma-separated)",
value=EXAMPLE_TAGS,
placeholder="e.g., piano,happy,pop",
lines=2
)
tag_checkboxes = []
with gr.Accordion("Tag Categories (Click to Expand)", open=False):
with gr.Row():
with gr.Column():
t1 = gr.CheckboxGroup(choices=TAG_DATA["Gender"], label="Gender")
tag_checkboxes.append(t1)
t2 = gr.CheckboxGroup(choices=TAG_DATA["Genre"], label="Genre")
tag_checkboxes.append(t2)
with gr.Column():
t3 = gr.CheckboxGroup(choices=TAG_DATA["Instrument"], label="Instrument")
tag_checkboxes.append(t3)
t4 = gr.CheckboxGroup(choices=TAG_DATA["Mood"], label="Mood")
tag_checkboxes.append(t4)
with gr.Column():
t5 = gr.CheckboxGroup(choices=TAG_DATA["Scene"], label="Scene")
tag_checkboxes.append(t5)
t6 = gr.CheckboxGroup(choices=TAG_DATA["Singer Timbre"], label="Singer Timbre")
tag_checkboxes.append(t6)
with gr.Column():
t7 = gr.CheckboxGroup(choices=TAG_DATA["Topic"], label="Topic")
tag_checkboxes.append(t7)
with gr.Row():
cfg_scale = gr.Slider(0.0, 3.0, value=1.5, step=0.1, label="CFG Scale")
duration = gr.Slider(10, 300, value=180, step=10, label="Duration (sec)")
with gr.Row():
temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
topk = gr.Slider(1, 100, value=50, step=1, label="Top-K")
with gr.Accordion("Advanced Settings", open=False):
backend = gr.Dropdown(
choices=[("HF Pipeline", "hf"), ("ExLlamaV2 (Experimental)", "exllama_v2")],
value="hf",
label="Backend"
)
version = gr.Dropdown(
choices=["3B", "7B", "RL-oss-3B-20260123"],
value="3B",
label="Model Version"
)
codec_version = gr.Dropdown(
choices=["oss", "oss-20260123"],
value="oss",
label="Codec Version"
)
quant_mode = gr.Dropdown(
choices=[("None", "none"), ("4-bit (NF4/FP4)", "4bit"), ("8-bit", "8bit")],
value="none",
label="Quantization"
)
keep_model_loaded = gr.Checkbox(value=True, label="Keep Model Loaded")
offload_mode = gr.Dropdown(
choices=["auto", "aggressive"],
value="auto",
label="Offload Mode"
)
output_format = gr.Radio(
choices=[("WAV", "wav"), ("MP3", "mp3")],
value="wav",
label="Output Format"
)
chunk_frames = gr.Slider(5, 100, value=20, step=1, label="Streaming Chunk Frames")
generate_btn = gr.Button("Generate Music", variant="primary", size="lg")
stream_btn = gr.Button("Generate Music (Streaming)", variant="primary", size="lg")
state_token = gr.Textbox(visible=False, label="Stream Token")
with gr.Column():
with gr.Accordion("Usage Notice", open=True):
gr.Markdown("""
### Lyrics Format Requirements
**Automatic Processing:**
1. All text will be converted to **lowercase**
2. Timestamps (e.g., [00:12]) will be **automatically removed**
3. Leading/trailing whitespace on each line will be **stripped**
4. Leading/trailing empty lines will be **removed**
5. Multiple consecutive empty lines (3+) will be **collapsed to 2**
**Recommended Format:**
- Use standard song structure tags: `[Intro]`, `[Verse]`, `[Chorus]`, `[Bridge]`, `[Outro]`, etc.
- Separate sections with **blank lines**
- Case doesn't matter (will be auto-converted)
**Example:**
```
[Intro]
[Verse]
The sun creeps in across the floor
I hear the traffic outside the door
[Chorus]
Every day the light returns
Every day the fire burns
```
---
### Tags Format
- Use **commas** to separate multiple tags: `piano,happy,pop`
- Tags influence the style and mood of the generated music
- Select from categories below or type directly
""")
output_audio = gr.Audio(label="Generated Music", type="filepath")
gr.HTML("""
<div id="stream-progress-container" style="width: 100%; background-color: #f0f0f0; border-radius: 8px; margin-top: 10px; display: none; overflow: hidden;">
<div id="stream-progress-bar" style="width: 0%; height: 24px; background-color: #4CAF50; transition: width 0.2s ease-in-out;"></div>
</div>
<div id="stream-progress-text" style="text-align: center; font-family: monospace; margin-top: 5px; color: #666;"></div>
<div id="stream-controls-container" style="display: none; justify-content: center; gap: 10px; margin-top: 15px;">
<button id="stream-play-btn" onclick="window.toggleStreamPlayback()" style="padding: 8px 16px; border-radius: 4px; border: 1px solid #ccc; background: white; cursor: pointer;">Pause</button>
</div>
<a id="stream-download-link" style="display:none; margin-top: 10px;" target="_blank"></a>
""")
for cb in tag_checkboxes:
cb.change(fn=update_tag_string, inputs=tag_checkboxes, outputs=tags)
format_btn.click(
fn=process_lyrics_correct,
inputs=[lyrics],
outputs=[lyrics]
)
generate_btn.click(
fn=generate,
inputs=[
lyrics,
tags,
cfg_scale,
duration,
temperature,
topk,
version,
codec_version,
quant_mode,
output_format,
keep_model_loaded,
offload_mode,
backend,
],
outputs=[output_audio]
)
stream_btn.click(
fn=prepare_streaming_session,
inputs=[
lyrics,
tags,
cfg_scale,
duration,
temperature,
topk,
version,
codec_version,
quant_mode,
keep_model_loaded,
offload_mode,
backend,
output_format,
chunk_frames,
],
outputs=[state_token]
).then(
fn=load_audio_from_result,
inputs=[state_token],
outputs=[output_audio],
js=STREAMING_JS,
)
with gr.Tab("Lyrics Generation"):
with gr.Row():
with gr.Column():
gr.Markdown("### Generate Lyrics with AI")
api_selector = gr.Radio(
choices=[
("Google Gemini", "gemini"),
("OpenAI", "openai"),
("DeepSeek", "deepseek"),
("Custom (OpenAI-compatible)", "custom")
],
value="gemini",
label="Select LLM Provider"
)
with gr.Accordion("API Configuration", open=True):
api_key_input = gr.Textbox(
label="API Key (Required)",
type="password",
placeholder="Enter your API key or set environment variable",
info="Will use environment variable if not provided here"
)
custom_base_url = gr.Textbox(
label="Custom Base URL (Optional)",
placeholder="e.g., https://api.your-provider.com/v1",
info="Leave empty to use default. For custom providers only.",
visible=False
)
custom_model = gr.Textbox(
label="Model Name (Optional)",
placeholder="e.g., gpt-4o, deepseek-chat",
info="Leave empty to use recommended default",
visible=False
)
def update_custom_fields(choice):
    """Toggle the provider-specific textboxes for the selected LLM provider.

    Returns a pair of ``gr.update`` payloads for ``(custom_base_url,
    custom_model)``:
      * "custom"   -> base URL shown, model name shown
      * "deepseek" -> base URL hidden, model name shown
      * otherwise  -> both hidden
    """
    base_url_visible = choice == "custom"
    model_visible = base_url_visible or choice == "deepseek"
    return gr.update(visible=base_url_visible), gr.update(visible=model_visible)
api_selector.change(
fn=update_custom_fields,
inputs=[api_selector],
outputs=[custom_base_url, custom_model]
)
theme_input = gr.Textbox(
label="Theme",
placeholder="e.g., Love lost in the city, Dreams and hope, Rainy day memories...",
lines=2
)
tags_gen = gr.Textbox(
label="Music Style/Tags",
placeholder="e.g., piano, melancholy, pop",
value="pop,emotional"
)
language_select = gr.Radio(
choices=[
("English", "en"),
("中文 (Chinese)", "zh"),
("日本語 (Japanese)", "jp"),
("한국어 (Korean)", "kr"),
("Español (Spanish)", "sp")
],
value="en",
label="Language"
)
generate_lyrics_btn = gr.Button(
"Generate Lyrics",
variant="primary",
size="lg"
)
with gr.Column():
with gr.Accordion("How to Use", open=True):
gr.Markdown("""
### How to Generate Lyrics
**Theme**: Describe your song's story or emotion
- Examples: "Lost love in Tokyo", "Overcoming obstacles", "Summer road trip"
**Music Style/Tags**: Define mood and genre
- Examples: "piano,melancholy,ballad", "upbeat,electronic,dance"
**Tips**
- Generated lyrics follow standard song structure ([Intro], [Verse], [Chorus], etc.)
- Edit lyrics before using for music generation
- Be specific with themes for better results
""")
generated_lyrics_output = gr.Textbox(
label="Generated Lyrics",
lines=20,
placeholder="Generated lyrics will appear here...",
interactive=False
)
copy_to_music_gen = gr.Button(
"Copy to Music Generation Tab",
size="sm"
)
generate_lyrics_btn.click(
fn=generate_lyrics,
inputs=[theme_input, tags_gen, language_select, api_selector, api_key_input, custom_base_url, custom_model],
outputs=[generated_lyrics_output]
)
def copy_lyrics(lyrics_text):
    """Forward the generated lyrics into the Music Generation tab's lyrics box.

    When nothing has been generated yet (``None``, empty, or whitespace-only
    text), return a no-op ``gr.update()`` instead. Otherwise the click would
    overwrite the lyrics the user already has in the Music Generation tab
    with an empty string.
    """
    if not lyrics_text or not lyrics_text.strip():
        return gr.update()  # leave the target textbox unchanged
    return lyrics_text
copy_to_music_gen.click(
fn=copy_lyrics,
inputs=[generated_lyrics_output],
outputs=[lyrics]
)
with gr.Tab("Lyrics Transcription"):
with gr.Row():
with gr.Column():
gr.Markdown("### Transcribe or Translate Lyrics from Audio")
audio_input = gr.Audio(
label="Audio Input",
type="filepath"
)
task_select = gr.Radio(
choices=[
("Transcribe (Original Language)", "transcribe"),
("Translate to English", "translate")
],
value="transcribe",
label="Task"
)
max_new_tokens = gr.Slider(
64, 512, value=256, step=16, label="Max New Tokens"
)
num_beams = gr.Slider(
1, 5, value=2, step=1, label="Beam Search"
)
transcribe_temperature = gr.Slider(
0.0, 0.8, value=0.2, step=0.1, label="Temperature"
)
transcribe_btn = gr.Button(
"Run Transcription",
variant="primary",
size="lg"
)
use_generated_audio = gr.Button(
"Use Generated Music",
size="sm"
)
with gr.Column():
with gr.Accordion("Notes", open=True):
gr.Markdown("""
### Notes
- Best results come from **vocals-only** stems.
- If you pass full mixes, consider source separation first.
- The HeartTranscriptor model will auto-download on first use.
""")
transcription_output = gr.Textbox(
label="Transcription Result",
lines=18,
placeholder="Transcribed lyrics will appear here...",
interactive=False
)
transcribe_btn.click(
fn=transcribe_audio,
inputs=[audio_input, task_select, max_new_tokens, num_beams, transcribe_temperature],
outputs=[transcription_output]
)
use_generated_audio.click(
fn=lambda x: x,
inputs=[output_audio],
outputs=[audio_input]
)
return demo
def build_app():
    """Assemble the ASGI application that uvicorn serves.

    Builds the Gradio UI and enables its request queue (capped at 8 pending
    events). It is then mounted at "/" on a FastAPI app that also exposes:
      * /download/...        static files served from DATA_DIR
      * /ws_stream/{token}   the streaming websocket endpoint

    Returns:
        The FastAPI app returned by ``gr.mount_gradio_app``.
    """
    ui = create_ui()
    ui.queue(max_size=8)
    server = FastAPI()
    data_root = str(DATA_DIR)
    # The fixed routes are registered before the Gradio mount at "/".
    # Starlette tries routes in registration order, so the catch-all
    # mount does not shadow them.
    server.mount("/download", StaticFiles(directory=data_root), name="download")
    server.add_api_websocket_route("/ws_stream/{token}", websocket_stream_endpoint)
    return gr.mount_gradio_app(
        server,
        ui,
        path="/",
        allowed_paths=[data_root],
    )
if __name__ == "__main__":
    # Command-line entry point. It parses the options and publishes the
    # checkpoint location through the module-level MODEL_PATH global. It then
    # builds the FastAPI + Gradio app and serves it with uvicorn.
    parser = argparse.ArgumentParser(description="Launch the HeartMuLa web demo.")
    parser.add_argument(
        "--model_path",
        type=str,
        default="./ckpt",
        help="Path to the model checkpoint directory (default: %(default)s).",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8888,
        help="Port for the web server (default: %(default)s).",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Interface to bind; 0.0.0.0 listens on all interfaces (default: %(default)s).",
    )
    args = parser.parse_args()
    # NOTE(review): presumably read by the pipeline loaders elsewhere in this
    # file; confirm against their definitions.
    MODEL_PATH = args.model_path
    # Import before building the UI so a missing uvicorn install fails fast,
    # instead of after the (potentially slow) app construction.
    import uvicorn
    app = build_app()
    uvicorn.run(app, host=args.host, port=args.port)