| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| import sys |
| import os |
| import uuid |
| import subprocess |
| import tempfile |
| import wave |
| from pathlib import Path |
| from typing import Dict, Any, Iterator, Optional, Tuple, TYPE_CHECKING |
|
|
| import numpy as np |
|
|
| |
# The `spaces` package is optional: any failure to import it is treated as
# "package not available" and the name is bound to None instead.
try:
    import spaces
except Exception:
    spaces = None


# True only when `spaces` imported successfully AND SPACE_ID is set.
# Coerced to a real bool so the flag never carries the SPACE_ID string itself.
IS_SPACE = bool(spaces is not None and os.environ.get("SPACE_ID"))


def _parse_gpu_duration(raw_value, default):
    """Return *raw_value* parsed as an int, or *default* when unusable.

    Args:
        raw_value: Raw environment string, or None when the variable is unset.
        default: Value returned when *raw_value* is None or not an integer.

    Returns:
        The parsed integer, or *default*.
    """
    if raw_value is None:
        return default
    try:
        return int(raw_value)
    except ValueError:
        # A malformed setting should not prevent the module from importing.
        print(f"WARN: invalid GPU_MAX_DURATION={raw_value!r}; using {default}")
        return default


_gpu_duration_env = os.environ.get("GPU_MAX_DURATION")
# ENABLE_AOTI counts as enabled when unset or set to a truthy word.
aoti_env = os.environ.get("ENABLE_AOTI")
enable_aoti_default = aoti_env is None or aoti_env.strip().lower() in ("1", "true", "yes", "y", "on")
# Without an explicit GPU_MAX_DURATION: 600 on a Space with AOTI enabled, else 100.
GPU_MAX_DURATION = _parse_gpu_duration(
    _gpu_duration_env,
    600 if IS_SPACE and enable_aoti_default else 100,
)
|
|
|
|
def _env_bool(name: str) -> Optional[bool]:
    """Read environment variable *name* as a tri-state flag.

    Returns:
        None when the variable is unset; otherwise True when its trimmed,
        lower-cased value is one of "1", "true", "yes", "y", "on", and
        False for any other value.
    """
    raw = os.environ.get(name)
    return None if raw is None else raw.strip().lower() in ("1", "true", "yes", "y", "on")
|
|
|
|
# An explicit KEEP_MODEL_LOADED_DEFAULT setting wins; when it is unset the
# default is "keep loaded" everywhere except on a Space.
_default_keep_model_loaded_env = _env_bool("KEEP_MODEL_LOADED_DEFAULT")
DEFAULT_KEEP_MODEL_LOADED = (
    (not IS_SPACE)
    if _default_keep_model_loaded_env is None
    else _default_keep_model_loaded_env
)
|
|
|
|
def _gpu_guard(fn):
    """Wrap *fn* with ``spaces.GPU`` when the ``spaces`` package is available.

    Returns *fn* unchanged when ``spaces`` is None. When GPU_MAX_DURATION is
    positive it is passed as ``duration``; if that call raises TypeError the
    function is wrapped without the keyword instead.
    """
    if spaces is None:
        return fn
    if GPU_MAX_DURATION <= 0:
        return spaces.GPU(fn)
    try:
        return spaces.GPU(fn, duration=GPU_MAX_DURATION)
    except TypeError:
        # Fall back to the plain call when `duration` is rejected.
        return spaces.GPU(fn)
# Put the `src` directory that sits next to this file at the front of sys.path.
_src_dir = os.path.join(os.path.dirname(__file__), 'src')
sys.path.insert(0, _src_dir)

# Defaults only: values already present in the environment are left untouched.
for _env_name, _env_value in (
    ("GRADIO_SSR_MODE", "0"),
    ("PYTORCH_ALLOC_CONF", "expandable_segments:True"),
    ("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True"),
):
    os.environ.setdefault(_env_name, _env_value)
|
|
| if TYPE_CHECKING: |
| from heartlib import HeartMuLaGenPipeline |
| import gradio as gr |
| import re |
| try: |
| from google import genai |
| from google.genai import types |
| except Exception: |
| genai = None |
| types = None |
| from openai import OpenAI |
| from transformers import BitsAndBytesConfig |
|
|
| |
|
|
| |
# Checkpoint directory; the MODEL_PATH environment variable overrides it.
MODEL_PATH = os.getenv("MODEL_PATH", "./ckpt")


# Lyric-generation provider presets, keyed by provider id. Each entry gives
# the display name, the client family ("gemini" or "openai"), the default
# model id, the environment variable holding the API key (None when there is
# none) and an optional base URL.
LLM_PRESETS = {
    "gemini": dict(
        name="Google Gemini",
        api_type="gemini",
        default_model="gemini-2.0-flash-lite",
        env_key="GEMINI_API_KEY",
        base_url=None,
    ),
    "openai": dict(
        name="OpenAI",
        api_type="openai",
        default_model="gpt-4o-mini",
        env_key="OPENAI_API_KEY",
        base_url=None,
    ),
    "deepseek": dict(
        name="DeepSeek",
        api_type="openai",
        default_model="deepseek-chat",
        env_key="DEEPSEEK_API_KEY",
        base_url="https://api.deepseek.com",
    ),
    "custom": dict(
        name="Custom OpenAI-Compatible",
        api_type="openai",
        default_model="custom-model",
        env_key=None,
        base_url=None,
    ),
}
|
|
| |
# Sample lyrics used as the initial value of the lyrics textbox. Sections are
# introduced by bracketed structure tags and separated by one blank line.
EXAMPLE_LYRICS = """[Intro]

[Verse]
The sun creeps in across the floor
I hear the traffic outside the door
The coffee pot begins to hiss
It is another morning just like this

[Prechorus]
The world keeps spinning round and round
Feet are planted on the ground
I find my rhythm in the sound

[Chorus]
Every day the light returns
Every day the fire burns
We keep on walking down this street
Moving to the same steady beat
It is the ordinary magic that we meet

[Verse]
The hours tick deeply into noon
Chasing shadows,chasing the moon
Work is done and the lights go low
Watching the city start to glow

[Bridge]
It is not always easy,not always bright
Sometimes we wrestle with the night
But we make it to the morning light

[Chorus]
Every day the light returns
Every day the fire burns
We keep on walking down this street
Moving to the same steady beat

[Outro]
Just another day
Every single day"""


# Sample comma-separated tag string used as the initial value of the tags textbox.
EXAMPLE_TAGS = "piano,happy"
|
|
| |
# Selectable tags grouped by category; each category backs one CheckboxGroup
# in the UI. Some tags intentionally appear in more than one category.
TAG_DATA = {
    "Gender": ["Male", "Female"],
    "Genre": [
        "Pop", "Folk", "Ballad", "Electronic", "Rock",
        "Acoustic", "R&B", "Indie", "Dance", "Indie Pop",
        "J-Pop", "Hip-Hop", "Country", "Latin", "Alternative",
        "Christian", "Cantopop", "Gospel", "Soul", "Mandopop",
    ],
    "Instrument": [
        "Drums", "Piano", "Guitar", "Strings", "Synthesizer",
        "Bass", "Acoustic Guitar", "Keyboard", "Electronic Drums", "Vocals",
        "Drum Machine", "Electric Guitar", "Percussion", "Beat", "Orchestra",
        "Saxophone", "Accordion", "Voice", "String", "Vocal",
    ],
    "Mood": [
        "Melancholy", "Romantic", "Energetic", "Hopeful", "Dreamy",
        "Relaxed", "Sad", "Calm", "Cheerful", "Reflective",
        "Emotional", "Joyful", "Sentimental", "Uplifting", "Warm",
        "Peaceful", "Upbeat", "Gentle", "Nostalgic", "Epic",
    ],
    "Scene": [
        "Driving", "Road Trip", "Cafe", "Relaxing", "Wedding",
        "Meditation", "Workout", "Walking", "Alone", "Travel",
        "Reflection", "Rainy Day", "Night", "Church", "Coffee Shop",
        "Gym", "Gaming", "Study", "Dating", "Date",
    ],
    "Singer Timbre": [
        "Soft", "Clear", "Warm", "Gentle", "Smooth",
        "Sweet", "Emotional", "Mellow", "Powerful", "Youthful",
        "Bright", "Rough", "Raspy", "Melodic", "Deep",
        "Soulful", "Strong", "Energetic", "Breathy", "Passionate",
    ],
    "Topic": [
        "Love", "Relationship", "Hope", "Longing", "Loss",
        "Heartbreak", "Memory", "Reflection", "Life", "Faith",
        "Regret", "Freedom", "Breakup", "Nature", "Loneliness",
        "Dreams", "Nostalgia", "Romance", "Friendship", "Youth",
    ],
}
|
|
# Output directory for generated audio. HEARTMULA_DATA_DIR overrides the
# default, which is "heartmula_stream" under the system temp directory.
_default_data_dir = os.path.join(tempfile.gettempdir(), "heartmula_stream")
DATA_DIR = Path(os.environ.get("HEARTMULA_DATA_DIR", _default_data_dir))
DATA_DIR.mkdir(parents=True, exist_ok=True)
print(f"DATA_DIR = {DATA_DIR}")
|
|
| |
import shutil


