| import argparse |
| import asyncio |
| import json |
| import os |
| import re |
| import subprocess |
| import tempfile |
| import uuid |
| import wave |
| from pathlib import Path |
| from typing import Any, Dict, Iterator, Optional, Tuple |
|
|
| import numpy as np |
| import gradio as gr |
| import torch |
| from fastapi import FastAPI, WebSocket |
| from google import genai |
| from google.genai import types |
| from openai import OpenAI |
| from starlette.concurrency import iterate_in_threadpool |
| from starlette.staticfiles import StaticFiles |
| from transformers import BitsAndBytesConfig |
|
|
| from heartlib import HeartMuLaGenPipeline, HeartTranscriptorPipeline |
|
|
| |
# Checkpoint directory used by get_model_manager(), generate() and transcribe_audio().
# It is None at import time; get_model_manager() passes it to os.makedirs(), so it
# has to be assigned a real path before the first model load.
MODEL_PATH = None
|
|
| |
# LLM provider presets consumed by generate_lyrics(). For each preset:
#   name          - label shown in progress messages
#   api_type      - which client is used ("gemini" -> genai.Client, "openai" -> OpenAI)
#   default_model - model used when no custom model name is supplied
#   env_key       - environment variable consulted when no API key is typed in (None = no fallback)
#   base_url      - endpoint override for OpenAI-compatible services (None = client default)
LLM_PRESETS = {
    "gemini": {
        "name": "Google Gemini",
        "api_type": "gemini",
        "default_model": "gemini-2.0-flash-lite",
        "env_key": "GEMINI_API_KEY",
        "base_url": None,
    },
    "openai": {
        "name": "OpenAI",
        "api_type": "openai",
        "default_model": "gpt-4o-mini",
        "env_key": "OPENAI_API_KEY",
        "base_url": None,
    },
    "deepseek": {
        "name": "DeepSeek",
        "api_type": "openai",
        "default_model": "deepseek-chat",
        "env_key": "DEEPSEEK_API_KEY",
        "base_url": "https://api.deepseek.com",
    },
    "custom": {
        "name": "Custom OpenAI-Compatible",
        "api_type": "openai",
        "default_model": "custom-model",
        "env_key": None,
        "base_url": None,
    }
}
|
|
| |
# Default content of the "Lyrics" textbox built in create_ui().
EXAMPLE_LYRICS = """[Intro]

[Verse]
The sun creeps in across the floor
I hear the traffic outside the door
The coffee pot begins to hiss
It is another morning just like this

[Prechorus]
The world keeps spinning round and round
Feet are planted on the ground
I find my rhythm in the sound

[Chorus]
Every day the light returns
Every day the fire burns
We keep on walking down this street
Moving to the same steady beat
It is the ordinary magic that we meet

[Verse]
The hours tick deeply into noon
Chasing shadows,chasing the moon
Work is done and the lights go low
Watching the city start to glow

[Bridge]
It is not always easy,not always bright
Sometimes we wrestle with the night
But we make it to the morning light

[Chorus]
Every day the light returns
Every day the fire burns
We keep on walking down this street
Moving to the same steady beat

[Outro]
Just another day
Every single day"""


# Default content of the comma-separated "Selected Tags" textbox in create_ui().
EXAMPLE_TAGS = "piano,happy"
|
|
| |
# Tag vocabulary grouped by category. create_ui() builds one CheckboxGroup per
# category from these lists; the category name is used as the group's label.
TAG_DATA = {
    "Gender": [
        "Male", "Female"
    ],
    "Genre": [
        "Pop", "Folk", "Ballad", "Electronic", "Rock", "Acoustic", "R&B",
        "Indie", "Dance", "Indie Pop", "J-Pop", "Hip-Hop", "Country",
        "Latin", "Alternative", "Christian", "Cantopop", "Gospel", "Soul",
        "Mandopop"
    ],
    "Instrument": [
        "Drums", "Piano", "Guitar", "Strings", "Synthesizer", "Bass",
        "Acoustic Guitar", "Keyboard", "Electronic Drums", "Vocals",
        "Drum Machine", "Electric Guitar", "Percussion", "Beat",
        "Orchestra", "Saxophone", "Accordion", "Voice", "String", "Vocal"
    ],
    "Mood": [
        "Melancholy", "Romantic", "Energetic", "Hopeful", "Dreamy",
        "Relaxed", "Sad", "Calm", "Cheerful", "Reflective", "Emotional",
        "Joyful", "Sentimental", "Uplifting", "Warm", "Peaceful", "Upbeat",
        "Gentle", "Nostalgic", "Epic"
    ],
    "Scene": [
        "Driving", "Road Trip", "Cafe", "Relaxing", "Wedding", "Meditation",
        "Workout", "Walking", "Alone", "Travel", "Reflection", "Rainy Day",
        "Night", "Church", "Coffee Shop", "Gym", "Gaming", "Study",
        "Dating", "Date"
    ],
    "Singer Timbre": [
        "Soft", "Clear", "Warm", "Gentle", "Smooth", "Sweet", "Emotional",
        "Mellow", "Powerful", "Youthful", "Bright", "Rough", "Raspy",
        "Melodic", "Deep", "Soulful", "Strong", "Energetic", "Breathy",
        "Passionate"
    ],
    "Topic": [
        "Love", "Relationship", "Hope", "Longing", "Loss", "Heartbreak",
        "Memory", "Reflection", "Life", "Faith", "Regret", "Freedom",
        "Breakup", "Nature", "Loneliness", "Dreams", "Nostalgia", "Romance",
        "Friendship", "Youth"
    ]
}
|
|
# Output directory for generated audio. Taken from HEARTMULA_DATA_DIR when set,
# otherwise "<system temp dir>/heartmula_stream". Created as an import side effect.
DATA_DIR = Path(os.environ.get("HEARTMULA_DATA_DIR", os.path.join(tempfile.gettempdir(), "heartmula_stream")))
DATA_DIR.mkdir(parents=True, exist_ok=True)