def _clear_directory_contents(directory):
    """Delete every entry inside *directory*, keeping the directory itself.

    Entries are removed through the filesystem API rather than a shell
    command, so names containing spaces or shell metacharacters are handled
    literally. A failure on one entry is reported and the remaining entries
    are still processed.

    Args:
        directory: Path of the directory whose contents are removed.
    """
    for entry in os.listdir(directory):
        target = os.path.join(directory, entry)
        try:
            # Symlinks are unlinked, never followed into their target.
            if os.path.isdir(target) and not os.path.islink(target):
                shutil.rmtree(target)
            else:
                os.unlink(target)
        except OSError as exc:
            print(f"WARN: failed to remove {target}: {exc}")


# Best-effort cleanup of the ZeroGPU offload cache at import time; a failure
# here is only logged.
offload_path = "/data-nvme/zerogpu-offload"
try:
    if os.path.exists(offload_path):
        _clear_directory_contents(offload_path)
except Exception as e:
    print(f"WARN: failed to clear ZeroGPU offload cache: {e}")
|
|
# --- Gradio queue / concurrency settings (all overridable via environment) ---
GRADIO_QUEUE_MAX_SIZE = int(os.getenv("GRADIO_QUEUE_MAX_SIZE", "24"))
GRADIO_DEFAULT_CONCURRENCY = int(os.getenv("GRADIO_DEFAULT_CONCURRENCY", "1"))
GPU_CONCURRENCY_LIMIT = int(os.getenv("GRADIO_GPU_CONCURRENCY", "1"))

# --- Generation defaults; the fallback depends on whether this is a Space ---
_duration_fallback = "60" if IS_SPACE else "180"
_quant_fallback = "4bit" if IS_SPACE else "none"
_offload_fallback = "aggressive" if IS_SPACE else "auto"
_generation_mode_fallback = "Accelerated" if IS_SPACE else "Original (No Acceleration)"
_preset_fallback = "ZeroGPU AOTI FP8" if IS_SPACE else "Balanced"

DEFAULT_DURATION_SEC = int(os.getenv("DEFAULT_DURATION_SEC", _duration_fallback))
DEFAULT_QUANT_MODE = os.getenv("DEFAULT_QUANT_MODE", _quant_fallback)
DEFAULT_OFFLOAD_MODE = os.getenv("DEFAULT_OFFLOAD_MODE", _offload_fallback)
DEFAULT_GENERATION_MODE = os.getenv("DEFAULT_GENERATION_MODE", _generation_mode_fallback)
DEFAULT_SPEED_SUBMODE = os.getenv("DEFAULT_SPEED_SUBMODE", "Standard")
DEFAULT_PRESET = os.getenv("DEFAULT_PRESET", _preset_fallback)

# --- Block-streaming geometry ---
MODEL_DETOKENIZE_INTERVAL_SEC = float(os.getenv("MODEL_DETOKENIZE_INTERVAL_SEC", "29.76"))
AUDIO_SAMPLE_RATE = int(os.getenv("AUDIO_SAMPLE_RATE", "48000"))
FRAME_MS = 80.0
# One block expressed in frames (FRAME_MS each) and in audio samples; both
# are clamped to at least 1.
_block_ms = MODEL_DETOKENIZE_INTERVAL_SEC * 1000.0
BLOCK_FRAMES = max(1, int(round(_block_ms / FRAME_MS)))
BLOCK_SAMPLES = max(1, int(round(MODEL_DETOKENIZE_INTERVAL_SEC * AUDIO_SAMPLE_RATE)))

# --- Stream buffering limits, counted in blocks ---
MIN_BUFFER_BLOCKS = int(os.getenv("MIN_BUFFER_BLOCKS", "1"))
PREFETCH_BLOCKS = int(os.getenv("PREFETCH_BLOCKS", "2"))
MAX_QUEUE_BLOCKS = int(os.getenv("MAX_QUEUE_BLOCKS", "4"))

# Streaming is always allowed off-Space; on a Space it needs the explicit
# ALLOW_STREAMING_ZERO_GPU opt-in.
ALLOW_STREAMING_ZERO_GPU = _env_bool("ALLOW_STREAMING_ZERO_GPU")
STREAMING_ALLOWED = not (IS_SPACE and not ALLOW_STREAMING_ZERO_GPU)
|
|
def _preset_entry(
    duration,
    quant_mode,
    offload_mode,
    keep_model_loaded,
    *,
    temperature=1.0,
    topk=50,
    cfg_scale=1.5,
    generation_mode="Accelerated",
    speed_submode="Standard",
):
    """Build one preset dict; keyword defaults are the values most presets share."""
    return {
        "duration": duration,
        "quant_mode": quant_mode,
        "offload_mode": offload_mode,
        "keep_model_loaded": keep_model_loaded,
        "temperature": temperature,
        "topk": topk,
        "cfg_scale": cfg_scale,
        "generation_mode": generation_mode,
        "speed_submode": speed_submode,
    }


# Named bundles of generation settings offered in the "Preset" dropdown.
PRESET_CONFIGS = {
    "ZeroGPU Safe": _preset_entry(60, "4bit", "aggressive", False),
    "ZeroGPU AOTI FP8": _preset_entry(60, "fp8", "aggressive", False),
    "Balanced": _preset_entry(
        120,
        "4bit" if IS_SPACE else "none",
        "auto",
        DEFAULT_KEEP_MODEL_LOADED,
    ),
    "Quality": _preset_entry(
        180,
        "none",
        "auto",
        DEFAULT_KEEP_MODEL_LOADED,
        temperature=0.9,
        cfg_scale=2.0,
        generation_mode="Original (No Acceleration)",
    ),
}

# An unknown preset name falls back to "Balanced".
if DEFAULT_PRESET not in PRESET_CONFIGS:
    DEFAULT_PRESET = "Balanced"
_default_preset_config = PRESET_CONFIGS[DEFAULT_PRESET]

# Each default below follows the selected preset unless its environment
# variable was set explicitly.
if os.getenv("DEFAULT_DURATION_SEC") is None:
    DEFAULT_DURATION_SEC = _default_preset_config["duration"]
if os.getenv("DEFAULT_QUANT_MODE") is None:
    DEFAULT_QUANT_MODE = _default_preset_config["quant_mode"]
if os.getenv("DEFAULT_OFFLOAD_MODE") is None:
    DEFAULT_OFFLOAD_MODE = _default_preset_config["offload_mode"]
if os.getenv("DEFAULT_GENERATION_MODE") is None:
    DEFAULT_GENERATION_MODE = _default_preset_config["generation_mode"]
if os.getenv("DEFAULT_SPEED_SUBMODE") is None:
    DEFAULT_SPEED_SUBMODE = _default_preset_config["speed_submode"]
|
|
|
|
class ModelManager:
    """Creates HeartMuLa generation pipelines on demand and caches them.

    One pipeline is kept per ``(version, codec_version, quant_mode)`` key.
    """

    def __init__(self, model_path: str, use_deepspeed_override: Optional[bool] = None):
        """Record the checkpoint path and pick device, dtype and DeepSpeed settings.

        Args:
            model_path: Directory passed to ``from_pretrained``.
            use_deepspeed_override: When not None, used as-is instead of the
                USE_DEEPSPEED_INFERENCE environment variable.
        """
        # Imported inside the constructor rather than at module import time.
        import torch
        from heartlib import HeartMuLaGenPipeline

        cuda_ready = torch.cuda.is_available()

        self.model_path = model_path
        self.device = torch.device("cuda" if cuda_ready else "cpu")
        self.dtype = torch.bfloat16 if cuda_ready else torch.float32
        self._gen_pipes: Dict[Tuple[str, str, str], "HeartMuLaGenPipeline"] = {}
        self.use_deepspeed = (
            os.getenv("USE_DEEPSPEED_INFERENCE", "0").lower() in ("1", "true", "yes")
            if use_deepspeed_override is None
            else use_deepspeed_override
        )
        self.ds_inference_config = self._make_ds_inference_config()
        self._HeartMuLaGenPipeline = HeartMuLaGenPipeline

    def _make_ds_inference_config(self) -> Dict[str, Any]:
        """Return the DeepSpeed inference settings, or ``{}`` when DeepSpeed is off."""
        if not self.use_deepspeed:
            return {}
        truthy = ("1", "true", "yes")
        return {
            # DEEPSPEED_TP_SIZE, then WORLD_SIZE, then 1.
            "mp_size": int(os.getenv("DEEPSPEED_TP_SIZE", os.getenv("WORLD_SIZE", "1"))),
            "dtype": self.dtype,
            "replace_method": os.getenv("DEEPSPEED_REPLACE_METHOD", "auto"),
            "replace_with_kernel_inject": os.getenv("DEEPSPEED_KERNEL_INJECT", "1").lower() in truthy,
        }

    def _make_bnb_config(self, quant_mode: str) -> Optional[BitsAndBytesConfig]:
        """Translate *quant_mode* into a BitsAndBytesConfig.

        Returns:
            None for "none" and "fp8"; a 4-bit or 8-bit config otherwise.

        Raises:
            gr.Error: If a quantized mode is requested without CUDA, or the
                mode is not recognised.
        """
        import torch

        if quant_mode == "none":
            return None
        if self.device.type != "cuda":
            raise gr.Error("Quantization requires CUDA.")
        if quant_mode == "fp8":
            # "fp8" yields no bitsandbytes config; get_gen_pipeline turns it
            # into the `torchao_quantize` flag instead.
            return None
        if quant_mode == "8bit":
            return BitsAndBytesConfig(load_in_8bit=True)
        if quant_mode != "4bit":
            raise gr.Error(f"Unknown quant mode: {quant_mode}")

        # "nf4" by default; "fp4" when the device's major compute capability
        # is 10 or higher. A failed capability query keeps "nf4".
        quant_type = "nf4"
        try:
            major, _ = torch.cuda.get_device_capability()
            if major >= 10:
                quant_type = "fp4"
        except Exception:
            pass
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=quant_type,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

    def get_gen_pipeline(self, version: str, codec_version: str, quant_mode: str) -> "HeartMuLaGenPipeline":
        """Return the cached pipeline for the given settings, creating it on first use."""
        cache_key = (version, codec_version, quant_mode)
        try:
            return self._gen_pipes[cache_key]
        except KeyError:
            pass

        bnb_config = self._make_bnb_config(quant_mode)
        pipeline = self._HeartMuLaGenPipeline.from_pretrained(
            self.model_path,
            device=self.device,
            dtype=self.dtype,
            version=version,
            codec_version=codec_version,
            bnb_config=bnb_config,
            torchao_quantize=quant_mode == "fp8",
            lazy_load=True,
            use_deepspeed=self.use_deepspeed,
            ds_inference_config=self.ds_inference_config,
        )
        self._gen_pipes[cache_key] = pipeline
        return pipeline
|
|
|
|
# ModelManager instances, keyed by "accelerated" or "original".
model_managers: Dict[str, ModelManager] = {}


def get_model_manager(use_acceleration: bool) -> ModelManager:
    """Return the shared ModelManager for the requested mode, creating it once.

    On first use the checkpoint directory is created and the models are
    downloaded if missing. The non-accelerated manager is built with
    DeepSpeed forced off; the accelerated one leaves that decision to the
    environment.
    """
    key = "accelerated" if use_acceleration else "original"
    try:
        return model_managers[key]
    except KeyError:
        pass

    os.makedirs(MODEL_PATH, exist_ok=True)
    download_models_if_needed(MODEL_PATH)
    manager = ModelManager(
        MODEL_PATH,
        use_deepspeed_override=None if use_acceleration else False,
    )
    model_managers[key] = manager
    return manager
|
|
|
|
def update_tag_string(*args):
    """Merge the selections of all tag groups into one comma-separated string.

    Args:
        *args: One selection per CheckboxGroup. A list contributes each of
            its items; any other truthy value contributes itself; empty or
            falsy selections are skipped.

    Returns:
        The tags joined with ",", duplicates removed, first-seen order kept.
    """
    collected = []
    for selection in args:
        if not selection:
            continue
        if isinstance(selection, list):
            collected += selection
        else:
            collected.append(selection)
    # dict.fromkeys drops duplicates while preserving insertion order.
    return ",".join(dict.fromkeys(collected))
|
|
|
|
def process_lyrics_correct(content):
    """Normalise lyrics text.

    Steps, in order:
    1. Lower-case the whole text (bracketed tags included).
    2. Remove bracketed timestamps such as ``[00:12]``.
    3. Strip surrounding whitespace from every line.
    4. Drop leading and trailing empty lines.
    5. Collapse runs of three or more newlines into exactly two.
    """
    lowered = content.lower()
    without_timestamps = re.sub(r'\[[^\]]*\d{1,2}:\d{2}[^\]]*\]', '', lowered)
    stripped_lines = (line.strip() for line in without_timestamps.split('\n'))
    # Lines are already stripped, so trimming newlines at both ends of the
    # joined text removes exactly the leading/trailing empty lines.
    joined = '\n'.join(stripped_lines).strip('\n')
    return re.sub(r'\n{3,}', '\n\n', joined)