# Streaming limits, each read from the environment variable of the same name.
# int() raises ValueError at import time if a variable holds a non-integer value.
STREAM_MAX_CONCURRENCY = int(os.environ.get("STREAM_MAX_CONCURRENCY", "1"))  # StreamingQueue(max_concurrency=...)
STREAM_MAX_QUEUE = int(os.environ.get("STREAM_MAX_QUEUE", "20"))  # StreamingQueue(max_queue=...)
STREAM_MAX_PER_USER = int(os.environ.get("STREAM_MAX_PER_USER", "1"))  # per-user cap for queue and sessions
STREAM_SESSION_TTL_SEC = int(os.environ.get("STREAM_SESSION_TTL_SEC", "900"))  # passed to cleanup_expired()
|
|
|
|
class StreamingQueue:
    """FIFO admission queue for streaming generation requests.

    A request is first ``join``-ed (placed in the waiting list), then polls
    ``acquire`` until it reaches the head of the list and a concurrency slot is
    free, and is finally ``release``-d. Every request occupies one per-user
    slot from ``join`` until it is released, cancelled or expired.

    The owner of each request is recorded at ``join`` time, so the per-user
    slot is given back exactly once no matter which of ``release``, ``cancel``
    or ``cleanup_expired`` retires the request, and regardless of the
    ``user_id`` value the caller passes afterwards.
    """

    def __init__(self, max_concurrency: int = 1, max_queue: int = 20, max_per_user: int = 1):
        # Every limit is clamped to at least 1 so the queue can always make progress.
        self._max_concurrency = max(1, max_concurrency)
        self._max_queue = max(1, max_queue)
        self._max_per_user = max(1, max_per_user)
        self._queue = []          # waiting request ids, head first
        self._running = set()     # request ids currently holding a concurrency slot
        self._per_user = {}       # user key -> number of tracked requests
        self._enqueued_at = {}    # request id -> event-loop time of join()
        self._owner = {}          # request id -> user key charged in _per_user
        self._lock = asyncio.Lock()

    def _forget(self, request_id: str) -> None:
        """Drop every trace of ``request_id``; the caller must hold the lock.

        Safe to call for an unknown or already retired request, in which case
        no per-user counter is touched.
        """
        if request_id in self._queue:
            self._queue.remove(request_id)
        self._running.discard(request_id)
        self._enqueued_at.pop(request_id, None)
        owner = self._owner.pop(request_id, None)
        if owner is None:
            return
        remaining = self._per_user.get(owner, 0) - 1
        if remaining > 0:
            self._per_user[owner] = remaining
        else:
            self._per_user.pop(owner, None)

    async def join(self, request_id: str, user_id: str) -> Tuple[bool, str]:
        """Add ``request_id`` to the waiting list on behalf of ``user_id``.

        Returns ``(True, "OK")`` on success (also when the request is already
        tracked and the limits still allow it), otherwise ``(False, reason)``.
        """
        async with self._lock:
            if len(self._queue) >= self._max_queue:
                return False, "Queue is full. Please try again later."
            if self._per_user.get(user_id, 0) >= self._max_per_user:
                return False, "You already have a pending request."
            if request_id not in self._owner:
                self._queue.append(request_id)
                self._owner[request_id] = user_id
                self._per_user[user_id] = self._per_user.get(user_id, 0) + 1
                self._enqueued_at[request_id] = asyncio.get_running_loop().time()
            return True, "OK"

    async def acquire(self, request_id: str) -> bool:
        """Move ``request_id`` from waiting to running if it is at the head and a slot is free."""
        async with self._lock:
            if self._queue and self._queue[0] == request_id and len(self._running) < self._max_concurrency:
                self._queue.pop(0)
                self._running.add(request_id)
                return True
            return False

    async def release(self, request_id: str, user_id: Optional[str] = None):
        """Retire ``request_id`` whether it is running or still waiting.

        ``user_id`` is accepted for backward compatibility; the owner recorded
        by ``join`` decides which per-user counter is decremented.
        """
        async with self._lock:
            self._forget(request_id)

    async def get_wait_info(self, request_id: str) -> Tuple[int, int]:
        """Return ``(requests ahead, estimated wait in seconds)``.

        The estimate assumes 60 seconds per request ahead; a request that is
        not waiting reports ``(0, 0)``.
        """
        async with self._lock:
            ahead = self._queue.index(request_id) if request_id in self._queue else 0
            est_wait = ahead * 60
            return ahead, est_wait

    async def cancel(self, request_id: str, user_id: Optional[str] = None):
        """Withdraw ``request_id`` if it is still waiting; running requests are left alone."""
        async with self._lock:
            if request_id in self._queue:
                self._forget(request_id)

    async def cleanup_expired(self, ttl_sec: int):
        """Drop waiting requests that joined more than ``ttl_sec`` seconds ago.

        Running requests are not interrupted; they are retired by ``release``.
        """
        async with self._lock:
            now = asyncio.get_running_loop().time()
            expired = [rid for rid, t in self._enqueued_at.items() if now - t > ttl_sec]
            for rid in expired:
                if rid in self._queue:
                    self._forget(rid)
|
|
|
|
# Process-wide admission queue for streaming requests, sized from the STREAM_* settings.
stream_queue = StreamingQueue(
    max_concurrency=STREAM_MAX_CONCURRENCY,
    max_queue=STREAM_MAX_QUEUE,
    max_per_user=STREAM_MAX_PER_USER,
)
# token -> generation parameters; filled by prepare_streaming_session(), removed by
# cancel_stream() and by websocket_stream_endpoint() when it finishes.
stream_sessions: Dict[str, Dict[str, Any]] = {}
# token -> path of the saved audio file; written by websocket_stream_endpoint(),
# read by load_audio_from_result().
stream_results: Dict[str, str] = {}
|
|
|
|
class ModelManager:
    """Builds the generation and transcription pipelines on demand and caches them.

    Generation pipelines are cached per ``(version, codec_version, quant_mode)``;
    a single transcription pipeline is shared. CUDA with bfloat16 is used when
    available, otherwise CPU with float32.
    """

    _TRUE_VALUES = ("1", "true", "yes")

    def __init__(self, model_path: str):
        has_cuda = torch.cuda.is_available()
        self.model_path = model_path
        self.device = torch.device("cuda" if has_cuda else "cpu")
        self.dtype = torch.bfloat16 if has_cuda else torch.float32
        self._gen_pipes: Dict[Tuple[str, str, str], HeartMuLaGenPipeline] = {}
        self._transcribe_pipe: Optional[HeartTranscriptorPipeline] = None
        self.use_deepspeed = self._env_flag("USE_DEEPSPEED_INFERENCE", "0")
        self.ds_inference_config = self._make_ds_inference_config()

    @classmethod
    def _env_flag(cls, name: str, default: str) -> bool:
        """Read environment variable ``name`` as a case-insensitive boolean flag."""
        return os.getenv(name, default).lower() in cls._TRUE_VALUES

    def _make_ds_inference_config(self) -> Dict[str, Any]:
        """Return the DeepSpeed inference settings, or ``{}`` when DeepSpeed is off.

        Values come from DEEPSPEED_TP_SIZE (falling back to WORLD_SIZE, then 1),
        DEEPSPEED_REPLACE_METHOD (default "auto") and DEEPSPEED_KERNEL_INJECT
        (default on).
        """
        if not self.use_deepspeed:
            return {}
        tensor_parallel = os.getenv("DEEPSPEED_TP_SIZE", os.getenv("WORLD_SIZE", "1"))
        return {
            "mp_size": int(tensor_parallel),
            "dtype": self.dtype,
            "replace_method": os.getenv("DEEPSPEED_REPLACE_METHOD", "auto"),
            "replace_with_kernel_inject": self._env_flag("DEEPSPEED_KERNEL_INJECT", "1"),
        }

    def _make_bnb_config(self, quant_mode: str) -> Optional[BitsAndBytesConfig]:
        """Translate ``quant_mode`` ("none", "4bit", "8bit") into a BitsAndBytesConfig.

        Returns ``None`` for "none". Raises ``gr.Error`` when quantization is
        requested without CUDA or when the mode is unknown.
        """
        if quant_mode == "none":
            return None
        if self.device.type != "cuda":
            raise gr.Error("Quantization requires CUDA.")
        if quant_mode == "8bit":
            return BitsAndBytesConfig(load_in_8bit=True)
        if quant_mode != "4bit":
            raise gr.Error(f"Unknown quant mode: {quant_mode}")

        # 4-bit: NF4 by default, FP4 on devices reporting compute capability >= 10.
        # A failed capability query is deliberately ignored and NF4 is kept.
        try:
            capability_major, _ = torch.cuda.get_device_capability()
        except Exception:
            capability_major = 0
        four_bit_type = "fp4" if capability_major >= 10 else "nf4"
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=four_bit_type,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

    def get_gen_pipeline(self, version: str, codec_version: str, quant_mode: str) -> HeartMuLaGenPipeline:
        """Return the cached generation pipeline for this configuration, loading it if needed."""
        cache_key = (version, codec_version, quant_mode)
        if cache_key in self._gen_pipes:
            return self._gen_pipes[cache_key]
        quant_config = self._make_bnb_config(quant_mode)
        self._gen_pipes[cache_key] = HeartMuLaGenPipeline.from_pretrained(
            self.model_path,
            device=self.device,
            dtype=self.dtype,
            version=version,
            codec_version=codec_version,
            bnb_config=quant_config,
            lazy_load=True,
            use_deepspeed=self.use_deepspeed,
            ds_inference_config=self.ds_inference_config,
        )
        return self._gen_pipes[cache_key]

    def get_transcriptor(self) -> HeartTranscriptorPipeline:
        """Return the shared transcription pipeline (float16 on CUDA, float32 on CPU)."""
        if self._transcribe_pipe is not None:
            return self._transcribe_pipe
        transcribe_dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self._transcribe_pipe = HeartTranscriptorPipeline.from_pretrained(
            self.model_path,
            device=self.device,
            dtype=transcribe_dtype,
        )
        return self._transcribe_pipe