|
|
|
|
def save_audio_to_wav(sample_rate: int, audio_np: np.ndarray, output_dir: Path) -> str:
    """Write float audio to a uniquely named 16-bit mono WAV file.

    Args:
        sample_rate: Frame rate written to the WAV header.
        audio_np: Float samples, nominally in [-1.0, 1.0]. NaN and infinite
            values are replaced with 0.0 and everything else is clipped to
            that range before conversion.
        output_dir: Destination directory; created if it does not exist.

    Returns:
        The path of the written file, as a string.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    wav_path = output_dir / f"{uuid.uuid4()}.wav"

    # Without sanitising, out-of-range samples overflow the int16 cast and
    # wrap around, and NaN converts to an undefined integer.
    finite = np.nan_to_num(audio_np, nan=0.0, posinf=0.0, neginf=0.0)
    audio_int16 = (np.clip(finite, -1.0, 1.0) * 32767).astype(np.int16)

    with wave.open(str(wav_path), "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # two bytes per sample = 16-bit PCM
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(audio_int16.tobytes())
    return str(wav_path)
|
|
|
|
def convert_wav_to_mp3(wav_path: str, output_dir: Path) -> str:
    """Encode *wav_path* to a uniquely named MP3 in *output_dir* using ffmpeg.

    The output directory is created if needed. ffmpeg's stdout and stderr are
    captured, and a non-zero exit status raises
    ``subprocess.CalledProcessError``.

    Returns:
        The path of the MP3 file, as a string.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    target = output_dir / f"{uuid.uuid4()}.mp3"
    subprocess.run(
        ["ffmpeg", "-y", "-i", wav_path, "-codec:a", "libmp3lame", "-qscale:a", "2", str(target)],
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    return str(target)
|
|
|
|
def check_models_exist(ckpt_dir):
    """Return True when every required checkpoint entry exists under *ckpt_dir*."""
    required_names = (
        "gen_config.json",
        "tokenizer.json",
        "HeartCodec-oss",
        "HeartMuLa-oss-3B",
    )
    return all(os.path.exists(os.path.join(ckpt_dir, name)) for name in required_names)
|
|
|
|
def download_models_if_needed(ckpt_dir):
    """Fetch the three model repositories from ModelScope unless already present.

    Does nothing except print a notice when ``check_models_exist`` reports
    that everything is in place. Otherwise the repositories are downloaded in
    a fixed order: the first into *ckpt_dir* itself, the other two into
    sub-directories named after the model.
    """
    heavy_rule = "=" * 50
    light_rule = "━" * 50

    if check_models_exist(ckpt_dir):
        print(heavy_rule)
        print(f"✓ Checkpoints found in {ckpt_dir}")
        print("✓ Skipping download")
        print(heavy_rule)
        return

    print(heavy_rule)
    print("⬇ Starting model download from ModelScope")
    print(heavy_rule)
    print("")

    # Imported only when a download is actually required.
    from modelscope import snapshot_download

    # (label, description, repository id, sub-directory or None for ckpt_dir)
    downloads = (
        ("HeartMuLaGen", "HeartMuLaGen config and tokenizer", 'HeartMuLa/HeartMuLaGen', None),
        ("HeartMuLa-oss-3B", "HeartMuLa-oss-3B model", 'HeartMuLa/HeartMuLa-oss-3B', 'HeartMuLa-oss-3B'),
        ("HeartCodec-oss", "HeartCodec-oss model", 'HeartMuLa/HeartCodec-oss', 'HeartCodec-oss'),
    )
    total = len(downloads)
    for step, (label, description, repo_id, subdir) in enumerate(downloads, start=1):
        print(light_rule)
        print(f"📦 [{step}/{total}] Downloading {description}...")
        print(light_rule)
        destination = ckpt_dir if subdir is None else os.path.join(ckpt_dir, subdir)
        snapshot_download(repo_id, local_dir=destination)
        print(f"✓ {label} download completed")
        print("")

    print(heavy_rule)
    print("✓ All models downloaded successfully!")
    print(heavy_rule)
    print("")
|
|
|
|
def load_pipeline(model_path, version, codec_version, quant_mode, use_acceleration: bool):
    """Return the generation pipeline for the given settings.

    The pipeline comes from the shared ModelManager for the requested mode.
    Note that *model_path* is only used in the log line; the manager itself
    is built from the module-level MODEL_PATH.
    """
    model_manager = get_model_manager(use_acceleration)
    print(f"Using model from {model_path} on {model_manager.device}...")
    pipeline = model_manager.get_gen_pipeline(version, codec_version, quant_mode)
    return pipeline
|
|
|
|
def _validate_generation_inputs(lyrics: str, tags: str) -> None:
    """Raise ``gr.Error`` when lyrics or tags are blank; lyrics are checked first."""
    checks = (
        (lyrics, "Please enter lyrics"),
        (tags, "Please enter tags"),
    )
    for text, message in checks:
        if not text.strip():
            raise gr.Error(message)
|
|
|
|
def generate(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    output_format,
    keep_model_loaded,
    offload_mode,
    backend,
    use_acceleration,
):
    """Run one full (non-streaming) generation and return the audio file path.

    The pipeline writes a WAV file under DATA_DIR; when *output_format* is
    "mp3" that file is converted and the MP3 path is returned instead.

    Raises:
        gr.Error: For blank inputs, and for any failure during generation
            (re-raised with a "Generation error:" prefix).
    """
    import torch

    _validate_generation_inputs(lyrics, tags)
    max_audio_length_ms = int(duration_sec * 1000)

    try:
        if backend == "exllama_v2":
            raise gr.Error("ExLlamaV2 backend is not implemented yet.")

        pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode, use_acceleration)
        output_path = os.path.join(DATA_DIR, f"gen_{uuid.uuid4().hex}.wav")
        request = {"lyrics": lyrics, "tags": tags}

        with torch.inference_mode():
            pipe(
                request,
                max_audio_length_ms=max_audio_length_ms,
                save_path=output_path,
                topk=topk,
                temperature=temperature,
                cfg_scale=cfg_scale,
                keep_model_loaded=keep_model_loaded,
                offload_mode=offload_mode,
            )

        # Size logging is informational only; a stat failure must not fail
        # the request.
        try:
            print(f"Generated file: {output_path} ({os.path.getsize(output_path)} bytes)")
        except Exception as e:
            print(f"WARN: failed to stat output file {output_path}: {e}")

        if output_format != "mp3":
            return output_path
        return convert_wav_to_mp3(output_path, DATA_DIR)

    except Exception as e:
        raise gr.Error(f"Generation error: {str(e)}")
    finally:
        # Release cached CUDA memory unless the caller wants the model kept.
        if not keep_model_loaded and torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
|
|
@_gpu_guard
def generate_original(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    output_format,
    keep_model_loaded,
    offload_mode,
    backend,
):
    """GPU-guarded entry point: run ``generate`` with acceleration disabled."""
    return generate(
        lyrics=lyrics,
        tags=tags,
        cfg_scale=cfg_scale,
        duration_sec=duration_sec,
        temperature=temperature,
        topk=topk,
        version=version,
        codec_version=codec_version,
        quant_mode=quant_mode,
        output_format=output_format,
        keep_model_loaded=keep_model_loaded,
        offload_mode=offload_mode,
        backend=backend,
        use_acceleration=False,
    )
|
|
|
|
@_gpu_guard
def generate_accelerated(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    output_format,
    keep_model_loaded,
    offload_mode,
    backend,
):
    """GPU-guarded entry point: run ``generate`` with acceleration enabled."""
    return generate(
        lyrics=lyrics,
        tags=tags,
        cfg_scale=cfg_scale,
        duration_sec=duration_sec,
        temperature=temperature,
        topk=topk,
        version=version,
        codec_version=codec_version,
        quant_mode=quant_mode,
        output_format=output_format,
        keep_model_loaded=keep_model_loaded,
        offload_mode=offload_mode,
        backend=backend,
        use_acceleration=True,
    )
|
|
|
|
def _normalize_stream_chunk(chunk: np.ndarray) -> np.ndarray:
    """Replace NaN and infinite samples with 0.0, then clip to [-1.0, 1.0]."""
    finite = np.nan_to_num(chunk, nan=0.0, posinf=0.0, neginf=0.0)
    return finite.clip(-1.0, 1.0)
|
|
|
|
def generate_music_streaming(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    keep_model_loaded,
    offload_mode,
    backend,
    use_acceleration,
) -> Iterator[Tuple[int, np.ndarray]]:
    """Yield ``(AUDIO_SAMPLE_RATE, samples)`` pairs as the pipeline streams audio.

    Each chunk from ``pipe.stream`` is moved to the CPU, converted to a NumPy
    array and passed through ``_normalize_stream_chunk``. Being a generator,
    nothing runs (validation included) until iteration starts.

    Raises:
        gr.Error: For blank inputs or the unimplemented "exllama_v2" backend.
    """
    _validate_generation_inputs(lyrics, tags)
    if backend == "exllama_v2":
        raise gr.Error("ExLlamaV2 backend is not implemented yet.")

    pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode, use_acceleration)
    stream_options = dict(
        max_audio_length_ms=int(duration_sec * 1000),
        temperature=temperature,
        topk=topk,
        cfg_scale=cfg_scale,
        chunk_frames=BLOCK_FRAMES,
        keep_model_loaded=keep_model_loaded,
        offload_mode=offload_mode,
    )
    for chunk in pipe.stream({"lyrics": lyrics, "tags": tags}, **stream_options):
        # A 2-D chunk has its first axis squeezed (a no-op unless its size is 1).
        flat = chunk.squeeze(0) if chunk.dim() == 2 else chunk
        samples = _normalize_stream_chunk(flat.cpu().numpy())
        print(f"stream chunk: samples={samples.shape[0]} sr={AUDIO_SAMPLE_RATE}")
        yield AUDIO_SAMPLE_RATE, samples
|
|
|
|
def _buffer_stream_blocks(blocks, start_threshold, max_queue_blocks):
    """Re-yield ``(sample_rate, block)`` pairs after an initial buffering phase.

    Blocks are held back until the queue reaches ``start_threshold`` or
    ``max_queue_blocks`` entries, whichever comes first. From then on every
    block is forwarded as soon as it arrives. Anything still held back when
    *blocks* is exhausted is flushed, so a stream shorter than the threshold
    still produces all of its audio.

    Args:
        blocks: Iterable of ``(sample_rate, block)`` pairs; each block must
            expose ``shape[0]`` (used for logging).
        start_threshold: Number of queued blocks that starts playback.
        max_queue_blocks: Upper bound on blocks held before playback starts.

    Yields:
        The input pairs, in their original order.
    """
    pending = []
    started = False
    for item in blocks:
        pending.append(item)
        still_buffering = len(pending) < start_threshold and len(pending) < max_queue_blocks
        if not started and still_buffering:
            continue
        if not started:
            started = True
            print(f"block stream start playback: buffered_blocks={len(pending)}")
        for sr, block in pending:
            print(f"block stream yield: samples={block.shape[0]}")
            yield sr, block
        pending.clear()

    # The source ended before the start threshold was reached: emit what was
    # buffered instead of silently dropping it.
    if pending:
        print(f"block stream flush: buffered_blocks={len(pending)}")
        for sr, block in pending:
            print(f"block stream yield: samples={block.shape[0]}")
            yield sr, block
        pending.clear()


def stream_generate(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    keep_model_loaded,
    offload_mode,
    backend,
    use_acceleration,
):
    """Stream generated audio as buffered ``(sample_rate, float32 block)`` pairs.

    Chunks from ``generate_music_streaming`` are cast to float32 and passed
    through a start-up buffer of ``max(MIN_BUFFER_BLOCKS, PREFETCH_BLOCKS)``
    blocks, capped at MAX_QUEUE_BLOCKS.

    Raises:
        gr.Error: For any failure, re-raised with a "Streaming error:" prefix.
    """
    try:
        _validate_generation_inputs(lyrics, tags)
        start_threshold = max(MIN_BUFFER_BLOCKS, PREFETCH_BLOCKS)
        print("block stream start:", f"block_sec={MODEL_DETOKENIZE_INTERVAL_SEC}", f"duration_sec={duration_sec}")
        source = (
            (sr, chunk_np.astype("float32", copy=False))
            for sr, chunk_np in generate_music_streaming(
                lyrics=lyrics,
                tags=tags,
                cfg_scale=cfg_scale,
                duration_sec=duration_sec,
                temperature=temperature,
                topk=topk,
                version=version,
                codec_version=codec_version,
                quant_mode=quant_mode,
                keep_model_loaded=keep_model_loaded,
                offload_mode=offload_mode,
                backend=backend,
                use_acceleration=use_acceleration,
            )
        )
        yield from _buffer_stream_blocks(source, start_threshold, MAX_QUEUE_BLOCKS)
    except Exception as e:
        raise gr.Error(f"Streaming error: {str(e)}") from e
|
|
|
|
@_gpu_guard
def stream_generate_accelerated(
    lyrics,
    tags,
    cfg_scale,
    duration_sec,
    temperature,
    topk,
    version,
    codec_version,
    quant_mode,
    keep_model_loaded,
    offload_mode,
    backend,
):
    """GPU-guarded entry point for accelerated block streaming.

    Returns:
        The generator produced by ``stream_generate`` with acceleration on.

    Raises:
        gr.Error: When running on a Space where streaming is not allowed.
    """
    streaming_blocked = IS_SPACE and not STREAMING_ALLOWED
    if streaming_blocked:
        raise gr.Error("Streaming is disabled on ZeroGPU. Use Standard mode or set ALLOW_STREAMING_ZERO_GPU=1.")
    # NOTE(review): this function returns a generator object rather than being
    # a generator function itself — confirm the caller/UI wiring expects that.
    return stream_generate(
        lyrics=lyrics,
        tags=tags,
        cfg_scale=cfg_scale,
        duration_sec=duration_sec,
        temperature=temperature,
        topk=topk,
        version=version,
        codec_version=codec_version,
        quant_mode=quant_mode,
        keep_model_loaded=keep_model_loaded,
        offload_mode=offload_mode,
        backend=backend,
        use_acceleration=True,
    )
|
|
|
|
def _strip_code_fence(text: str) -> str:
    """Remove a Markdown code fence wrapped around *text*, if there is one.

    The opening fence line (which may carry a language tag) is dropped. The
    last line is dropped only when it is itself a fence, so a response with a
    missing closing fence keeps its final lyric line. Text that does not
    start with a fence, or that has nothing after the opening fence line, is
    returned unchanged.
    """
    if not text.startswith("```"):
        return text
    lines = text.split("\n")
    if len(lines) < 2:
        return text
    body = lines[1:]
    if body and body[-1].strip().startswith("```"):
        body = body[:-1]
    return "\n".join(body)