|
|
|
|
# Lazily created singleton; get_model_manager() assigns it on first use.
model_manager: Optional[ModelManager] = None
|
|
|
|
def get_model_manager() -> ModelManager:
    """Return the shared ModelManager, creating it on the first call.

    The first call creates the MODEL_PATH directory, downloads the generation
    checkpoints if they are missing, and builds the manager.
    """
    global model_manager
    if model_manager is not None:
        return model_manager
    os.makedirs(MODEL_PATH, exist_ok=True)
    download_models_if_needed(MODEL_PATH)
    model_manager = ModelManager(MODEL_PATH)
    return model_manager
|
|
|
|
def update_tag_string(*args):
    """Merge the selections of all tag groups into one comma-separated string.

    Each argument is the selection of one group: a list of tags, a single tag,
    or an empty/None value (ignored). Duplicates are dropped, keeping the
    first occurrence, and the original order is preserved.
    """
    picked = []
    for group in args:
        if not group:
            continue
        picked.extend(group if isinstance(group, list) else [group])
    # dict.fromkeys de-duplicates while keeping insertion order.
    return ",".join(dict.fromkeys(picked))
|
|
|
|
def process_lyrics_correct(content):
    """Normalize lyrics text into the layout the model was trained on.

    Steps, in order:
    1. Lowercase everything (structure tags included).
    2. Remove bracketed timestamps such as ``[00:12]``.
    3. Strip whitespace from both ends of every line.
    4. Drop leading and trailing empty lines.
    5. Collapse runs of three or more newlines into two.
    """
    lowered = content.lower()
    without_stamps = re.sub(r"\[[^\]]*\d{1,2}:\d{2}[^\]]*\]", "", lowered)
    trimmed_lines = (raw.strip() for raw in without_stamps.split("\n"))
    # Every line is already stripped, so trimming newlines from the joined text
    # removes exactly the empty lines at the start and end.
    body = "\n".join(trimmed_lines).strip("\n")
    return re.sub(r"\n{3,}", "\n\n", body)
|
|
|
|
def save_audio_to_wav(sample_rate: int, audio_np: np.ndarray, output_dir: Path) -> str:
    """Write mono float audio to a uniquely named 16-bit PCM WAV file.

    Args:
        sample_rate: Sampling rate written to the WAV header.
        audio_np: Float samples, nominally in [-1.0, 1.0]. Out-of-range values
            are clipped and NaN is written as silence, instead of wrapping
            around in the int16 conversion.
        output_dir: Directory for the file; created if it does not exist.

    Returns:
        The path of the written file as a string.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    wav_path = output_dir / f"{uuid.uuid4()}.wav"
    bounded = np.clip(np.nan_to_num(np.asarray(audio_np), nan=0.0), -1.0, 1.0)
    # WAV PCM data is little-endian; make that explicit rather than relying on
    # the byte order of the host.
    pcm = (bounded * 32767).astype("<i2")
    with wave.open(str(wav_path), "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm.tobytes())
    return str(wav_path)
|
|
|
|
def convert_wav_to_mp3(wav_path: str, output_dir: Path) -> str:
    """Encode ``wav_path`` to a uniquely named MP3 in ``output_dir`` using ffmpeg.

    Runs ``ffmpeg`` with the libmp3lame encoder at quality level 2 and captures
    its output. Raises ``subprocess.CalledProcessError`` when ffmpeg exits with
    a non-zero status. Returns the MP3 path as a string.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    mp3_path = output_dir / f"{uuid.uuid4()}.mp3"
    encoder_args = ["-codec:a", "libmp3lame", "-qscale:a", "2"]
    subprocess.run(
        ["ffmpeg", "-y", "-i", wav_path, *encoder_args, str(mp3_path)],
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    return str(mp3_path)
|
|
|
|
def check_models_exist(ckpt_dir):
    """Return True when every required generation checkpoint entry exists in ``ckpt_dir``."""
    required_names = (
        "gen_config.json",
        "tokenizer.json",
        "HeartCodec-oss",
        "HeartMuLa-oss-3B",
    )
    return all(os.path.exists(os.path.join(ckpt_dir, name)) for name in required_names)
|
|
|
|
def download_models_if_needed(ckpt_dir):
    """Download the generation checkpoints into ``ckpt_dir`` unless they are already there.

    ``modelscope`` is imported only when a download is actually required.
    """
    if not check_models_exist(ckpt_dir):
        from modelscope import snapshot_download

        # The base repository goes into ckpt_dir itself, the other two into sub-directories.
        downloads = (
            ("HeartMuLa/HeartMuLaGen", ckpt_dir),
            ("HeartMuLa/HeartMuLa-oss-3B", os.path.join(ckpt_dir, "HeartMuLa-oss-3B")),
            ("HeartMuLa/HeartCodec-oss", os.path.join(ckpt_dir, "HeartCodec-oss")),
        )
        for repo_id, target_dir in downloads:
            snapshot_download(repo_id, local_dir=target_dir)
|
|
|
|
def check_transcriptor_exists(ckpt_dir):
    """Return True when ``ckpt_dir`` contains a "HeartTranscriptor-oss" entry."""
    transcriptor_path = os.path.join(ckpt_dir, "HeartTranscriptor-oss")
    found = os.path.exists(transcriptor_path)
    return found
|
|
|
|
def download_transcriptor_if_needed(ckpt_dir):
    """Download the transcription checkpoint into ``ckpt_dir`` unless it is already there.

    ``modelscope`` is imported only when a download is actually required.
    """
    if not check_transcriptor_exists(ckpt_dir):
        from modelscope import snapshot_download

        target_dir = os.path.join(ckpt_dir, "HeartTranscriptor-oss")
        snapshot_download("HeartMuLa/HeartTranscriptor-oss", local_dir=target_dir)
|
|
|
|
def load_pipeline(model_path, version, codec_version, quant_mode):
    """Return the cached generation pipeline for the given configuration.

    ``model_path`` is accepted but not used: the shared manager from
    get_model_manager() decides where checkpoints are loaded from.
    """
    return get_model_manager().get_gen_pipeline(version, codec_version, quant_mode)
|
|
|
|
def load_transcriptor(model_path):
    """Ensure the transcription checkpoint exists under ``model_path`` and return the shared pipeline."""
    download_transcriptor_if_needed(model_path)
    return get_model_manager().get_transcriptor()
|
|
|
|
def generate(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    output_format,
    keep_model_loaded,
    offload_mode,
    backend,
):
    """Generate a complete song and return the path of the resulting audio file.

    The pipeline writes a temporary WAV file. For ``output_format == "mp3"``
    that file is converted into DATA_DIR and the MP3 path is returned;
    otherwise the WAV path itself is returned.

    Raises:
        gr.Error: when lyrics or tags are empty, when the unimplemented
            "exllama_v2" backend is selected, or when generation or
            conversion fails (the original exception is attached as cause).
    """
    if not lyrics.strip():
        raise gr.Error("Please enter lyrics")
    if not tags.strip():
        raise gr.Error("Please enter tags")
    if backend == "exllama_v2":
        raise gr.Error("ExLlamaV2 backend is not implemented yet.")

    max_audio_length_ms = int(duration_sec * 1000)
    pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode)

    # Only the name is needed; the pipeline opens the path itself.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        output_path = f.name

    result_path = None
    try:
        with torch.no_grad():
            pipe(
                {
                    "lyrics": lyrics,
                    "tags": tags,
                },
                max_audio_length_ms=max_audio_length_ms,
                save_path=output_path,
                topk=topk,
                temperature=temperature,
                cfg_scale=cfg_scale,
                keep_model_loaded=keep_model_loaded,
                offload_mode=offload_mode,
            )

        if output_format == "mp3":
            result_path = convert_wav_to_mp3(output_path, DATA_DIR)
        else:
            result_path = output_path
        return result_path
    except Exception as e:
        raise gr.Error(f"Generation error: {str(e)}") from e
    finally:
        # The temporary WAV is only worth keeping when it is the returned
        # result; after a failure or an MP3 conversion it would just pile up.
        if result_path != output_path:
            try:
                os.remove(output_path)
            except OSError:
                pass  # already gone or not removable; nothing useful to do here