def generate_lyrics(theme, tags, language, api_choice, api_key_input, custom_base_url, custom_model, progress=gr.Progress()):
    """Generate song lyrics with the selected LLM provider.

    Args:
        theme: Subject of the song; must not be blank.
        tags: Comma-separated style tags; "pop, emotional" is used when blank.
        language: Language code ("en", "zh", "jp", "kr", "sp"); unknown codes
            fall back to English.
        api_choice: Key into LLM_PRESETS.
        api_key_input: API key typed by the user; when blank, the preset's
            environment variable is consulted.
        custom_base_url: Optional base URL overriding the preset's.
        custom_model: Optional model id overriding the preset's default.
        progress: Gradio progress reporter.

    Returns:
        The generated lyrics after code-fence removal and
        ``process_lyrics_correct`` normalisation.

    Raises:
        gr.Error: For a blank theme, unknown provider, missing API key, or any
            failure while calling the provider (prefixed with
            "Lyrics generation error:").
    """
    if not theme or not theme.strip():
        raise gr.Error("Please enter a theme")

    progress(0.1, desc="Preparing request...")

    if api_choice not in LLM_PRESETS:
        raise gr.Error(f"Unknown API choice: {api_choice}")

    preset = LLM_PRESETS[api_choice]

    # A key typed into the UI wins over the preset's environment variable.
    api_key = api_key_input.strip() if api_key_input and api_key_input.strip() else None
    if not api_key and preset["env_key"]:
        api_key = os.environ.get(preset["env_key"])

    if not api_key:
        raise gr.Error("No API key provided. Please enter your API key in the field above.")

    # Custom values win over preset defaults.
    base_url = custom_base_url.strip() if custom_base_url and custom_base_url.strip() else preset["base_url"]
    model_name = custom_model.strip() if custom_model and custom_model.strip() else preset["default_model"]

    language_names = {
        "en": "English",
        "zh": "Chinese",
        "jp": "Japanese",
        "kr": "Korean",
        "sp": "Spanish"
    }
    lang_name = language_names.get(language, "English")

    # Guarded like the other optional text inputs so a None value does not crash.
    tags_text = tags.strip() if tags and tags.strip() else "pop, emotional"

    prompt = f"""You are a professional songwriter. Generate song lyrics based on the following requirements:

**Theme**: {theme}
**Music Style/Tags**: {tags_text}
**Language**: {lang_name}

**Format Requirements** (CRITICAL):
1. Use lowercase for all lyrics text (except structure tags which are in brackets)
2. Include proper song structure tags: [Intro], [Verse], [Prechorus], [Chorus], [Bridge], [Outro]
3. Each structure tag should be on its own line
4. Separate different sections with a blank line (one empty line between sections)
5. NO timestamps like [00:12] - only structure tags allowed
6. Keep lyrics concise and suitable for a 3-4 minute song

**Structure Guidelines**:
- [Intro]: Optional, 1-2 lines if included
- [Verse]: Story-telling part, 4-6 lines, can repeat with different lyrics
- [Prechorus]: Optional, 2-4 lines, builds tension before chorus
- [Chorus]: Main hook, catchy and repetitive, 4-6 lines
- [Bridge]: Optional, provides contrast, 4-6 lines
- [Outro]: Closing, 1-2 lines

**Example Format**:
```
[Intro]

[Verse]
the sun creeps in across the floor
i hear the traffic outside the door
the coffee pot begins to hiss
it is another morning just like this

[Chorus]
every day the light returns
every day the fire burns
we keep on walking down this street
moving to the same steady beat
```

Now generate lyrics in {lang_name} based on the theme "{theme}" with style "{tags_text}".
Output ONLY the lyrics with structure tags, no explanations.
"""

    try:
        if preset["api_type"] == "gemini":
            if genai is None or types is None:
                raise gr.Error("Gemini SDK not available. Install `google-genai` or switch provider.")

            progress(0.3, desc=f"Connecting to {preset['name']}...")

            # NOTE(review): this overwrites http_proxy/https_proxy for the
            # whole process (default 127.0.0.1:7890) and never restores them;
            # confirm this is intended outside local development.
            try:
                proxy_host = os.environ.get("PROXY_HOST", "127.0.0.1")
                proxy_port = os.environ.get("PROXY_PORT", "7890")
                os.environ['http_proxy'] = f'http://{proxy_host}:{proxy_port}'
                os.environ['https_proxy'] = f'http://{proxy_host}:{proxy_port}'
            except Exception:
                pass

            client = genai.Client(api_key=api_key)

            progress(0.5, desc=f"Generating lyrics with {preset['name']}...")

            response = client.models.generate_content(
                model=model_name,
                contents=[
                    types.Content(
                        role='user',
                        parts=[types.Part(text=prompt)]
                    )
                ],
                config=types.GenerateContentConfig(
                    temperature=0.8,
                    max_output_tokens=2000
                )
            )

            raw_text = response.text

        elif preset["api_type"] == "openai":
            progress(0.3, desc=f"Connecting to {preset['name']}...")

            if base_url:
                client = OpenAI(api_key=api_key, base_url=base_url)
            else:
                client = OpenAI(api_key=api_key)

            progress(0.5, desc=f"Generating lyrics with {preset['name']}...")

            response = client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": "You are a professional songwriter who creates well-structured lyrics."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.8,
                max_tokens=2000
            )

            raw_text = response.choices[0].message.content

        else:
            raise gr.Error(f"Unknown API type: {preset['api_type']}")

        # Report a missing text payload clearly instead of failing on
        # `None.strip()` with an opaque AttributeError.
        if raw_text is None:
            raise ValueError(f"{preset['name']} returned no text")
        generated_lyrics = raw_text.strip()

        progress(0.9, desc="Processing response...")

        generated_lyrics = _strip_code_fence(generated_lyrics)
        generated_lyrics = process_lyrics_correct(generated_lyrics)

        progress(1.0, desc="Done!")

        return generated_lyrics

    except Exception as e:
        raise gr.Error(f"Lyrics generation error: {str(e)}") from e