|
|
|
|
def transcribe_audio(audio_path, task, max_new_tokens, num_beams, temperature):
    """Transcribe ``audio_path`` with the shared transcription pipeline.

    Returns the "text" entry when the pipeline yields a dict, otherwise the
    string form of its result. Raises ``gr.Error`` when no file is given or
    when transcription fails.
    """
    if not audio_path:
        raise gr.Error("Please upload an audio file")

    pipe = load_transcriptor(MODEL_PATH)
    decode_options = {
        "max_new_tokens": int(max_new_tokens),
        "num_beams": int(num_beams),
        "task": task,
        "condition_on_prev_tokens": False,
        "compression_ratio_threshold": 1.8,
        "temperature": float(temperature),
        "logprob_threshold": -1.0,
        "no_speech_threshold": 0.4,
    }
    try:
        with torch.no_grad():
            outcome = pipe(audio_path, **decode_options)
        return outcome.get("text", "") if isinstance(outcome, dict) else str(outcome)
    except Exception as e:
        raise gr.Error(f"Transcription error: {str(e)}")
|
|
|
|
| def _normalize_stream_chunk(chunk: np.ndarray) -> np.ndarray: |
| chunk = np.nan_to_num(chunk, nan=0.0, posinf=0.0, neginf=0.0) |
| return np.clip(chunk, -1.0, 1.0) |
|
|
|
|
def generate_music_streaming(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    keep_model_loaded,
    offload_mode,
    backend,
    chunk_frames,
) -> Iterator[Tuple[int, np.ndarray]]:
    """Stream generated audio as ``(sample_rate, samples)`` chunks.

    Each chunk from ``pipe.stream`` is moved to the CPU, converted to a
    float32 NumPy array and sanitized with _normalize_stream_chunk. The
    sample rate reported with every chunk is 48000.

    Being a generator, nothing runs (including the backend check) until the
    first item is requested.

    Raises:
        gr.Error: when the unimplemented "exllama_v2" backend is selected.
    """
    if backend == "exllama_v2":
        raise gr.Error("ExLlamaV2 backend is not implemented yet.")
    pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode)
    max_audio_length_ms = int(duration_sec * 1000)
    for chunk in pipe.stream(
        {"lyrics": lyrics, "tags": tags},
        max_audio_length_ms=max_audio_length_ms,
        temperature=temperature,
        topk=topk,
        cfg_scale=cfg_scale,
        chunk_frames=int(chunk_frames),
        keep_model_loaded=keep_model_loaded,
        offload_mode=offload_mode,
    ):
        if chunk.dim() == 2:
            chunk = chunk.squeeze(0)
        # Tensor.numpy() rejects tensors that track gradients and dtypes NumPy
        # does not have (bfloat16, which the model manager selects on CUDA), so
        # detach and upcast to float32 first. Both are no-ops for a plain
        # float32 tensor.
        chunk_np = chunk.detach().cpu().float().numpy()
        yield 48000, _normalize_stream_chunk(chunk_np)
|
|
|
|
def prepare_streaming_session(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    keep_model_loaded,
    offload_mode,
    backend,
    output_format,
    chunk_frames,
    user_id,
):
    """Register the parameters of a streaming request and return its session token.

    The token is a random UUID string; websocket_stream_endpoint() looks the
    parameters up under it in ``stream_sessions``.

    Raises:
        gr.Error: when ``user_id`` already owns STREAM_MAX_PER_USER sessions.
    """
    # Local import: the module does not import ``time`` at file level.
    import time

    # Count on a snapshot so a session added or removed concurrently cannot
    # break the iteration.
    active_for_user = sum(
        1 for meta in list(stream_sessions.values())
        if meta.get("user_id") == user_id
    )
    if active_for_user >= STREAM_MAX_PER_USER:
        raise gr.Error("You already have a pending streaming request.")

    token = str(uuid.uuid4())
    stream_sessions[token] = {
        "lyrics": lyrics,
        "tags": tags,
        "cfg_scale": cfg_scale,
        "duration_sec": duration_sec,
        "temperature": temperature,
        "topk": topk,
        "version": version,
        "codec_version": codec_version,
        "quant_mode": quant_mode,
        "keep_model_loaded": keep_model_loaded,
        "offload_mode": offload_mode,
        "backend": backend,
        "output_format": output_format,
        "chunk_frames": chunk_frames,
        "user_id": user_id,
        # Monotonic timestamp. asyncio.get_event_loop() is avoided because this
        # is a plain (non-async) callback: when it runs in a worker thread
        # there is no event loop and the call raises RuntimeError.
        "created_at": time.monotonic(),
    }
    return token
|
|
|
|
def load_audio_from_result(token):
    """Return the saved audio path recorded for ``token``, or None.

    None is returned for an empty token, an unknown token, or a recorded file
    that no longer exists on disk.
    """
    saved = stream_results.get(token) if token else None
    if not saved:
        return None
    return saved if os.path.exists(saved) else None
|
|
|
|
def cancel_stream(token, user_id):
    """Remove the session for ``token`` if it belongs to ``user_id``; return a status message."""
    if not token:
        return "No active stream."
    session = stream_sessions.get(token)
    owned_by_caller = bool(session) and session.get("user_id") == user_id
    if not owned_by_caller:
        return "No active stream for this session."
    stream_sessions.pop(token, None)
    return "Canceled."