|
|
|
|
| def create_ui(): |
| """Create Gradio UI""" |
| speed_submode_value = DEFAULT_SPEED_SUBMODE if STREAMING_ALLOWED else "Standard" |
| show_stream_default = ( |
| STREAMING_ALLOWED |
| and DEFAULT_GENERATION_MODE != "Original (No Acceleration)" |
| and speed_submode_value == "Streaming" |
| ) |
|
|
| with gr.Blocks(title="HeartMuLa Music Generation") as demo: |
| gr.Markdown("# HeartMuLa Music Generation") |
| gr.Markdown("Generate music from lyrics and style tags") |
| gr.Markdown("Tip: start with the **ZeroGPU Safe** preset for reliable generation on small GPUs.") |
|
|
| with gr.Tabs(): |
| |
| with gr.Tab("Music Generation"): |
| with gr.Row(): |
| with gr.Column(): |
| lyrics = gr.Textbox( |
| label="Lyrics", |
| lines=15, |
| value=EXAMPLE_LYRICS, |
| placeholder="Enter lyrics here..." |
| ) |
|
|
| |
| format_btn = gr.Button("Format Lyrics", size="sm") |
|
|
| |
| gr.Markdown("### Tags") |
|
|
| tags = gr.Textbox( |
| label="Selected Tags (comma-separated)", |
| value=EXAMPLE_TAGS, |
| placeholder="e.g., piano,happy,pop", |
| lines=2 |
| ) |
|
|
| |
| tag_checkboxes = [] |
| with gr.Accordion("Tag Categories (Click to Expand)", open=False): |
| with gr.Row(): |
| with gr.Column(): |
| t1 = gr.CheckboxGroup(choices=TAG_DATA["Gender"], label="Gender") |
| tag_checkboxes.append(t1) |
| t2 = gr.CheckboxGroup(choices=TAG_DATA["Genre"], label="Genre") |
| tag_checkboxes.append(t2) |
| with gr.Column(): |
| t3 = gr.CheckboxGroup(choices=TAG_DATA["Instrument"], label="Instrument") |
| tag_checkboxes.append(t3) |
| t4 = gr.CheckboxGroup(choices=TAG_DATA["Mood"], label="Mood") |
| tag_checkboxes.append(t4) |
| with gr.Column(): |
| t5 = gr.CheckboxGroup(choices=TAG_DATA["Scene"], label="Scene") |
| tag_checkboxes.append(t5) |
| t6 = gr.CheckboxGroup(choices=TAG_DATA["Singer Timbre"], label="Singer Timbre") |
| tag_checkboxes.append(t6) |
| with gr.Column(): |
| t7 = gr.CheckboxGroup(choices=TAG_DATA["Topic"], label="Topic") |
| tag_checkboxes.append(t7) |
|
|
| gr.Markdown("### Quick Presets") |
| with gr.Row(): |
| preset_selector = gr.Dropdown( |
| choices=list(PRESET_CONFIGS.keys()), |
| value=DEFAULT_PRESET, |
| label="Preset" |
| ) |
| apply_preset_btn = gr.Button("Apply Preset", size="sm") |
|
|
| |
| with gr.Row(): |
| cfg_scale = gr.Slider(0.0, 3.0, value=1.5, step=0.1, label="CFG Scale") |
| duration = gr.Slider(10, 300, value=DEFAULT_DURATION_SEC, step=10, label="Duration (sec)") |
|
|
| with gr.Row(): |
| temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature") |
| topk = gr.Slider(1, 100, value=50, step=1, label="Top-K") |
|
|
| with gr.Accordion("Advanced Settings", open=False): |
| backend = gr.Dropdown( |
| choices=[("HF Pipeline", "hf"), ("ExLlamaV2 (Not Implemented)", "exllama_v2")], |
| value="hf", |
| label="Backend" |
| ) |
| version = gr.Dropdown( |
| choices=["3B", "7B", "RL-oss-3B-20260123"], |
| value="3B", |
| label="Model Version" |
| ) |
| codec_version = gr.Dropdown( |
| choices=["oss", "oss-20260123"], |
| value="oss", |
| label="Codec Version" |
| ) |
| quant_mode = gr.Dropdown( |
| choices=[ |
| ("None", "none"), |
| ("4-bit (NF4/FP4)", "4bit"), |
| ("8-bit", "8bit"), |
| ("FP8 (TorchAO)", "fp8"), |
| ], |
| value=DEFAULT_QUANT_MODE, |
| label="Quantization (ZeroGPU recommended: 4-bit)" |
| ) |
| keep_model_loaded = gr.Checkbox( |
| value=DEFAULT_KEEP_MODEL_LOADED, |
| label="Keep Model Loaded" |
| ) |
| offload_mode = gr.Dropdown( |
| choices=["auto", "aggressive"], |
| value=DEFAULT_OFFLOAD_MODE, |
| label="Offload Mode" |
| ) |
| output_format = gr.Radio( |
| choices=[("WAV", "wav"), ("MP3", "mp3")], |
| value="wav", |
| label="Output Format", |
| visible=not show_stream_default, |
| ) |
| gr.Markdown( |
| f"Streaming is block-based: {MODEL_DETOKENIZE_INTERVAL_SEC:.2f}s per block " |
| f"({BLOCK_FRAMES} frames, {BLOCK_SAMPLES} samples)." |
| ) |
|
|
| gr.Markdown("### 🚀 Generation") |
|
|
| generation_mode = gr.Radio( |
| choices=["Original (No Acceleration)", "Accelerated"], |
| value=DEFAULT_GENERATION_MODE, |
| label="Generation Mode", |
| ) |
|
|
| speed_choices = [("Standard", "Standard")] |
| if STREAMING_ALLOWED: |
| speed_choices.append(("Block Streaming (Preview)", "Streaming")) |
|
|
| speed_submode = gr.Radio( |
| choices=speed_choices, |
| value=speed_submode_value, |
| label="Accelerated Options", |
| visible=DEFAULT_GENERATION_MODE != "Original (No Acceleration)", |
| ) |
|
|
| btn_original = gr.Button( |
| "🎼 Generate Music (Original)", |
| variant="primary", |
| size="lg", |
| visible=DEFAULT_GENERATION_MODE == "Original (No Acceleration)", |
| ) |
| btn_accel = gr.Button( |
| "🎼 Generate Music (Accelerated)", |
| variant="primary", |
| size="lg", |
| visible=DEFAULT_GENERATION_MODE != "Original (No Acceleration)" |
| and not show_stream_default, |
| ) |
| btn_stream = gr.Button( |
| "🎼 Generate Music (Block Streaming)", |
| variant="primary", |
| size="lg", |
| visible=show_stream_default, |
| ) |
| cancel_stream_btn = gr.Button( |
| "Cancel Streaming", |
| variant="secondary", |
| size="lg", |
| visible=show_stream_default, |
| ) |
| cancel_state = gr.State() |
|
|
| with gr.Column(): |
| |
| with gr.Accordion("Usage Notice", open=True): |
| gr.Markdown(""" |
| ### ZeroGPU Tips |
| - Default preset is **ZeroGPU AOTI FP8** for fastest H200 path |
| - Prefer **4-bit quantization** if FP8 or AOTI is unavailable |
| - Keep durations short (<= 60s) for faster first results |
| - **Streaming is block-based**, not token streaming: each block is ~29.76s |
| - Use Streaming for preview; Standard mode outputs a full downloadable file |
| - Try **FP8 + AOTI** (preset: ZeroGPU AOTI FP8) for the fastest H200 path |
| - ZeroGPU does not support generator streaming in forked GPU; streaming is for local deploy only |
| |
| ### Lyrics Format Requirements |
| |
| **Automatic Processing:** |
| 1. All text will be converted to **lowercase** |
| 2. Timestamps (e.g., [00:12]) will be **automatically removed** |
| 3. Leading/trailing whitespace on each line will be **stripped** |
| 4. Leading/trailing empty lines will be **removed** |
| 5. Multiple consecutive empty lines (3+) will be **collapsed to 2** |
| |
| **Recommended Format:** |
| - Use standard song structure tags: `[Intro]`, `[Verse]`, `[Chorus]`, `[Bridge]`, `[Outro]`, etc. |
| - Separate sections with **blank lines** |
| - Case doesn't matter (will be auto-converted) |
| |
| **Example:** |
| ``` |
| [Intro] |
| |
| [Verse] |
| The sun creeps in across the floor |
| I hear the traffic outside the door |
| |
| [Chorus] |
| Every day the light returns |
| Every day the fire burns |
| ``` |
| |
| --- |
| |
| ### Tags Format |
| - Use **commas** to separate multiple tags: `piano,happy,pop` |
| - Tags influence the style and mood of the generated music |
| - Select from categories below or type directly |
| """) |
|
|
| output_audio_file = gr.Audio( |
| label="Generated Music (Full)", |
| type="filepath", |
| interactive=False, |
| visible=not show_stream_default, |
| ) |
| stream_audio = gr.Audio( |
| label="Block Streaming Audio (Preview)", |
| streaming=True, |
| autoplay=True, |
| type="numpy", |
| format="wav", |
| interactive=False, |
| visible=show_stream_default, |
| ) |
|
|
| |
| for cb in tag_checkboxes: |
| cb.change(fn=update_tag_string, inputs=tag_checkboxes, outputs=tags) |
|
|
| |
| format_btn.click( |
| fn=process_lyrics_correct, |
| inputs=[lyrics], |
| outputs=[lyrics] |
| ) |
|
|
def update_visibility(gen_mode, spd_mode):
    """Compute visibility updates for the mode-dependent generation widgets.

    Args:
        gen_mode: Current value of the "Generation Mode" radio.
        spd_mode: Current value of the "Accelerated Options" radio.

    Returns:
        A tuple of eight ``gr.update(visible=...)`` objects. The order must
        match the ``outputs`` lists of the events wired to this callback:
        accelerated-options radio, original button, accelerated button,
        streaming button, cancel-streaming button, output-format radio,
        full-file audio player, streaming audio player.
    """
    is_original = gen_mode == "Original (No Acceleration)"
    # Streaming widgets are only ever shown in accelerated mode, and only
    # when streaming is enabled for this deployment.
    show_stream = (not is_original) and STREAMING_ALLOWED and spd_mode == "Streaming"

    visibility = (
        not is_original,                      # accelerated-options radio
        is_original,                          # "Generate (Original)" button
        not is_original and not show_stream,  # "Generate (Accelerated)" button
        show_stream,                          # "Generate (Block Streaming)" button
        show_stream,                          # "Cancel Streaming" button
        not show_stream,                      # output-format radio
        not show_stream,                      # full-file audio player
        show_stream,                          # streaming audio player
    )
    return tuple(gr.update(visible=flag) for flag in visibility)
|
|
def apply_preset(preset_name):
    """Translate a preset name into value updates for the settings widgets.

    Args:
        preset_name: Key into ``PRESET_CONFIGS``. Unknown names fall back to
            the ``DEFAULT_PRESET`` entry.

    Returns:
        A tuple of nine ``gr.update(value=...)`` objects. The order must match
        the ``outputs`` list of the "Apply Preset" click event: CFG scale,
        duration, temperature, top-k, quantization mode, keep-model-loaded,
        offload mode, generation mode, accelerated sub-mode.

    Raises:
        KeyError: If the selected preset lacks one of the expected keys.
    """
    preset = PRESET_CONFIGS.get(preset_name, PRESET_CONFIGS[DEFAULT_PRESET])
    # Keep this sequence aligned with the event's `outputs` list.
    ordered_keys = (
        "cfg_scale",
        "duration",
        "temperature",
        "topk",
        "quant_mode",
        "keep_model_loaded",
        "offload_mode",
        "generation_mode",
        "speed_submode",
    )
    return tuple(gr.update(value=preset[key]) for key in ordered_keys)