|
|
|
|
async def websocket_stream_endpoint(websocket: WebSocket, token: str):
    """Serve one streaming generation session over a WebSocket.

    Messages sent to the client:
      * text ``{"type": "queue", ...}`` once per second while waiting for a slot,
      * text ``{"type": "config", "sample_rate": 48000}`` when generation starts,
      * text ``{"type": "progress", ...}`` followed by one binary float32 frame per chunk,
      * text ``{"type": "download", ...}`` with the saved file, then ``{"type": "complete"}``,
      * text ``{"type": "error", ...}`` on failure or cancellation.

    A session counts as cancelled as soon as its token disappears from
    ``stream_sessions``. After a cancellation nothing further is sent and no
    file is saved.
    """
    await websocket.accept()
    if token not in stream_sessions:
        await websocket.send_text(json.dumps({"type": "error", "message": "Invalid or expired session"}))
        await websocket.close()
        return

    params = stream_sessions[token]
    request_id = token
    user_id = params.get("user_id")
    await stream_queue.cleanup_expired(STREAM_SESSION_TTL_SEC)
    ok, msg = await stream_queue.join(request_id, user_id or "")
    if not ok:
        await websocket.send_text(json.dumps({"type": "error", "message": msg}))
        await websocket.close()
        return

    acquired = False
    try:
        # Poll until this request reaches the head of the queue and a slot is free.
        while True:
            if await stream_queue.acquire(request_id):
                acquired = True
                break
            if token not in stream_sessions:
                await websocket.send_text(json.dumps({"type": "error", "message": "Stream canceled."}))
                return
            ahead_count, est_wait = await stream_queue.get_wait_info(request_id)
            await websocket.send_text(json.dumps({
                "type": "queue",
                "ahead": ahead_count,
                "wait_seconds": int(est_wait),
                "message": f"Waiting in queue... {ahead_count} ahead (Est. {int(est_wait)}s)",
            }))
            await asyncio.sleep(1.0)

        await websocket.send_text(json.dumps({"type": "config", "sample_rate": 48000}))

        iterator = generate_music_streaming(
            lyrics=params["lyrics"],
            tags=params["tags"],
            cfg_scale=params["cfg_scale"],
            duration_sec=params["duration_sec"],
            temperature=params["temperature"],
            topk=params["topk"],
            version=params["version"],
            codec_version=params["codec_version"],
            quant_mode=params["quant_mode"],
            keep_model_loaded=params["keep_model_loaded"],
            offload_mode=params["offload_mode"],
            backend=params["backend"],
            chunk_frames=params["chunk_frames"],
        )

        total_generated_sec = 0.0
        full_audio_buffer = []

        async for item in iterate_in_threadpool(iterator):
            if token not in stream_sessions:
                await websocket.send_text(json.dumps({"type": "error", "message": "Stream canceled."}))
                # Stop here: a cancelled stream must not be reported as
                # complete, saved, or offered for download.
                return
            sr, chunk_np = item
            full_audio_buffer.append(chunk_np)
            chunk_duration = chunk_np.shape[0] / sr
            total_generated_sec += chunk_duration

            # Capped below 1.0 so the bar only fills on the final message.
            progress_val = min(total_generated_sec / max(params["duration_sec"], 1), 0.99)
            await websocket.send_text(json.dumps({
                "type": "progress",
                "value": progress_val,
                "text": f"Generated {total_generated_sec:.1f}s",
            }))
            await websocket.send_bytes(chunk_np.astype("float32").tobytes())

        await websocket.send_text(json.dumps({
            "type": "progress",
            "value": 1.0,
            "text": f"Generated {total_generated_sec:.1f}s (Complete)",
        }))

        saved_path = ""
        download_url = ""
        if full_audio_buffer:
            # Writing the WAV and running ffmpeg are blocking operations; run
            # them in the default executor so other connections keep being served.
            loop = asyncio.get_running_loop()
            full_audio_np = np.concatenate(full_audio_buffer)
            wav_path = await loop.run_in_executor(None, save_audio_to_wav, 48000, full_audio_np, DATA_DIR)
            saved_path = wav_path
            if params["output_format"] == "mp3":
                saved_path = await loop.run_in_executor(None, convert_wav_to_mp3, wav_path, DATA_DIR)
            download_url = f"download/{Path(saved_path).name}"

        if saved_path:
            stream_results[token] = str(saved_path)
            await websocket.send_text(json.dumps({
                "type": "download",
                "url": download_url,
                "filename": Path(saved_path).name,
            }))
        await websocket.send_text(json.dumps({"type": "complete"}))
    except Exception as e:
        # Best effort: the failure may be the client disconnecting, in which
        # case this send fails too and must not mask the cleanup below.
        try:
            await websocket.send_text(json.dumps({"type": "error", "message": str(e)}))
        except Exception:
            pass
    finally:
        if not acquired:
            # Remove a still-queued entry first so an abandoned request cannot
            # block the head of the line; release() then settles the rest of
            # the bookkeeping.
            await stream_queue.cancel(request_id)
        await stream_queue.release(request_id, user_id)
        stream_sessions.pop(token, None)
        try:
            await websocket.close()
        except Exception:
            pass  # socket already closed
|
|
|
|
def _strip_code_fence(text: str) -> str:
    """Remove a Markdown code fence wrapped around ``text``, if there is one.

    The opening fence line (with or without a language tag) is dropped; the
    last line is dropped only when it really is a closing fence, so an
    unterminated fence does not cost a line of lyrics.
    """
    if not text.startswith("```"):
        return text
    lines = text.split("\n")[1:]
    if lines and lines[-1].strip().startswith("```"):
        lines = lines[:-1]
    return "\n".join(lines)


def generate_lyrics(theme, tags, language, api_choice, api_key_input, custom_base_url, custom_model, progress=gr.Progress()):
    """Ask the selected LLM provider for lyrics on ``theme`` and return them normalized.

    Args:
        theme: Subject of the song; required.
        tags: Comma-separated style tags; "pop, emotional" is used when empty.
        language: Language code ("en", "zh", "jp", "kr", "sp"); anything else means English.
        api_choice: Key into LLM_PRESETS.
        api_key_input: API key; when empty, the preset's environment variable is consulted.
        custom_base_url: Optional endpoint overriding the preset's base URL.
        custom_model: Optional model name overriding the preset's default.
        progress: Gradio progress reporter.

    Returns:
        The lyrics after code-fence removal and process_lyrics_correct().

    Raises:
        gr.Error: on missing theme or API key, unknown preset, an empty model
            response, or any failure while calling the provider.
    """
    if not theme or not theme.strip():
        raise gr.Error("Please enter a theme")

    progress(0.1, desc="Preparing request...")

    if api_choice not in LLM_PRESETS:
        raise gr.Error(f"Unknown API choice: {api_choice}")

    preset = LLM_PRESETS[api_choice]

    api_key = api_key_input.strip() if api_key_input and api_key_input.strip() else None
    if not api_key and preset["env_key"]:
        api_key = os.environ.get(preset["env_key"])

    if not api_key:
        raise gr.Error("No API key provided. Please enter your API key in the field above.")

    base_url = custom_base_url.strip() if custom_base_url and custom_base_url.strip() else preset["base_url"]
    model_name = custom_model.strip() if custom_model and custom_model.strip() else preset["default_model"]

    language_names = {
        "en": "English",
        "zh": "Chinese",
        "jp": "Japanese",
        "kr": "Korean",
        "sp": "Spanish"
    }
    lang_name = language_names.get(language, "English")

    # ``tags`` may be None when the textbox was never filled in.
    tags_text = (tags or "").strip() or "pop, emotional"

    prompt = f"""You are a professional songwriter. Generate song lyrics based on the following requirements:

**Theme**: {theme}
**Music Style/Tags**: {tags_text}
**Language**: {lang_name}

**Format Requirements** (CRITICAL):
1. Use lowercase for all lyrics text (except structure tags which are in brackets)
2. Include proper song structure tags: [Intro], [Verse], [Prechorus], [Chorus], [Bridge], [Outro]
3. Each structure tag should be on its own line
4. Separate different sections with a blank line (one empty line between sections)
5. NO timestamps like [00:12] - only structure tags allowed
6. Keep lyrics concise and suitable for a 3-4 minute song

**Structure Guidelines**:
- [Intro]: Optional, 1-2 lines if included
- [Verse]: Story-telling part, 4-6 lines, can repeat with different lyrics
- [Prechorus]: Optional, 2-4 lines, builds tension before chorus
- [Chorus]: Main hook, catchy and repetitive, 4-6 lines
- [Bridge]: Optional, provides contrast, 4-6 lines
- [Outro]: Closing, 1-2 lines

**Example Format**:
```
[Intro]

[Verse]
the sun creeps in across the floor
i hear the traffic outside the door
the coffee pot begins to hiss
it is another morning just like this

[Chorus]
every day the light returns
every day the fire burns
we keep on walking down this street
moving to the same steady beat
```

Now generate lyrics in {lang_name} based on the theme "{theme}" with style "{tags_text}".
Output ONLY the lyrics with structure tags, no explanations.
"""

    try:
        if preset["api_type"] == "gemini":
            progress(0.3, desc=f"Connecting to {preset['name']}...")
            client = genai.Client(api_key=api_key)
            progress(0.5, desc=f"Generating lyrics with {preset['name']}...")
            response = client.models.generate_content(
                model=model_name,
                contents=[
                    types.Content(
                        role='user',
                        parts=[types.Part(text=prompt)]
                    )
                ],
                config=types.GenerateContentConfig(
                    temperature=0.8,
                    max_output_tokens=2000
                )
            )
            raw_text = response.text
        elif preset["api_type"] == "openai":
            progress(0.3, desc=f"Connecting to {preset['name']}...")
            if base_url:
                client = OpenAI(api_key=api_key, base_url=base_url)
            else:
                client = OpenAI(api_key=api_key)
            progress(0.5, desc=f"Generating lyrics with {preset['name']}...")
            response = client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": "You are a professional songwriter who creates well-structured lyrics."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.8,
                max_tokens=2000
            )
            raw_text = response.choices[0].message.content
        else:
            raise gr.Error(f"Unknown API type: {preset['api_type']}")

        # A provider can answer without any text; report that plainly instead
        # of failing on ``None.strip()``.
        if not raw_text or not raw_text.strip():
            raise gr.Error("The model returned an empty response. Please try again.")
        generated_lyrics = raw_text.strip()

        progress(0.9, desc="Processing response...")
        generated_lyrics = _strip_code_fence(generated_lyrics)
        generated_lyrics = process_lyrics_correct(generated_lyrics)
        progress(1.0, desc="Done!")
        return generated_lyrics
    except gr.Error:
        # Already a user-facing message; do not wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"Lyrics generation error: {str(e)}") from e
|
|
|
|
# Browser-side handler, written in JavaScript, for a streaming session.
# It opens a WebSocket to /ws_stream/<token>, schedules each binary float32
# frame for gap-free playback through the Web Audio API, and updates the
# "stream-*" page elements from the JSON control messages. The returned promise
# resolves with the token on "complete" and rejects on an "error" message, on a
# socket error, or when the socket closes before "complete" arrives (without
# that last case the promise would stay pending forever).
STREAMING_JS = """
async (token) => {
    if (!token) {
        console.error("No token provided");
        alert("Session initialization failed.");
        return;
    }

    return new Promise((resolve, reject) => {
        const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
        const host = window.location.host;
        const wsUrl = `${protocol}//${host}/ws_stream/${token}`;
        const ws = new WebSocket(wsUrl);
        ws.binaryType = "arraybuffer";

        const AudioContext = window.AudioContext || window.webkitAudioContext;
        const ctx = new AudioContext();
        let nextTime = 0;
        let sampleRate = 48000;

        ws.onopen = () => {
            if (ctx.state === "suspended") ctx.resume();
        };

        ws.onmessage = (event) => {
            const data = event.data;
            if (typeof data === "string") {
                const msg = JSON.parse(data);
                if (msg.type === "queue") {
                    const progressText = document.getElementById("stream-progress-text");
                    if (progressText) progressText.innerText = msg.message;
                } else if (msg.type === "config") {
                    sampleRate = msg.sample_rate;
                    const progressContainer = document.getElementById("stream-progress-container");
                    const controlsContainer = document.getElementById("stream-controls-container");
                    if (progressContainer) progressContainer.style.display = "block";
                    if (controlsContainer) controlsContainer.style.display = "flex";
                } else if (msg.type === "progress") {
                    const progressBar = document.getElementById("stream-progress-bar");
                    const progressText = document.getElementById("stream-progress-text");
                    if (progressBar) progressBar.style.width = (msg.value * 100) + "%";
                    if (progressText) progressText.innerText = msg.text;
                } else if (msg.type === "download") {
                    const downloadLink = document.getElementById("stream-download-link");
                    if (downloadLink) {
                        downloadLink.href = msg.url;
                        downloadLink.download = msg.filename;
                        downloadLink.style.display = "inline-block";
                        downloadLink.innerText = "Download Full Audio";
                    }
                } else if (msg.type === "error") {
                    alert("Error: " + msg.message);
                    ws.close();
                    reject(msg.message);
                } else if (msg.type === "complete") {
                    ws.close();
                    resolve(token);
                }
                return;
            }

            if (data instanceof ArrayBuffer) {
                const float32 = new Float32Array(data);
                if (float32.length === 0) return;

                const buffer = ctx.createBuffer(1, float32.length, sampleRate);
                buffer.copyToChannel(float32, 0);
                const source = ctx.createBufferSource();
                source.buffer = buffer;
                source.connect(ctx.destination);
                const now = ctx.currentTime;
                if (nextTime < now) nextTime = now + 0.05;
                source.start(nextTime);
                nextTime += buffer.duration;
            }
        };

        window.toggleStreamPlayback = () => {
            const playBtn = document.getElementById("stream-play-btn");
            if (ctx.state === "running") {
                ctx.suspend();
                if (playBtn) playBtn.innerText = "Resume";
            } else if (ctx.state === "suspended") {
                ctx.resume();
                if (playBtn) playBtn.innerText = "Pause";
            }
        };

        ws.onerror = (e) => {
            console.error("WS Error", e);
            reject(e);
        };
        ws.onclose = () => {
            // No effect once the promise has settled; otherwise the socket
            // closed without a "complete" or "error" message.
            reject(new Error("Connection closed before the stream completed."));
        };
    });
}
"""
|
|
|
|
def create_ui():
    """Build the Gradio Blocks interface and return it.

    The interface consists of three tabs:

    * "Music Generation" - lyrics/tags inputs, sampling and model settings,
      a button wired to ``generate`` and a streaming button wired to
      ``prepare_streaming_session`` followed by ``load_audio_from_result``
      (the latter with ``STREAMING_JS`` attached).
    * "Lyrics Generation" - provider/API settings and prompt fields wired to
      ``generate_lyrics``, plus a button copying the result into the lyrics
      box of the first tab.
    * "Lyrics Transcription" - audio input and decoding settings wired to
      ``transcribe_audio``, plus a button reusing the generated audio.

    Returns:
        The ``gr.Blocks`` instance holding every component and event listener.
    """

    # ---- Callbacks used by the event wiring below ----

    def update_custom_fields(choice):
        # Returns updates for (custom_base_url, custom_model), in that order.
        show_base_url = choice == "custom"
        show_model = choice in ("custom", "deepseek")
        return gr.update(visible=show_base_url), gr.update(visible=show_model)

    def copy_lyrics(lyrics_text):
        # Pass-through: the click handler routes the value to another component.
        return lyrics_text

    # ---- Static layout data ----

    # One inner tuple per column of the tag accordion; each name is both the
    # TAG_DATA key and the checkbox group label.
    tag_columns = (
        ("Gender", "Genre"),
        ("Instrument", "Mood"),
        ("Scene", "Singer Timbre"),
        ("Topic",),
    )

    usage_notice_md = """
    ### Lyrics Format Requirements

    **Automatic Processing:**
    1. All text will be converted to **lowercase**
    2. Timestamps (e.g., [00:12]) will be **automatically removed**
    3. Leading/trailing whitespace on each line will be **stripped**
    4. Leading/trailing empty lines will be **removed**
    5. Multiple consecutive empty lines (3+) will be **collapsed to 2**

    **Recommended Format:**
    - Use standard song structure tags: `[Intro]`, `[Verse]`, `[Chorus]`, `[Bridge]`, `[Outro]`, etc.
    - Separate sections with **blank lines**
    - Case doesn't matter (will be auto-converted)

    **Example:**
    ```
    [Intro]

    [Verse]
    The sun creeps in across the floor
    I hear the traffic outside the door

    [Chorus]
    Every day the light returns
    Every day the fire burns
    ```

    ---

    ### Tags Format
    - Use **commas** to separate multiple tags: `piano,happy,pop`
    - Tags influence the style and mood of the generated music
    - Select from categories below or type directly
    """

    # Element ids here are the hooks for the streaming progress/playback UI.
    stream_widgets_html = """
    <div id="stream-progress-container" style="width: 100%; background-color: #f0f0f0; border-radius: 8px; margin-top: 10px; display: none; overflow: hidden;">
    <div id="stream-progress-bar" style="width: 0%; height: 24px; background-color: #4CAF50; transition: width 0.2s ease-in-out;"></div>
    </div>
    <div id="stream-progress-text" style="text-align: center; font-family: monospace; margin-top: 5px; color: #666;"></div>
    <div id="stream-controls-container" style="display: none; justify-content: center; gap: 10px; margin-top: 15px;">
    <button id="stream-play-btn" onclick="window.toggleStreamPlayback()" style="padding: 8px 16px; border-radius: 4px; border: 1px solid #ccc; background: white; cursor: pointer;">Pause</button>
    </div>
    <a id="stream-download-link" style="display:none; margin-top: 10px;" target="_blank"></a>
    """

    lyrics_howto_md = """
    ### How to Generate Lyrics

    **Theme**: Describe your song's story or emotion
    - Examples: "Lost love in Tokyo", "Overcoming obstacles", "Summer road trip"

    **Music Style/Tags**: Define mood and genre
    - Examples: "piano,melancholy,ballad", "upbeat,electronic,dance"

    **Tips**
    - Generated lyrics follow standard song structure ([Intro], [Verse], [Chorus], etc.)
    - Edit lyrics before using for music generation
    - Be specific with themes for better results

    """

    transcription_notes_md = """
    ### Notes
    - Best results come from **vocals-only** stems.
    - If you pass full mixes, consider source separation first.
    - The HeartTranscriptor model will auto-download on first use.
    """

    with gr.Blocks(title="HeartMuLa Music Generation") as demo:
        gr.Markdown("# HeartMuLa Music Generation")
        gr.Markdown("Generate music from lyrics and style tags")

        with gr.Tabs():
            # ================= Tab 1: music generation =================
            with gr.Tab("Music Generation"):
                with gr.Row():
                    # Left column: inputs and generation settings.
                    with gr.Column():
                        lyrics = gr.Textbox(
                            value=EXAMPLE_LYRICS,
                            label="Lyrics",
                            placeholder="Enter lyrics here...",
                            lines=15,
                        )

                        format_btn = gr.Button("Format Lyrics", size="sm")

                        gr.Markdown("### Tags")
                        tags = gr.Textbox(
                            value=EXAMPLE_TAGS,
                            label="Selected Tags (comma-separated)",
                            placeholder="e.g., piano,happy,pop",
                            lines=2,
                        )

                        # Checkbox groups are collected in creation order so the
                        # change handler receives them in a stable order.
                        tag_checkboxes = []
                        with gr.Accordion("Tag Categories (Click to Expand)", open=False):
                            with gr.Row():
                                for column_categories in tag_columns:
                                    with gr.Column():
                                        for category in column_categories:
                                            group = gr.CheckboxGroup(
                                                choices=TAG_DATA[category],
                                                label=category,
                                            )
                                            tag_checkboxes.append(group)

                        with gr.Row():
                            cfg_scale = gr.Slider(
                                minimum=0.0, maximum=3.0, value=1.5, step=0.1,
                                label="CFG Scale",
                            )
                            duration = gr.Slider(
                                minimum=10, maximum=300, value=180, step=10,
                                label="Duration (sec)",
                            )

                        with gr.Row():
                            temperature = gr.Slider(
                                minimum=0.1, maximum=2.0, value=1.0, step=0.1,
                                label="Temperature",
                            )
                            topk = gr.Slider(
                                minimum=1, maximum=100, value=50, step=1,
                                label="Top-K",
                            )

                        with gr.Accordion("Advanced Settings", open=False):
                            backend = gr.Dropdown(
                                label="Backend",
                                value="hf",
                                choices=[
                                    ("HF Pipeline", "hf"),
                                    ("ExLlamaV2 (Experimental)", "exllama_v2"),
                                ],
                            )
                            version = gr.Dropdown(
                                label="Model Version",
                                value="3B",
                                choices=["3B", "7B", "RL-oss-3B-20260123"],
                            )
                            codec_version = gr.Dropdown(
                                label="Codec Version",
                                value="oss",
                                choices=["oss", "oss-20260123"],
                            )
                            quant_mode = gr.Dropdown(
                                label="Quantization",
                                value="none",
                                choices=[
                                    ("None", "none"),
                                    ("4-bit (NF4/FP4)", "4bit"),
                                    ("8-bit", "8bit"),
                                ],
                            )
                            keep_model_loaded = gr.Checkbox(label="Keep Model Loaded", value=True)
                            offload_mode = gr.Dropdown(
                                label="Offload Mode",
                                value="auto",
                                choices=["auto", "aggressive"],
                            )
                            output_format = gr.Radio(
                                label="Output Format",
                                value="wav",
                                choices=[("WAV", "wav"), ("MP3", "mp3")],
                            )
                            chunk_frames = gr.Slider(
                                minimum=5, maximum=100, value=20, step=1,
                                label="Streaming Chunk Frames",
                            )

                        generate_btn = gr.Button("Generate Music", variant="primary", size="lg")
                        stream_btn = gr.Button(
                            "Generate Music (Streaming)", variant="primary", size="lg"
                        )
                        # Hidden field carrying the value returned by
                        # prepare_streaming_session to the follow-up step.
                        state_token = gr.Textbox(label="Stream Token", visible=False)

                    # Right column: usage notes and the audio output.
                    with gr.Column():
                        with gr.Accordion("Usage Notice", open=True):
                            gr.Markdown(usage_notice_md)

                        output_audio = gr.Audio(label="Generated Music", type="filepath")
                        gr.HTML(stream_widgets_html)

                # ---- Event wiring for tab 1 ----

                # Any checkbox group change rewrites the tag string from all groups.
                for checkbox_group in tag_checkboxes:
                    checkbox_group.change(
                        fn=update_tag_string,
                        inputs=tag_checkboxes,
                        outputs=tags,
                    )

                format_btn.click(fn=process_lyrics_correct, inputs=[lyrics], outputs=[lyrics])

                # Both generation handlers start with the same nine inputs; the
                # remaining ones are passed in a different order to each handler.
                shared_inputs = [
                    lyrics,
                    tags,
                    cfg_scale,
                    duration,
                    temperature,
                    topk,
                    version,
                    codec_version,
                    quant_mode,
                ]

                generate_btn.click(
                    fn=generate,
                    inputs=shared_inputs + [
                        output_format,
                        keep_model_loaded,
                        offload_mode,
                        backend,
                    ],
                    outputs=[output_audio],
                )

                streaming_inputs = shared_inputs + [
                    keep_model_loaded,
                    offload_mode,
                    backend,
                    output_format,
                    chunk_frames,
                ]
                streaming_prepared = stream_btn.click(
                    fn=prepare_streaming_session,
                    inputs=streaming_inputs,
                    outputs=[state_token],
                )
                # Second step runs after the token is stored; STREAMING_JS is
                # attached to this step.
                streaming_prepared.then(
                    fn=load_audio_from_result,
                    inputs=[state_token],
                    outputs=[output_audio],
                    js=STREAMING_JS,
                )

            # ================= Tab 2: lyrics generation =================
            with gr.Tab("Lyrics Generation"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Generate Lyrics with AI")

                        api_selector = gr.Radio(
                            label="Select LLM Provider",
                            value="gemini",
                            choices=[
                                ("Google Gemini", "gemini"),
                                ("OpenAI", "openai"),
                                ("DeepSeek", "deepseek"),
                                ("Custom (OpenAI-compatible)", "custom"),
                            ],
                        )

                        with gr.Accordion("API Configuration", open=True):
                            api_key_input = gr.Textbox(
                                type="password",
                                label="API Key (Required)",
                                placeholder="Enter your API key or set environment variable",
                                info="Will use environment variable if not provided here",
                            )

                            # Both fields start hidden; update_custom_fields
                            # toggles them when the provider changes.
                            custom_base_url = gr.Textbox(
                                visible=False,
                                label="Custom Base URL (Optional)",
                                placeholder="e.g., https://api.your-provider.com/v1",
                                info="Leave empty to use default. For custom providers only.",
                            )

                            custom_model = gr.Textbox(
                                visible=False,
                                label="Model Name (Optional)",
                                placeholder="e.g., gpt-4o, deepseek-chat",
                                info="Leave empty to use recommended default",
                            )

                        api_selector.change(
                            fn=update_custom_fields,
                            inputs=[api_selector],
                            outputs=[custom_base_url, custom_model],
                        )

                        theme_input = gr.Textbox(
                            label="Theme",
                            placeholder="e.g., Love lost in the city, Dreams and hope, Rainy day memories...",
                            lines=2,
                        )

                        tags_gen = gr.Textbox(
                            value="pop,emotional",
                            label="Music Style/Tags",
                            placeholder="e.g., piano, melancholy, pop",
                        )

                        language_select = gr.Radio(
                            label="Language",
                            value="en",
                            choices=[
                                ("English", "en"),
                                ("中文 (Chinese)", "zh"),
                                ("日本語 (Japanese)", "jp"),
                                ("한국어 (Korean)", "kr"),
                                ("Español (Spanish)", "sp"),
                            ],
                        )

                        generate_lyrics_btn = gr.Button(
                            "Generate Lyrics", variant="primary", size="lg"
                        )

                    with gr.Column():
                        with gr.Accordion("How to Use", open=True):
                            gr.Markdown(lyrics_howto_md)

                        generated_lyrics_output = gr.Textbox(
                            interactive=False,
                            label="Generated Lyrics",
                            placeholder="Generated lyrics will appear here...",
                            lines=20,
                        )

                        copy_to_music_gen = gr.Button("Copy to Music Generation Tab", size="sm")

                # ---- Event wiring for tab 2 ----
                lyrics_request_inputs = [
                    theme_input,
                    tags_gen,
                    language_select,
                    api_selector,
                    api_key_input,
                    custom_base_url,
                    custom_model,
                ]
                generate_lyrics_btn.click(
                    fn=generate_lyrics,
                    inputs=lyrics_request_inputs,
                    outputs=[generated_lyrics_output],
                )

                # Writes the generated text into the lyrics box of the first tab.
                copy_to_music_gen.click(
                    fn=copy_lyrics,
                    inputs=[generated_lyrics_output],
                    outputs=[lyrics],
                )

            # ================= Tab 3: lyrics transcription =================
            with gr.Tab("Lyrics Transcription"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Transcribe or Translate Lyrics from Audio")

                        audio_input = gr.Audio(type="filepath", label="Audio Input")

                        task_select = gr.Radio(
                            label="Task",
                            value="transcribe",
                            choices=[
                                ("Transcribe (Original Language)", "transcribe"),
                                ("Translate to English", "translate"),
                            ],
                        )

                        max_new_tokens = gr.Slider(
                            minimum=64, maximum=512, value=256, step=16,
                            label="Max New Tokens",
                        )
                        num_beams = gr.Slider(
                            minimum=1, maximum=5, value=2, step=1,
                            label="Beam Search",
                        )
                        transcribe_temperature = gr.Slider(
                            minimum=0.0, maximum=0.8, value=0.2, step=0.1,
                            label="Temperature",
                        )

                        transcribe_btn = gr.Button(
                            "Run Transcription", variant="primary", size="lg"
                        )

                        use_generated_audio = gr.Button("Use Generated Music", size="sm")

                    with gr.Column():
                        with gr.Accordion("Notes", open=True):
                            gr.Markdown(transcription_notes_md)

                        transcription_output = gr.Textbox(
                            interactive=False,
                            label="Transcription Result",
                            placeholder="Transcribed lyrics will appear here...",
                            lines=18,
                        )

                # ---- Event wiring for tab 3 ----
                transcription_inputs = [
                    audio_input,
                    task_select,
                    max_new_tokens,
                    num_beams,
                    transcribe_temperature,
                ]
                transcribe_btn.click(
                    fn=transcribe_audio,
                    inputs=transcription_inputs,
                    outputs=[transcription_output],
                )

                # Feeds the audio from the first tab's output into this tab's input.
                use_generated_audio.click(
                    fn=lambda x: x,
                    inputs=[output_audio],
                    outputs=[audio_input],
                )

    return demo
|
|
|
|
def build_app():
    """Assemble the FastAPI application that hosts the Gradio UI.

    Steps, in order:

    1. Build the Blocks UI with ``create_ui`` and enable its queue
       (``max_size=8``).
    2. Create a FastAPI app, serve ``DATA_DIR`` as static files under
       ``/download`` and register ``websocket_stream_endpoint`` at
       ``/ws_stream/{token}``.
    3. Mount the Gradio UI at ``/`` with ``DATA_DIR`` in ``allowed_paths``.

    Returns:
        The application object returned by ``gr.mount_gradio_app``.
    """
    ui = create_ui()
    ui.queue(max_size=8)

    data_dir = str(DATA_DIR)

    server = FastAPI()
    # These routes are registered before the Gradio app is mounted at "/".
    server.mount("/download", StaticFiles(directory=data_dir), name="download")
    server.add_api_websocket_route("/ws_stream/{token}", websocket_stream_endpoint)

    return gr.mount_gradio_app(server, ui, path="/", allowed_paths=[data_dir])
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI options, build the app, serve it.
    parser = argparse.ArgumentParser()
    cli_options = (
        # (flag, type, default)
        ("--model_path", str, "./ckpt"),
        ("--port", int, 8888),
        ("--host", str, "0.0.0.0"),
    )
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()

    # Rebind the module-level MODEL_PATH before the app is built.
    MODEL_PATH = args.model_path

    app = build_app()

    # uvicorn is imported only on the script path, after the app exists.
    import uvicorn

    uvicorn.run(app, port=args.port, host=args.host)
|
|