|
|
| generation_mode.change( |
| fn=update_visibility, |
| inputs=[generation_mode, speed_submode], |
| outputs=[ |
| speed_submode, |
| btn_original, |
| btn_accel, |
| btn_stream, |
| cancel_stream_btn, |
| output_format, |
| output_audio_file, |
| stream_audio, |
| ], |
| ) |
| speed_submode.change( |
| fn=update_visibility, |
| inputs=[generation_mode, speed_submode], |
| outputs=[ |
| speed_submode, |
| btn_original, |
| btn_accel, |
| btn_stream, |
| cancel_stream_btn, |
| output_format, |
| output_audio_file, |
| stream_audio, |
| ], |
| ) |
|
|
| preset_event = apply_preset_btn.click( |
| fn=apply_preset, |
| inputs=[preset_selector], |
| outputs=[ |
| cfg_scale, |
| duration, |
| temperature, |
| topk, |
| quant_mode, |
| keep_model_loaded, |
| offload_mode, |
| generation_mode, |
| speed_submode, |
| ], |
| ) |
|
|
| preset_event.then( |
| fn=update_visibility, |
| inputs=[generation_mode, speed_submode], |
| outputs=[ |
| speed_submode, |
| btn_original, |
| btn_accel, |
| btn_stream, |
| cancel_stream_btn, |
| output_format, |
| output_audio_file, |
| stream_audio, |
| ], |
| ) |
|
|
| btn_original.click( |
| fn=generate_original, |
| inputs=[ |
| lyrics, |
| tags, |
| cfg_scale, |
| duration, |
| temperature, |
| topk, |
| version, |
| codec_version, |
| quant_mode, |
| output_format, |
| keep_model_loaded, |
| offload_mode, |
| backend, |
| ], |
| outputs=[output_audio_file], |
| concurrency_id="gpu_queue", |
| concurrency_limit=GPU_CONCURRENCY_LIMIT, |
| ) |
|
|
| btn_accel.click( |
| fn=generate_accelerated, |
| inputs=[ |
| lyrics, |
| tags, |
| cfg_scale, |
| duration, |
| temperature, |
| topk, |
| version, |
| codec_version, |
| quant_mode, |
| output_format, |
| keep_model_loaded, |
| offload_mode, |
| backend, |
| ], |
| outputs=[output_audio_file], |
| concurrency_id="gpu_queue", |
| concurrency_limit=GPU_CONCURRENCY_LIMIT, |
| ) |
|
|
| stream_event = btn_stream.click( |
| fn=stream_generate_accelerated, |
| inputs=[ |
| lyrics, |
| tags, |
| cfg_scale, |
| duration, |
| temperature, |
| topk, |
| version, |
| codec_version, |
| quant_mode, |
| keep_model_loaded, |
| offload_mode, |
| backend, |
| ], |
| outputs=[stream_audio], |
| concurrency_id="gpu_queue", |
| concurrency_limit=GPU_CONCURRENCY_LIMIT, |
| ) |
|
|
| cancel_stream_btn.click( |
| fn=lambda: None, |
| inputs=None, |
| outputs=[cancel_state], |
| cancels=[stream_event], |
| ) |
|
|
| |
| with gr.Tab("Lyrics Generation"): |
| with gr.Row(): |
| with gr.Column(): |
| gr.Markdown("### Generate Lyrics with AI") |
|
|
| api_selector = gr.Radio( |
| choices=[ |
| ("Google Gemini", "gemini"), |
| ("OpenAI", "openai"), |
| ("DeepSeek", "deepseek"), |
| ("Custom (OpenAI-compatible)", "custom") |
| ], |
| value="gemini", |
| label="Select LLM Provider" |
| ) |
|
|
| with gr.Accordion("API Configuration", open=True): |
| api_key_input = gr.Textbox( |
| label="API Key (Required)", |
| type="password", |
| placeholder="Enter your API key or set environment variable", |
| info="Will use environment variable if not provided here" |
| ) |
|
|
| custom_base_url = gr.Textbox( |
| label="Custom Base URL (Optional)", |
| placeholder="e.g., https://api.your-provider.com/v1", |
| info="Leave empty to use default. For custom providers only.", |
| visible=False |
| ) |
|
|
| custom_model = gr.Textbox( |
| label="Model Name (Optional)", |
| placeholder="e.g., gpt-4o, deepseek-chat", |
| info="Leave empty to use recommended default", |
| visible=False |
| ) |
|
|
def update_custom_fields(choice):
    """Show or hide the custom base-URL and model-name fields for a provider.

    Args:
        choice: Selected provider value from the provider radio
            (``"gemini"``, ``"openai"``, ``"deepseek"`` or ``"custom"``).

    Returns:
        A pair of ``gr.update(visible=...)`` objects: first for the custom
        base-URL textbox, then for the model-name textbox.
    """
    # Only the fully custom provider needs a base URL; DeepSeek additionally
    # exposes the model-name override.
    show_base_url = choice == "custom"
    show_model = choice in ("custom", "deepseek")
    return gr.update(visible=show_base_url), gr.update(visible=show_model)
|
|
| api_selector.change( |
| fn=update_custom_fields, |
| inputs=[api_selector], |
| outputs=[custom_base_url, custom_model] |
| ) |
|
|
| theme_input = gr.Textbox( |
| label="Theme", |
| placeholder="e.g., Love lost in the city, Dreams and hope, Rainy day memories...", |
| lines=2 |
| ) |
|
|
| tags_gen = gr.Textbox( |
| label="Music Style/Tags", |
| placeholder="e.g., piano, melancholy, pop", |
| value="pop,emotional" |
| ) |
|
|
| language_select = gr.Radio( |
| choices=[ |
| ("English", "en"), |
| ("中文 (Chinese)", "zh"), |
| ("日本語 (Japanese)", "jp"), |
| ("한국어 (Korean)", "kr"), |
| ("Español (Spanish)", "sp") |
| ], |
| value="en", |
| label="Language" |
| ) |
|
|
| generate_lyrics_btn = gr.Button( |
| "Generate Lyrics", |
| variant="primary", |
| size="lg" |
| ) |
|
|
| with gr.Column(): |
| with gr.Accordion("How to Use", open=True): |
| gr.Markdown(""" |
| ### How to Generate Lyrics |
| |
| **Theme**: Describe your song's story or emotion |
| - Examples: "Lost love in Tokyo", "Overcoming obstacles", "Summer road trip" |
| |
| **Music Style/Tags**: Define mood and genre |
| - Examples: "piano,melancholy,ballad", "upbeat,electronic,dance" |
| |
| **Tips** |
| - Generated lyrics follow standard song structure ([Intro], [Verse], [Chorus], etc.) |
| - Edit lyrics before using for music generation |
| - Be specific with themes for better results |
| |
| """) |
|
|
| generated_lyrics_output = gr.Textbox( |
| label="Generated Lyrics", |
| lines=20, |
| placeholder="Generated lyrics will appear here...", |
| interactive=False |
| ) |
|
|
| copy_to_music_gen = gr.Button( |
| "Copy to Music Generation Tab", |
| size="sm" |
| ) |
|
|
| |
| generate_lyrics_btn.click( |
| fn=generate_lyrics, |
| inputs=[theme_input, tags_gen, language_select, api_selector, api_key_input, custom_base_url, custom_model], |
| outputs=[generated_lyrics_output] |
| ) |
|
|
| |
def copy_lyrics(lyrics_text):
    """Pass the generated lyrics through unchanged.

    Used as the click callback that forwards the generated-lyrics textbox
    value to the ``lyrics`` input component.

    Args:
        lyrics_text: Current content of the generated-lyrics textbox.

    Returns:
        ``lyrics_text`` exactly as received.
    """
    return lyrics_text
|
|
| copy_to_music_gen.click( |
| fn=copy_lyrics, |
| inputs=[generated_lyrics_output], |
| outputs=[lyrics] |
| ) |
|
|
| return demo |
|
|
|
|
# Build the UI at import time so `demo` exists as a module-level name even
# when this file is imported rather than executed as a script.
demo = create_ui()
demo.queue(max_size=GRADIO_QUEUE_MAX_SIZE, default_concurrency_limit=GRADIO_DEFAULT_CONCURRENCY)


if __name__ == "__main__":
    # The listening port can be overridden through the PORT environment variable.
    port = int(os.environ.get("PORT", "7860"))
    demo.launch(
        server_name="0.0.0.0",
        server_port=port,
        # Directories Gradio is allowed to serve files from.
        allowed_paths=[str(DATA_DIR), "/tmp"],
        ssr_mode=False,
    )
|
|