Spaces:
Running on Zero
Running on Zero
Commit ·
dc5fc4b
0
Parent(s):
feat: ACE-Step Studio — custom frontend for ACE-Step v1.5 music generation
Browse files- gr.Server with custom HTML frontend + API endpoints
- v1.5 AceStepHandler with acestep-v15-xl-turbo (8-step turbo)
- Peak normalization, ZeroGPU permission fixes
- /generate and /inspire API endpoints
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +39 -0
- .gitignore +1 -0
- README.md +31 -0
- acestep/__init__.py +1 -0
- acestep/acestep_v15_pipeline.py +303 -0
- acestep/api_server.py +1700 -0
- acestep/audio_utils.py +378 -0
- acestep/constants.py +109 -0
- acestep/constrained_logits_processor.py +0 -0
- acestep/dataset_handler.py +37 -0
- acestep/dit_alignment_score.py +870 -0
- acestep/genres_vocab.txt +0 -0
- acestep/gradio_ui/__init__.py +1 -0
- acestep/gradio_ui/events/__init__.py +1355 -0
- acestep/gradio_ui/events/generation_handlers.py +1071 -0
- acestep/gradio_ui/events/results_handlers.py +0 -0
- acestep/gradio_ui/events/training_handlers.py +644 -0
- acestep/gradio_ui/i18n.py +152 -0
- acestep/gradio_ui/i18n/en.json +245 -0
- acestep/gradio_ui/i18n/ja.json +245 -0
- acestep/gradio_ui/i18n/zh.json +245 -0
- acestep/gradio_ui/interfaces/__init__.py +105 -0
- acestep/gradio_ui/interfaces/dataset.py +101 -0
- acestep/gradio_ui/interfaces/generation.py +694 -0
- acestep/gradio_ui/interfaces/result.py +598 -0
- acestep/gradio_ui/interfaces/training.py +562 -0
- acestep/handler.py +0 -0
- acestep/inference.py +1181 -0
- acestep/llm_inference.py +0 -0
- acestep/local_cache.py +129 -0
- acestep/test_time_scaling.py +410 -0
- acestep/third_parts/nano-vllm/LICENSE +21 -0
- acestep/third_parts/nano-vllm/README.md +66 -0
- acestep/third_parts/nano-vllm/bench.py +32 -0
- acestep/third_parts/nano-vllm/example.py +33 -0
- acestep/third_parts/nano-vllm/nanovllm/__init__.py +2 -0
- acestep/third_parts/nano-vllm/nanovllm/config.py +26 -0
- acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py +119 -0
- acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py +178 -0
- acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py +543 -0
- acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py +230 -0
- acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py +96 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/activation.py +14 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/attention.py +75 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py +66 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/layernorm.py +50 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/linear.py +153 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/rotary_embedding.py +61 -0
- acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py +114 -0
- acestep/third_parts/nano-vllm/nanovllm/llm.py +5 -0
.gitattributes
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
.claude/
|
README.md
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Ace-Step Studio
|
| 3 |
+
emoji: 🎵
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: gray
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 6.12.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
short_description: Minimalist dark UI for ACE-Step music generation
|
| 12 |
+
models:
|
| 13 |
+
- ACE-Step/Ace-Step1.5
|
| 14 |
+
- ACE-Step/acestep-v15-xl-turbo
|
| 15 |
+
preload_from_hub:
|
| 16 |
+
- ACE-Step/Ace-Step1.5
|
| 17 |
+
- ACE-Step/acestep-v15-xl-turbo
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
# ACE-Step Studio
|
| 21 |
+
|
| 22 |
+
A minimalist, dark-themed interface for generating music with [ACE-Step](https://github.com/ace-step/ACE-Step).
|
| 23 |
+
|
| 24 |
+
**Model**: `ACE-Step/acestep-v15-xl-turbo` — generates 1 minute of audio in ~2 seconds (8-step turbo distillation).
|
| 25 |
+
|
| 26 |
+
## Usage
|
| 27 |
+
|
| 28 |
+
1. Enter style tags (e.g. `lo-fi, chill, piano, female vocals`)
|
| 29 |
+
2. Write lyrics with `[verse]`, `[chorus]`, `[bridge]` section markers
|
| 30 |
+
3. Hit **Generate** — a waveform appears when ready
|
| 31 |
+
4. Use **✨ Inspire me** to auto-generate lyrics via LLM
|
acestep/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""ACE-Step package."""
|
acestep/acestep_v15_pipeline.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACE-Step V1.5 Pipeline
|
| 3 |
+
Handler wrapper connecting model and UI
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
# Load environment variables from .env file in project root
|
| 9 |
+
# This allows configuration without hardcoding values
|
| 10 |
+
# Falls back to .env.example if .env is not found
|
| 11 |
+
try:
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
# Get project root directory
|
| 14 |
+
_current_file = os.path.abspath(__file__)
|
| 15 |
+
_project_root = os.path.dirname(os.path.dirname(_current_file))
|
| 16 |
+
_env_path = os.path.join(_project_root, '.env')
|
| 17 |
+
_env_example_path = os.path.join(_project_root, '.env.example')
|
| 18 |
+
|
| 19 |
+
if os.path.exists(_env_path):
|
| 20 |
+
load_dotenv(_env_path)
|
| 21 |
+
print(f"Loaded configuration from {_env_path}")
|
| 22 |
+
elif os.path.exists(_env_example_path):
|
| 23 |
+
load_dotenv(_env_example_path)
|
| 24 |
+
print(f"Loaded configuration from {_env_example_path} (fallback)")
|
| 25 |
+
except ImportError:
|
| 26 |
+
# python-dotenv not installed, skip loading .env
|
| 27 |
+
pass
|
| 28 |
+
|
| 29 |
+
# Clear proxy settings that may affect Gradio
|
| 30 |
+
for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
|
| 31 |
+
os.environ.pop(proxy_var, None)
|
| 32 |
+
|
| 33 |
+
try:
|
| 34 |
+
# When executed as a module: `python -m acestep.acestep_v15_pipeline`
|
| 35 |
+
from .handler import AceStepHandler
|
| 36 |
+
from .llm_inference import LLMHandler
|
| 37 |
+
from .dataset_handler import DatasetHandler
|
| 38 |
+
from .gradio_ui import create_gradio_interface
|
| 39 |
+
except ImportError:
|
| 40 |
+
# When executed as a script: `python acestep/acestep_v15_pipeline.py`
|
| 41 |
+
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 42 |
+
if project_root not in sys.path:
|
| 43 |
+
sys.path.insert(0, project_root)
|
| 44 |
+
from acestep.handler import AceStepHandler
|
| 45 |
+
from acestep.llm_inference import LLMHandler
|
| 46 |
+
from acestep.dataset_handler import DatasetHandler
|
| 47 |
+
from acestep.gradio_ui import create_gradio_interface
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def create_demo(init_params=None, language='en'):
|
| 51 |
+
"""
|
| 52 |
+
Create Gradio demo interface
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
init_params: Dictionary containing initialization parameters and state.
|
| 56 |
+
If None, service will not be pre-initialized.
|
| 57 |
+
Keys: 'pre_initialized' (bool), 'checkpoint', 'config_path', 'device',
|
| 58 |
+
'init_llm', 'lm_model_path', 'backend', 'use_flash_attention',
|
| 59 |
+
'offload_to_cpu', 'offload_dit_to_cpu', 'init_status',
|
| 60 |
+
'dit_handler', 'llm_handler' (initialized handlers if pre-initialized),
|
| 61 |
+
'language' (UI language code)
|
| 62 |
+
language: UI language code ('en', 'zh', 'ja', default: 'en')
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
Gradio Blocks instance
|
| 66 |
+
"""
|
| 67 |
+
# Get persistent storage path from init_params (for HuggingFace Space)
|
| 68 |
+
persistent_storage_path = None
|
| 69 |
+
if init_params:
|
| 70 |
+
persistent_storage_path = init_params.get('persistent_storage_path')
|
| 71 |
+
|
| 72 |
+
# Use pre-initialized handlers if available, otherwise create new ones
|
| 73 |
+
if init_params and init_params.get('pre_initialized') and 'dit_handler' in init_params:
|
| 74 |
+
dit_handler = init_params['dit_handler']
|
| 75 |
+
llm_handler = init_params['llm_handler']
|
| 76 |
+
else:
|
| 77 |
+
dit_handler = AceStepHandler(persistent_storage_path=persistent_storage_path)
|
| 78 |
+
llm_handler = LLMHandler(persistent_storage_path=persistent_storage_path)
|
| 79 |
+
|
| 80 |
+
dataset_handler = DatasetHandler() # Dataset handler
|
| 81 |
+
|
| 82 |
+
# Create Gradio interface with all handlers and initialization parameters
|
| 83 |
+
demo = create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=init_params, language=language)
|
| 84 |
+
|
| 85 |
+
return demo
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def get_gpu_memory_gb():
|
| 89 |
+
"""
|
| 90 |
+
Get GPU memory in GB. Returns 0 if no GPU is available.
|
| 91 |
+
"""
|
| 92 |
+
try:
|
| 93 |
+
import torch
|
| 94 |
+
if torch.cuda.is_available():
|
| 95 |
+
# Get total memory of the first GPU in GB
|
| 96 |
+
total_memory = torch.cuda.get_device_properties(0).total_memory
|
| 97 |
+
memory_gb = total_memory / (1024**3) # Convert bytes to GB
|
| 98 |
+
return memory_gb
|
| 99 |
+
else:
|
| 100 |
+
return 0
|
| 101 |
+
except Exception as e:
|
| 102 |
+
print(f"Warning: Failed to detect GPU memory: {e}", file=sys.stderr)
|
| 103 |
+
return 0
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def main():
|
| 107 |
+
"""Main entry function"""
|
| 108 |
+
import argparse
|
| 109 |
+
|
| 110 |
+
# Detect GPU memory to auto-configure offload settings
|
| 111 |
+
gpu_memory_gb = get_gpu_memory_gb()
|
| 112 |
+
auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
|
| 113 |
+
|
| 114 |
+
if auto_offload:
|
| 115 |
+
print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
|
| 116 |
+
print("Auto-enabling CPU offload to reduce GPU memory usage")
|
| 117 |
+
elif gpu_memory_gb > 0:
|
| 118 |
+
print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (>= 16GB)")
|
| 119 |
+
print("CPU offload disabled by default")
|
| 120 |
+
else:
|
| 121 |
+
print("No GPU detected, running on CPU")
|
| 122 |
+
|
| 123 |
+
parser = argparse.ArgumentParser(description="Gradio Demo for ACE-Step V1.5")
|
| 124 |
+
parser.add_argument("--port", type=int, default=7860, help="Port to run the gradio server on")
|
| 125 |
+
parser.add_argument("--share", action="store_true", help="Create a public link")
|
| 126 |
+
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
|
| 127 |
+
parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name (default: 127.0.0.1, use 0.0.0.0 for all interfaces)")
|
| 128 |
+
parser.add_argument("--language", type=str, default="en", choices=["en", "zh", "ja"], help="UI language: en (English), zh (中文), ja (日本語)")
|
| 129 |
+
|
| 130 |
+
# Service mode argument
|
| 131 |
+
parser.add_argument("--service_mode", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False,
|
| 132 |
+
help="Enable service mode (default: False). When enabled, uses preset models and restricts UI options.")
|
| 133 |
+
|
| 134 |
+
# Service initialization arguments
|
| 135 |
+
parser.add_argument("--init_service", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Initialize service on startup (default: False)")
|
| 136 |
+
parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file path (optional, for display purposes)")
|
| 137 |
+
parser.add_argument("--config_path", type=str, default=None, help="Main model path (e.g., 'acestep-v15-turbo')")
|
| 138 |
+
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"], help="Processing device (default: auto)")
|
| 139 |
+
parser.add_argument("--init_llm", type=lambda x: x.lower() in ['true', '1', 'yes'], default=True, help="Initialize 5Hz LM (default: True)")
|
| 140 |
+
parser.add_argument("--lm_model_path", type=str, default=None, help="5Hz LM model path (e.g., 'acestep-5Hz-lm-0.6B')")
|
| 141 |
+
parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "pt"], help="5Hz LM backend (default: vllm)")
|
| 142 |
+
parser.add_argument("--use_flash_attention", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Use flash attention (default: auto-detect)")
|
| 143 |
+
parser.add_argument("--offload_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=auto_offload, help=f"Offload models to CPU (default: {'True' if auto_offload else 'False'}, auto-detected based on GPU VRAM)")
|
| 144 |
+
parser.add_argument("--offload_dit_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload DiT to CPU (default: False)")
|
| 145 |
+
|
| 146 |
+
args = parser.parse_args()
|
| 147 |
+
|
| 148 |
+
# Service mode defaults (can be configured via .env file)
|
| 149 |
+
if args.service_mode:
|
| 150 |
+
print("Service mode enabled - applying preset configurations...")
|
| 151 |
+
# Force init_service in service mode
|
| 152 |
+
args.init_service = True
|
| 153 |
+
# Default DiT model for service mode (from env or fallback)
|
| 154 |
+
if args.config_path is None:
|
| 155 |
+
args.config_path = os.environ.get(
|
| 156 |
+
"SERVICE_MODE_DIT_MODEL",
|
| 157 |
+
"acestep-v15-turbo-fix-inst-shift-dynamic"
|
| 158 |
+
)
|
| 159 |
+
# Default LM model for service mode (from env or fallback)
|
| 160 |
+
if args.lm_model_path is None:
|
| 161 |
+
args.lm_model_path = os.environ.get(
|
| 162 |
+
"SERVICE_MODE_LM_MODEL",
|
| 163 |
+
"acestep-5Hz-lm-1.7B-v4-fix"
|
| 164 |
+
)
|
| 165 |
+
# Backend for service mode (from env or fallback to vllm)
|
| 166 |
+
args.backend = os.environ.get("SERVICE_MODE_BACKEND", "vllm")
|
| 167 |
+
print(f" DiT model: {args.config_path}")
|
| 168 |
+
print(f" LM model: {args.lm_model_path}")
|
| 169 |
+
print(f" Backend: {args.backend}")
|
| 170 |
+
|
| 171 |
+
try:
|
| 172 |
+
init_params = None
|
| 173 |
+
|
| 174 |
+
# If init_service is True, perform initialization before creating UI
|
| 175 |
+
if args.init_service:
|
| 176 |
+
print("Initializing service from command line...")
|
| 177 |
+
|
| 178 |
+
# Create handler instances for initialization
|
| 179 |
+
dit_handler = AceStepHandler()
|
| 180 |
+
llm_handler = LLMHandler()
|
| 181 |
+
|
| 182 |
+
# Auto-select config_path if not provided
|
| 183 |
+
if args.config_path is None:
|
| 184 |
+
available_models = dit_handler.get_available_acestep_v15_models()
|
| 185 |
+
if available_models:
|
| 186 |
+
args.config_path = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else available_models[0]
|
| 187 |
+
print(f"Auto-selected config_path: {args.config_path}")
|
| 188 |
+
else:
|
| 189 |
+
print("Error: No available models found. Please specify --config_path", file=sys.stderr)
|
| 190 |
+
sys.exit(1)
|
| 191 |
+
|
| 192 |
+
# Get project root (same logic as in handler)
|
| 193 |
+
current_file = os.path.abspath(__file__)
|
| 194 |
+
project_root = os.path.dirname(os.path.dirname(current_file))
|
| 195 |
+
|
| 196 |
+
# Determine flash attention setting
|
| 197 |
+
use_flash_attention = args.use_flash_attention
|
| 198 |
+
if use_flash_attention is None:
|
| 199 |
+
use_flash_attention = dit_handler.is_flash_attention_available()
|
| 200 |
+
|
| 201 |
+
# Initialize DiT handler
|
| 202 |
+
print(f"Initializing DiT model: {args.config_path} on {args.device}...")
|
| 203 |
+
init_status, enable_generate = dit_handler.initialize_service(
|
| 204 |
+
project_root=project_root,
|
| 205 |
+
config_path=args.config_path,
|
| 206 |
+
device=args.device,
|
| 207 |
+
use_flash_attention=use_flash_attention,
|
| 208 |
+
compile_model=False,
|
| 209 |
+
offload_to_cpu=args.offload_to_cpu,
|
| 210 |
+
offload_dit_to_cpu=args.offload_dit_to_cpu
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
if not enable_generate:
|
| 214 |
+
print(f"Error initializing DiT model: {init_status}", file=sys.stderr)
|
| 215 |
+
sys.exit(1)
|
| 216 |
+
|
| 217 |
+
print(f"DiT model initialized successfully")
|
| 218 |
+
|
| 219 |
+
# Initialize LM handler if requested
|
| 220 |
+
lm_status = ""
|
| 221 |
+
if args.init_llm:
|
| 222 |
+
if args.lm_model_path is None:
|
| 223 |
+
# Try to get default LM model
|
| 224 |
+
available_lm_models = llm_handler.get_available_5hz_lm_models()
|
| 225 |
+
if available_lm_models:
|
| 226 |
+
args.lm_model_path = available_lm_models[0]
|
| 227 |
+
print(f"Using default LM model: {args.lm_model_path}")
|
| 228 |
+
else:
|
| 229 |
+
print("Warning: No LM models available, skipping LM initialization", file=sys.stderr)
|
| 230 |
+
args.init_llm = False
|
| 231 |
+
|
| 232 |
+
if args.init_llm and args.lm_model_path:
|
| 233 |
+
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 234 |
+
print(f"Initializing 5Hz LM: {args.lm_model_path} on {args.device}...")
|
| 235 |
+
lm_status, lm_success = llm_handler.initialize(
|
| 236 |
+
checkpoint_dir=checkpoint_dir,
|
| 237 |
+
lm_model_path=args.lm_model_path,
|
| 238 |
+
backend=args.backend,
|
| 239 |
+
device=args.device,
|
| 240 |
+
offload_to_cpu=args.offload_to_cpu,
|
| 241 |
+
dtype=dit_handler.dtype
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
if lm_success:
|
| 245 |
+
print(f"5Hz LM initialized successfully")
|
| 246 |
+
init_status += f"\n{lm_status}"
|
| 247 |
+
else:
|
| 248 |
+
print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
|
| 249 |
+
init_status += f"\n{lm_status}"
|
| 250 |
+
|
| 251 |
+
# Prepare initialization parameters for UI
|
| 252 |
+
init_params = {
|
| 253 |
+
'pre_initialized': True,
|
| 254 |
+
'service_mode': args.service_mode,
|
| 255 |
+
'checkpoint': args.checkpoint,
|
| 256 |
+
'config_path': args.config_path,
|
| 257 |
+
'device': args.device,
|
| 258 |
+
'init_llm': args.init_llm,
|
| 259 |
+
'lm_model_path': args.lm_model_path,
|
| 260 |
+
'backend': args.backend,
|
| 261 |
+
'use_flash_attention': use_flash_attention,
|
| 262 |
+
'offload_to_cpu': args.offload_to_cpu,
|
| 263 |
+
'offload_dit_to_cpu': args.offload_dit_to_cpu,
|
| 264 |
+
'init_status': init_status,
|
| 265 |
+
'enable_generate': enable_generate,
|
| 266 |
+
'dit_handler': dit_handler,
|
| 267 |
+
'llm_handler': llm_handler,
|
| 268 |
+
'language': args.language
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
print("Service initialization completed successfully!")
|
| 272 |
+
|
| 273 |
+
# Create and launch demo
|
| 274 |
+
print(f"Creating Gradio interface with language: {args.language}...")
|
| 275 |
+
demo = create_demo(init_params=init_params, language=args.language)
|
| 276 |
+
|
| 277 |
+
# Enable queue for multi-user support
|
| 278 |
+
# This ensures proper request queuing and prevents concurrent generation conflicts
|
| 279 |
+
print("Enabling queue for multi-user support...")
|
| 280 |
+
demo.queue(
|
| 281 |
+
max_size=20, # Maximum queue size (adjust based on your needs)
|
| 282 |
+
status_update_rate="auto", # Update rate for queue status
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
print(f"Launching server on {args.server_name}:{args.port}...")
|
| 286 |
+
demo.launch(
|
| 287 |
+
server_name=args.server_name,
|
| 288 |
+
server_port=args.port,
|
| 289 |
+
share=args.share,
|
| 290 |
+
debug=args.debug,
|
| 291 |
+
show_error=True,
|
| 292 |
+
prevent_thread_lock=False, # Keep thread locked to maintain server running
|
| 293 |
+
inbrowser=False, # Don't auto-open browser
|
| 294 |
+
)
|
| 295 |
+
except Exception as e:
|
| 296 |
+
print(f"Error launching Gradio: {e}", file=sys.stderr)
|
| 297 |
+
import traceback
|
| 298 |
+
traceback.print_exc()
|
| 299 |
+
sys.exit(1)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
if __name__ == "__main__":
|
| 303 |
+
main()
|
acestep/api_server.py
ADDED
|
@@ -0,0 +1,1700 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI server for ACE-Step V1.5.
|
| 2 |
+
|
| 3 |
+
Endpoints:
|
| 4 |
+
- POST /release_task Create music generation task
|
| 5 |
+
- POST /query_result Batch query task results
|
| 6 |
+
- POST /v1/music/random Create random sample task
|
| 7 |
+
- GET /v1/models List available models
|
| 8 |
+
- GET /v1/audio Download audio file
|
| 9 |
+
- GET /health Health check
|
| 10 |
+
|
| 11 |
+
NOTE:
|
| 12 |
+
- In-memory queue and job store -> run uvicorn with workers=1.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import asyncio
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
import time
|
| 22 |
+
import traceback
|
| 23 |
+
import tempfile
|
| 24 |
+
import urllib.parse
|
| 25 |
+
from collections import deque
|
| 26 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 27 |
+
from contextlib import asynccontextmanager
|
| 28 |
+
from dataclasses import dataclass
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from threading import Lock
|
| 31 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 32 |
+
from uuid import uuid4
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
from dotenv import load_dotenv
|
| 36 |
+
except ImportError: # Optional dependency
|
| 37 |
+
load_dotenv = None # type: ignore
|
| 38 |
+
|
| 39 |
+
from fastapi import FastAPI, HTTPException, Request
|
| 40 |
+
from pydantic import BaseModel, Field
|
| 41 |
+
from starlette.datastructures import UploadFile as StarletteUploadFile
|
| 42 |
+
|
| 43 |
+
from acestep.handler import AceStepHandler
|
| 44 |
+
from acestep.llm_inference import LLMHandler
|
| 45 |
+
from acestep.constants import (
|
| 46 |
+
DEFAULT_DIT_INSTRUCTION,
|
| 47 |
+
DEFAULT_LM_INSTRUCTION,
|
| 48 |
+
TASK_INSTRUCTIONS,
|
| 49 |
+
)
|
| 50 |
+
from acestep.inference import (
|
| 51 |
+
GenerationParams,
|
| 52 |
+
GenerationConfig,
|
| 53 |
+
generate_music,
|
| 54 |
+
create_sample,
|
| 55 |
+
format_sample,
|
| 56 |
+
)
|
| 57 |
+
from acestep.gradio_ui.events.results_handlers import _build_generation_info
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# =============================================================================
|
| 61 |
+
# Constants
|
| 62 |
+
# =============================================================================
|
| 63 |
+
|
| 64 |
+
RESULT_KEY_PREFIX = "ace_step_v1.5_"
|
| 65 |
+
RESULT_EXPIRE_SECONDS = 7 * 24 * 60 * 60 # 7 days
|
| 66 |
+
TASK_TIMEOUT_SECONDS = 3600 # 1 hour
|
| 67 |
+
STATUS_MAP = {"queued": 0, "running": 0, "succeeded": 1, "failed": 2}
|
| 68 |
+
|
| 69 |
+
LM_DEFAULT_TEMPERATURE = 0.85
|
| 70 |
+
LM_DEFAULT_CFG_SCALE = 2.5
|
| 71 |
+
LM_DEFAULT_TOP_P = 0.9
|
| 72 |
+
|
| 73 |
+
# Parameter aliases for request parsing
|
| 74 |
+
PARAM_ALIASES = {
|
| 75 |
+
"prompt": ["prompt"],
|
| 76 |
+
"sample_mode": ["sample_mode", "sampleMode"],
|
| 77 |
+
"sample_query": ["sample_query", "sampleQuery", "description", "desc"],
|
| 78 |
+
"use_format": ["use_format", "useFormat", "format"],
|
| 79 |
+
"model": ["model", "dit_model", "ditModel"],
|
| 80 |
+
"key_scale": ["key_scale", "keyscale", "keyScale"],
|
| 81 |
+
"time_signature": ["time_signature", "timesignature", "timeSignature"],
|
| 82 |
+
"audio_duration": ["audio_duration", "duration", "audioDuration", "target_duration", "targetDuration"],
|
| 83 |
+
"vocal_language": ["vocal_language", "vocalLanguage"],
|
| 84 |
+
"inference_steps": ["inference_steps", "inferenceSteps"],
|
| 85 |
+
"guidance_scale": ["guidance_scale", "guidanceScale"],
|
| 86 |
+
"use_random_seed": ["use_random_seed", "useRandomSeed"],
|
| 87 |
+
"audio_code_string": ["audio_code_string", "audioCodeString"],
|
| 88 |
+
"audio_cover_strength": ["audio_cover_strength", "audioCoverStrength"],
|
| 89 |
+
"task_type": ["task_type", "taskType"],
|
| 90 |
+
"infer_method": ["infer_method", "inferMethod"],
|
| 91 |
+
"use_tiled_decode": ["use_tiled_decode", "useTiledDecode"],
|
| 92 |
+
"constrained_decoding": ["constrained_decoding", "constrainedDecoding", "constrained"],
|
| 93 |
+
"constrained_decoding_debug": ["constrained_decoding_debug", "constrainedDecodingDebug"],
|
| 94 |
+
"use_cot_caption": ["use_cot_caption", "cot_caption", "cot-caption"],
|
| 95 |
+
"use_cot_language": ["use_cot_language", "cot_language", "cot-language"],
|
| 96 |
+
"is_format_caption": ["is_format_caption", "isFormatCaption"],
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _parse_description_hints(description: str) -> tuple[Optional[str], bool]:
|
| 101 |
+
"""
|
| 102 |
+
Parse a description string to extract language code and instrumental flag.
|
| 103 |
+
|
| 104 |
+
This function analyzes user descriptions like "Pop rock. English" or "piano solo"
|
| 105 |
+
to detect:
|
| 106 |
+
- Language: Maps language names to ISO codes (e.g., "English" -> "en")
|
| 107 |
+
- Instrumental: Detects patterns indicating instrumental/no-vocal music
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
description: User's natural language music description
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
(language_code, is_instrumental) tuple:
|
| 114 |
+
- language_code: ISO language code (e.g., "en", "zh") or None if not detected
|
| 115 |
+
- is_instrumental: True if description indicates instrumental music
|
| 116 |
+
"""
|
| 117 |
+
import re
|
| 118 |
+
|
| 119 |
+
if not description:
|
| 120 |
+
return None, False
|
| 121 |
+
|
| 122 |
+
description_lower = description.lower().strip()
|
| 123 |
+
|
| 124 |
+
# Language mapping: input patterns -> ISO code
|
| 125 |
+
language_mapping = {
|
| 126 |
+
'english': 'en', 'en': 'en',
|
| 127 |
+
'chinese': 'zh', '中文': 'zh', 'zh': 'zh', 'mandarin': 'zh',
|
| 128 |
+
'japanese': 'ja', '日本語': 'ja', 'ja': 'ja',
|
| 129 |
+
'korean': 'ko', '한국어': 'ko', 'ko': 'ko',
|
| 130 |
+
'spanish': 'es', 'español': 'es', 'es': 'es',
|
| 131 |
+
'french': 'fr', 'français': 'fr', 'fr': 'fr',
|
| 132 |
+
'german': 'de', 'deutsch': 'de', 'de': 'de',
|
| 133 |
+
'italian': 'it', 'italiano': 'it', 'it': 'it',
|
| 134 |
+
'portuguese': 'pt', 'português': 'pt', 'pt': 'pt',
|
| 135 |
+
'russian': 'ru', 'русский': 'ru', 'ru': 'ru',
|
| 136 |
+
'bengali': 'bn', 'bn': 'bn',
|
| 137 |
+
'hindi': 'hi', 'hi': 'hi',
|
| 138 |
+
'arabic': 'ar', 'ar': 'ar',
|
| 139 |
+
'thai': 'th', 'th': 'th',
|
| 140 |
+
'vietnamese': 'vi', 'vi': 'vi',
|
| 141 |
+
'indonesian': 'id', 'id': 'id',
|
| 142 |
+
'turkish': 'tr', 'tr': 'tr',
|
| 143 |
+
'dutch': 'nl', 'nl': 'nl',
|
| 144 |
+
'polish': 'pl', 'pl': 'pl',
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
# Detect language
|
| 148 |
+
detected_language = None
|
| 149 |
+
for lang_name, lang_code in language_mapping.items():
|
| 150 |
+
if len(lang_name) <= 2:
|
| 151 |
+
pattern = r'(?:^|\s|[.,;:!?])' + re.escape(lang_name) + r'(?:$|\s|[.,;:!?])'
|
| 152 |
+
else:
|
| 153 |
+
pattern = r'\b' + re.escape(lang_name) + r'\b'
|
| 154 |
+
|
| 155 |
+
if re.search(pattern, description_lower):
|
| 156 |
+
detected_language = lang_code
|
| 157 |
+
break
|
| 158 |
+
|
| 159 |
+
# Detect instrumental
|
| 160 |
+
is_instrumental = False
|
| 161 |
+
if 'instrumental' in description_lower:
|
| 162 |
+
is_instrumental = True
|
| 163 |
+
elif 'pure music' in description_lower or 'pure instrument' in description_lower:
|
| 164 |
+
is_instrumental = True
|
| 165 |
+
elif description_lower.endswith(' solo') or description_lower == 'solo':
|
| 166 |
+
is_instrumental = True
|
| 167 |
+
|
| 168 |
+
return detected_language, is_instrumental
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
JobStatus = Literal["queued", "running", "succeeded", "failed"]
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class GenerateMusicRequest(BaseModel):
|
| 175 |
+
prompt: str = Field(default="", description="Text prompt describing the music")
|
| 176 |
+
lyrics: str = Field(default="", description="Lyric text")
|
| 177 |
+
|
| 178 |
+
# New API semantics:
|
| 179 |
+
# - thinking=True: use 5Hz LM to generate audio codes (lm-dit behavior)
|
| 180 |
+
# - thinking=False: do not use LM to generate codes (dit behavior)
|
| 181 |
+
# Regardless of thinking, if some metas are missing, server may use LM to fill them.
|
| 182 |
+
thinking: bool = False
|
| 183 |
+
# Sample-mode requests auto-generate caption/lyrics/metas via LM (no user prompt).
|
| 184 |
+
sample_mode: bool = False
|
| 185 |
+
# Description for sample mode: auto-generate caption/lyrics from description query
|
| 186 |
+
sample_query: str = Field(default="", description="Query/description for sample mode (use create_sample)")
|
| 187 |
+
# Whether to use format_sample() to enhance input caption/lyrics
|
| 188 |
+
use_format: bool = Field(default=False, description="Use format_sample() to enhance input (default: False)")
|
| 189 |
+
# Model name for multi-model support (select which DiT model to use)
|
| 190 |
+
model: Optional[str] = Field(default=None, description="Model name to use (e.g., 'acestep-v15-turbo')")
|
| 191 |
+
|
| 192 |
+
bpm: Optional[int] = None
|
| 193 |
+
# Accept common client keys via manual parsing (see RequestParser).
|
| 194 |
+
key_scale: str = ""
|
| 195 |
+
time_signature: str = ""
|
| 196 |
+
vocal_language: str = "en"
|
| 197 |
+
inference_steps: int = 8
|
| 198 |
+
guidance_scale: float = 7.0
|
| 199 |
+
use_random_seed: bool = True
|
| 200 |
+
seed: int = -1
|
| 201 |
+
|
| 202 |
+
reference_audio_path: Optional[str] = None
|
| 203 |
+
src_audio_path: Optional[str] = None
|
| 204 |
+
audio_duration: Optional[float] = None
|
| 205 |
+
batch_size: Optional[int] = None
|
| 206 |
+
|
| 207 |
+
audio_code_string: str = ""
|
| 208 |
+
|
| 209 |
+
repainting_start: float = 0.0
|
| 210 |
+
repainting_end: Optional[float] = None
|
| 211 |
+
|
| 212 |
+
instruction: str = DEFAULT_DIT_INSTRUCTION
|
| 213 |
+
audio_cover_strength: float = 1.0
|
| 214 |
+
task_type: str = "text2music"
|
| 215 |
+
|
| 216 |
+
use_adg: bool = False
|
| 217 |
+
cfg_interval_start: float = 0.0
|
| 218 |
+
cfg_interval_end: float = 1.0
|
| 219 |
+
infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
|
| 220 |
+
shift: float = Field(
|
| 221 |
+
default=3.0,
|
| 222 |
+
description="Timestep shift factor (range 1.0~5.0, default 3.0). Only effective for base models, not turbo models."
|
| 223 |
+
)
|
| 224 |
+
timesteps: Optional[str] = Field(
|
| 225 |
+
default=None,
|
| 226 |
+
description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference_steps and shift."
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
audio_format: str = "mp3"
|
| 230 |
+
use_tiled_decode: bool = True
|
| 231 |
+
|
| 232 |
+
# 5Hz LM (server-side): used for metadata completion and (when thinking=True) codes generation.
|
| 233 |
+
lm_model_path: Optional[str] = None # e.g. "acestep-5Hz-lm-0.6B"
|
| 234 |
+
lm_backend: Literal["vllm", "pt"] = "vllm"
|
| 235 |
+
|
| 236 |
+
constrained_decoding: bool = True
|
| 237 |
+
constrained_decoding_debug: bool = False
|
| 238 |
+
use_cot_caption: bool = True
|
| 239 |
+
use_cot_language: bool = True
|
| 240 |
+
is_format_caption: bool = False
|
| 241 |
+
|
| 242 |
+
lm_temperature: float = 0.85
|
| 243 |
+
lm_cfg_scale: float = 2.5
|
| 244 |
+
lm_top_k: Optional[int] = None
|
| 245 |
+
lm_top_p: Optional[float] = 0.9
|
| 246 |
+
lm_repetition_penalty: float = 1.0
|
| 247 |
+
lm_negative_prompt: str = "NO USER INPUT"
|
| 248 |
+
|
| 249 |
+
class Config:
|
| 250 |
+
allow_population_by_field_name = True
|
| 251 |
+
allow_population_by_alias = True
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class CreateJobResponse(BaseModel):
|
| 255 |
+
task_id: str
|
| 256 |
+
status: JobStatus
|
| 257 |
+
queue_position: int = 0 # 1-based best-effort position when queued
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class JobResult(BaseModel):
|
| 261 |
+
first_audio_path: Optional[str] = None
|
| 262 |
+
second_audio_path: Optional[str] = None
|
| 263 |
+
audio_paths: list[str] = Field(default_factory=list)
|
| 264 |
+
|
| 265 |
+
generation_info: str = ""
|
| 266 |
+
status_message: str = ""
|
| 267 |
+
seed_value: str = ""
|
| 268 |
+
|
| 269 |
+
metas: Dict[str, Any] = Field(default_factory=dict)
|
| 270 |
+
bpm: Optional[int] = None
|
| 271 |
+
duration: Optional[float] = None
|
| 272 |
+
genres: Optional[str] = None
|
| 273 |
+
keyscale: Optional[str] = None
|
| 274 |
+
timesignature: Optional[str] = None
|
| 275 |
+
|
| 276 |
+
# Model information
|
| 277 |
+
lm_model: Optional[str] = None
|
| 278 |
+
dit_model: Optional[str] = None
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class JobResponse(BaseModel):
|
| 282 |
+
job_id: str
|
| 283 |
+
status: JobStatus
|
| 284 |
+
created_at: float
|
| 285 |
+
started_at: Optional[float] = None
|
| 286 |
+
finished_at: Optional[float] = None
|
| 287 |
+
|
| 288 |
+
# queue observability
|
| 289 |
+
queue_position: int = 0
|
| 290 |
+
eta_seconds: Optional[float] = None
|
| 291 |
+
avg_job_seconds: Optional[float] = None
|
| 292 |
+
|
| 293 |
+
result: Optional[JobResult] = None
|
| 294 |
+
error: Optional[str] = None
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
@dataclass
|
| 298 |
+
class _JobRecord:
|
| 299 |
+
job_id: str
|
| 300 |
+
status: JobStatus
|
| 301 |
+
created_at: float
|
| 302 |
+
started_at: Optional[float] = None
|
| 303 |
+
finished_at: Optional[float] = None
|
| 304 |
+
result: Optional[Dict[str, Any]] = None
|
| 305 |
+
error: Optional[str] = None
|
| 306 |
+
env: str = "development"
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class _JobStore:
|
| 310 |
+
def __init__(self) -> None:
|
| 311 |
+
self._lock = Lock()
|
| 312 |
+
self._jobs: Dict[str, _JobRecord] = {}
|
| 313 |
+
|
| 314 |
+
def create(self) -> _JobRecord:
|
| 315 |
+
job_id = str(uuid4())
|
| 316 |
+
rec = _JobRecord(job_id=job_id, status="queued", created_at=time.time())
|
| 317 |
+
with self._lock:
|
| 318 |
+
self._jobs[job_id] = rec
|
| 319 |
+
return rec
|
| 320 |
+
|
| 321 |
+
def create_with_id(self, job_id: str, env: str = "development") -> _JobRecord:
|
| 322 |
+
"""Create job record with specified ID"""
|
| 323 |
+
rec = _JobRecord(
|
| 324 |
+
job_id=job_id,
|
| 325 |
+
status="queued",
|
| 326 |
+
created_at=time.time(),
|
| 327 |
+
env=env
|
| 328 |
+
)
|
| 329 |
+
with self._lock:
|
| 330 |
+
self._jobs[job_id] = rec
|
| 331 |
+
return rec
|
| 332 |
+
|
| 333 |
+
def get(self, job_id: str) -> Optional[_JobRecord]:
|
| 334 |
+
with self._lock:
|
| 335 |
+
return self._jobs.get(job_id)
|
| 336 |
+
|
| 337 |
+
def mark_running(self, job_id: str) -> None:
|
| 338 |
+
with self._lock:
|
| 339 |
+
rec = self._jobs[job_id]
|
| 340 |
+
rec.status = "running"
|
| 341 |
+
rec.started_at = time.time()
|
| 342 |
+
|
| 343 |
+
def mark_succeeded(self, job_id: str, result: Dict[str, Any]) -> None:
|
| 344 |
+
with self._lock:
|
| 345 |
+
rec = self._jobs[job_id]
|
| 346 |
+
rec.status = "succeeded"
|
| 347 |
+
rec.finished_at = time.time()
|
| 348 |
+
rec.result = result
|
| 349 |
+
rec.error = None
|
| 350 |
+
|
| 351 |
+
def mark_failed(self, job_id: str, error: str) -> None:
|
| 352 |
+
with self._lock:
|
| 353 |
+
rec = self._jobs[job_id]
|
| 354 |
+
rec.status = "failed"
|
| 355 |
+
rec.finished_at = time.time()
|
| 356 |
+
rec.result = None
|
| 357 |
+
rec.error = error
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def _env_bool(name: str, default: bool) -> bool:
|
| 361 |
+
v = os.getenv(name)
|
| 362 |
+
if v is None:
|
| 363 |
+
return default
|
| 364 |
+
return v.strip().lower() in {"1", "true", "yes", "y", "on"}
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def _get_project_root() -> str:
|
| 368 |
+
current_file = os.path.abspath(__file__)
|
| 369 |
+
return os.path.dirname(os.path.dirname(current_file))
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def _get_model_name(config_path: str) -> str:
|
| 373 |
+
"""
|
| 374 |
+
Extract model name from config_path.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
config_path: Path like "acestep-v15-turbo" or "/path/to/acestep-v15-turbo"
|
| 378 |
+
|
| 379 |
+
Returns:
|
| 380 |
+
Model name (last directory name from config_path)
|
| 381 |
+
"""
|
| 382 |
+
if not config_path:
|
| 383 |
+
return ""
|
| 384 |
+
normalized = config_path.rstrip("/\\")
|
| 385 |
+
return os.path.basename(normalized)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def _load_project_env() -> None:
|
| 389 |
+
if load_dotenv is None:
|
| 390 |
+
return
|
| 391 |
+
try:
|
| 392 |
+
project_root = _get_project_root()
|
| 393 |
+
env_path = os.path.join(project_root, ".env")
|
| 394 |
+
if os.path.exists(env_path):
|
| 395 |
+
load_dotenv(env_path, override=False)
|
| 396 |
+
except Exception:
|
| 397 |
+
# Optional best-effort: continue even if .env loading fails.
|
| 398 |
+
pass
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
_load_project_env()
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def _to_int(v: Any, default: Optional[int] = None) -> Optional[int]:
|
| 405 |
+
if v is None:
|
| 406 |
+
return default
|
| 407 |
+
if isinstance(v, int):
|
| 408 |
+
return v
|
| 409 |
+
s = str(v).strip()
|
| 410 |
+
if s == "":
|
| 411 |
+
return default
|
| 412 |
+
try:
|
| 413 |
+
return int(s)
|
| 414 |
+
except Exception:
|
| 415 |
+
return default
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def _to_float(v: Any, default: Optional[float] = None) -> Optional[float]:
|
| 419 |
+
if v is None:
|
| 420 |
+
return default
|
| 421 |
+
if isinstance(v, float):
|
| 422 |
+
return v
|
| 423 |
+
s = str(v).strip()
|
| 424 |
+
if s == "":
|
| 425 |
+
return default
|
| 426 |
+
try:
|
| 427 |
+
return float(s)
|
| 428 |
+
except Exception:
|
| 429 |
+
return default
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def _to_bool(v: Any, default: bool = False) -> bool:
|
| 433 |
+
if v is None:
|
| 434 |
+
return default
|
| 435 |
+
if isinstance(v, bool):
|
| 436 |
+
return v
|
| 437 |
+
s = str(v).strip().lower()
|
| 438 |
+
if s == "":
|
| 439 |
+
return default
|
| 440 |
+
return s in {"1", "true", "yes", "y", "on"}
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def _map_status(status: str) -> int:
|
| 444 |
+
"""Map job status string to integer code."""
|
| 445 |
+
return STATUS_MAP.get(status, 2)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def _parse_timesteps(s: Optional[str]) -> Optional[List[float]]:
|
| 449 |
+
"""Parse comma-separated timesteps string to list of floats."""
|
| 450 |
+
if not s or not s.strip():
|
| 451 |
+
return None
|
| 452 |
+
try:
|
| 453 |
+
return [float(t.strip()) for t in s.split(",") if t.strip()]
|
| 454 |
+
except (ValueError, Exception):
|
| 455 |
+
return None
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
class RequestParser:
|
| 459 |
+
"""Parse request parameters from multiple sources with alias support."""
|
| 460 |
+
|
| 461 |
+
def __init__(self, raw: dict):
|
| 462 |
+
self._raw = dict(raw) if raw else {}
|
| 463 |
+
self._param_obj = self._parse_json(self._raw.get("param_obj"))
|
| 464 |
+
self._metas = self._find_metas()
|
| 465 |
+
|
| 466 |
+
def _parse_json(self, v) -> dict:
|
| 467 |
+
if isinstance(v, dict):
|
| 468 |
+
return v
|
| 469 |
+
if isinstance(v, str) and v.strip():
|
| 470 |
+
try:
|
| 471 |
+
return json.loads(v)
|
| 472 |
+
except Exception:
|
| 473 |
+
pass
|
| 474 |
+
return {}
|
| 475 |
+
|
| 476 |
+
def _find_metas(self) -> dict:
|
| 477 |
+
for key in ("metas", "meta", "metadata", "user_metadata", "userMetadata"):
|
| 478 |
+
v = self._raw.get(key)
|
| 479 |
+
if v:
|
| 480 |
+
return self._parse_json(v)
|
| 481 |
+
return {}
|
| 482 |
+
|
| 483 |
+
def get(self, name: str, default=None):
|
| 484 |
+
"""Get parameter by canonical name from all sources."""
|
| 485 |
+
aliases = PARAM_ALIASES.get(name, [name])
|
| 486 |
+
for source in (self._raw, self._param_obj, self._metas):
|
| 487 |
+
for alias in aliases:
|
| 488 |
+
v = source.get(alias)
|
| 489 |
+
if v is not None:
|
| 490 |
+
return v
|
| 491 |
+
return default
|
| 492 |
+
|
| 493 |
+
def str(self, name: str, default: str = "") -> str:
|
| 494 |
+
v = self.get(name)
|
| 495 |
+
return str(v) if v is not None else default
|
| 496 |
+
|
| 497 |
+
def int(self, name: str, default: Optional[int] = None) -> Optional[int]:
|
| 498 |
+
return _to_int(self.get(name), default)
|
| 499 |
+
|
| 500 |
+
def float(self, name: str, default: Optional[float] = None) -> Optional[float]:
|
| 501 |
+
return _to_float(self.get(name), default)
|
| 502 |
+
|
| 503 |
+
def bool(self, name: str, default: bool = False) -> bool:
|
| 504 |
+
return _to_bool(self.get(name), default)
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
async def _save_upload_to_temp(upload: StarletteUploadFile, *, prefix: str) -> str:
|
| 508 |
+
suffix = Path(upload.filename or "").suffix
|
| 509 |
+
fd, path = tempfile.mkstemp(prefix=f"{prefix}_", suffix=suffix)
|
| 510 |
+
os.close(fd)
|
| 511 |
+
try:
|
| 512 |
+
with open(path, "wb") as f:
|
| 513 |
+
while True:
|
| 514 |
+
chunk = await upload.read(1024 * 1024)
|
| 515 |
+
if not chunk:
|
| 516 |
+
break
|
| 517 |
+
f.write(chunk)
|
| 518 |
+
except Exception:
|
| 519 |
+
try:
|
| 520 |
+
os.remove(path)
|
| 521 |
+
except Exception:
|
| 522 |
+
pass
|
| 523 |
+
raise
|
| 524 |
+
finally:
|
| 525 |
+
try:
|
| 526 |
+
await upload.close()
|
| 527 |
+
except Exception:
|
| 528 |
+
pass
|
| 529 |
+
return path
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def create_app() -> FastAPI:
|
| 533 |
+
store = _JobStore()
|
| 534 |
+
|
| 535 |
+
QUEUE_MAXSIZE = int(os.getenv("ACESTEP_QUEUE_MAXSIZE", "200"))
|
| 536 |
+
WORKER_COUNT = int(os.getenv("ACESTEP_QUEUE_WORKERS", "1")) # Single GPU recommended
|
| 537 |
+
|
| 538 |
+
INITIAL_AVG_JOB_SECONDS = float(os.getenv("ACESTEP_AVG_JOB_SECONDS", "5.0"))
|
| 539 |
+
AVG_WINDOW = int(os.getenv("ACESTEP_AVG_WINDOW", "50"))
|
| 540 |
+
|
| 541 |
+
def _path_to_audio_url(path: str) -> str:
|
| 542 |
+
"""Convert local file path to downloadable relative URL"""
|
| 543 |
+
if not path:
|
| 544 |
+
return path
|
| 545 |
+
if path.startswith("http://") or path.startswith("https://"):
|
| 546 |
+
return path
|
| 547 |
+
encoded_path = urllib.parse.quote(path, safe="")
|
| 548 |
+
return f"/v1/audio?path={encoded_path}"
|
| 549 |
+
|
| 550 |
+
@asynccontextmanager
|
| 551 |
+
async def lifespan(app: FastAPI):
|
| 552 |
+
# Clear proxy env that may affect downstream libs
|
| 553 |
+
for proxy_var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
|
| 554 |
+
os.environ.pop(proxy_var, None)
|
| 555 |
+
|
| 556 |
+
# Ensure compilation/temp caches do not fill up small default /tmp.
|
| 557 |
+
# Triton/Inductor (and the system compiler) can create large temporary files.
|
| 558 |
+
project_root = _get_project_root()
|
| 559 |
+
cache_root = os.path.join(project_root, ".cache", "acestep")
|
| 560 |
+
tmp_root = (os.getenv("ACESTEP_TMPDIR") or os.path.join(cache_root, "tmp")).strip()
|
| 561 |
+
triton_cache_root = (os.getenv("TRITON_CACHE_DIR") or os.path.join(cache_root, "triton")).strip()
|
| 562 |
+
inductor_cache_root = (os.getenv("TORCHINDUCTOR_CACHE_DIR") or os.path.join(cache_root, "torchinductor")).strip()
|
| 563 |
+
|
| 564 |
+
for p in [cache_root, tmp_root, triton_cache_root, inductor_cache_root]:
|
| 565 |
+
try:
|
| 566 |
+
os.makedirs(p, exist_ok=True)
|
| 567 |
+
except Exception:
|
| 568 |
+
# Best-effort: do not block startup if directory creation fails.
|
| 569 |
+
pass
|
| 570 |
+
|
| 571 |
+
# Respect explicit user overrides; if ACESTEP_TMPDIR is set, it should win.
|
| 572 |
+
if os.getenv("ACESTEP_TMPDIR"):
|
| 573 |
+
os.environ["TMPDIR"] = tmp_root
|
| 574 |
+
os.environ["TEMP"] = tmp_root
|
| 575 |
+
os.environ["TMP"] = tmp_root
|
| 576 |
+
else:
|
| 577 |
+
os.environ.setdefault("TMPDIR", tmp_root)
|
| 578 |
+
os.environ.setdefault("TEMP", tmp_root)
|
| 579 |
+
os.environ.setdefault("TMP", tmp_root)
|
| 580 |
+
|
| 581 |
+
os.environ.setdefault("TRITON_CACHE_DIR", triton_cache_root)
|
| 582 |
+
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", inductor_cache_root)
|
| 583 |
+
|
| 584 |
+
handler = AceStepHandler()
|
| 585 |
+
llm_handler = LLMHandler()
|
| 586 |
+
init_lock = asyncio.Lock()
|
| 587 |
+
app.state._initialized = False
|
| 588 |
+
app.state._init_error = None
|
| 589 |
+
app.state._init_lock = init_lock
|
| 590 |
+
|
| 591 |
+
app.state.llm_handler = llm_handler
|
| 592 |
+
app.state._llm_initialized = False
|
| 593 |
+
app.state._llm_init_error = None
|
| 594 |
+
app.state._llm_init_lock = Lock()
|
| 595 |
+
|
| 596 |
+
# Multi-model support: secondary DiT handlers
|
| 597 |
+
handler2 = None
|
| 598 |
+
handler3 = None
|
| 599 |
+
config_path2 = os.getenv("ACESTEP_CONFIG_PATH2", "").strip()
|
| 600 |
+
config_path3 = os.getenv("ACESTEP_CONFIG_PATH3", "").strip()
|
| 601 |
+
|
| 602 |
+
if config_path2:
|
| 603 |
+
handler2 = AceStepHandler()
|
| 604 |
+
if config_path3:
|
| 605 |
+
handler3 = AceStepHandler()
|
| 606 |
+
|
| 607 |
+
app.state.handler2 = handler2
|
| 608 |
+
app.state.handler3 = handler3
|
| 609 |
+
app.state._initialized2 = False
|
| 610 |
+
app.state._initialized3 = False
|
| 611 |
+
app.state._config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo")
|
| 612 |
+
app.state._config_path2 = config_path2
|
| 613 |
+
app.state._config_path3 = config_path3
|
| 614 |
+
|
| 615 |
+
max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1"))
|
| 616 |
+
executor = ThreadPoolExecutor(max_workers=max_workers)
|
| 617 |
+
|
| 618 |
+
# Queue & observability
|
| 619 |
+
app.state.job_queue = asyncio.Queue(maxsize=QUEUE_MAXSIZE) # (job_id, req)
|
| 620 |
+
app.state.pending_ids = deque() # queued job_ids
|
| 621 |
+
app.state.pending_lock = asyncio.Lock()
|
| 622 |
+
|
| 623 |
+
# temp files per job (from multipart uploads)
|
| 624 |
+
app.state.job_temp_files = {} # job_id -> list[path]
|
| 625 |
+
app.state.job_temp_files_lock = asyncio.Lock()
|
| 626 |
+
|
| 627 |
+
# stats
|
| 628 |
+
app.state.stats_lock = asyncio.Lock()
|
| 629 |
+
app.state.recent_durations = deque(maxlen=AVG_WINDOW)
|
| 630 |
+
app.state.avg_job_seconds = INITIAL_AVG_JOB_SECONDS
|
| 631 |
+
|
| 632 |
+
app.state.handler = handler
|
| 633 |
+
app.state.executor = executor
|
| 634 |
+
app.state.job_store = store
|
| 635 |
+
app.state._python_executable = sys.executable
|
| 636 |
+
|
| 637 |
+
# Temporary directory for saving generated audio files
|
| 638 |
+
app.state.temp_audio_dir = os.path.join(tmp_root, "api_audio")
|
| 639 |
+
os.makedirs(app.state.temp_audio_dir, exist_ok=True)
|
| 640 |
+
|
| 641 |
+
# Initialize local cache
|
| 642 |
+
try:
|
| 643 |
+
from acestep.local_cache import get_local_cache
|
| 644 |
+
local_cache_dir = os.path.join(cache_root, "local_redis")
|
| 645 |
+
app.state.local_cache = get_local_cache(local_cache_dir)
|
| 646 |
+
except ImportError:
|
| 647 |
+
app.state.local_cache = None
|
| 648 |
+
|
| 649 |
+
async def _ensure_initialized() -> None:
|
| 650 |
+
h: AceStepHandler = app.state.handler
|
| 651 |
+
|
| 652 |
+
if getattr(app.state, "_initialized", False):
|
| 653 |
+
return
|
| 654 |
+
if getattr(app.state, "_init_error", None):
|
| 655 |
+
raise RuntimeError(app.state._init_error)
|
| 656 |
+
|
| 657 |
+
async with app.state._init_lock:
|
| 658 |
+
if getattr(app.state, "_initialized", False):
|
| 659 |
+
return
|
| 660 |
+
if getattr(app.state, "_init_error", None):
|
| 661 |
+
raise RuntimeError(app.state._init_error)
|
| 662 |
+
|
| 663 |
+
project_root = _get_project_root()
|
| 664 |
+
config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo")
|
| 665 |
+
device = os.getenv("ACESTEP_DEVICE", "auto")
|
| 666 |
+
|
| 667 |
+
use_flash_attention = _env_bool("ACESTEP_USE_FLASH_ATTENTION", True)
|
| 668 |
+
offload_to_cpu = _env_bool("ACESTEP_OFFLOAD_TO_CPU", False)
|
| 669 |
+
offload_dit_to_cpu = _env_bool("ACESTEP_OFFLOAD_DIT_TO_CPU", False)
|
| 670 |
+
|
| 671 |
+
# Initialize primary model
|
| 672 |
+
status_msg, ok = h.initialize_service(
|
| 673 |
+
project_root=project_root,
|
| 674 |
+
config_path=config_path,
|
| 675 |
+
device=device,
|
| 676 |
+
use_flash_attention=use_flash_attention,
|
| 677 |
+
compile_model=False,
|
| 678 |
+
offload_to_cpu=offload_to_cpu,
|
| 679 |
+
offload_dit_to_cpu=offload_dit_to_cpu,
|
| 680 |
+
)
|
| 681 |
+
if not ok:
|
| 682 |
+
app.state._init_error = status_msg
|
| 683 |
+
raise RuntimeError(status_msg)
|
| 684 |
+
app.state._initialized = True
|
| 685 |
+
|
| 686 |
+
# Initialize secondary model if configured
|
| 687 |
+
if app.state.handler2 and app.state._config_path2:
|
| 688 |
+
try:
|
| 689 |
+
status_msg2, ok2 = app.state.handler2.initialize_service(
|
| 690 |
+
project_root=project_root,
|
| 691 |
+
config_path=app.state._config_path2,
|
| 692 |
+
device=device,
|
| 693 |
+
use_flash_attention=use_flash_attention,
|
| 694 |
+
compile_model=False,
|
| 695 |
+
offload_to_cpu=offload_to_cpu,
|
| 696 |
+
offload_dit_to_cpu=offload_dit_to_cpu,
|
| 697 |
+
)
|
| 698 |
+
app.state._initialized2 = ok2
|
| 699 |
+
if ok2:
|
| 700 |
+
print(f"[API Server] Secondary model loaded: {_get_model_name(app.state._config_path2)}")
|
| 701 |
+
else:
|
| 702 |
+
print(f"[API Server] Warning: Secondary model failed to load: {status_msg2}")
|
| 703 |
+
except Exception as e:
|
| 704 |
+
print(f"[API Server] Warning: Failed to initialize secondary model: {e}")
|
| 705 |
+
app.state._initialized2 = False
|
| 706 |
+
|
| 707 |
+
# Initialize third model if configured
|
| 708 |
+
if app.state.handler3 and app.state._config_path3:
|
| 709 |
+
try:
|
| 710 |
+
status_msg3, ok3 = app.state.handler3.initialize_service(
|
| 711 |
+
project_root=project_root,
|
| 712 |
+
config_path=app.state._config_path3,
|
| 713 |
+
device=device,
|
| 714 |
+
use_flash_attention=use_flash_attention,
|
| 715 |
+
compile_model=False,
|
| 716 |
+
offload_to_cpu=offload_to_cpu,
|
| 717 |
+
offload_dit_to_cpu=offload_dit_to_cpu,
|
| 718 |
+
)
|
| 719 |
+
app.state._initialized3 = ok3
|
| 720 |
+
if ok3:
|
| 721 |
+
print(f"[API Server] Third model loaded: {_get_model_name(app.state._config_path3)}")
|
| 722 |
+
else:
|
| 723 |
+
print(f"[API Server] Warning: Third model failed to load: {status_msg3}")
|
| 724 |
+
except Exception as e:
|
| 725 |
+
print(f"[API Server] Warning: Failed to initialize third model: {e}")
|
| 726 |
+
app.state._initialized3 = False
|
| 727 |
+
|
| 728 |
+
async def _cleanup_job_temp_files(job_id: str) -> None:
|
| 729 |
+
async with app.state.job_temp_files_lock:
|
| 730 |
+
paths = app.state.job_temp_files.pop(job_id, [])
|
| 731 |
+
for p in paths:
|
| 732 |
+
try:
|
| 733 |
+
os.remove(p)
|
| 734 |
+
except Exception:
|
| 735 |
+
pass
|
| 736 |
+
|
| 737 |
+
def _update_local_cache(job_id: str, result: Optional[Dict], status: str) -> None:
|
| 738 |
+
"""Update local cache with job result"""
|
| 739 |
+
local_cache = getattr(app.state, 'local_cache', None)
|
| 740 |
+
if not local_cache:
|
| 741 |
+
return
|
| 742 |
+
|
| 743 |
+
rec = store.get(job_id)
|
| 744 |
+
env = getattr(rec, 'env', 'development') if rec else 'development'
|
| 745 |
+
create_time = rec.created_at if rec else time.time()
|
| 746 |
+
|
| 747 |
+
status_int = _map_status(status)
|
| 748 |
+
|
| 749 |
+
if status == "succeeded" and result:
|
| 750 |
+
audio_paths = result.get("audio_paths", [])
|
| 751 |
+
# Final prompt/lyrics (may be modified by thinking/format)
|
| 752 |
+
final_prompt = result.get("prompt", "")
|
| 753 |
+
final_lyrics = result.get("lyrics", "")
|
| 754 |
+
# Original user input from metas
|
| 755 |
+
metas_raw = result.get("metas", {}) or {}
|
| 756 |
+
original_prompt = metas_raw.get("prompt", "")
|
| 757 |
+
original_lyrics = metas_raw.get("lyrics", "")
|
| 758 |
+
# metas contains original input + other metadata
|
| 759 |
+
metas = {
|
| 760 |
+
"bpm": metas_raw.get("bpm"),
|
| 761 |
+
"duration": metas_raw.get("duration"),
|
| 762 |
+
"genres": metas_raw.get("genres", ""),
|
| 763 |
+
"keyscale": metas_raw.get("keyscale", ""),
|
| 764 |
+
"timesignature": metas_raw.get("timesignature", ""),
|
| 765 |
+
"prompt": original_prompt,
|
| 766 |
+
"lyrics": original_lyrics,
|
| 767 |
+
}
|
| 768 |
+
# Extra fields for Discord bot
|
| 769 |
+
generation_info = result.get("generation_info", "")
|
| 770 |
+
seed_value = result.get("seed_value", "")
|
| 771 |
+
lm_model = result.get("lm_model", "")
|
| 772 |
+
dit_model = result.get("dit_model", "")
|
| 773 |
+
|
| 774 |
+
if audio_paths:
|
| 775 |
+
result_data = [
|
| 776 |
+
{
|
| 777 |
+
"file": p,
|
| 778 |
+
"wave": "",
|
| 779 |
+
"status": status_int,
|
| 780 |
+
"create_time": int(create_time),
|
| 781 |
+
"env": env,
|
| 782 |
+
"prompt": final_prompt,
|
| 783 |
+
"lyrics": final_lyrics,
|
| 784 |
+
"metas": metas,
|
| 785 |
+
"generation_info": generation_info,
|
| 786 |
+
"seed_value": seed_value,
|
| 787 |
+
"lm_model": lm_model,
|
| 788 |
+
"dit_model": dit_model,
|
| 789 |
+
}
|
| 790 |
+
for p in audio_paths
|
| 791 |
+
]
|
| 792 |
+
else:
|
| 793 |
+
result_data = [{
|
| 794 |
+
"file": "",
|
| 795 |
+
"wave": "",
|
| 796 |
+
"status": status_int,
|
| 797 |
+
"create_time": int(create_time),
|
| 798 |
+
"env": env,
|
| 799 |
+
"prompt": final_prompt,
|
| 800 |
+
"lyrics": final_lyrics,
|
| 801 |
+
"metas": metas,
|
| 802 |
+
"generation_info": generation_info,
|
| 803 |
+
"seed_value": seed_value,
|
| 804 |
+
"lm_model": lm_model,
|
| 805 |
+
"dit_model": dit_model,
|
| 806 |
+
}]
|
| 807 |
+
else:
|
| 808 |
+
result_data = [{"file": "", "wave": "", "status": status_int, "create_time": int(create_time), "env": env}]
|
| 809 |
+
|
| 810 |
+
result_key = f"{RESULT_KEY_PREFIX}{job_id}"
|
| 811 |
+
local_cache.set(result_key, result_data, ex=RESULT_EXPIRE_SECONDS)
|
| 812 |
+
|
| 813 |
+
async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
|
| 814 |
+
job_store: _JobStore = app.state.job_store
|
| 815 |
+
llm: LLMHandler = app.state.llm_handler
|
| 816 |
+
executor: ThreadPoolExecutor = app.state.executor
|
| 817 |
+
|
| 818 |
+
await _ensure_initialized()
|
| 819 |
+
job_store.mark_running(job_id)
|
| 820 |
+
|
| 821 |
+
# Select DiT handler based on user's model choice
|
| 822 |
+
# Default: use primary handler
|
| 823 |
+
selected_handler: AceStepHandler = app.state.handler
|
| 824 |
+
selected_model_name = _get_model_name(app.state._config_path)
|
| 825 |
+
|
| 826 |
+
if req.model:
|
| 827 |
+
model_matched = False
|
| 828 |
+
|
| 829 |
+
# Check if it matches the second model
|
| 830 |
+
if app.state.handler2 and getattr(app.state, "_initialized2", False):
|
| 831 |
+
model2_name = _get_model_name(app.state._config_path2)
|
| 832 |
+
if req.model == model2_name:
|
| 833 |
+
selected_handler = app.state.handler2
|
| 834 |
+
selected_model_name = model2_name
|
| 835 |
+
model_matched = True
|
| 836 |
+
print(f"[API Server] Job {job_id}: Using second model: {model2_name}")
|
| 837 |
+
|
| 838 |
+
# Check if it matches the third model
|
| 839 |
+
if not model_matched and app.state.handler3 and getattr(app.state, "_initialized3", False):
|
| 840 |
+
model3_name = _get_model_name(app.state._config_path3)
|
| 841 |
+
if req.model == model3_name:
|
| 842 |
+
selected_handler = app.state.handler3
|
| 843 |
+
selected_model_name = model3_name
|
| 844 |
+
model_matched = True
|
| 845 |
+
print(f"[API Server] Job {job_id}: Using third model: {model3_name}")
|
| 846 |
+
|
| 847 |
+
if not model_matched:
|
| 848 |
+
available_models = [_get_model_name(app.state._config_path)]
|
| 849 |
+
if app.state.handler2 and getattr(app.state, "_initialized2", False):
|
| 850 |
+
available_models.append(_get_model_name(app.state._config_path2))
|
| 851 |
+
if app.state.handler3 and getattr(app.state, "_initialized3", False):
|
| 852 |
+
available_models.append(_get_model_name(app.state._config_path3))
|
| 853 |
+
print(f"[API Server] Job {job_id}: Model '{req.model}' not found in {available_models}, using primary: {selected_model_name}")
|
| 854 |
+
|
| 855 |
+
# Use selected handler for generation
|
| 856 |
+
h: AceStepHandler = selected_handler
|
| 857 |
+
|
| 858 |
+
def _blocking_generate() -> Dict[str, Any]:
|
| 859 |
+
"""Generate music using unified inference logic from acestep.inference"""
|
| 860 |
+
|
| 861 |
+
def _ensure_llm_ready() -> None:
|
| 862 |
+
"""Ensure LLM handler is initialized when needed"""
|
| 863 |
+
with app.state._llm_init_lock:
|
| 864 |
+
initialized = getattr(app.state, "_llm_initialized", False)
|
| 865 |
+
had_error = getattr(app.state, "_llm_init_error", None)
|
| 866 |
+
if initialized or had_error is not None:
|
| 867 |
+
return
|
| 868 |
+
|
| 869 |
+
project_root = _get_project_root()
|
| 870 |
+
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 871 |
+
lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B").strip()
|
| 872 |
+
backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
|
| 873 |
+
if backend not in {"vllm", "pt"}:
|
| 874 |
+
backend = "vllm"
|
| 875 |
+
|
| 876 |
+
lm_device = os.getenv("ACESTEP_LM_DEVICE", os.getenv("ACESTEP_DEVICE", "auto"))
|
| 877 |
+
lm_offload = _env_bool("ACESTEP_LM_OFFLOAD_TO_CPU", False)
|
| 878 |
+
|
| 879 |
+
status, ok = llm.initialize(
|
| 880 |
+
checkpoint_dir=checkpoint_dir,
|
| 881 |
+
lm_model_path=lm_model_path,
|
| 882 |
+
backend=backend,
|
| 883 |
+
device=lm_device,
|
| 884 |
+
offload_to_cpu=lm_offload,
|
| 885 |
+
dtype=h.dtype,
|
| 886 |
+
)
|
| 887 |
+
if not ok:
|
| 888 |
+
app.state._llm_init_error = status
|
| 889 |
+
else:
|
| 890 |
+
app.state._llm_initialized = True
|
| 891 |
+
|
| 892 |
+
def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
|
| 893 |
+
"""Ensure a stable `metas` dict (keys always present)."""
|
| 894 |
+
meta = meta or {}
|
| 895 |
+
out: Dict[str, Any] = dict(meta)
|
| 896 |
+
|
| 897 |
+
# Normalize key aliases
|
| 898 |
+
if "keyscale" not in out and "key_scale" in out:
|
| 899 |
+
out["keyscale"] = out.get("key_scale")
|
| 900 |
+
if "timesignature" not in out and "time_signature" in out:
|
| 901 |
+
out["timesignature"] = out.get("time_signature")
|
| 902 |
+
|
| 903 |
+
# Ensure required keys exist
|
| 904 |
+
for k in ["bpm", "duration", "genres", "keyscale", "timesignature"]:
|
| 905 |
+
if out.get(k) in (None, ""):
|
| 906 |
+
out[k] = "N/A"
|
| 907 |
+
return out
|
| 908 |
+
|
| 909 |
+
# Normalize LM sampling parameters
|
| 910 |
+
lm_top_k = req.lm_top_k if req.lm_top_k and req.lm_top_k > 0 else 0
|
| 911 |
+
lm_top_p = req.lm_top_p if req.lm_top_p and req.lm_top_p < 1.0 else 0.9
|
| 912 |
+
|
| 913 |
+
# Determine if LLM is needed
|
| 914 |
+
thinking = bool(req.thinking)
|
| 915 |
+
sample_mode = bool(req.sample_mode)
|
| 916 |
+
has_sample_query = bool(req.sample_query and req.sample_query.strip())
|
| 917 |
+
use_format = bool(req.use_format)
|
| 918 |
+
use_cot_caption = bool(req.use_cot_caption)
|
| 919 |
+
use_cot_language = bool(req.use_cot_language)
|
| 920 |
+
|
| 921 |
+
# LLM is needed for:
|
| 922 |
+
# - thinking mode (LM generates audio codes)
|
| 923 |
+
# - sample_mode (LM generates random caption/lyrics/metas)
|
| 924 |
+
# - sample_query/description (LM generates from description)
|
| 925 |
+
# - use_format (LM enhances caption/lyrics)
|
| 926 |
+
# - use_cot_caption or use_cot_language (LM enhances metadata)
|
| 927 |
+
need_llm = thinking or sample_mode or has_sample_query or use_format or use_cot_caption or use_cot_language
|
| 928 |
+
|
| 929 |
+
# Ensure LLM is ready if needed
|
| 930 |
+
if need_llm:
|
| 931 |
+
_ensure_llm_ready()
|
| 932 |
+
if getattr(app.state, "_llm_init_error", None):
|
| 933 |
+
raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
|
| 934 |
+
|
| 935 |
+
# Handle sample mode or description: generate caption/lyrics/metas via LM
|
| 936 |
+
caption = req.prompt
|
| 937 |
+
lyrics = req.lyrics
|
| 938 |
+
bpm = req.bpm
|
| 939 |
+
key_scale = req.key_scale
|
| 940 |
+
time_signature = req.time_signature
|
| 941 |
+
audio_duration = req.audio_duration
|
| 942 |
+
|
| 943 |
+
# Save original user input for metas
|
| 944 |
+
original_prompt = req.prompt or ""
|
| 945 |
+
original_lyrics = req.lyrics or ""
|
| 946 |
+
|
| 947 |
+
if sample_mode or has_sample_query:
|
| 948 |
+
if has_sample_query:
|
| 949 |
+
# Use create_sample() with description query
|
| 950 |
+
parsed_language, parsed_instrumental = _parse_description_hints(req.sample_query)
|
| 951 |
+
|
| 952 |
+
# Determine vocal_language with priority:
|
| 953 |
+
# 1. User-specified vocal_language (if not default "en")
|
| 954 |
+
# 2. Language parsed from description
|
| 955 |
+
# 3. None (no constraint)
|
| 956 |
+
if req.vocal_language and req.vocal_language not in ("en", "unknown", ""):
|
| 957 |
+
sample_language = req.vocal_language
|
| 958 |
+
else:
|
| 959 |
+
sample_language = parsed_language
|
| 960 |
+
|
| 961 |
+
sample_result = create_sample(
|
| 962 |
+
llm_handler=llm,
|
| 963 |
+
query=req.sample_query,
|
| 964 |
+
instrumental=parsed_instrumental,
|
| 965 |
+
vocal_language=sample_language,
|
| 966 |
+
temperature=req.lm_temperature,
|
| 967 |
+
top_k=lm_top_k if lm_top_k > 0 else None,
|
| 968 |
+
top_p=lm_top_p if lm_top_p < 1.0 else None,
|
| 969 |
+
use_constrained_decoding=req.constrained_decoding,
|
| 970 |
+
)
|
| 971 |
+
|
| 972 |
+
if not sample_result.success:
|
| 973 |
+
raise RuntimeError(f"create_sample failed: {sample_result.error or sample_result.status_message}")
|
| 974 |
+
|
| 975 |
+
# Use generated sample data
|
| 976 |
+
caption = sample_result.caption
|
| 977 |
+
lyrics = sample_result.lyrics
|
| 978 |
+
bpm = sample_result.bpm
|
| 979 |
+
key_scale = sample_result.keyscale
|
| 980 |
+
time_signature = sample_result.timesignature
|
| 981 |
+
audio_duration = sample_result.duration
|
| 982 |
+
else:
|
| 983 |
+
# Original sample_mode behavior: random generation
|
| 984 |
+
sample_metadata, sample_status = llm.understand_audio_from_codes(
|
| 985 |
+
audio_codes="NO USER INPUT",
|
| 986 |
+
temperature=req.lm_temperature,
|
| 987 |
+
top_k=lm_top_k if lm_top_k > 0 else None,
|
| 988 |
+
top_p=lm_top_p if lm_top_p < 1.0 else None,
|
| 989 |
+
repetition_penalty=req.lm_repetition_penalty,
|
| 990 |
+
use_constrained_decoding=req.constrained_decoding,
|
| 991 |
+
constrained_decoding_debug=req.constrained_decoding_debug,
|
| 992 |
+
)
|
| 993 |
+
|
| 994 |
+
if not sample_metadata or str(sample_status).startswith("❌"):
|
| 995 |
+
raise RuntimeError(f"Sample generation failed: {sample_status}")
|
| 996 |
+
|
| 997 |
+
# Use generated values with fallback defaults
|
| 998 |
+
caption = sample_metadata.get("caption", "")
|
| 999 |
+
lyrics = sample_metadata.get("lyrics", "")
|
| 1000 |
+
bpm = _to_int(sample_metadata.get("bpm"), None) or _to_int(os.getenv("ACESTEP_SAMPLE_DEFAULT_BPM", "120"), 120)
|
| 1001 |
+
key_scale = sample_metadata.get("keyscale", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major")
|
| 1002 |
+
time_signature = sample_metadata.get("timesignature", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4")
|
| 1003 |
+
audio_duration = _to_float(sample_metadata.get("duration"), None) or _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0)
|
| 1004 |
+
|
| 1005 |
+
# Apply format_sample() if use_format is True and caption/lyrics are provided
|
| 1006 |
+
format_has_duration = False
|
| 1007 |
+
|
| 1008 |
+
if req.use_format and (caption or lyrics):
|
| 1009 |
+
_ensure_llm_ready()
|
| 1010 |
+
if getattr(app.state, "_llm_init_error", None):
|
| 1011 |
+
raise RuntimeError(f"5Hz LM init failed (needed for format): {app.state._llm_init_error}")
|
| 1012 |
+
|
| 1013 |
+
# Build user_metadata from request params (matching bot.py behavior)
|
| 1014 |
+
user_metadata_for_format = {}
|
| 1015 |
+
if bpm is not None:
|
| 1016 |
+
user_metadata_for_format['bpm'] = bpm
|
| 1017 |
+
if audio_duration is not None and audio_duration > 0:
|
| 1018 |
+
user_metadata_for_format['duration'] = int(audio_duration)
|
| 1019 |
+
if key_scale:
|
| 1020 |
+
user_metadata_for_format['keyscale'] = key_scale
|
| 1021 |
+
if time_signature:
|
| 1022 |
+
user_metadata_for_format['timesignature'] = time_signature
|
| 1023 |
+
if req.vocal_language and req.vocal_language != "unknown":
|
| 1024 |
+
user_metadata_for_format['language'] = req.vocal_language
|
| 1025 |
+
|
| 1026 |
+
format_result = format_sample(
|
| 1027 |
+
llm_handler=llm,
|
| 1028 |
+
caption=caption,
|
| 1029 |
+
lyrics=lyrics,
|
| 1030 |
+
user_metadata=user_metadata_for_format if user_metadata_for_format else None,
|
| 1031 |
+
temperature=req.lm_temperature,
|
| 1032 |
+
top_k=lm_top_k if lm_top_k > 0 else None,
|
| 1033 |
+
top_p=lm_top_p if lm_top_p < 1.0 else None,
|
| 1034 |
+
use_constrained_decoding=req.constrained_decoding,
|
| 1035 |
+
)
|
| 1036 |
+
|
| 1037 |
+
if format_result.success:
|
| 1038 |
+
# Extract all formatted data (matching bot.py behavior)
|
| 1039 |
+
caption = format_result.caption or caption
|
| 1040 |
+
lyrics = format_result.lyrics or lyrics
|
| 1041 |
+
if format_result.duration:
|
| 1042 |
+
audio_duration = format_result.duration
|
| 1043 |
+
format_has_duration = True
|
| 1044 |
+
if format_result.bpm:
|
| 1045 |
+
bpm = format_result.bpm
|
| 1046 |
+
if format_result.keyscale:
|
| 1047 |
+
key_scale = format_result.keyscale
|
| 1048 |
+
if format_result.timesignature:
|
| 1049 |
+
time_signature = format_result.timesignature
|
| 1050 |
+
|
| 1051 |
+
# Parse timesteps string to list of floats if provided
|
| 1052 |
+
parsed_timesteps = _parse_timesteps(req.timesteps)
|
| 1053 |
+
|
| 1054 |
+
# Determine actual inference steps (timesteps override inference_steps)
|
| 1055 |
+
actual_inference_steps = len(parsed_timesteps) if parsed_timesteps else req.inference_steps
|
| 1056 |
+
|
| 1057 |
+
# Auto-select instruction based on task_type if user didn't provide custom instruction
|
| 1058 |
+
# This matches gradio behavior which uses TASK_INSTRUCTIONS for each task type
|
| 1059 |
+
instruction_to_use = req.instruction
|
| 1060 |
+
if instruction_to_use == DEFAULT_DIT_INSTRUCTION and req.task_type in TASK_INSTRUCTIONS:
|
| 1061 |
+
instruction_to_use = TASK_INSTRUCTIONS[req.task_type]
|
| 1062 |
+
|
| 1063 |
+
# Build GenerationParams using unified interface
|
| 1064 |
+
# Note: thinking controls LM code generation, sample_mode only affects CoT metas
|
| 1065 |
+
params = GenerationParams(
|
| 1066 |
+
task_type=req.task_type,
|
| 1067 |
+
instruction=instruction_to_use,
|
| 1068 |
+
reference_audio=req.reference_audio_path,
|
| 1069 |
+
src_audio=req.src_audio_path,
|
| 1070 |
+
audio_codes=req.audio_code_string,
|
| 1071 |
+
caption=caption,
|
| 1072 |
+
lyrics=lyrics,
|
| 1073 |
+
instrumental=False,
|
| 1074 |
+
vocal_language=req.vocal_language,
|
| 1075 |
+
bpm=bpm,
|
| 1076 |
+
keyscale=key_scale,
|
| 1077 |
+
timesignature=time_signature,
|
| 1078 |
+
duration=audio_duration if audio_duration else -1.0,
|
| 1079 |
+
inference_steps=req.inference_steps,
|
| 1080 |
+
seed=req.seed,
|
| 1081 |
+
guidance_scale=req.guidance_scale,
|
| 1082 |
+
use_adg=req.use_adg,
|
| 1083 |
+
cfg_interval_start=req.cfg_interval_start,
|
| 1084 |
+
cfg_interval_end=req.cfg_interval_end,
|
| 1085 |
+
shift=req.shift,
|
| 1086 |
+
infer_method=req.infer_method,
|
| 1087 |
+
timesteps=parsed_timesteps,
|
| 1088 |
+
repainting_start=req.repainting_start,
|
| 1089 |
+
repainting_end=req.repainting_end if req.repainting_end else -1,
|
| 1090 |
+
audio_cover_strength=req.audio_cover_strength,
|
| 1091 |
+
# LM parameters
|
| 1092 |
+
thinking=thinking, # Use LM for code generation when thinking=True
|
| 1093 |
+
lm_temperature=req.lm_temperature,
|
| 1094 |
+
lm_cfg_scale=req.lm_cfg_scale,
|
| 1095 |
+
lm_top_k=lm_top_k,
|
| 1096 |
+
lm_top_p=lm_top_p,
|
| 1097 |
+
lm_negative_prompt=req.lm_negative_prompt,
|
| 1098 |
+
# use_cot_metas logic:
|
| 1099 |
+
# - sample_mode: metas already generated, skip Phase 1
|
| 1100 |
+
# - format with duration: metas already generated, skip Phase 1
|
| 1101 |
+
# - format without duration: need Phase 1 to generate duration
|
| 1102 |
+
# - no format: need Phase 1 to generate all metas
|
| 1103 |
+
use_cot_metas=not sample_mode and not format_has_duration,
|
| 1104 |
+
use_cot_caption=req.use_cot_caption,
|
| 1105 |
+
use_cot_language=req.use_cot_language,
|
| 1106 |
+
use_constrained_decoding=req.constrained_decoding,
|
| 1107 |
+
)
|
| 1108 |
+
|
| 1109 |
+
# Build GenerationConfig - default to 2 audios like gradio_ui
|
| 1110 |
+
batch_size = req.batch_size if req.batch_size is not None else 2
|
| 1111 |
+
config = GenerationConfig(
|
| 1112 |
+
batch_size=batch_size,
|
| 1113 |
+
use_random_seed=req.use_random_seed,
|
| 1114 |
+
seeds=None, # Let unified logic handle seed generation
|
| 1115 |
+
audio_format=req.audio_format,
|
| 1116 |
+
constrained_decoding_debug=req.constrained_decoding_debug,
|
| 1117 |
+
)
|
| 1118 |
+
|
| 1119 |
+
# Check LLM initialization status
|
| 1120 |
+
llm_is_initialized = getattr(app.state, "_llm_initialized", False)
|
| 1121 |
+
llm_to_pass = llm if llm_is_initialized else None
|
| 1122 |
+
|
| 1123 |
+
# Generate music using unified interface
|
| 1124 |
+
result = generate_music(
|
| 1125 |
+
dit_handler=h,
|
| 1126 |
+
llm_handler=llm_to_pass,
|
| 1127 |
+
params=params,
|
| 1128 |
+
config=config,
|
| 1129 |
+
save_dir=app.state.temp_audio_dir,
|
| 1130 |
+
progress=None,
|
| 1131 |
+
)
|
| 1132 |
+
|
| 1133 |
+
if not result.success:
|
| 1134 |
+
raise RuntimeError(f"Music generation failed: {result.error or result.status_message}")
|
| 1135 |
+
|
| 1136 |
+
# Extract results
|
| 1137 |
+
audio_paths = [audio["path"] for audio in result.audios if audio.get("path")]
|
| 1138 |
+
first_audio = audio_paths[0] if len(audio_paths) > 0 else None
|
| 1139 |
+
second_audio = audio_paths[1] if len(audio_paths) > 1 else None
|
| 1140 |
+
|
| 1141 |
+
# Get metadata from LM or CoT results
|
| 1142 |
+
lm_metadata = result.extra_outputs.get("lm_metadata", {})
|
| 1143 |
+
metas_out = _normalize_metas(lm_metadata)
|
| 1144 |
+
|
| 1145 |
+
# Update metas with actual values used
|
| 1146 |
+
if params.cot_bpm:
|
| 1147 |
+
metas_out["bpm"] = params.cot_bpm
|
| 1148 |
+
elif bpm:
|
| 1149 |
+
metas_out["bpm"] = bpm
|
| 1150 |
+
|
| 1151 |
+
if params.cot_duration:
|
| 1152 |
+
metas_out["duration"] = params.cot_duration
|
| 1153 |
+
elif audio_duration:
|
| 1154 |
+
metas_out["duration"] = audio_duration
|
| 1155 |
+
|
| 1156 |
+
if params.cot_keyscale:
|
| 1157 |
+
metas_out["keyscale"] = params.cot_keyscale
|
| 1158 |
+
elif key_scale:
|
| 1159 |
+
metas_out["keyscale"] = key_scale
|
| 1160 |
+
|
| 1161 |
+
if params.cot_timesignature:
|
| 1162 |
+
metas_out["timesignature"] = params.cot_timesignature
|
| 1163 |
+
elif time_signature:
|
| 1164 |
+
metas_out["timesignature"] = time_signature
|
| 1165 |
+
|
| 1166 |
+
# Store original user input in metas (not the final/modified values)
|
| 1167 |
+
metas_out["prompt"] = original_prompt
|
| 1168 |
+
metas_out["lyrics"] = original_lyrics
|
| 1169 |
+
|
| 1170 |
+
# Extract seed values for response (comma-separated for multiple audios)
|
| 1171 |
+
seed_values = []
|
| 1172 |
+
for audio in result.audios:
|
| 1173 |
+
audio_params = audio.get("params", {})
|
| 1174 |
+
seed = audio_params.get("seed")
|
| 1175 |
+
if seed is not None:
|
| 1176 |
+
seed_values.append(str(seed))
|
| 1177 |
+
seed_value = ",".join(seed_values) if seed_values else ""
|
| 1178 |
+
|
| 1179 |
+
# Build generation_info using the helper function (like gradio_ui)
|
| 1180 |
+
time_costs = result.extra_outputs.get("time_costs", {})
|
| 1181 |
+
generation_info = _build_generation_info(
|
| 1182 |
+
lm_metadata=lm_metadata,
|
| 1183 |
+
time_costs=time_costs,
|
| 1184 |
+
seed_value=seed_value,
|
| 1185 |
+
inference_steps=req.inference_steps,
|
| 1186 |
+
num_audios=len(result.audios),
|
| 1187 |
+
)
|
| 1188 |
+
|
| 1189 |
+
def _none_if_na_str(v: Any) -> Optional[str]:
|
| 1190 |
+
if v is None:
|
| 1191 |
+
return None
|
| 1192 |
+
s = str(v).strip()
|
| 1193 |
+
if s in {"", "N/A"}:
|
| 1194 |
+
return None
|
| 1195 |
+
return s
|
| 1196 |
+
|
| 1197 |
+
# Get model information
|
| 1198 |
+
lm_model_name = os.getenv("ACESTEP_LM_MODEL_PATH", "acestep-5Hz-lm-0.6B")
|
| 1199 |
+
# Use selected_model_name (set at the beginning of _run_one_job)
|
| 1200 |
+
dit_model_name = selected_model_name
|
| 1201 |
+
|
| 1202 |
+
return {
|
| 1203 |
+
"first_audio_path": _path_to_audio_url(first_audio) if first_audio else None,
|
| 1204 |
+
"second_audio_path": _path_to_audio_url(second_audio) if second_audio else None,
|
| 1205 |
+
"audio_paths": [_path_to_audio_url(p) for p in audio_paths],
|
| 1206 |
+
"generation_info": generation_info,
|
| 1207 |
+
"status_message": result.status_message,
|
| 1208 |
+
"seed_value": seed_value,
|
| 1209 |
+
# Final prompt/lyrics (may be modified by thinking/format)
|
| 1210 |
+
"prompt": caption or "",
|
| 1211 |
+
"lyrics": lyrics or "",
|
| 1212 |
+
# metas contains original user input + other metadata
|
| 1213 |
+
"metas": metas_out,
|
| 1214 |
+
"bpm": metas_out.get("bpm") if isinstance(metas_out.get("bpm"), int) else None,
|
| 1215 |
+
"duration": metas_out.get("duration") if isinstance(metas_out.get("duration"), (int, float)) else None,
|
| 1216 |
+
"genres": _none_if_na_str(metas_out.get("genres")),
|
| 1217 |
+
"keyscale": _none_if_na_str(metas_out.get("keyscale")),
|
| 1218 |
+
"timesignature": _none_if_na_str(metas_out.get("timesignature")),
|
| 1219 |
+
"lm_model": lm_model_name,
|
| 1220 |
+
"dit_model": dit_model_name,
|
| 1221 |
+
}
|
| 1222 |
+
|
| 1223 |
+
t0 = time.time()
|
| 1224 |
+
try:
|
| 1225 |
+
loop = asyncio.get_running_loop()
|
| 1226 |
+
result = await loop.run_in_executor(executor, _blocking_generate)
|
| 1227 |
+
job_store.mark_succeeded(job_id, result)
|
| 1228 |
+
|
| 1229 |
+
# Update local cache
|
| 1230 |
+
_update_local_cache(job_id, result, "succeeded")
|
| 1231 |
+
except Exception:
|
| 1232 |
+
job_store.mark_failed(job_id, traceback.format_exc())
|
| 1233 |
+
|
| 1234 |
+
# Update local cache
|
| 1235 |
+
_update_local_cache(job_id, None, "failed")
|
| 1236 |
+
finally:
|
| 1237 |
+
dt = max(0.0, time.time() - t0)
|
| 1238 |
+
async with app.state.stats_lock:
|
| 1239 |
+
app.state.recent_durations.append(dt)
|
| 1240 |
+
if app.state.recent_durations:
|
| 1241 |
+
app.state.avg_job_seconds = sum(app.state.recent_durations) / len(app.state.recent_durations)
|
| 1242 |
+
|
| 1243 |
+
async def _queue_worker(worker_idx: int) -> None:
|
| 1244 |
+
while True:
|
| 1245 |
+
job_id, req = await app.state.job_queue.get()
|
| 1246 |
+
try:
|
| 1247 |
+
async with app.state.pending_lock:
|
| 1248 |
+
try:
|
| 1249 |
+
app.state.pending_ids.remove(job_id)
|
| 1250 |
+
except ValueError:
|
| 1251 |
+
pass
|
| 1252 |
+
|
| 1253 |
+
await _run_one_job(job_id, req)
|
| 1254 |
+
finally:
|
| 1255 |
+
await _cleanup_job_temp_files(job_id)
|
| 1256 |
+
app.state.job_queue.task_done()
|
| 1257 |
+
|
| 1258 |
+
worker_count = max(1, WORKER_COUNT)
|
| 1259 |
+
workers = [asyncio.create_task(_queue_worker(i)) for i in range(worker_count)]
|
| 1260 |
+
app.state.worker_tasks = workers
|
| 1261 |
+
|
| 1262 |
+
try:
|
| 1263 |
+
yield
|
| 1264 |
+
finally:
|
| 1265 |
+
for t in workers:
|
| 1266 |
+
t.cancel()
|
| 1267 |
+
executor.shutdown(wait=False, cancel_futures=True)
|
| 1268 |
+
|
| 1269 |
+
app = FastAPI(title="ACE-Step API", version="1.0", lifespan=lifespan)
|
| 1270 |
+
|
| 1271 |
+
async def _queue_position(job_id: str) -> int:
|
| 1272 |
+
async with app.state.pending_lock:
|
| 1273 |
+
try:
|
| 1274 |
+
return list(app.state.pending_ids).index(job_id) + 1
|
| 1275 |
+
except ValueError:
|
| 1276 |
+
return 0
|
| 1277 |
+
|
| 1278 |
+
async def _eta_seconds_for_position(pos: int) -> Optional[float]:
|
| 1279 |
+
if pos <= 0:
|
| 1280 |
+
return None
|
| 1281 |
+
async with app.state.stats_lock:
|
| 1282 |
+
avg = float(getattr(app.state, "avg_job_seconds", INITIAL_AVG_JOB_SECONDS))
|
| 1283 |
+
return pos * avg
|
| 1284 |
+
|
| 1285 |
+
@app.post("/release_task", response_model=CreateJobResponse)
|
| 1286 |
+
async def create_music_generate_job(request: Request) -> CreateJobResponse:
|
| 1287 |
+
content_type = (request.headers.get("content-type") or "").lower()
|
| 1288 |
+
temp_files: list[str] = []
|
| 1289 |
+
|
| 1290 |
+
def _build_request(p: RequestParser, **kwargs) -> GenerateMusicRequest:
|
| 1291 |
+
"""Build GenerateMusicRequest from parsed parameters."""
|
| 1292 |
+
return GenerateMusicRequest(
|
| 1293 |
+
prompt=p.str("prompt"),
|
| 1294 |
+
lyrics=p.str("lyrics"),
|
| 1295 |
+
thinking=p.bool("thinking"),
|
| 1296 |
+
sample_mode=p.bool("sample_mode"),
|
| 1297 |
+
sample_query=p.str("sample_query"),
|
| 1298 |
+
use_format=p.bool("use_format"),
|
| 1299 |
+
model=p.str("model") or None,
|
| 1300 |
+
bpm=p.int("bpm"),
|
| 1301 |
+
key_scale=p.str("key_scale"),
|
| 1302 |
+
time_signature=p.str("time_signature"),
|
| 1303 |
+
audio_duration=p.float("audio_duration"),
|
| 1304 |
+
vocal_language=p.str("vocal_language", "en"),
|
| 1305 |
+
inference_steps=p.int("inference_steps", 8),
|
| 1306 |
+
guidance_scale=p.float("guidance_scale", 7.0),
|
| 1307 |
+
use_random_seed=p.bool("use_random_seed", True),
|
| 1308 |
+
seed=p.int("seed", -1),
|
| 1309 |
+
batch_size=p.int("batch_size"),
|
| 1310 |
+
audio_code_string=p.str("audio_code_string"),
|
| 1311 |
+
repainting_start=p.float("repainting_start", 0.0),
|
| 1312 |
+
repainting_end=p.float("repainting_end"),
|
| 1313 |
+
instruction=p.str("instruction", DEFAULT_DIT_INSTRUCTION),
|
| 1314 |
+
audio_cover_strength=p.float("audio_cover_strength", 1.0),
|
| 1315 |
+
task_type=p.str("task_type", "text2music"),
|
| 1316 |
+
use_adg=p.bool("use_adg"),
|
| 1317 |
+
cfg_interval_start=p.float("cfg_interval_start", 0.0),
|
| 1318 |
+
cfg_interval_end=p.float("cfg_interval_end", 1.0),
|
| 1319 |
+
infer_method=p.str("infer_method", "ode"),
|
| 1320 |
+
shift=p.float("shift", 3.0),
|
| 1321 |
+
audio_format=p.str("audio_format", "mp3"),
|
| 1322 |
+
use_tiled_decode=p.bool("use_tiled_decode", True),
|
| 1323 |
+
lm_model_path=p.str("lm_model_path") or None,
|
| 1324 |
+
lm_backend=p.str("lm_backend", "vllm"),
|
| 1325 |
+
lm_temperature=p.float("lm_temperature", LM_DEFAULT_TEMPERATURE),
|
| 1326 |
+
lm_cfg_scale=p.float("lm_cfg_scale", LM_DEFAULT_CFG_SCALE),
|
| 1327 |
+
lm_top_k=p.int("lm_top_k"),
|
| 1328 |
+
lm_top_p=p.float("lm_top_p", LM_DEFAULT_TOP_P),
|
| 1329 |
+
lm_repetition_penalty=p.float("lm_repetition_penalty", 1.0),
|
| 1330 |
+
lm_negative_prompt=p.str("lm_negative_prompt", "NO USER INPUT"),
|
| 1331 |
+
constrained_decoding=p.bool("constrained_decoding", True),
|
| 1332 |
+
constrained_decoding_debug=p.bool("constrained_decoding_debug"),
|
| 1333 |
+
use_cot_caption=p.bool("use_cot_caption", True),
|
| 1334 |
+
use_cot_language=p.bool("use_cot_language", True),
|
| 1335 |
+
is_format_caption=p.bool("is_format_caption"),
|
| 1336 |
+
**kwargs,
|
| 1337 |
+
)
|
| 1338 |
+
|
| 1339 |
+
if content_type.startswith("application/json"):
|
| 1340 |
+
body = await request.json()
|
| 1341 |
+
if not isinstance(body, dict):
|
| 1342 |
+
raise HTTPException(status_code=400, detail="JSON payload must be an object")
|
| 1343 |
+
req = _build_request(RequestParser(body))
|
| 1344 |
+
|
| 1345 |
+
elif content_type.endswith("+json"):
|
| 1346 |
+
body = await request.json()
|
| 1347 |
+
if not isinstance(body, dict):
|
| 1348 |
+
raise HTTPException(status_code=400, detail="JSON payload must be an object")
|
| 1349 |
+
req = _build_request(RequestParser(body))
|
| 1350 |
+
|
| 1351 |
+
elif content_type.startswith("multipart/form-data"):
|
| 1352 |
+
form = await request.form()
|
| 1353 |
+
|
| 1354 |
+
ref_up = form.get("reference_audio")
|
| 1355 |
+
src_up = form.get("src_audio")
|
| 1356 |
+
|
| 1357 |
+
reference_audio_path = None
|
| 1358 |
+
src_audio_path = None
|
| 1359 |
+
|
| 1360 |
+
if isinstance(ref_up, StarletteUploadFile):
|
| 1361 |
+
reference_audio_path = await _save_upload_to_temp(ref_up, prefix="reference_audio")
|
| 1362 |
+
temp_files.append(reference_audio_path)
|
| 1363 |
+
else:
|
| 1364 |
+
reference_audio_path = str(form.get("reference_audio_path") or "").strip() or None
|
| 1365 |
+
|
| 1366 |
+
if isinstance(src_up, StarletteUploadFile):
|
| 1367 |
+
src_audio_path = await _save_upload_to_temp(src_up, prefix="src_audio")
|
| 1368 |
+
temp_files.append(src_audio_path)
|
| 1369 |
+
else:
|
| 1370 |
+
src_audio_path = str(form.get("src_audio_path") or "").strip() or None
|
| 1371 |
+
|
| 1372 |
+
req = _build_request(
|
| 1373 |
+
RequestParser(dict(form)),
|
| 1374 |
+
reference_audio_path=reference_audio_path,
|
| 1375 |
+
src_audio_path=src_audio_path,
|
| 1376 |
+
)
|
| 1377 |
+
|
| 1378 |
+
elif content_type.startswith("application/x-www-form-urlencoded"):
|
| 1379 |
+
form = await request.form()
|
| 1380 |
+
reference_audio_path = str(form.get("reference_audio_path") or "").strip() or None
|
| 1381 |
+
src_audio_path = str(form.get("src_audio_path") or "").strip() or None
|
| 1382 |
+
req = _build_request(
|
| 1383 |
+
RequestParser(dict(form)),
|
| 1384 |
+
reference_audio_path=reference_audio_path,
|
| 1385 |
+
src_audio_path=src_audio_path,
|
| 1386 |
+
)
|
| 1387 |
+
|
| 1388 |
+
else:
|
| 1389 |
+
raw = await request.body()
|
| 1390 |
+
raw_stripped = raw.lstrip()
|
| 1391 |
+
# Best-effort: accept missing/incorrect Content-Type if payload is valid JSON.
|
| 1392 |
+
if raw_stripped.startswith(b"{") or raw_stripped.startswith(b"["):
|
| 1393 |
+
try:
|
| 1394 |
+
body = json.loads(raw.decode("utf-8"))
|
| 1395 |
+
if isinstance(body, dict):
|
| 1396 |
+
req = _build_request(RequestParser(body))
|
| 1397 |
+
else:
|
| 1398 |
+
raise HTTPException(status_code=400, detail="JSON payload must be an object")
|
| 1399 |
+
except HTTPException:
|
| 1400 |
+
raise
|
| 1401 |
+
except Exception:
|
| 1402 |
+
raise HTTPException(
|
| 1403 |
+
status_code=400,
|
| 1404 |
+
detail="Invalid JSON body (hint: set 'Content-Type: application/json')",
|
| 1405 |
+
)
|
| 1406 |
+
# Best-effort: parse key=value bodies even if Content-Type is missing.
|
| 1407 |
+
elif raw_stripped and b"=" in raw:
|
| 1408 |
+
parsed = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True)
|
| 1409 |
+
flat = {k: (v[0] if isinstance(v, list) and v else v) for k, v in parsed.items()}
|
| 1410 |
+
reference_audio_path = str(flat.get("reference_audio_path") or "").strip() or None
|
| 1411 |
+
src_audio_path = str(flat.get("src_audio_path") or "").strip() or None
|
| 1412 |
+
req = _build_request(
|
| 1413 |
+
RequestParser(flat),
|
| 1414 |
+
reference_audio_path=reference_audio_path,
|
| 1415 |
+
src_audio_path=src_audio_path,
|
| 1416 |
+
)
|
| 1417 |
+
else:
|
| 1418 |
+
raise HTTPException(
|
| 1419 |
+
status_code=415,
|
| 1420 |
+
detail=(
|
| 1421 |
+
f"Unsupported Content-Type: {content_type or '(missing)'}; "
|
| 1422 |
+
"use application/json, application/x-www-form-urlencoded, or multipart/form-data"
|
| 1423 |
+
),
|
| 1424 |
+
)
|
| 1425 |
+
|
| 1426 |
+
rec = store.create()
|
| 1427 |
+
|
| 1428 |
+
q: asyncio.Queue = app.state.job_queue
|
| 1429 |
+
if q.full():
|
| 1430 |
+
for p in temp_files:
|
| 1431 |
+
try:
|
| 1432 |
+
os.remove(p)
|
| 1433 |
+
except Exception:
|
| 1434 |
+
pass
|
| 1435 |
+
raise HTTPException(status_code=429, detail="Server busy: queue is full")
|
| 1436 |
+
|
| 1437 |
+
if temp_files:
|
| 1438 |
+
async with app.state.job_temp_files_lock:
|
| 1439 |
+
app.state.job_temp_files[rec.job_id] = temp_files
|
| 1440 |
+
|
| 1441 |
+
async with app.state.pending_lock:
|
| 1442 |
+
app.state.pending_ids.append(rec.job_id)
|
| 1443 |
+
position = len(app.state.pending_ids)
|
| 1444 |
+
|
| 1445 |
+
await q.put((rec.job_id, req))
|
| 1446 |
+
return CreateJobResponse(task_id=rec.job_id, status="queued", queue_position=position)
|
| 1447 |
+
|
| 1448 |
+
@app.post("/v1/music/random", response_model=CreateJobResponse)
|
| 1449 |
+
async def create_random_sample_job(request: Request) -> CreateJobResponse:
|
| 1450 |
+
"""Create a sample-mode job that auto-generates caption/lyrics via LM."""
|
| 1451 |
+
|
| 1452 |
+
thinking_value: Any = None
|
| 1453 |
+
content_type = (request.headers.get("content-type") or "").lower()
|
| 1454 |
+
body_dict: Dict[str, Any] = {}
|
| 1455 |
+
|
| 1456 |
+
if "json" in content_type:
|
| 1457 |
+
try:
|
| 1458 |
+
payload = await request.json()
|
| 1459 |
+
if isinstance(payload, dict):
|
| 1460 |
+
body_dict = payload
|
| 1461 |
+
except Exception:
|
| 1462 |
+
body_dict = {}
|
| 1463 |
+
|
| 1464 |
+
if not body_dict and request.query_params:
|
| 1465 |
+
body_dict = dict(request.query_params)
|
| 1466 |
+
|
| 1467 |
+
thinking_value = body_dict.get("thinking")
|
| 1468 |
+
if thinking_value is None:
|
| 1469 |
+
thinking_value = body_dict.get("Thinking")
|
| 1470 |
+
|
| 1471 |
+
thinking_flag = _to_bool(thinking_value, True)
|
| 1472 |
+
|
| 1473 |
+
req = GenerateMusicRequest(
|
| 1474 |
+
caption="",
|
| 1475 |
+
lyrics="",
|
| 1476 |
+
thinking=thinking_flag,
|
| 1477 |
+
sample_mode=True,
|
| 1478 |
+
)
|
| 1479 |
+
|
| 1480 |
+
rec = store.create()
|
| 1481 |
+
q: asyncio.Queue = app.state.job_queue
|
| 1482 |
+
if q.full():
|
| 1483 |
+
raise HTTPException(status_code=429, detail="Server busy: queue is full")
|
| 1484 |
+
|
| 1485 |
+
async with app.state.pending_lock:
|
| 1486 |
+
app.state.pending_ids.append(rec.job_id)
|
| 1487 |
+
position = len(app.state.pending_ids)
|
| 1488 |
+
|
| 1489 |
+
await q.put((rec.job_id, req))
|
| 1490 |
+
return CreateJobResponse(task_id=rec.job_id, status="queued", queue_position=position)
|
| 1491 |
+
|
| 1492 |
+
@app.post("/query_result")
|
| 1493 |
+
async def query_result(request: Request) -> List[Dict[str, Any]]:
|
| 1494 |
+
"""Batch query job results"""
|
| 1495 |
+
content_type = (request.headers.get("content-type") or "").lower()
|
| 1496 |
+
|
| 1497 |
+
if "json" in content_type:
|
| 1498 |
+
body = await request.json()
|
| 1499 |
+
else:
|
| 1500 |
+
form = await request.form()
|
| 1501 |
+
body = {k: v for k, v in form.items()}
|
| 1502 |
+
|
| 1503 |
+
task_id_list_str = body.get("task_id_list", "[]")
|
| 1504 |
+
|
| 1505 |
+
# Parse task ID list
|
| 1506 |
+
if isinstance(task_id_list_str, list):
|
| 1507 |
+
task_id_list = task_id_list_str
|
| 1508 |
+
else:
|
| 1509 |
+
try:
|
| 1510 |
+
task_id_list = json.loads(task_id_list_str)
|
| 1511 |
+
except Exception:
|
| 1512 |
+
task_id_list = []
|
| 1513 |
+
|
| 1514 |
+
local_cache = getattr(app.state, 'local_cache', None)
|
| 1515 |
+
data_list = []
|
| 1516 |
+
current_time = time.time()
|
| 1517 |
+
|
| 1518 |
+
for task_id in task_id_list:
|
| 1519 |
+
result_key = f"{RESULT_KEY_PREFIX}{task_id}"
|
| 1520 |
+
|
| 1521 |
+
# Read from local cache first
|
| 1522 |
+
if local_cache:
|
| 1523 |
+
data = local_cache.get(result_key)
|
| 1524 |
+
if data:
|
| 1525 |
+
try:
|
| 1526 |
+
data_json = json.loads(data)
|
| 1527 |
+
except Exception:
|
| 1528 |
+
data_json = []
|
| 1529 |
+
|
| 1530 |
+
if len(data_json) <= 0:
|
| 1531 |
+
data_list.append({"task_id": task_id, "result": data, "status": 2})
|
| 1532 |
+
else:
|
| 1533 |
+
status = data_json[0].get("status")
|
| 1534 |
+
create_time = data_json[0].get("create_time", 0)
|
| 1535 |
+
if status == 0 and (current_time - create_time) > TASK_TIMEOUT_SECONDS:
|
| 1536 |
+
data_list.append({"task_id": task_id, "result": data, "status": 2})
|
| 1537 |
+
else:
|
| 1538 |
+
data_list.append({
|
| 1539 |
+
"task_id": task_id,
|
| 1540 |
+
"result": data,
|
| 1541 |
+
"status": int(status) if status is not None else 1,
|
| 1542 |
+
})
|
| 1543 |
+
continue
|
| 1544 |
+
|
| 1545 |
+
# Fallback to job_store query
|
| 1546 |
+
rec = store.get(task_id)
|
| 1547 |
+
if rec:
|
| 1548 |
+
env = getattr(rec, 'env', 'development')
|
| 1549 |
+
create_time = rec.created_at
|
| 1550 |
+
status_int = _map_status(rec.status)
|
| 1551 |
+
|
| 1552 |
+
if rec.result and rec.status == "succeeded":
|
| 1553 |
+
audio_paths = rec.result.get("audio_paths", [])
|
| 1554 |
+
metas = rec.result.get("metas", {}) or {}
|
| 1555 |
+
result_data = [
|
| 1556 |
+
{
|
| 1557 |
+
"file": p, "wave": "", "status": status_int,
|
| 1558 |
+
"create_time": int(create_time), "env": env,
|
| 1559 |
+
"prompt": metas.get("caption", ""),
|
| 1560 |
+
"lyrics": metas.get("lyrics", ""),
|
| 1561 |
+
"metas": {
|
| 1562 |
+
"bpm": metas.get("bpm"),
|
| 1563 |
+
"duration": metas.get("duration"),
|
| 1564 |
+
"genres": metas.get("genres", ""),
|
| 1565 |
+
"keyscale": metas.get("keyscale", ""),
|
| 1566 |
+
"timesignature": metas.get("timesignature", ""),
|
| 1567 |
+
}
|
| 1568 |
+
}
|
| 1569 |
+
for p in audio_paths
|
| 1570 |
+
] if audio_paths else [{
|
| 1571 |
+
"file": "", "wave": "", "status": status_int,
|
| 1572 |
+
"create_time": int(create_time), "env": env,
|
| 1573 |
+
"prompt": metas.get("caption", ""),
|
| 1574 |
+
"lyrics": metas.get("lyrics", ""),
|
| 1575 |
+
"metas": {
|
| 1576 |
+
"bpm": metas.get("bpm"),
|
| 1577 |
+
"duration": metas.get("duration"),
|
| 1578 |
+
"genres": metas.get("genres", ""),
|
| 1579 |
+
"keyscale": metas.get("keyscale", ""),
|
| 1580 |
+
"timesignature": metas.get("timesignature", ""),
|
| 1581 |
+
}
|
| 1582 |
+
}]
|
| 1583 |
+
else:
|
| 1584 |
+
result_data = [{
|
| 1585 |
+
"file": "", "wave": "", "status": status_int,
|
| 1586 |
+
"create_time": int(create_time), "env": env,
|
| 1587 |
+
"prompt": "", "lyrics": "",
|
| 1588 |
+
"metas": {}
|
| 1589 |
+
}]
|
| 1590 |
+
|
| 1591 |
+
data_list.append({
|
| 1592 |
+
"task_id": task_id,
|
| 1593 |
+
"result": json.dumps(result_data, ensure_ascii=False),
|
| 1594 |
+
"status": status_int,
|
| 1595 |
+
})
|
| 1596 |
+
else:
|
| 1597 |
+
data_list.append({"task_id": task_id, "result": "[]", "status": 0})
|
| 1598 |
+
|
| 1599 |
+
return data_list
|
| 1600 |
+
|
| 1601 |
+
@app.get("/health")
|
| 1602 |
+
async def health_check():
|
| 1603 |
+
"""Health check endpoint for service status."""
|
| 1604 |
+
return {
|
| 1605 |
+
"status": "ok",
|
| 1606 |
+
"service": "ACE-Step API",
|
| 1607 |
+
"version": "1.0",
|
| 1608 |
+
}
|
| 1609 |
+
|
| 1610 |
+
@app.get("/v1/models")
|
| 1611 |
+
async def list_models():
|
| 1612 |
+
"""List available DiT models."""
|
| 1613 |
+
models = []
|
| 1614 |
+
|
| 1615 |
+
# Primary model (always available if initialized)
|
| 1616 |
+
if getattr(app.state, "_initialized", False):
|
| 1617 |
+
primary_model = _get_model_name(app.state._config_path)
|
| 1618 |
+
if primary_model:
|
| 1619 |
+
models.append({
|
| 1620 |
+
"name": primary_model,
|
| 1621 |
+
"is_default": True,
|
| 1622 |
+
})
|
| 1623 |
+
|
| 1624 |
+
# Secondary model
|
| 1625 |
+
if getattr(app.state, "_initialized2", False) and app.state._config_path2:
|
| 1626 |
+
secondary_model = _get_model_name(app.state._config_path2)
|
| 1627 |
+
if secondary_model:
|
| 1628 |
+
models.append({
|
| 1629 |
+
"name": secondary_model,
|
| 1630 |
+
"is_default": False,
|
| 1631 |
+
})
|
| 1632 |
+
|
| 1633 |
+
# Third model
|
| 1634 |
+
if getattr(app.state, "_initialized3", False) and app.state._config_path3:
|
| 1635 |
+
third_model = _get_model_name(app.state._config_path3)
|
| 1636 |
+
if third_model:
|
| 1637 |
+
models.append({
|
| 1638 |
+
"name": third_model,
|
| 1639 |
+
"is_default": False,
|
| 1640 |
+
})
|
| 1641 |
+
|
| 1642 |
+
return {
|
| 1643 |
+
"models": models,
|
| 1644 |
+
"default_model": models[0]["name"] if models else None,
|
| 1645 |
+
}
|
| 1646 |
+
|
| 1647 |
+
@app.get("/v1/audio")
|
| 1648 |
+
async def get_audio(path: str):
|
| 1649 |
+
"""Serve audio file by path."""
|
| 1650 |
+
from fastapi.responses import FileResponse
|
| 1651 |
+
|
| 1652 |
+
if not os.path.exists(path):
|
| 1653 |
+
raise HTTPException(status_code=404, detail=f"Audio file not found: {path}")
|
| 1654 |
+
|
| 1655 |
+
ext = os.path.splitext(path)[1].lower()
|
| 1656 |
+
media_types = {
|
| 1657 |
+
".mp3": "audio/mpeg",
|
| 1658 |
+
".wav": "audio/wav",
|
| 1659 |
+
".flac": "audio/flac",
|
| 1660 |
+
".ogg": "audio/ogg",
|
| 1661 |
+
}
|
| 1662 |
+
media_type = media_types.get(ext, "audio/mpeg")
|
| 1663 |
+
|
| 1664 |
+
return FileResponse(path, media_type=media_type)
|
| 1665 |
+
|
| 1666 |
+
return app
|
| 1667 |
+
|
| 1668 |
+
|
| 1669 |
+
app = create_app()
|
| 1670 |
+
|
| 1671 |
+
|
| 1672 |
+
def main() -> None:
|
| 1673 |
+
import argparse
|
| 1674 |
+
import uvicorn
|
| 1675 |
+
|
| 1676 |
+
parser = argparse.ArgumentParser(description="ACE-Step API server")
|
| 1677 |
+
parser.add_argument(
|
| 1678 |
+
"--host",
|
| 1679 |
+
default=os.getenv("ACESTEP_API_HOST", "127.0.0.1"),
|
| 1680 |
+
help="Bind host (default from ACESTEP_API_HOST or 127.0.0.1)",
|
| 1681 |
+
)
|
| 1682 |
+
parser.add_argument(
|
| 1683 |
+
"--port",
|
| 1684 |
+
type=int,
|
| 1685 |
+
default=int(os.getenv("ACESTEP_API_PORT", "8001")),
|
| 1686 |
+
help="Bind port (default from ACESTEP_API_PORT or 8001)",
|
| 1687 |
+
)
|
| 1688 |
+
args = parser.parse_args()
|
| 1689 |
+
|
| 1690 |
+
# IMPORTANT: in-memory queue/store -> workers MUST be 1
|
| 1691 |
+
uvicorn.run(
|
| 1692 |
+
"acestep.api_server:app",
|
| 1693 |
+
host=str(args.host),
|
| 1694 |
+
port=int(args.port),
|
| 1695 |
+
reload=False,
|
| 1696 |
+
workers=1,
|
| 1697 |
+
)
|
| 1698 |
+
|
| 1699 |
+
if __name__ == "__main__":
|
| 1700 |
+
main()
|
acestep/audio_utils.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audio saving and transcoding utility module
|
| 3 |
+
|
| 4 |
+
Independent audio file operations outside of handler, supporting:
|
| 5 |
+
- Save audio tensor/numpy to files (default FLAC format, fast)
|
| 6 |
+
- Format conversion (FLAC/WAV/MP3)
|
| 7 |
+
- Batch processing
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
# Disable torchcodec backend to avoid CUDA dependency issues on HuggingFace Space
|
| 13 |
+
# This forces torchaudio to use ffmpeg/sox/soundfile backends instead
|
| 14 |
+
os.environ["TORCHAUDIO_USE_TORCHCODEC"] = "0"
|
| 15 |
+
|
| 16 |
+
import hashlib
|
| 17 |
+
import json
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Union, Optional, List, Tuple
|
| 20 |
+
import torch
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torchaudio
|
| 23 |
+
from loguru import logger
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class AudioSaver:
|
| 27 |
+
"""Audio saving and transcoding utility class"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, default_format: str = "flac"):
|
| 30 |
+
"""
|
| 31 |
+
Initialize audio saver
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
default_format: Default save format ('flac', 'wav', 'mp3')
|
| 35 |
+
"""
|
| 36 |
+
self.default_format = default_format.lower()
|
| 37 |
+
if self.default_format not in ["flac", "wav", "mp3"]:
|
| 38 |
+
logger.warning(f"Unsupported format {default_format}, using 'flac'")
|
| 39 |
+
self.default_format = "flac"
|
| 40 |
+
|
| 41 |
+
def save_audio(
|
| 42 |
+
self,
|
| 43 |
+
audio_data: Union[torch.Tensor, np.ndarray],
|
| 44 |
+
output_path: Union[str, Path],
|
| 45 |
+
sample_rate: int = 48000,
|
| 46 |
+
format: Optional[str] = None,
|
| 47 |
+
channels_first: bool = True,
|
| 48 |
+
) -> str:
|
| 49 |
+
"""
|
| 50 |
+
Save audio data to file
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
|
| 54 |
+
output_path: Output file path (extension can be omitted)
|
| 55 |
+
sample_rate: Sample rate
|
| 56 |
+
format: Audio format ('flac', 'wav', 'mp3'), defaults to default_format
|
| 57 |
+
channels_first: If True, tensor format is [channels, samples], else [samples, channels]
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
Actual saved file path
|
| 61 |
+
"""
|
| 62 |
+
format = (format or self.default_format).lower()
|
| 63 |
+
if format not in ["flac", "wav", "mp3"]:
|
| 64 |
+
logger.warning(f"Unsupported format {format}, using {self.default_format}")
|
| 65 |
+
format = self.default_format
|
| 66 |
+
|
| 67 |
+
# Ensure output path has correct extension
|
| 68 |
+
output_path = Path(output_path)
|
| 69 |
+
if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
|
| 70 |
+
output_path = output_path.with_suffix(f'.{format}')
|
| 71 |
+
|
| 72 |
+
# Convert to torch tensor
|
| 73 |
+
if isinstance(audio_data, np.ndarray):
|
| 74 |
+
if channels_first:
|
| 75 |
+
# numpy [samples, channels] -> tensor [channels, samples]
|
| 76 |
+
audio_tensor = torch.from_numpy(audio_data.T).float()
|
| 77 |
+
else:
|
| 78 |
+
# numpy [samples, channels] -> tensor [samples, channels] -> [channels, samples]
|
| 79 |
+
audio_tensor = torch.from_numpy(audio_data).float()
|
| 80 |
+
if audio_tensor.dim() == 2 and audio_tensor.shape[0] < audio_tensor.shape[1]:
|
| 81 |
+
audio_tensor = audio_tensor.T
|
| 82 |
+
else:
|
| 83 |
+
# torch tensor
|
| 84 |
+
audio_tensor = audio_data.cpu().float()
|
| 85 |
+
if not channels_first and audio_tensor.dim() == 2:
|
| 86 |
+
# [samples, channels] -> [channels, samples]
|
| 87 |
+
if audio_tensor.shape[0] > audio_tensor.shape[1]:
|
| 88 |
+
audio_tensor = audio_tensor.T
|
| 89 |
+
|
| 90 |
+
# Ensure memory is contiguous
|
| 91 |
+
audio_tensor = audio_tensor.contiguous()
|
| 92 |
+
|
| 93 |
+
# Select backend and save
|
| 94 |
+
try:
|
| 95 |
+
if format == "mp3":
|
| 96 |
+
# MP3 uses ffmpeg backend
|
| 97 |
+
torchaudio.save(
|
| 98 |
+
str(output_path),
|
| 99 |
+
audio_tensor,
|
| 100 |
+
sample_rate,
|
| 101 |
+
channels_first=True,
|
| 102 |
+
backend='ffmpeg',
|
| 103 |
+
)
|
| 104 |
+
elif format in ["flac", "wav"]:
|
| 105 |
+
# FLAC and WAV use soundfile backend (fastest)
|
| 106 |
+
torchaudio.save(
|
| 107 |
+
str(output_path),
|
| 108 |
+
audio_tensor,
|
| 109 |
+
sample_rate,
|
| 110 |
+
channels_first=True,
|
| 111 |
+
backend='soundfile',
|
| 112 |
+
)
|
| 113 |
+
else:
|
| 114 |
+
# Other formats use default backend
|
| 115 |
+
torchaudio.save(
|
| 116 |
+
str(output_path),
|
| 117 |
+
audio_tensor,
|
| 118 |
+
sample_rate,
|
| 119 |
+
channels_first=True,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
|
| 123 |
+
return str(output_path)
|
| 124 |
+
|
| 125 |
+
except Exception as e:
|
| 126 |
+
try:
|
| 127 |
+
import soundfile as sf
|
| 128 |
+
audio_np = audio_tensor.transpose(0, 1).numpy() # -> [samples, channels]
|
| 129 |
+
sf.write(str(output_path), audio_np, sample_rate, format=format.upper())
|
| 130 |
+
logger.debug(f"[AudioSaver] Fallback soundfile Saved audio to {output_path} ({format}, {sample_rate}Hz)")
|
| 131 |
+
return str(output_path)
|
| 132 |
+
except Exception as e:
|
| 133 |
+
logger.error(f"[AudioSaver] Failed to save audio: {e}")
|
| 134 |
+
raise
|
| 135 |
+
|
| 136 |
+
def _load_audio_file(self, audio_file: Union[str, Path]) -> Tuple[torch.Tensor, int]:
|
| 137 |
+
"""
|
| 138 |
+
Load audio file with ffmpeg backend, fallback to soundfile if failed.
|
| 139 |
+
|
| 140 |
+
This handles CUDA dependency issues with torchcodec on HuggingFace Space.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
audio_file: Path to the audio file
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
Tuple of (audio_tensor, sample_rate)
|
| 147 |
+
|
| 148 |
+
Raises:
|
| 149 |
+
FileNotFoundError: If the audio file doesn't exist
|
| 150 |
+
Exception: If all methods fail to load the audio
|
| 151 |
+
"""
|
| 152 |
+
audio_file = str(audio_file)
|
| 153 |
+
|
| 154 |
+
# Check if file exists first
|
| 155 |
+
if not Path(audio_file).exists():
|
| 156 |
+
raise FileNotFoundError(f"Audio file not found: {audio_file}")
|
| 157 |
+
|
| 158 |
+
# Try torchaudio with explicit ffmpeg backend first
|
| 159 |
+
try:
|
| 160 |
+
audio, sr = torchaudio.load(audio_file, backend="ffmpeg")
|
| 161 |
+
return audio, sr
|
| 162 |
+
except Exception as e:
|
| 163 |
+
logger.debug(f"[AudioSaver._load_audio_file] ffmpeg backend failed: {e}, trying soundfile fallback")
|
| 164 |
+
|
| 165 |
+
# Fallback: use soundfile directly (most compatible)
|
| 166 |
+
try:
|
| 167 |
+
import soundfile as sf
|
| 168 |
+
audio_np, sr = sf.read(audio_file)
|
| 169 |
+
# soundfile returns [samples, channels] or [samples], convert to [channels, samples]
|
| 170 |
+
audio = torch.from_numpy(audio_np).float()
|
| 171 |
+
if audio.dim() == 1:
|
| 172 |
+
# Mono: [samples] -> [1, samples]
|
| 173 |
+
audio = audio.unsqueeze(0)
|
| 174 |
+
else:
|
| 175 |
+
# Stereo: [samples, channels] -> [channels, samples]
|
| 176 |
+
audio = audio.T
|
| 177 |
+
return audio, sr
|
| 178 |
+
except Exception as e:
|
| 179 |
+
logger.error(f"[AudioSaver._load_audio_file] All methods failed to load audio: {audio_file}, error: {e}")
|
| 180 |
+
raise
|
| 181 |
+
|
| 182 |
+
def convert_audio(
|
| 183 |
+
self,
|
| 184 |
+
input_path: Union[str, Path],
|
| 185 |
+
output_path: Union[str, Path],
|
| 186 |
+
output_format: str,
|
| 187 |
+
remove_input: bool = False,
|
| 188 |
+
) -> str:
|
| 189 |
+
"""
|
| 190 |
+
Convert audio format
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
input_path: Input audio file path
|
| 194 |
+
output_path: Output audio file path
|
| 195 |
+
output_format: Target format ('flac', 'wav', 'mp3')
|
| 196 |
+
remove_input: Whether to delete input file
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
Output file path
|
| 200 |
+
"""
|
| 201 |
+
input_path = Path(input_path)
|
| 202 |
+
output_path = Path(output_path)
|
| 203 |
+
|
| 204 |
+
if not input_path.exists():
|
| 205 |
+
raise FileNotFoundError(f"Input file not found: {input_path}")
|
| 206 |
+
|
| 207 |
+
# Load audio with fallback backends
|
| 208 |
+
audio_tensor, sample_rate = self._load_audio_file(input_path)
|
| 209 |
+
|
| 210 |
+
# Save as new format
|
| 211 |
+
output_path = self.save_audio(
|
| 212 |
+
audio_tensor,
|
| 213 |
+
output_path,
|
| 214 |
+
sample_rate=sample_rate,
|
| 215 |
+
format=output_format,
|
| 216 |
+
channels_first=True
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
# Delete input file if needed
|
| 220 |
+
if remove_input:
|
| 221 |
+
input_path.unlink()
|
| 222 |
+
logger.debug(f"[AudioSaver] Removed input file: {input_path}")
|
| 223 |
+
|
| 224 |
+
return output_path
|
| 225 |
+
|
| 226 |
+
def save_batch(
|
| 227 |
+
self,
|
| 228 |
+
audio_batch: Union[List[torch.Tensor], torch.Tensor],
|
| 229 |
+
output_dir: Union[str, Path],
|
| 230 |
+
file_prefix: str = "audio",
|
| 231 |
+
sample_rate: int = 48000,
|
| 232 |
+
format: Optional[str] = None,
|
| 233 |
+
channels_first: bool = True,
|
| 234 |
+
) -> List[str]:
|
| 235 |
+
"""
|
| 236 |
+
Save audio batch
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
|
| 240 |
+
output_dir: Output directory
|
| 241 |
+
file_prefix: File prefix
|
| 242 |
+
sample_rate: Sample rate
|
| 243 |
+
format: Audio format
|
| 244 |
+
channels_first: Tensor format flag
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
List of saved file paths
|
| 248 |
+
"""
|
| 249 |
+
output_dir = Path(output_dir)
|
| 250 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 251 |
+
|
| 252 |
+
# Process batch
|
| 253 |
+
if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
|
| 254 |
+
# [batch, channels, samples]
|
| 255 |
+
audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
|
| 256 |
+
elif isinstance(audio_batch, list):
|
| 257 |
+
audio_list = audio_batch
|
| 258 |
+
else:
|
| 259 |
+
audio_list = [audio_batch]
|
| 260 |
+
|
| 261 |
+
saved_paths = []
|
| 262 |
+
for i, audio in enumerate(audio_list):
|
| 263 |
+
output_path = output_dir / f"{file_prefix}_{i:04d}"
|
| 264 |
+
saved_path = self.save_audio(
|
| 265 |
+
audio,
|
| 266 |
+
output_path,
|
| 267 |
+
sample_rate=sample_rate,
|
| 268 |
+
format=format,
|
| 269 |
+
channels_first=channels_first
|
| 270 |
+
)
|
| 271 |
+
saved_paths.append(saved_path)
|
| 272 |
+
|
| 273 |
+
return saved_paths
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def get_audio_file_hash(audio_file) -> str:
|
| 277 |
+
"""
|
| 278 |
+
Get hash identifier for an audio file.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
audio_file: Path to audio file (str) or file-like object
|
| 282 |
+
|
| 283 |
+
Returns:
|
| 284 |
+
Hash string or empty string
|
| 285 |
+
"""
|
| 286 |
+
if audio_file is None:
|
| 287 |
+
return ""
|
| 288 |
+
|
| 289 |
+
try:
|
| 290 |
+
if isinstance(audio_file, str):
|
| 291 |
+
if os.path.exists(audio_file):
|
| 292 |
+
with open(audio_file, 'rb') as f:
|
| 293 |
+
return hashlib.md5(f.read()).hexdigest()
|
| 294 |
+
return hashlib.md5(audio_file.encode('utf-8')).hexdigest()
|
| 295 |
+
elif hasattr(audio_file, 'name'):
|
| 296 |
+
return hashlib.md5(str(audio_file.name).encode('utf-8')).hexdigest()
|
| 297 |
+
return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
|
| 298 |
+
except Exception:
|
| 299 |
+
return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def generate_uuid_from_params(params_dict) -> str:
|
| 303 |
+
"""
|
| 304 |
+
Generate deterministic UUID from generation parameters.
|
| 305 |
+
Same parameters will always generate the same UUID.
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
params_dict: Dictionary of parameters
|
| 309 |
+
|
| 310 |
+
Returns:
|
| 311 |
+
UUID string
|
| 312 |
+
"""
|
| 313 |
+
|
| 314 |
+
params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
|
| 315 |
+
hash_obj = hashlib.sha256(params_json.encode('utf-8'))
|
| 316 |
+
hash_hex = hash_obj.hexdigest()
|
| 317 |
+
uuid_str = f"{hash_hex[0:8]}-{hash_hex[8:12]}-{hash_hex[12:16]}-{hash_hex[16:20]}-{hash_hex[20:32]}"
|
| 318 |
+
return uuid_str
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def generate_uuid_from_audio_data(
|
| 322 |
+
audio_data: Union[torch.Tensor, np.ndarray],
|
| 323 |
+
seed: Optional[int] = None
|
| 324 |
+
) -> str:
|
| 325 |
+
"""
|
| 326 |
+
Generate UUID from audio data (for caching/deduplication)
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
audio_data: Audio data
|
| 330 |
+
seed: Optional seed value
|
| 331 |
+
|
| 332 |
+
Returns:
|
| 333 |
+
UUID string
|
| 334 |
+
"""
|
| 335 |
+
if isinstance(audio_data, torch.Tensor):
|
| 336 |
+
# Convert to numpy and calculate hash
|
| 337 |
+
audio_np = audio_data.cpu().numpy()
|
| 338 |
+
else:
|
| 339 |
+
audio_np = audio_data
|
| 340 |
+
|
| 341 |
+
# Calculate data hash
|
| 342 |
+
data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()
|
| 343 |
+
|
| 344 |
+
if seed is not None:
|
| 345 |
+
combined = f"{data_hash}_{seed}"
|
| 346 |
+
return hashlib.md5(combined.encode()).hexdigest()
|
| 347 |
+
|
| 348 |
+
return data_hash
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
# Global default instance
|
| 352 |
+
_default_saver = AudioSaver(default_format="flac")
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def save_audio(
|
| 356 |
+
audio_data: Union[torch.Tensor, np.ndarray],
|
| 357 |
+
output_path: Union[str, Path],
|
| 358 |
+
sample_rate: int = 48000,
|
| 359 |
+
format: Optional[str] = None,
|
| 360 |
+
channels_first: bool = True,
|
| 361 |
+
) -> str:
|
| 362 |
+
"""
|
| 363 |
+
Convenience function: save audio (using default configuration)
|
| 364 |
+
|
| 365 |
+
Args:
|
| 366 |
+
audio_data: Audio data
|
| 367 |
+
output_path: Output path
|
| 368 |
+
sample_rate: Sample rate
|
| 369 |
+
format: Format (default flac)
|
| 370 |
+
channels_first: Tensor format flag
|
| 371 |
+
|
| 372 |
+
Returns:
|
| 373 |
+
Saved file path
|
| 374 |
+
"""
|
| 375 |
+
return _default_saver.save_audio(
|
| 376 |
+
audio_data, output_path, sample_rate, format, channels_first
|
| 377 |
+
)
|
| 378 |
+
|
acestep/constants.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Constants for ACE-Step
|
| 3 |
+
Centralized constants used across the codebase
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# ==============================================================================
|
| 7 |
+
# Language Constants
|
| 8 |
+
# ==============================================================================
|
| 9 |
+
|
| 10 |
+
VALID_LANGUAGES = [
|
| 11 |
+
'ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en',
|
| 12 |
+
'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id',
|
| 13 |
+
'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no',
|
| 14 |
+
'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw',
|
| 15 |
+
'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh',
|
| 16 |
+
'unknown'
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ==============================================================================
|
| 21 |
+
# Keyscale Constants
|
| 22 |
+
# ==============================================================================
|
| 23 |
+
|
| 24 |
+
KEYSCALE_NOTES = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
|
| 25 |
+
KEYSCALE_ACCIDENTALS = ['', '#', 'b', '♯', '♭'] # empty + ASCII sharp/flat + Unicode sharp/flat
|
| 26 |
+
KEYSCALE_MODES = ['major', 'minor']
|
| 27 |
+
|
| 28 |
+
# Generate all valid keyscales: 7 notes × 5 accidentals × 2 modes = 70 combinations
|
| 29 |
+
VALID_KEYSCALES = set()
|
| 30 |
+
for note in KEYSCALE_NOTES:
|
| 31 |
+
for acc in KEYSCALE_ACCIDENTALS:
|
| 32 |
+
for mode in KEYSCALE_MODES:
|
| 33 |
+
VALID_KEYSCALES.add(f"{note}{acc} {mode}")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ==============================================================================
|
| 37 |
+
# Metadata Range Constants
|
| 38 |
+
# ==============================================================================
|
| 39 |
+
|
| 40 |
+
# BPM (Beats Per Minute) range
|
| 41 |
+
BPM_MIN = 30
|
| 42 |
+
BPM_MAX = 300
|
| 43 |
+
|
| 44 |
+
# Duration range (in seconds)
|
| 45 |
+
DURATION_MIN = 10
|
| 46 |
+
DURATION_MAX = 600
|
| 47 |
+
|
| 48 |
+
# Valid time signatures
|
| 49 |
+
VALID_TIME_SIGNATURES = [2, 3, 4, 6]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ==============================================================================
|
| 53 |
+
# Task Type Constants
|
| 54 |
+
# ==============================================================================
|
| 55 |
+
|
| 56 |
+
TASK_TYPES = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
|
| 57 |
+
|
| 58 |
+
# Task types available for turbo models (subset)
|
| 59 |
+
TASK_TYPES_TURBO = ["text2music", "repaint", "cover"]
|
| 60 |
+
|
| 61 |
+
# Task types available for base models (full set)
|
| 62 |
+
TASK_TYPES_BASE = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# ==============================================================================
|
| 66 |
+
# Instruction Constants
|
| 67 |
+
# ==============================================================================
|
| 68 |
+
|
| 69 |
+
# Default instructions
|
| 70 |
+
DEFAULT_DIT_INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
|
| 71 |
+
DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"
|
| 72 |
+
DEFAULT_LM_UNDERSTAND_INSTRUCTION = "Understand the given musical conditions and describe the audio semantics accordingly:"
|
| 73 |
+
DEFAULT_LM_INSPIRED_INSTRUCTION = "Expand the user's input into a more detailed and specific musical description:"
|
| 74 |
+
DEFAULT_LM_REWRITE_INSTRUCTION = "Format the user's input into a more detailed and specific musical description:"
|
| 75 |
+
|
| 76 |
+
# Instruction templates for each task type
|
| 77 |
+
# Note: Some instructions use placeholders like {TRACK_NAME} or {TRACK_CLASSES}
|
| 78 |
+
# These should be formatted using .format() or f-strings when used
|
| 79 |
+
TASK_INSTRUCTIONS = {
|
| 80 |
+
"text2music": "Fill the audio semantic mask based on the given conditions:",
|
| 81 |
+
"repaint": "Repaint the mask area based on the given conditions:",
|
| 82 |
+
"cover": "Generate audio semantic tokens based on the given conditions:",
|
| 83 |
+
"extract": "Extract the {TRACK_NAME} track from the audio:",
|
| 84 |
+
"extract_default": "Extract the track from the audio:",
|
| 85 |
+
"lego": "Generate the {TRACK_NAME} track based on the audio context:",
|
| 86 |
+
"lego_default": "Generate the track based on the audio context:",
|
| 87 |
+
"complete": "Complete the input track with {TRACK_CLASSES}:",
|
| 88 |
+
"complete_default": "Complete the input track:",
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ==============================================================================
|
| 93 |
+
# Track/Instrument Constants
|
| 94 |
+
# ==============================================================================
|
| 95 |
+
|
| 96 |
+
TRACK_NAMES = [
|
| 97 |
+
"woodwinds", "brass", "fx", "synth", "strings", "percussion",
|
| 98 |
+
"keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
SFT_GEN_PROMPT = """# Instruction
|
| 102 |
+
{}
|
| 103 |
+
|
| 104 |
+
# Caption
|
| 105 |
+
{}
|
| 106 |
+
|
| 107 |
+
# Metas
|
| 108 |
+
{}<|endoftext|>
|
| 109 |
+
"""
|
acestep/constrained_logits_processor.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/dataset_handler.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset Handler
|
| 3 |
+
Handles dataset import and exploration functionality
|
| 4 |
+
"""
|
| 5 |
+
from typing import Optional, Tuple, Any, Dict
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DatasetHandler:
|
| 9 |
+
"""Dataset Handler for Dataset Explorer functionality"""
|
| 10 |
+
|
| 11 |
+
def __init__(self):
|
| 12 |
+
"""Initialize dataset handler"""
|
| 13 |
+
self.dataset = None
|
| 14 |
+
self.dataset_imported = False
|
| 15 |
+
|
| 16 |
+
def import_dataset(self, dataset_type: str) -> str:
|
| 17 |
+
"""
|
| 18 |
+
Import dataset (temporarily disabled)
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
dataset_type: Type of dataset to import (e.g., "train", "test")
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Status message string
|
| 25 |
+
"""
|
| 26 |
+
self.dataset_imported = False
|
| 27 |
+
return f"⚠️ Dataset import is currently disabled. Text2MusicDataset dependency not available."
|
| 28 |
+
|
| 29 |
+
def get_item_data(self, *args, **kwargs) -> Tuple:
|
| 30 |
+
"""
|
| 31 |
+
Get dataset item (temporarily disabled)
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
Tuple of placeholder values matching the expected return format
|
| 35 |
+
"""
|
| 36 |
+
return "", "", "", "", "", None, None, None, "❌ Dataset not available", "", 0, "", None, None, None, {}, "text2music"
|
| 37 |
+
|
acestep/dit_alignment_score.py
ADDED
|
@@ -0,0 +1,870 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DiT Alignment Score Module
|
| 3 |
+
|
| 4 |
+
This module provides lyrics-to-audio alignment using cross-attention matrices
|
| 5 |
+
from DiT model for generating LRC timestamps.
|
| 6 |
+
|
| 7 |
+
Refactored from lyrics_alignment_infos.py for integration with ACE-Step.
|
| 8 |
+
"""
|
| 9 |
+
import numba
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from dataclasses import dataclass, asdict
|
| 14 |
+
from typing import List, Dict, Any, Optional, Tuple, Union
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ================= Data Classes =================
|
| 18 |
+
@dataclass
|
| 19 |
+
class TokenTimestamp:
|
| 20 |
+
"""Stores per-token timing information."""
|
| 21 |
+
token_id: int
|
| 22 |
+
text: str
|
| 23 |
+
start: float
|
| 24 |
+
end: float
|
| 25 |
+
probability: float
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class SentenceTimestamp:
|
| 30 |
+
"""Stores per-sentence timing information with token list."""
|
| 31 |
+
text: str
|
| 32 |
+
start: float
|
| 33 |
+
end: float
|
| 34 |
+
tokens: List[TokenTimestamp]
|
| 35 |
+
confidence: float
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ================= DTW Algorithm (Numba Optimized) =================
|
| 39 |
+
@numba.jit(nopython=True)
|
| 40 |
+
def dtw_cpu(x: np.ndarray):
|
| 41 |
+
"""
|
| 42 |
+
Dynamic Time Warping algorithm optimized with Numba.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
x: Cost matrix of shape [N, M]
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
Tuple of (text_indices, time_indices) arrays
|
| 49 |
+
"""
|
| 50 |
+
N, M = x.shape
|
| 51 |
+
# Use float32 for memory efficiency
|
| 52 |
+
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
|
| 53 |
+
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
|
| 54 |
+
cost[0, 0] = 0
|
| 55 |
+
|
| 56 |
+
for j in range(1, M + 1):
|
| 57 |
+
for i in range(1, N + 1):
|
| 58 |
+
c0 = cost[i - 1, j - 1]
|
| 59 |
+
c1 = cost[i - 1, j]
|
| 60 |
+
c2 = cost[i, j - 1]
|
| 61 |
+
|
| 62 |
+
if c0 < c1 and c0 < c2:
|
| 63 |
+
c, t = c0, 0
|
| 64 |
+
elif c1 < c0 and c1 < c2:
|
| 65 |
+
c, t = c1, 1
|
| 66 |
+
else:
|
| 67 |
+
c, t = c2, 2
|
| 68 |
+
|
| 69 |
+
cost[i, j] = x[i - 1, j - 1] + c
|
| 70 |
+
trace[i, j] = t
|
| 71 |
+
|
| 72 |
+
return _backtrace(trace, N, M)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@numba.jit(nopython=True)
|
| 76 |
+
def _backtrace(trace: np.ndarray, N: int, M: int):
|
| 77 |
+
"""
|
| 78 |
+
Optimized backtrace function for DTW.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
trace: Trace matrix of shape (N+1, M+1)
|
| 82 |
+
N, M: Original matrix dimensions
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
Path array of shape (2, path_len) - first row is text indices, second is time indices
|
| 86 |
+
"""
|
| 87 |
+
# Boundary handling
|
| 88 |
+
trace[0, :] = 2
|
| 89 |
+
trace[:, 0] = 1
|
| 90 |
+
|
| 91 |
+
# Pre-allocate array, max path length is N+M
|
| 92 |
+
max_path_len = N + M
|
| 93 |
+
path = np.zeros((2, max_path_len), dtype=np.int32)
|
| 94 |
+
|
| 95 |
+
i, j = N, M
|
| 96 |
+
path_idx = max_path_len - 1
|
| 97 |
+
|
| 98 |
+
while i > 0 or j > 0:
|
| 99 |
+
path[0, path_idx] = i - 1 # text index
|
| 100 |
+
path[1, path_idx] = j - 1 # time index
|
| 101 |
+
path_idx -= 1
|
| 102 |
+
|
| 103 |
+
t = trace[i, j]
|
| 104 |
+
if t == 0:
|
| 105 |
+
i -= 1
|
| 106 |
+
j -= 1
|
| 107 |
+
elif t == 1:
|
| 108 |
+
i -= 1
|
| 109 |
+
elif t == 2:
|
| 110 |
+
j -= 1
|
| 111 |
+
else:
|
| 112 |
+
break
|
| 113 |
+
|
| 114 |
+
actual_len = max_path_len - path_idx - 1
|
| 115 |
+
return path[:, path_idx + 1:max_path_len]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ================= Utility Functions =================
|
| 119 |
+
def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor:
|
| 120 |
+
"""
|
| 121 |
+
Apply median filter to tensor.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
x: Input tensor
|
| 125 |
+
filter_width: Width of median filter
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
Filtered tensor
|
| 129 |
+
"""
|
| 130 |
+
pad_width = filter_width // 2
|
| 131 |
+
if x.shape[-1] <= pad_width:
|
| 132 |
+
return x
|
| 133 |
+
if x.ndim == 2:
|
| 134 |
+
x = x[None, :]
|
| 135 |
+
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
|
| 136 |
+
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
|
| 137 |
+
if result.ndim > 2:
|
| 138 |
+
result = result.squeeze(0)
|
| 139 |
+
return result
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# ================= Main Aligner Class =================
|
| 143 |
+
class MusicStampsAligner:
|
| 144 |
+
"""
|
| 145 |
+
Aligner class for generating lyrics timestamps from cross-attention matrices.
|
| 146 |
+
|
| 147 |
+
Uses bidirectional consensus denoising and DTW for alignment.
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
def __init__(self, tokenizer):
|
| 151 |
+
"""
|
| 152 |
+
Initialize the aligner.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
tokenizer: Text tokenizer for decoding tokens
|
| 156 |
+
"""
|
| 157 |
+
self.tokenizer = tokenizer
|
| 158 |
+
|
| 159 |
+
def _apply_bidirectional_consensus(
|
| 160 |
+
self,
|
| 161 |
+
weights_stack: torch.Tensor,
|
| 162 |
+
violence_level: float,
|
| 163 |
+
medfilt_width: int
|
| 164 |
+
) -> tuple:
|
| 165 |
+
"""
|
| 166 |
+
Core denoising logic using bidirectional consensus.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
weights_stack: Attention weights [Heads, Tokens, Frames]
|
| 170 |
+
violence_level: Denoising strength coefficient
|
| 171 |
+
medfilt_width: Median filter width
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
Tuple of (calc_matrix, energy_matrix) as numpy arrays
|
| 175 |
+
"""
|
| 176 |
+
# A. Bidirectional Consensus
|
| 177 |
+
row_prob = F.softmax(weights_stack, dim=-1) # Token -> Frame
|
| 178 |
+
col_prob = F.softmax(weights_stack, dim=-2) # Frame -> Token
|
| 179 |
+
processed = row_prob * col_prob
|
| 180 |
+
|
| 181 |
+
# 1. Row suppression (kill horizontal crossing lines)
|
| 182 |
+
row_medians = torch.quantile(processed, 0.5, dim=-1, keepdim=True)
|
| 183 |
+
processed = processed - (violence_level * row_medians)
|
| 184 |
+
processed = torch.relu(processed)
|
| 185 |
+
|
| 186 |
+
# 2. Column suppression (kill vertical crossing lines)
|
| 187 |
+
col_medians = torch.quantile(processed, 0.5, dim=-2, keepdim=True)
|
| 188 |
+
processed = processed - (violence_level * col_medians)
|
| 189 |
+
processed = torch.relu(processed)
|
| 190 |
+
|
| 191 |
+
# C. Power sharpening
|
| 192 |
+
processed = processed ** 2
|
| 193 |
+
|
| 194 |
+
# Energy matrix for confidence
|
| 195 |
+
energy_matrix = processed.mean(dim=0).cpu().numpy()
|
| 196 |
+
|
| 197 |
+
# D. Z-Score normalization
|
| 198 |
+
std, mean = torch.std_mean(processed, unbiased=False)
|
| 199 |
+
weights_processed = (processed - mean) / (std + 1e-9)
|
| 200 |
+
|
| 201 |
+
# E. Median filtering
|
| 202 |
+
weights_processed = median_filter(weights_processed, filter_width=medfilt_width)
|
| 203 |
+
calc_matrix = weights_processed.mean(dim=0).numpy()
|
| 204 |
+
|
| 205 |
+
return calc_matrix, energy_matrix
|
| 206 |
+
|
| 207 |
+
def _preprocess_attention(
|
| 208 |
+
self,
|
| 209 |
+
attention_matrix: torch.Tensor,
|
| 210 |
+
custom_config: Dict[int, List[int]],
|
| 211 |
+
violence_level: float,
|
| 212 |
+
medfilt_width: int = 7
|
| 213 |
+
) -> tuple:
|
| 214 |
+
"""
|
| 215 |
+
Preprocess attention matrix for alignment.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
attention_matrix: Attention tensor [Layers, Heads, Tokens, Frames]
|
| 219 |
+
custom_config: Dict mapping layer indices to head indices
|
| 220 |
+
violence_level: Denoising strength
|
| 221 |
+
medfilt_width: Median filter width
|
| 222 |
+
|
| 223 |
+
Returns:
|
| 224 |
+
Tuple of (calc_matrix, energy_matrix, visual_matrix)
|
| 225 |
+
"""
|
| 226 |
+
if not isinstance(attention_matrix, torch.Tensor):
|
| 227 |
+
weights = torch.tensor(attention_matrix)
|
| 228 |
+
else:
|
| 229 |
+
weights = attention_matrix.clone()
|
| 230 |
+
|
| 231 |
+
weights = weights.cpu().float()
|
| 232 |
+
|
| 233 |
+
selected_tensors = []
|
| 234 |
+
for layer_idx, head_indices in custom_config.items():
|
| 235 |
+
for head_idx in head_indices:
|
| 236 |
+
if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
|
| 237 |
+
head_matrix = weights[layer_idx, head_idx]
|
| 238 |
+
selected_tensors.append(head_matrix)
|
| 239 |
+
|
| 240 |
+
if not selected_tensors:
|
| 241 |
+
return None, None, None
|
| 242 |
+
|
| 243 |
+
# Stack selected heads: [Heads, Tokens, Frames]
|
| 244 |
+
weights_stack = torch.stack(selected_tensors, dim=0)
|
| 245 |
+
visual_matrix = weights_stack.mean(dim=0).numpy()
|
| 246 |
+
|
| 247 |
+
calc_matrix, energy_matrix = self._apply_bidirectional_consensus(
|
| 248 |
+
weights_stack, violence_level, medfilt_width
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
return calc_matrix, energy_matrix, visual_matrix
|
| 252 |
+
|
| 253 |
+
def stamps_align_info(
|
| 254 |
+
self,
|
| 255 |
+
attention_matrix: torch.Tensor,
|
| 256 |
+
lyrics_tokens: List[int],
|
| 257 |
+
total_duration_seconds: float,
|
| 258 |
+
custom_config: Dict[int, List[int]],
|
| 259 |
+
return_matrices: bool = False,
|
| 260 |
+
violence_level: float = 2.0,
|
| 261 |
+
medfilt_width: int = 1
|
| 262 |
+
) -> Dict[str, Any]:
|
| 263 |
+
"""
|
| 264 |
+
Get alignment information from attention matrix.
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
attention_matrix: Cross-attention tensor [Layers, Heads, Tokens, Frames]
|
| 268 |
+
lyrics_tokens: List of lyrics token IDs
|
| 269 |
+
total_duration_seconds: Total audio duration in seconds
|
| 270 |
+
custom_config: Dict mapping layer indices to head indices
|
| 271 |
+
return_matrices: Whether to return intermediate matrices
|
| 272 |
+
violence_level: Denoising strength
|
| 273 |
+
medfilt_width: Median filter width
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
Dict containing calc_matrix, lyrics_tokens, total_duration_seconds,
|
| 277 |
+
and optionally energy_matrix and vis_matrix
|
| 278 |
+
"""
|
| 279 |
+
calc_matrix, energy_matrix, visual_matrix = self._preprocess_attention(
|
| 280 |
+
attention_matrix, custom_config, violence_level, medfilt_width
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
if calc_matrix is None:
|
| 284 |
+
return {
|
| 285 |
+
"calc_matrix": None,
|
| 286 |
+
"lyrics_tokens": lyrics_tokens,
|
| 287 |
+
"total_duration_seconds": total_duration_seconds,
|
| 288 |
+
"error": "No valid attention heads found"
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
return_dict = {
|
| 292 |
+
"calc_matrix": calc_matrix,
|
| 293 |
+
"lyrics_tokens": lyrics_tokens,
|
| 294 |
+
"total_duration_seconds": total_duration_seconds
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
if return_matrices:
|
| 298 |
+
return_dict['energy_matrix'] = energy_matrix
|
| 299 |
+
return_dict['vis_matrix'] = visual_matrix
|
| 300 |
+
|
| 301 |
+
return return_dict
|
| 302 |
+
|
| 303 |
+
def _decode_tokens_incrementally(self, token_ids: List[int]) -> List[str]:
|
| 304 |
+
"""
|
| 305 |
+
Decode tokens incrementally to properly handle multi-byte UTF-8 characters.
|
| 306 |
+
|
| 307 |
+
For Chinese and other multi-byte characters, the tokenizer may split them
|
| 308 |
+
into multiple byte-level tokens. Decoding each token individually produces
|
| 309 |
+
invalid UTF-8 sequences (showing as �). This method uses byte-level comparison
|
| 310 |
+
to correctly track which characters each token contributes.
|
| 311 |
+
|
| 312 |
+
Args:
|
| 313 |
+
token_ids: List of token IDs
|
| 314 |
+
|
| 315 |
+
Returns:
|
| 316 |
+
List of decoded text for each token position
|
| 317 |
+
"""
|
| 318 |
+
decoded_tokens = []
|
| 319 |
+
prev_bytes = b""
|
| 320 |
+
|
| 321 |
+
for i in range(len(token_ids)):
|
| 322 |
+
# Decode tokens from start to current position
|
| 323 |
+
current_text = self.tokenizer.decode(token_ids[:i+1], skip_special_tokens=False)
|
| 324 |
+
current_bytes = current_text.encode('utf-8', errors='surrogatepass')
|
| 325 |
+
|
| 326 |
+
# The contribution of current token is the new bytes added
|
| 327 |
+
if len(current_bytes) >= len(prev_bytes):
|
| 328 |
+
new_bytes = current_bytes[len(prev_bytes):]
|
| 329 |
+
# Try to decode the new bytes; if incomplete, use empty string
|
| 330 |
+
try:
|
| 331 |
+
token_text = new_bytes.decode('utf-8')
|
| 332 |
+
except UnicodeDecodeError:
|
| 333 |
+
# Incomplete UTF-8 sequence, this token doesn't complete a character
|
| 334 |
+
token_text = ""
|
| 335 |
+
else:
|
| 336 |
+
# Edge case: current decode is shorter (shouldn't happen normally)
|
| 337 |
+
token_text = ""
|
| 338 |
+
|
| 339 |
+
decoded_tokens.append(token_text)
|
| 340 |
+
prev_bytes = current_bytes
|
| 341 |
+
|
| 342 |
+
return decoded_tokens
|
| 343 |
+
|
| 344 |
+
def token_timestamps(
|
| 345 |
+
self,
|
| 346 |
+
calc_matrix: np.ndarray,
|
| 347 |
+
lyrics_tokens: List[int],
|
| 348 |
+
total_duration_seconds: float
|
| 349 |
+
) -> List[TokenTimestamp]:
|
| 350 |
+
"""
|
| 351 |
+
Generate per-token timestamps using DTW.
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
calc_matrix: Processed attention matrix [Tokens, Frames]
|
| 355 |
+
lyrics_tokens: List of token IDs
|
| 356 |
+
total_duration_seconds: Total audio duration
|
| 357 |
+
|
| 358 |
+
Returns:
|
| 359 |
+
List of TokenTimestamp objects
|
| 360 |
+
"""
|
| 361 |
+
n_frames = calc_matrix.shape[-1]
|
| 362 |
+
text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float64))
|
| 363 |
+
|
| 364 |
+
seconds_per_frame = total_duration_seconds / n_frames
|
| 365 |
+
alignment_results = []
|
| 366 |
+
|
| 367 |
+
# Use incremental decoding to properly handle multi-byte UTF-8 characters
|
| 368 |
+
decoded_tokens = self._decode_tokens_incrementally(lyrics_tokens)
|
| 369 |
+
|
| 370 |
+
for i in range(len(lyrics_tokens)):
|
| 371 |
+
mask = (text_indices == i)
|
| 372 |
+
|
| 373 |
+
if not np.any(mask):
|
| 374 |
+
start = alignment_results[-1].end if alignment_results else 0.0
|
| 375 |
+
end = start
|
| 376 |
+
token_conf = 0.0
|
| 377 |
+
else:
|
| 378 |
+
times = time_indices[mask] * seconds_per_frame
|
| 379 |
+
start = times[0]
|
| 380 |
+
end = times[-1]
|
| 381 |
+
token_conf = 0.0
|
| 382 |
+
|
| 383 |
+
if end < start:
|
| 384 |
+
end = start
|
| 385 |
+
|
| 386 |
+
alignment_results.append(TokenTimestamp(
|
| 387 |
+
token_id=lyrics_tokens[i],
|
| 388 |
+
text=decoded_tokens[i],
|
| 389 |
+
start=float(start),
|
| 390 |
+
end=float(end),
|
| 391 |
+
probability=token_conf
|
| 392 |
+
))
|
| 393 |
+
|
| 394 |
+
return alignment_results
|
| 395 |
+
|
| 396 |
+
def _decode_sentence_from_tokens(self, tokens: List[TokenTimestamp]) -> str:
|
| 397 |
+
"""
|
| 398 |
+
Decode a sentence by decoding all token IDs together.
|
| 399 |
+
This avoids UTF-8 encoding issues from joining individual token texts.
|
| 400 |
+
|
| 401 |
+
Args:
|
| 402 |
+
tokens: List of TokenTimestamp objects
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
Properly decoded sentence text
|
| 406 |
+
"""
|
| 407 |
+
token_ids = [t.token_id for t in tokens]
|
| 408 |
+
return self.tokenizer.decode(token_ids, skip_special_tokens=False)
|
| 409 |
+
|
| 410 |
+
def sentence_timestamps(
|
| 411 |
+
self,
|
| 412 |
+
token_alignment: List[TokenTimestamp]
|
| 413 |
+
) -> List[SentenceTimestamp]:
|
| 414 |
+
"""
|
| 415 |
+
Group token timestamps into sentence timestamps.
|
| 416 |
+
|
| 417 |
+
Args:
|
| 418 |
+
token_alignment: List of TokenTimestamp objects
|
| 419 |
+
|
| 420 |
+
Returns:
|
| 421 |
+
List of SentenceTimestamp objects
|
| 422 |
+
"""
|
| 423 |
+
results = []
|
| 424 |
+
current_tokens = []
|
| 425 |
+
|
| 426 |
+
for token in token_alignment:
|
| 427 |
+
current_tokens.append(token)
|
| 428 |
+
|
| 429 |
+
if '\n' in token.text:
|
| 430 |
+
# Decode all token IDs together to avoid UTF-8 issues
|
| 431 |
+
full_text = self._decode_sentence_from_tokens(current_tokens)
|
| 432 |
+
|
| 433 |
+
if full_text.strip():
|
| 434 |
+
valid_scores = [t.probability for t in current_tokens if t.probability > 0]
|
| 435 |
+
sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
|
| 436 |
+
|
| 437 |
+
results.append(SentenceTimestamp(
|
| 438 |
+
text=full_text.strip(),
|
| 439 |
+
start=round(current_tokens[0].start, 3),
|
| 440 |
+
end=round(current_tokens[-1].end, 3),
|
| 441 |
+
tokens=list(current_tokens),
|
| 442 |
+
confidence=sent_conf
|
| 443 |
+
))
|
| 444 |
+
|
| 445 |
+
current_tokens = []
|
| 446 |
+
|
| 447 |
+
# Handle last sentence
|
| 448 |
+
if current_tokens:
|
| 449 |
+
# Decode all token IDs together to avoid UTF-8 issues
|
| 450 |
+
full_text = self._decode_sentence_from_tokens(current_tokens)
|
| 451 |
+
if full_text.strip():
|
| 452 |
+
valid_scores = [t.probability for t in current_tokens if t.probability > 0]
|
| 453 |
+
sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
|
| 454 |
+
|
| 455 |
+
results.append(SentenceTimestamp(
|
| 456 |
+
text=full_text.strip(),
|
| 457 |
+
start=round(current_tokens[0].start, 3),
|
| 458 |
+
end=round(current_tokens[-1].end, 3),
|
| 459 |
+
tokens=list(current_tokens),
|
| 460 |
+
confidence=sent_conf
|
| 461 |
+
))
|
| 462 |
+
|
| 463 |
+
# Normalize confidence scores
|
| 464 |
+
if results:
|
| 465 |
+
all_scores = [s.confidence for s in results]
|
| 466 |
+
min_score = min(all_scores)
|
| 467 |
+
max_score = max(all_scores)
|
| 468 |
+
score_range = max_score - min_score
|
| 469 |
+
|
| 470 |
+
if score_range > 1e-9:
|
| 471 |
+
for s in results:
|
| 472 |
+
normalized_score = (s.confidence - min_score) / score_range
|
| 473 |
+
s.confidence = round(normalized_score, 2)
|
| 474 |
+
else:
|
| 475 |
+
for s in results:
|
| 476 |
+
s.confidence = round(s.confidence, 2)
|
| 477 |
+
|
| 478 |
+
return results
|
| 479 |
+
|
| 480 |
+
def format_lrc(
|
| 481 |
+
self,
|
| 482 |
+
sentence_timestamps: List[SentenceTimestamp],
|
| 483 |
+
include_end_time: bool = False
|
| 484 |
+
) -> str:
|
| 485 |
+
"""
|
| 486 |
+
Format sentence timestamps as LRC lyrics format.
|
| 487 |
+
|
| 488 |
+
Args:
|
| 489 |
+
sentence_timestamps: List of SentenceTimestamp objects
|
| 490 |
+
include_end_time: Whether to include end time (enhanced LRC format)
|
| 491 |
+
|
| 492 |
+
Returns:
|
| 493 |
+
LRC formatted string
|
| 494 |
+
"""
|
| 495 |
+
lines = []
|
| 496 |
+
|
| 497 |
+
for sentence in sentence_timestamps:
|
| 498 |
+
# Convert seconds to mm:ss.xx format
|
| 499 |
+
start_minutes = int(sentence.start // 60)
|
| 500 |
+
start_seconds = sentence.start % 60
|
| 501 |
+
|
| 502 |
+
if include_end_time:
|
| 503 |
+
end_minutes = int(sentence.end // 60)
|
| 504 |
+
end_seconds = sentence.end % 60
|
| 505 |
+
timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}][{end_minutes:02d}:{end_seconds:05.2f}]"
|
| 506 |
+
else:
|
| 507 |
+
timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}]"
|
| 508 |
+
|
| 509 |
+
# Clean the text (remove structural tags like [verse], [chorus])
|
| 510 |
+
text = sentence.text
|
| 511 |
+
|
| 512 |
+
lines.append(f"{timestamp}{text}")
|
| 513 |
+
|
| 514 |
+
return "\n".join(lines)
|
| 515 |
+
|
| 516 |
+
def get_timestamps_and_lrc(
|
| 517 |
+
self,
|
| 518 |
+
calc_matrix: np.ndarray,
|
| 519 |
+
lyrics_tokens: List[int],
|
| 520 |
+
total_duration_seconds: float
|
| 521 |
+
) -> Dict[str, Any]:
|
| 522 |
+
"""
|
| 523 |
+
Convenience method to get both timestamps and LRC in one call.
|
| 524 |
+
|
| 525 |
+
Args:
|
| 526 |
+
calc_matrix: Processed attention matrix
|
| 527 |
+
lyrics_tokens: List of token IDs
|
| 528 |
+
total_duration_seconds: Total audio duration
|
| 529 |
+
|
| 530 |
+
Returns:
|
| 531 |
+
Dict containing token_timestamps, sentence_timestamps, and lrc_text
|
| 532 |
+
"""
|
| 533 |
+
token_stamps = self.token_timestamps(
|
| 534 |
+
calc_matrix=calc_matrix,
|
| 535 |
+
lyrics_tokens=lyrics_tokens,
|
| 536 |
+
total_duration_seconds=total_duration_seconds
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
sentence_stamps = self.sentence_timestamps(token_stamps)
|
| 540 |
+
lrc_text = self.format_lrc(sentence_stamps)
|
| 541 |
+
|
| 542 |
+
return {
|
| 543 |
+
"token_timestamps": token_stamps,
|
| 544 |
+
"sentence_timestamps": sentence_stamps,
|
| 545 |
+
"lrc_text": lrc_text
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
class MusicLyricScorer:
|
| 550 |
+
"""
|
| 551 |
+
Scorer class for evaluating lyrics-to-audio alignment quality.
|
| 552 |
+
|
| 553 |
+
Focuses on calculating alignment quality metrics (Coverage, Monotonicity, Confidence)
|
| 554 |
+
using tensor operations for potential differentiability or GPU acceleration.
|
| 555 |
+
"""
|
| 556 |
+
|
| 557 |
+
def __init__(self, tokenizer: Any):
|
| 558 |
+
"""
|
| 559 |
+
Initialize the aligner.
|
| 560 |
+
|
| 561 |
+
Args:
|
| 562 |
+
tokenizer: Tokenizer instance (must implement .decode()).
|
| 563 |
+
"""
|
| 564 |
+
self.tokenizer = tokenizer
|
| 565 |
+
|
| 566 |
+
def _generate_token_type_mask(self, token_ids: List[int]) -> np.ndarray:
|
| 567 |
+
"""
|
| 568 |
+
Generate a mask distinguishing lyrics (1) from structural tags (0).
|
| 569 |
+
Uses self.tokenizer to decode tokens.
|
| 570 |
+
|
| 571 |
+
Args:
|
| 572 |
+
token_ids: List of token IDs.
|
| 573 |
+
|
| 574 |
+
Returns:
|
| 575 |
+
Numpy array of shape [len(token_ids)] with 1 or 0.
|
| 576 |
+
"""
|
| 577 |
+
decoded_tokens = [self.tokenizer.decode([tid]) for tid in token_ids]
|
| 578 |
+
mask = np.ones(len(token_ids), dtype=np.int32)
|
| 579 |
+
in_bracket = False
|
| 580 |
+
|
| 581 |
+
for i, token_str in enumerate(decoded_tokens):
|
| 582 |
+
if '[' in token_str:
|
| 583 |
+
in_bracket = True
|
| 584 |
+
if in_bracket:
|
| 585 |
+
mask[i] = 0
|
| 586 |
+
if ']' in token_str:
|
| 587 |
+
in_bracket = False
|
| 588 |
+
mask[i] = 0
|
| 589 |
+
return mask
|
| 590 |
+
|
| 591 |
+
def _preprocess_attention(
|
| 592 |
+
self,
|
| 593 |
+
attention_matrix: Union[torch.Tensor, np.ndarray],
|
| 594 |
+
custom_config: Dict[int, List[int]],
|
| 595 |
+
medfilt_width: int = 1
|
| 596 |
+
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[torch.Tensor]]:
|
| 597 |
+
"""
|
| 598 |
+
Extracts and normalizes the attention matrix.
|
| 599 |
+
|
| 600 |
+
Logic V4: Uses Min-Max normalization to highlight energy differences.
|
| 601 |
+
|
| 602 |
+
Args:
|
| 603 |
+
attention_matrix: Raw attention tensor [Layers, Heads, Tokens, Frames].
|
| 604 |
+
custom_config: Config mapping layers to heads.
|
| 605 |
+
medfilt_width: Width for median filtering.
|
| 606 |
+
|
| 607 |
+
Returns:
|
| 608 |
+
Tuple of (calc_matrix, energy_matrix, avg_weights_tensor).
|
| 609 |
+
"""
|
| 610 |
+
# 1. Prepare Tensor
|
| 611 |
+
if not isinstance(attention_matrix, torch.Tensor):
|
| 612 |
+
weights = torch.tensor(attention_matrix)
|
| 613 |
+
else:
|
| 614 |
+
weights = attention_matrix.clone()
|
| 615 |
+
weights = weights.cpu().float()
|
| 616 |
+
|
| 617 |
+
# 2. Select Heads based on config
|
| 618 |
+
selected_tensors = []
|
| 619 |
+
for layer_idx, head_indices in custom_config.items():
|
| 620 |
+
for head_idx in head_indices:
|
| 621 |
+
if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
|
| 622 |
+
selected_tensors.append(weights[layer_idx, head_idx])
|
| 623 |
+
|
| 624 |
+
if not selected_tensors:
|
| 625 |
+
return None, None, None
|
| 626 |
+
|
| 627 |
+
weights_stack = torch.stack(selected_tensors, dim=0)
|
| 628 |
+
|
| 629 |
+
# 3. Average Heads
|
| 630 |
+
avg_weights = weights_stack.mean(dim=0) # [Tokens, Frames]
|
| 631 |
+
|
| 632 |
+
# 4. Preprocessing Logic
|
| 633 |
+
# Min-Max normalization preserving energy distribution
|
| 634 |
+
# Median filter is applied to the energy matrix
|
| 635 |
+
energy_tensor = median_filter(avg_weights, filter_width=medfilt_width)
|
| 636 |
+
energy_matrix = energy_tensor.numpy()
|
| 637 |
+
|
| 638 |
+
e_min, e_max = energy_matrix.min(), energy_matrix.max()
|
| 639 |
+
|
| 640 |
+
if e_max - e_min > 1e-9:
|
| 641 |
+
energy_matrix = (energy_matrix - e_min) / (e_max - e_min)
|
| 642 |
+
else:
|
| 643 |
+
energy_matrix = np.zeros_like(energy_matrix)
|
| 644 |
+
|
| 645 |
+
# Contrast enhancement for DTW pathfinding
|
| 646 |
+
# calc_matrix is used for pathfinding, energy_matrix for scoring
|
| 647 |
+
calc_matrix = energy_matrix ** 2
|
| 648 |
+
|
| 649 |
+
return calc_matrix, energy_matrix, avg_weights
|
| 650 |
+
|
| 651 |
+
def _compute_alignment_metrics(
|
| 652 |
+
self,
|
| 653 |
+
energy_matrix: torch.Tensor,
|
| 654 |
+
path_coords: torch.Tensor,
|
| 655 |
+
type_mask: torch.Tensor,
|
| 656 |
+
time_weight: float = 0.01,
|
| 657 |
+
overlap_frames: float = 9.0,
|
| 658 |
+
instrumental_weight: float = 1.0
|
| 659 |
+
) -> Tuple[float, float, float]:
|
| 660 |
+
"""
|
| 661 |
+
Core metric calculation logic using high-precision Tensor operations.
|
| 662 |
+
|
| 663 |
+
Args:
|
| 664 |
+
energy_matrix: Normalized energy [Rows, Cols].
|
| 665 |
+
path_coords: DTW path coordinates [Steps, 2].
|
| 666 |
+
type_mask: Token type mask [Rows] (1=Lyrics, 0=Tags).
|
| 667 |
+
time_weight: Minimum energy threshold for monotonicity.
|
| 668 |
+
overlap_frames: Allowed overlap for monotonicity check.
|
| 669 |
+
instrumental_weight: Weight for non-lyric tokens in confidence calc.
|
| 670 |
+
|
| 671 |
+
Returns:
|
| 672 |
+
Tuple of (coverage, monotonicity, confidence).
|
| 673 |
+
"""
|
| 674 |
+
# Ensure high precision for internal calculation
|
| 675 |
+
energy_matrix = energy_matrix.to(dtype=torch.float64)
|
| 676 |
+
path_coords = path_coords.long()
|
| 677 |
+
type_mask = type_mask.long()
|
| 678 |
+
|
| 679 |
+
device = energy_matrix.device
|
| 680 |
+
rows, cols = energy_matrix.shape
|
| 681 |
+
|
| 682 |
+
is_lyrics_row = (type_mask == 1)
|
| 683 |
+
|
| 684 |
+
# ================= A. Coverage Score =================
|
| 685 |
+
# Ratio of lyric lines that have significant energy peak
|
| 686 |
+
row_max_energies = energy_matrix.max(dim=1).values
|
| 687 |
+
total_sung_rows = is_lyrics_row.sum().double()
|
| 688 |
+
|
| 689 |
+
coverage_threshold = 0.1
|
| 690 |
+
valid_sung_mask = is_lyrics_row & (row_max_energies > coverage_threshold)
|
| 691 |
+
valid_sung_rows = valid_sung_mask.sum().double()
|
| 692 |
+
|
| 693 |
+
if total_sung_rows > 0:
|
| 694 |
+
coverage_score = valid_sung_rows / total_sung_rows
|
| 695 |
+
else:
|
| 696 |
+
coverage_score = torch.tensor(1.0, device=device, dtype=torch.float64)
|
| 697 |
+
|
| 698 |
+
# ================= B. Monotonicity Score =================
|
| 699 |
+
# Check if the "center of mass" of lyric lines moves forward in time
|
| 700 |
+
col_indices = torch.arange(cols, device=device, dtype=torch.float64)
|
| 701 |
+
|
| 702 |
+
# Zero out low energy noise
|
| 703 |
+
weights = torch.where(
|
| 704 |
+
energy_matrix > time_weight,
|
| 705 |
+
energy_matrix,
|
| 706 |
+
torch.zeros_like(energy_matrix)
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
sum_w = weights.sum(dim=1)
|
| 710 |
+
sum_t = (weights * col_indices).sum(dim=1)
|
| 711 |
+
|
| 712 |
+
# Calculate centroids
|
| 713 |
+
centroids = torch.full((rows,), -1.0, device=device, dtype=torch.float64)
|
| 714 |
+
valid_w_mask = sum_w > 1e-9
|
| 715 |
+
centroids[valid_w_mask] = sum_t[valid_w_mask] / sum_w[valid_w_mask]
|
| 716 |
+
|
| 717 |
+
# Extract sequence of valid lyrics centroids
|
| 718 |
+
valid_sequence_mask = is_lyrics_row & (centroids >= 0)
|
| 719 |
+
sung_centroids = centroids[valid_sequence_mask]
|
| 720 |
+
|
| 721 |
+
cnt = sung_centroids.shape[0]
|
| 722 |
+
if cnt > 1:
|
| 723 |
+
curr_c = sung_centroids[:-1]
|
| 724 |
+
next_c = sung_centroids[1:]
|
| 725 |
+
|
| 726 |
+
# Check non-decreasing order with overlap tolerance
|
| 727 |
+
non_decreasing = (next_c >= (curr_c - overlap_frames)).double().sum()
|
| 728 |
+
pairs = torch.tensor(cnt - 1, device=device, dtype=torch.float64)
|
| 729 |
+
monotonicity_score = non_decreasing / pairs
|
| 730 |
+
else:
|
| 731 |
+
monotonicity_score = torch.tensor(1.0, device=device, dtype=torch.float64)
|
| 732 |
+
|
| 733 |
+
# ================= C. Path Confidence =================
|
| 734 |
+
# Average energy along the optimal path
|
| 735 |
+
if path_coords.shape[0] > 0:
|
| 736 |
+
p_rows = path_coords[:, 0]
|
| 737 |
+
p_cols = path_coords[:, 1]
|
| 738 |
+
|
| 739 |
+
path_energies = energy_matrix[p_rows, p_cols]
|
| 740 |
+
step_weights = torch.ones_like(path_energies)
|
| 741 |
+
|
| 742 |
+
# Lower weight for instrumental/tag steps
|
| 743 |
+
is_inst_step = (type_mask[p_rows] == 0)
|
| 744 |
+
step_weights[is_inst_step] = instrumental_weight
|
| 745 |
+
|
| 746 |
+
total_energy = (path_energies * step_weights).sum()
|
| 747 |
+
total_steps = step_weights.sum()
|
| 748 |
+
|
| 749 |
+
if total_steps > 0:
|
| 750 |
+
path_confidence = total_energy / total_steps
|
| 751 |
+
else:
|
| 752 |
+
path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
|
| 753 |
+
else:
|
| 754 |
+
path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
|
| 755 |
+
|
| 756 |
+
return coverage_score.item(), monotonicity_score.item(), path_confidence.item()
|
| 757 |
+
|
| 758 |
+
def lyrics_alignment_info(
|
| 759 |
+
self,
|
| 760 |
+
attention_matrix: Union[torch.Tensor, np.ndarray],
|
| 761 |
+
token_ids: List[int],
|
| 762 |
+
custom_config: Dict[int, List[int]],
|
| 763 |
+
return_matrices: bool = False,
|
| 764 |
+
medfilt_width: int = 1
|
| 765 |
+
) -> Dict[str, Any]:
|
| 766 |
+
"""
|
| 767 |
+
Generates alignment path and processed matrices.
|
| 768 |
+
|
| 769 |
+
Args:
|
| 770 |
+
attention_matrix: Input attention tensor.
|
| 771 |
+
token_ids: Corresponding token IDs.
|
| 772 |
+
custom_config: Layer/Head configuration.
|
| 773 |
+
return_matrices: If True, returns matrices in the output.
|
| 774 |
+
medfilt_width: Median filter width.
|
| 775 |
+
|
| 776 |
+
Returns:
|
| 777 |
+
Dict or AlignmentInfo object containing path and masks.
|
| 778 |
+
"""
|
| 779 |
+
calc_matrix, energy_matrix, vis_matrix = self._preprocess_attention(
|
| 780 |
+
attention_matrix, custom_config, medfilt_width
|
| 781 |
+
)
|
| 782 |
+
|
| 783 |
+
if calc_matrix is None:
|
| 784 |
+
return {
|
| 785 |
+
"calc_matrix": None,
|
| 786 |
+
"error": "No valid attention heads found"
|
| 787 |
+
}
|
| 788 |
+
|
| 789 |
+
# 1. Generate Semantic Mask (1=Lyrics, 0=Tags)
|
| 790 |
+
# Uses self.tokenizer internally
|
| 791 |
+
type_mask = self._generate_token_type_mask(token_ids)
|
| 792 |
+
|
| 793 |
+
# Safety check for shape mismatch
|
| 794 |
+
if len(type_mask) != energy_matrix.shape[0]:
|
| 795 |
+
# Fallback to all lyrics if shapes don't align
|
| 796 |
+
type_mask = np.ones(energy_matrix.shape[0], dtype=np.int32)
|
| 797 |
+
|
| 798 |
+
# 2. DTW Pathfinding
|
| 799 |
+
# Using negative calc_matrix because DTW minimizes cost
|
| 800 |
+
text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float32))
|
| 801 |
+
path_coords = np.stack([text_indices, time_indices], axis=1)
|
| 802 |
+
|
| 803 |
+
return_dict = {
|
| 804 |
+
"path_coords": path_coords,
|
| 805 |
+
"type_mask": type_mask,
|
| 806 |
+
"energy_matrix": energy_matrix
|
| 807 |
+
}
|
| 808 |
+
if return_matrices:
|
| 809 |
+
return_dict['calc_matrix'] = calc_matrix
|
| 810 |
+
return_dict['vis_matrix'] = vis_matrix
|
| 811 |
+
|
| 812 |
+
return return_dict
|
| 813 |
+
|
| 814 |
+
def calculate_score(
|
| 815 |
+
self,
|
| 816 |
+
energy_matrix: Union[torch.Tensor, np.ndarray],
|
| 817 |
+
type_mask: Union[torch.Tensor, np.ndarray],
|
| 818 |
+
path_coords: Union[torch.Tensor, np.ndarray],
|
| 819 |
+
time_weight: float = 0.01,
|
| 820 |
+
overlap_frames: float = 9.0,
|
| 821 |
+
instrumental_weight: float = 1.0
|
| 822 |
+
) -> Dict[str, Any]:
|
| 823 |
+
"""
|
| 824 |
+
Calculates the final alignment score based on pre-computed components.
|
| 825 |
+
|
| 826 |
+
Args:
|
| 827 |
+
energy_matrix: Processed energy matrix.
|
| 828 |
+
type_mask: Token type mask.
|
| 829 |
+
path_coords: DTW path coordinates.
|
| 830 |
+
time_weight: Minimum energy threshold for monotonicity.
|
| 831 |
+
overlap_frames: Allowed backward movement frames.
|
| 832 |
+
instrumental_weight: Weight for non-lyric path steps.
|
| 833 |
+
|
| 834 |
+
Returns:
|
| 835 |
+
AlignmentScore object containing individual metrics and final score.
|
| 836 |
+
"""
|
| 837 |
+
# Ensure Inputs are Tensors on the correct device
|
| 838 |
+
if not isinstance(energy_matrix, torch.Tensor):
|
| 839 |
+
energy_matrix = torch.tensor(energy_matrix, device='cuda', dtype=torch.float32)
|
| 840 |
+
|
| 841 |
+
device = energy_matrix.device
|
| 842 |
+
|
| 843 |
+
if not isinstance(type_mask, torch.Tensor):
|
| 844 |
+
type_mask = torch.tensor(type_mask, device=device, dtype=torch.long)
|
| 845 |
+
else:
|
| 846 |
+
type_mask = type_mask.to(device=device, dtype=torch.long)
|
| 847 |
+
|
| 848 |
+
if not isinstance(path_coords, torch.Tensor):
|
| 849 |
+
path_coords = torch.tensor(path_coords, device=device, dtype=torch.long)
|
| 850 |
+
else:
|
| 851 |
+
path_coords = path_coords.to(device=device, dtype=torch.long)
|
| 852 |
+
|
| 853 |
+
# Compute Metrics
|
| 854 |
+
coverage, monotonicity, confidence = self._compute_alignment_metrics(
|
| 855 |
+
energy_matrix=energy_matrix,
|
| 856 |
+
path_coords=path_coords,
|
| 857 |
+
type_mask=type_mask,
|
| 858 |
+
time_weight=time_weight,
|
| 859 |
+
overlap_frames=overlap_frames,
|
| 860 |
+
instrumental_weight=instrumental_weight
|
| 861 |
+
)
|
| 862 |
+
|
| 863 |
+
# Final Score Calculation
|
| 864 |
+
# (Cov^2 * Mono^2 * Conf)
|
| 865 |
+
final_score = (coverage ** 2) * (monotonicity ** 2) * confidence
|
| 866 |
+
final_score = float(np.clip(final_score, 0.0, 1.0))
|
| 867 |
+
|
| 868 |
+
return {
|
| 869 |
+
"lyrics_score": round(final_score, 4)
|
| 870 |
+
}
|
acestep/genres_vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/gradio_ui/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from acestep.gradio_ui.interfaces import create_gradio_interface
|
acestep/gradio_ui/events/__init__.py
ADDED
|
@@ -0,0 +1,1355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Event Handlers Module
|
| 3 |
+
Main entry point for setting up all event handlers
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import gradio as gr
|
| 7 |
+
from typing import Optional
|
| 8 |
+
from loguru import logger
|
| 9 |
+
|
| 10 |
+
# Import handler modules
|
| 11 |
+
from . import generation_handlers as gen_h
|
| 12 |
+
from . import results_handlers as res_h
|
| 13 |
+
from . import training_handlers as train_h
|
| 14 |
+
from acestep.gradio_ui.i18n import t
|
| 15 |
+
|
| 16 |
+
# HuggingFace Space environment detection for ZeroGPU support
|
| 17 |
+
IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _get_spaces_gpu_decorator(duration=120):
|
| 21 |
+
"""
|
| 22 |
+
Get the @spaces.GPU decorator if running in HuggingFace Space environment.
|
| 23 |
+
Returns identity decorator if not in Space environment.
|
| 24 |
+
"""
|
| 25 |
+
if IS_HUGGINGFACE_SPACE:
|
| 26 |
+
try:
|
| 27 |
+
import spaces
|
| 28 |
+
return spaces.GPU(duration=duration)
|
| 29 |
+
except ImportError:
|
| 30 |
+
logger.warning("spaces package not found, GPU decorator disabled")
|
| 31 |
+
return lambda func: func
|
| 32 |
+
return lambda func: func
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section, init_params=None):
|
| 36 |
+
"""Setup event handlers connecting UI components and business logic
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
init_params: Dictionary containing initialization parameters including:
|
| 40 |
+
- dit_handler_2: Optional second DiT handler for multi-model setup
|
| 41 |
+
- available_dit_models: List of available DiT model names
|
| 42 |
+
- config_path: Primary model config path
|
| 43 |
+
- config_path_2: Secondary model config path (if available)
|
| 44 |
+
"""
|
| 45 |
+
# Get secondary DiT handler from init_params (for multi-model support)
|
| 46 |
+
dit_handler_2 = init_params.get('dit_handler_2') if init_params else None
|
| 47 |
+
config_path_1 = init_params.get('config_path', '') if init_params else ''
|
| 48 |
+
config_path_2 = init_params.get('config_path_2', '') if init_params else ''
|
| 49 |
+
|
| 50 |
+
# ========== Dataset Handlers ==========
|
| 51 |
+
dataset_section["import_dataset_btn"].click(
|
| 52 |
+
fn=dataset_handler.import_dataset,
|
| 53 |
+
inputs=[dataset_section["dataset_type"]],
|
| 54 |
+
outputs=[dataset_section["data_status"]]
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# ========== Service Initialization ==========
|
| 58 |
+
generation_section["refresh_btn"].click(
|
| 59 |
+
fn=lambda: gen_h.refresh_checkpoints(dit_handler),
|
| 60 |
+
outputs=[generation_section["checkpoint_dropdown"]]
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
generation_section["config_path"].change(
|
| 64 |
+
fn=gen_h.update_model_type_settings,
|
| 65 |
+
inputs=[generation_section["config_path"]],
|
| 66 |
+
outputs=[
|
| 67 |
+
generation_section["inference_steps"],
|
| 68 |
+
generation_section["guidance_scale"],
|
| 69 |
+
generation_section["use_adg"],
|
| 70 |
+
generation_section["shift"],
|
| 71 |
+
generation_section["cfg_interval_start"],
|
| 72 |
+
generation_section["cfg_interval_end"],
|
| 73 |
+
generation_section["task_type"],
|
| 74 |
+
]
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
generation_section["init_btn"].click(
|
| 78 |
+
fn=lambda *args: gen_h.init_service_wrapper(dit_handler, llm_handler, *args),
|
| 79 |
+
inputs=[
|
| 80 |
+
generation_section["checkpoint_dropdown"],
|
| 81 |
+
generation_section["config_path"],
|
| 82 |
+
generation_section["device"],
|
| 83 |
+
generation_section["init_llm_checkbox"],
|
| 84 |
+
generation_section["lm_model_path"],
|
| 85 |
+
generation_section["backend_dropdown"],
|
| 86 |
+
generation_section["use_flash_attention_checkbox"],
|
| 87 |
+
generation_section["offload_to_cpu_checkbox"],
|
| 88 |
+
generation_section["offload_dit_to_cpu_checkbox"],
|
| 89 |
+
],
|
| 90 |
+
outputs=[
|
| 91 |
+
generation_section["init_status"],
|
| 92 |
+
generation_section["generate_btn"],
|
| 93 |
+
generation_section["service_config_accordion"],
|
| 94 |
+
# Model type settings (updated based on actual loaded model)
|
| 95 |
+
generation_section["inference_steps"],
|
| 96 |
+
generation_section["guidance_scale"],
|
| 97 |
+
generation_section["use_adg"],
|
| 98 |
+
generation_section["shift"],
|
| 99 |
+
generation_section["cfg_interval_start"],
|
| 100 |
+
generation_section["cfg_interval_end"],
|
| 101 |
+
generation_section["task_type"],
|
| 102 |
+
]
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# ========== LoRA Handlers ==========
|
| 106 |
+
generation_section["load_lora_btn"].click(
|
| 107 |
+
fn=dit_handler.load_lora,
|
| 108 |
+
inputs=[generation_section["lora_path"]],
|
| 109 |
+
outputs=[generation_section["lora_status"]]
|
| 110 |
+
).then(
|
| 111 |
+
# Update checkbox to enabled state after loading
|
| 112 |
+
fn=lambda: gr.update(value=True),
|
| 113 |
+
outputs=[generation_section["use_lora_checkbox"]]
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
generation_section["unload_lora_btn"].click(
|
| 117 |
+
fn=dit_handler.unload_lora,
|
| 118 |
+
outputs=[generation_section["lora_status"]]
|
| 119 |
+
).then(
|
| 120 |
+
# Update checkbox to disabled state after unloading
|
| 121 |
+
fn=lambda: gr.update(value=False),
|
| 122 |
+
outputs=[generation_section["use_lora_checkbox"]]
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
generation_section["use_lora_checkbox"].change(
|
| 126 |
+
fn=dit_handler.set_use_lora,
|
| 127 |
+
inputs=[generation_section["use_lora_checkbox"]],
|
| 128 |
+
outputs=[generation_section["lora_status"]]
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# ========== UI Visibility Updates ==========
|
| 132 |
+
generation_section["init_llm_checkbox"].change(
|
| 133 |
+
fn=gen_h.update_negative_prompt_visibility,
|
| 134 |
+
inputs=[generation_section["init_llm_checkbox"]],
|
| 135 |
+
outputs=[generation_section["lm_negative_prompt"]]
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
generation_section["init_llm_checkbox"].change(
|
| 139 |
+
fn=gen_h.update_audio_cover_strength_visibility,
|
| 140 |
+
inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
|
| 141 |
+
outputs=[generation_section["audio_cover_strength"]]
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
generation_section["task_type"].change(
|
| 145 |
+
fn=gen_h.update_audio_cover_strength_visibility,
|
| 146 |
+
inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
|
| 147 |
+
outputs=[generation_section["audio_cover_strength"]]
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
generation_section["batch_size_input"].change(
|
| 151 |
+
fn=gen_h.update_audio_components_visibility,
|
| 152 |
+
inputs=[generation_section["batch_size_input"]],
|
| 153 |
+
outputs=[
|
| 154 |
+
results_section["audio_col_1"],
|
| 155 |
+
results_section["audio_col_2"],
|
| 156 |
+
results_section["audio_col_3"],
|
| 157 |
+
results_section["audio_col_4"],
|
| 158 |
+
results_section["audio_row_5_8"],
|
| 159 |
+
results_section["audio_col_5"],
|
| 160 |
+
results_section["audio_col_6"],
|
| 161 |
+
results_section["audio_col_7"],
|
| 162 |
+
results_section["audio_col_8"],
|
| 163 |
+
]
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
# ========== Audio Conversion ==========
|
| 167 |
+
generation_section["convert_src_to_codes_btn"].click(
|
| 168 |
+
fn=lambda src: gen_h.convert_src_audio_to_codes_wrapper(dit_handler, src),
|
| 169 |
+
inputs=[generation_section["src_audio"]],
|
| 170 |
+
outputs=[generation_section["text2music_audio_code_string"]]
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# ========== Instruction UI Updates ==========
|
| 174 |
+
for trigger in [generation_section["task_type"], generation_section["track_name"], generation_section["complete_track_classes"]]:
|
| 175 |
+
trigger.change(
|
| 176 |
+
fn=lambda *args: gen_h.update_instruction_ui(dit_handler, *args),
|
| 177 |
+
inputs=[
|
| 178 |
+
generation_section["task_type"],
|
| 179 |
+
generation_section["track_name"],
|
| 180 |
+
generation_section["complete_track_classes"],
|
| 181 |
+
generation_section["text2music_audio_code_string"],
|
| 182 |
+
generation_section["init_llm_checkbox"]
|
| 183 |
+
],
|
| 184 |
+
outputs=[
|
| 185 |
+
generation_section["instruction_display_gen"],
|
| 186 |
+
generation_section["track_name"],
|
| 187 |
+
generation_section["complete_track_classes"],
|
| 188 |
+
generation_section["audio_cover_strength"],
|
| 189 |
+
generation_section["repainting_group"],
|
| 190 |
+
generation_section["text2music_audio_codes_group"],
|
| 191 |
+
]
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
# ========== Sample/Transcribe Handlers ==========
|
| 195 |
+
# Load random example from ./examples/text2music directory
|
| 196 |
+
generation_section["sample_btn"].click(
|
| 197 |
+
fn=lambda task: gen_h.load_random_example(task) + (True,),
|
| 198 |
+
inputs=[
|
| 199 |
+
generation_section["task_type"],
|
| 200 |
+
],
|
| 201 |
+
outputs=[
|
| 202 |
+
generation_section["captions"],
|
| 203 |
+
generation_section["lyrics"],
|
| 204 |
+
generation_section["think_checkbox"],
|
| 205 |
+
generation_section["bpm"],
|
| 206 |
+
generation_section["audio_duration"],
|
| 207 |
+
generation_section["key_scale"],
|
| 208 |
+
generation_section["vocal_language"],
|
| 209 |
+
generation_section["time_signature"],
|
| 210 |
+
results_section["is_format_caption_state"]
|
| 211 |
+
]
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
generation_section["text2music_audio_code_string"].change(
|
| 215 |
+
fn=gen_h.update_transcribe_button_text,
|
| 216 |
+
inputs=[generation_section["text2music_audio_code_string"]],
|
| 217 |
+
outputs=[generation_section["transcribe_btn"]]
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
generation_section["transcribe_btn"].click(
|
| 221 |
+
fn=lambda codes, debug: gen_h.transcribe_audio_codes(llm_handler, codes, debug),
|
| 222 |
+
inputs=[
|
| 223 |
+
generation_section["text2music_audio_code_string"],
|
| 224 |
+
generation_section["constrained_decoding_debug"]
|
| 225 |
+
],
|
| 226 |
+
outputs=[
|
| 227 |
+
results_section["status_output"],
|
| 228 |
+
generation_section["captions"],
|
| 229 |
+
generation_section["lyrics"],
|
| 230 |
+
generation_section["bpm"],
|
| 231 |
+
generation_section["audio_duration"],
|
| 232 |
+
generation_section["key_scale"],
|
| 233 |
+
generation_section["vocal_language"],
|
| 234 |
+
generation_section["time_signature"],
|
| 235 |
+
results_section["is_format_caption_state"]
|
| 236 |
+
]
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# ========== Reset Format Caption Flag ==========
|
| 240 |
+
for trigger in [generation_section["captions"], generation_section["lyrics"], generation_section["bpm"],
|
| 241 |
+
generation_section["key_scale"], generation_section["time_signature"],
|
| 242 |
+
generation_section["vocal_language"], generation_section["audio_duration"]]:
|
| 243 |
+
trigger.change(
|
| 244 |
+
fn=gen_h.reset_format_caption_flag,
|
| 245 |
+
inputs=[],
|
| 246 |
+
outputs=[results_section["is_format_caption_state"]]
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
# ========== Audio Uploads Accordion ==========
|
| 250 |
+
for trigger in [generation_section["reference_audio"], generation_section["src_audio"]]:
|
| 251 |
+
trigger.change(
|
| 252 |
+
fn=gen_h.update_audio_uploads_accordion,
|
| 253 |
+
inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
|
| 254 |
+
outputs=[generation_section["audio_uploads_accordion"]]
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# ========== Instrumental Checkbox ==========
|
| 258 |
+
generation_section["instrumental_checkbox"].change(
|
| 259 |
+
fn=gen_h.handle_instrumental_checkbox,
|
| 260 |
+
inputs=[generation_section["instrumental_checkbox"], generation_section["lyrics"]],
|
| 261 |
+
outputs=[generation_section["lyrics"]]
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# ========== Format Button ==========
|
| 265 |
+
# Note: cfg_scale and negative_prompt are not supported in format mode
|
| 266 |
+
@_get_spaces_gpu_decorator(duration=120)
|
| 267 |
+
def handle_format_sample_wrapper(caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug):
|
| 268 |
+
return gen_h.handle_format_sample(
|
| 269 |
+
llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
generation_section["format_btn"].click(
|
| 273 |
+
fn=handle_format_sample_wrapper,
|
| 274 |
+
inputs=[
|
| 275 |
+
generation_section["captions"],
|
| 276 |
+
generation_section["lyrics"],
|
| 277 |
+
generation_section["bpm"],
|
| 278 |
+
generation_section["audio_duration"],
|
| 279 |
+
generation_section["key_scale"],
|
| 280 |
+
generation_section["time_signature"],
|
| 281 |
+
generation_section["lm_temperature"],
|
| 282 |
+
generation_section["lm_top_k"],
|
| 283 |
+
generation_section["lm_top_p"],
|
| 284 |
+
generation_section["constrained_decoding_debug"],
|
| 285 |
+
],
|
| 286 |
+
outputs=[
|
| 287 |
+
generation_section["captions"],
|
| 288 |
+
generation_section["lyrics"],
|
| 289 |
+
generation_section["bpm"],
|
| 290 |
+
generation_section["audio_duration"],
|
| 291 |
+
generation_section["key_scale"],
|
| 292 |
+
generation_section["vocal_language"],
|
| 293 |
+
generation_section["time_signature"],
|
| 294 |
+
results_section["is_format_caption_state"],
|
| 295 |
+
results_section["status_output"],
|
| 296 |
+
]
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
# ========== Generation Mode Toggle (Simple/Custom/Cover/Repaint) ==========
|
| 300 |
+
generation_section["generation_mode"].change(
|
| 301 |
+
fn=gen_h.handle_generation_mode_change,
|
| 302 |
+
inputs=[generation_section["generation_mode"]],
|
| 303 |
+
outputs=[
|
| 304 |
+
generation_section["simple_mode_group"],
|
| 305 |
+
generation_section["custom_mode_content"],
|
| 306 |
+
generation_section["cover_mode_group"],
|
| 307 |
+
generation_section["repainting_group"],
|
| 308 |
+
generation_section["task_type"],
|
| 309 |
+
generation_section["generate_btn"],
|
| 310 |
+
generation_section["simple_sample_created"],
|
| 311 |
+
generation_section["src_audio_group"],
|
| 312 |
+
generation_section["audio_cover_strength"],
|
| 313 |
+
generation_section["think_checkbox"], # Disable thinking for cover/repaint modes
|
| 314 |
+
]
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
# ========== Process Source Audio Button ==========
|
| 318 |
+
# Combines Convert to Codes + Transcribe in one step
|
| 319 |
+
# Note: @spaces.GPU decorator must be on the function passed directly to fn=,
|
| 320 |
+
# not on a module-level function wrapped in a lambda. Lambdas capturing handler
|
| 321 |
+
# objects cause pickling errors on ZeroGPU because the model contains unpicklable
|
| 322 |
+
# local objects (e.g. AceStepDiTModel.__init__ lambdas).
|
| 323 |
+
@_get_spaces_gpu_decorator(duration=120)
|
| 324 |
+
def process_source_audio_wrapper(src, debug):
|
| 325 |
+
return gen_h.process_source_audio(dit_handler, llm_handler, src, debug)
|
| 326 |
+
|
| 327 |
+
generation_section["process_src_btn"].click(
|
| 328 |
+
fn=process_source_audio_wrapper,
|
| 329 |
+
inputs=[
|
| 330 |
+
generation_section["src_audio"],
|
| 331 |
+
generation_section["constrained_decoding_debug"]
|
| 332 |
+
],
|
| 333 |
+
outputs=[
|
| 334 |
+
generation_section["text2music_audio_code_string"],
|
| 335 |
+
results_section["status_output"],
|
| 336 |
+
generation_section["captions"],
|
| 337 |
+
generation_section["lyrics"],
|
| 338 |
+
generation_section["bpm"],
|
| 339 |
+
generation_section["audio_duration"],
|
| 340 |
+
generation_section["key_scale"],
|
| 341 |
+
generation_section["vocal_language"],
|
| 342 |
+
generation_section["time_signature"],
|
| 343 |
+
results_section["is_format_caption_state"],
|
| 344 |
+
]
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
# ========== Simple Mode Instrumental Checkbox ==========
|
| 348 |
+
# When instrumental is checked, disable vocal language and set to ["unknown"]
|
| 349 |
+
generation_section["simple_instrumental_checkbox"].change(
|
| 350 |
+
fn=gen_h.handle_simple_instrumental_change,
|
| 351 |
+
inputs=[generation_section["simple_instrumental_checkbox"]],
|
| 352 |
+
outputs=[generation_section["simple_vocal_language"]]
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
# ========== Random Description Button ==========
|
| 356 |
+
generation_section["random_desc_btn"].click(
|
| 357 |
+
fn=gen_h.load_random_simple_description,
|
| 358 |
+
inputs=[],
|
| 359 |
+
outputs=[
|
| 360 |
+
generation_section["simple_query_input"],
|
| 361 |
+
generation_section["simple_instrumental_checkbox"],
|
| 362 |
+
generation_section["simple_vocal_language"],
|
| 363 |
+
]
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
# ========== Create Sample Button (Simple Mode) ==========
|
| 367 |
+
# Note: cfg_scale and negative_prompt are not supported in create_sample mode
|
| 368 |
+
@_get_spaces_gpu_decorator(duration=120)
|
| 369 |
+
def handle_create_sample_wrapper(query, instrumental, vocal_lang, temp, top_k, top_p, debug):
|
| 370 |
+
return gen_h.handle_create_sample(
|
| 371 |
+
llm_handler, query, instrumental, vocal_lang, temp, top_k, top_p, debug
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
generation_section["create_sample_btn"].click(
|
| 375 |
+
fn=handle_create_sample_wrapper,
|
| 376 |
+
inputs=[
|
| 377 |
+
generation_section["simple_query_input"],
|
| 378 |
+
generation_section["simple_instrumental_checkbox"],
|
| 379 |
+
generation_section["simple_vocal_language"],
|
| 380 |
+
generation_section["lm_temperature"],
|
| 381 |
+
generation_section["lm_top_k"],
|
| 382 |
+
generation_section["lm_top_p"],
|
| 383 |
+
generation_section["constrained_decoding_debug"],
|
| 384 |
+
],
|
| 385 |
+
outputs=[
|
| 386 |
+
generation_section["captions"],
|
| 387 |
+
generation_section["lyrics"],
|
| 388 |
+
generation_section["bpm"],
|
| 389 |
+
generation_section["audio_duration"],
|
| 390 |
+
generation_section["key_scale"],
|
| 391 |
+
generation_section["vocal_language"],
|
| 392 |
+
generation_section["simple_vocal_language"],
|
| 393 |
+
generation_section["time_signature"],
|
| 394 |
+
generation_section["instrumental_checkbox"],
|
| 395 |
+
generation_section["caption_accordion"],
|
| 396 |
+
generation_section["lyrics_accordion"],
|
| 397 |
+
generation_section["generate_btn"],
|
| 398 |
+
generation_section["simple_sample_created"],
|
| 399 |
+
generation_section["think_checkbox"],
|
| 400 |
+
results_section["is_format_caption_state"],
|
| 401 |
+
results_section["status_output"],
|
| 402 |
+
]
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
# ========== Load/Save Metadata ==========
|
| 406 |
+
generation_section["load_file"].upload(
|
| 407 |
+
fn=gen_h.load_metadata,
|
| 408 |
+
inputs=[generation_section["load_file"]],
|
| 409 |
+
outputs=[
|
| 410 |
+
generation_section["task_type"],
|
| 411 |
+
generation_section["captions"],
|
| 412 |
+
generation_section["lyrics"],
|
| 413 |
+
generation_section["vocal_language"],
|
| 414 |
+
generation_section["bpm"],
|
| 415 |
+
generation_section["key_scale"],
|
| 416 |
+
generation_section["time_signature"],
|
| 417 |
+
generation_section["audio_duration"],
|
| 418 |
+
generation_section["batch_size_input"],
|
| 419 |
+
generation_section["inference_steps"],
|
| 420 |
+
generation_section["guidance_scale"],
|
| 421 |
+
generation_section["seed"],
|
| 422 |
+
generation_section["random_seed_checkbox"],
|
| 423 |
+
generation_section["use_adg"],
|
| 424 |
+
generation_section["cfg_interval_start"],
|
| 425 |
+
generation_section["cfg_interval_end"],
|
| 426 |
+
generation_section["shift"],
|
| 427 |
+
generation_section["infer_method"],
|
| 428 |
+
generation_section["custom_timesteps"],
|
| 429 |
+
generation_section["audio_format"],
|
| 430 |
+
generation_section["lm_temperature"],
|
| 431 |
+
generation_section["lm_cfg_scale"],
|
| 432 |
+
generation_section["lm_top_k"],
|
| 433 |
+
generation_section["lm_top_p"],
|
| 434 |
+
generation_section["lm_negative_prompt"],
|
| 435 |
+
generation_section["use_cot_metas"], # Added: use_cot_metas
|
| 436 |
+
generation_section["use_cot_caption"],
|
| 437 |
+
generation_section["use_cot_language"],
|
| 438 |
+
generation_section["audio_cover_strength"],
|
| 439 |
+
generation_section["think_checkbox"],
|
| 440 |
+
generation_section["text2music_audio_code_string"],
|
| 441 |
+
generation_section["repainting_start"],
|
| 442 |
+
generation_section["repainting_end"],
|
| 443 |
+
generation_section["track_name"],
|
| 444 |
+
generation_section["complete_track_classes"],
|
| 445 |
+
generation_section["instrumental_checkbox"], # Added: instrumental_checkbox
|
| 446 |
+
results_section["is_format_caption_state"]
|
| 447 |
+
]
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
# Save buttons for all 8 audio outputs
|
| 451 |
+
download_existing_js = """(current_audio, batch_files) => {
|
| 452 |
+
// Debug: print what the input actually is
|
| 453 |
+
console.log("👉 [Debug] Current Audio Input:", current_audio);
|
| 454 |
+
|
| 455 |
+
// 1. Safety check
|
| 456 |
+
if (!current_audio) {
|
| 457 |
+
console.warn("⚠️ No audio selected or audio is empty.");
|
| 458 |
+
return;
|
| 459 |
+
}
|
| 460 |
+
if (!batch_files || !Array.isArray(batch_files)) {
|
| 461 |
+
console.warn("⚠️ Batch file list is empty/not ready.");
|
| 462 |
+
return;
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
// 2. Smartly extract path string
|
| 466 |
+
let pathString = "";
|
| 467 |
+
|
| 468 |
+
if (typeof current_audio === "string") {
|
| 469 |
+
// Case A: direct path string received
|
| 470 |
+
pathString = current_audio;
|
| 471 |
+
} else if (typeof current_audio === "object") {
|
| 472 |
+
// Case B: an object is received, try common properties
|
| 473 |
+
// Gradio file objects usually have path, url, or name
|
| 474 |
+
pathString = current_audio.path || current_audio.name || current_audio.url || "";
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
if (!pathString) {
|
| 478 |
+
console.error("❌ Error: Could not extract a valid path string from input.", current_audio);
|
| 479 |
+
return;
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
// 3. Extract Key (UUID)
|
| 483 |
+
// Path could be /tmp/.../uuid.mp3 or url like /file=.../uuid.mp3
|
| 484 |
+
let filename = pathString.split(/[\\\\/]/).pop(); // get the filename
|
| 485 |
+
let key = filename.split('.')[0]; // get UUID without extension
|
| 486 |
+
|
| 487 |
+
console.log(`🔑 Key extracted: ${key}`);
|
| 488 |
+
|
| 489 |
+
// 4. Find matching file(s) in the list
|
| 490 |
+
let targets = batch_files.filter(f => {
|
| 491 |
+
// Also extract names from batch_files objects
|
| 492 |
+
// f usually contains name (backend path) and orig_name (download name)
|
| 493 |
+
const fPath = f.name || f.path || "";
|
| 494 |
+
return fPath.includes(key);
|
| 495 |
+
});
|
| 496 |
+
|
| 497 |
+
if (targets.length === 0) {
|
| 498 |
+
console.warn("❌ No matching files found in batch list for key:", key);
|
| 499 |
+
alert("Batch list does not contain this file yet. Please wait for generation to finish.");
|
| 500 |
+
return;
|
| 501 |
+
}
|
| 502 |
+
|
| 503 |
+
// 5. Trigger download(s)
|
| 504 |
+
console.log(`🎯 Found ${targets.length} files to download.`);
|
| 505 |
+
targets.forEach((f, index) => {
|
| 506 |
+
setTimeout(() => {
|
| 507 |
+
const a = document.createElement('a');
|
| 508 |
+
// Prefer url (frontend-accessible link), otherwise try data
|
| 509 |
+
a.href = f.url || f.data;
|
| 510 |
+
a.download = f.orig_name || "download";
|
| 511 |
+
a.style.display = 'none';
|
| 512 |
+
document.body.appendChild(a);
|
| 513 |
+
a.click();
|
| 514 |
+
document.body.removeChild(a);
|
| 515 |
+
}, index * 1000); // 300ms interval to avoid browser blocking
|
| 516 |
+
});
|
| 517 |
+
}
|
| 518 |
+
"""
|
| 519 |
+
for btn_idx in range(1, 9):
|
| 520 |
+
results_section[f"save_btn_{btn_idx}"].click(
|
| 521 |
+
fn=None,
|
| 522 |
+
inputs=[
|
| 523 |
+
results_section[f"generated_audio_{btn_idx}"],
|
| 524 |
+
results_section["generated_audio_batch"],
|
| 525 |
+
],
|
| 526 |
+
js=download_existing_js # Run the above JS
|
| 527 |
+
)
|
| 528 |
+
# ========== Send to Cover Handlers ==========
|
| 529 |
+
def send_to_cover_handler(audio_file, lm_metadata):
|
| 530 |
+
"""Send audio to cover mode and switch to cover"""
|
| 531 |
+
if audio_file is None:
|
| 532 |
+
return (gr.skip(),) * 11
|
| 533 |
+
return (
|
| 534 |
+
audio_file, # src_audio
|
| 535 |
+
gr.skip(), # bpm
|
| 536 |
+
gr.skip(), # captions
|
| 537 |
+
gr.skip(), # lyrics
|
| 538 |
+
gr.skip(), # audio_duration
|
| 539 |
+
gr.skip(), # key_scale
|
| 540 |
+
gr.skip(), # vocal_language
|
| 541 |
+
gr.skip(), # time_signature
|
| 542 |
+
gr.skip(), # is_format_caption_state
|
| 543 |
+
"cover", # generation_mode - switch to cover
|
| 544 |
+
"cover", # task_type - set to cover
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
for btn_idx in range(1, 9):
|
| 548 |
+
results_section[f"send_to_cover_btn_{btn_idx}"].click(
|
| 549 |
+
fn=send_to_cover_handler,
|
| 550 |
+
inputs=[
|
| 551 |
+
results_section[f"generated_audio_{btn_idx}"],
|
| 552 |
+
results_section["lm_metadata_state"]
|
| 553 |
+
],
|
| 554 |
+
outputs=[
|
| 555 |
+
generation_section["src_audio"],
|
| 556 |
+
generation_section["bpm"],
|
| 557 |
+
generation_section["captions"],
|
| 558 |
+
generation_section["lyrics"],
|
| 559 |
+
generation_section["audio_duration"],
|
| 560 |
+
generation_section["key_scale"],
|
| 561 |
+
generation_section["vocal_language"],
|
| 562 |
+
generation_section["time_signature"],
|
| 563 |
+
results_section["is_format_caption_state"],
|
| 564 |
+
generation_section["generation_mode"],
|
| 565 |
+
generation_section["task_type"],
|
| 566 |
+
]
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
# ========== Send to Repaint Handlers ==========
|
| 570 |
+
def send_to_repaint_handler(audio_file, lm_metadata):
|
| 571 |
+
"""Send audio to repaint mode and switch to repaint"""
|
| 572 |
+
if audio_file is None:
|
| 573 |
+
return (gr.skip(),) * 11
|
| 574 |
+
return (
|
| 575 |
+
audio_file, # src_audio
|
| 576 |
+
gr.skip(), # bpm
|
| 577 |
+
gr.skip(), # captions
|
| 578 |
+
gr.skip(), # lyrics
|
| 579 |
+
gr.skip(), # audio_duration
|
| 580 |
+
gr.skip(), # key_scale
|
| 581 |
+
gr.skip(), # vocal_language
|
| 582 |
+
gr.skip(), # time_signature
|
| 583 |
+
gr.skip(), # is_format_caption_state
|
| 584 |
+
"repaint", # generation_mode - switch to repaint
|
| 585 |
+
"repaint", # task_type - set to repaint
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
for btn_idx in range(1, 9):
|
| 589 |
+
results_section[f"send_to_repaint_btn_{btn_idx}"].click(
|
| 590 |
+
fn=send_to_repaint_handler,
|
| 591 |
+
inputs=[
|
| 592 |
+
results_section[f"generated_audio_{btn_idx}"],
|
| 593 |
+
results_section["lm_metadata_state"]
|
| 594 |
+
],
|
| 595 |
+
outputs=[
|
| 596 |
+
generation_section["src_audio"],
|
| 597 |
+
generation_section["bpm"],
|
| 598 |
+
generation_section["captions"],
|
| 599 |
+
generation_section["lyrics"],
|
| 600 |
+
generation_section["audio_duration"],
|
| 601 |
+
generation_section["key_scale"],
|
| 602 |
+
generation_section["vocal_language"],
|
| 603 |
+
generation_section["time_signature"],
|
| 604 |
+
results_section["is_format_caption_state"],
|
| 605 |
+
generation_section["generation_mode"],
|
| 606 |
+
generation_section["task_type"],
|
| 607 |
+
]
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
# ========== Score Calculation Handlers ==========
|
| 611 |
+
# Use default argument to capture btn_idx value at definition time (Python closure fix)
|
| 612 |
+
# Note: @spaces.GPU decorator applied here (not on module-level function) to avoid
|
| 613 |
+
# pickling issues on ZeroGPU when handler objects are captured in closures.
|
| 614 |
+
def make_score_handler(idx):
|
| 615 |
+
@_get_spaces_gpu_decorator(duration=120)
|
| 616 |
+
def score_handler(scale, batch_idx, queue):
|
| 617 |
+
return res_h.calculate_score_handler_with_selection(
|
| 618 |
+
dit_handler, llm_handler, idx, scale, batch_idx, queue
|
| 619 |
+
)
|
| 620 |
+
return score_handler
|
| 621 |
+
|
| 622 |
+
for btn_idx in range(1, 9):
|
| 623 |
+
results_section[f"score_btn_{btn_idx}"].click(
|
| 624 |
+
fn=make_score_handler(btn_idx),
|
| 625 |
+
inputs=[
|
| 626 |
+
generation_section["score_scale"],
|
| 627 |
+
results_section["current_batch_index"],
|
| 628 |
+
results_section["batch_queue"],
|
| 629 |
+
],
|
| 630 |
+
outputs=[
|
| 631 |
+
results_section[f"score_display_{btn_idx}"],
|
| 632 |
+
results_section[f"details_accordion_{btn_idx}"],
|
| 633 |
+
results_section["batch_queue"]
|
| 634 |
+
]
|
| 635 |
+
)
|
| 636 |
+
|
| 637 |
+
# ========== LRC Timestamp Handlers ==========
|
| 638 |
+
# Use default argument to capture btn_idx value at definition time (Python closure fix)
|
| 639 |
+
def make_lrc_handler(idx):
|
| 640 |
+
@_get_spaces_gpu_decorator(duration=120)
|
| 641 |
+
def lrc_handler(batch_idx, queue, vocal_lang, infer_steps):
|
| 642 |
+
return res_h.generate_lrc_handler(
|
| 643 |
+
dit_handler, idx, batch_idx, queue, vocal_lang, infer_steps
|
| 644 |
+
)
|
| 645 |
+
return lrc_handler
|
| 646 |
+
|
| 647 |
+
for btn_idx in range(1, 9):
|
| 648 |
+
results_section[f"lrc_btn_{btn_idx}"].click(
|
| 649 |
+
fn=make_lrc_handler(btn_idx),
|
| 650 |
+
inputs=[
|
| 651 |
+
results_section["current_batch_index"],
|
| 652 |
+
results_section["batch_queue"],
|
| 653 |
+
generation_section["vocal_language"],
|
| 654 |
+
generation_section["inference_steps"],
|
| 655 |
+
],
|
| 656 |
+
outputs=[
|
| 657 |
+
results_section[f"lrc_display_{btn_idx}"],
|
| 658 |
+
results_section[f"details_accordion_{btn_idx}"],
|
| 659 |
+
# NOTE: Removed generated_audio output!
|
| 660 |
+
# Audio subtitles are now updated via lrc_display.change() event.
|
| 661 |
+
results_section["batch_queue"]
|
| 662 |
+
]
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
@_get_spaces_gpu_decorator(duration=120)
|
| 666 |
+
def generation_wrapper(selected_model, generation_mode, simple_query_input, simple_vocal_language, *args):
|
| 667 |
+
"""Wrapper that selects the appropriate DiT handler based on model selection"""
|
| 668 |
+
# Convert args to list for modification
|
| 669 |
+
args_list = list(args)
|
| 670 |
+
|
| 671 |
+
# args order (after simple mode params):
|
| 672 |
+
# captions (0), lyrics (1), bpm (2), key_scale (3), time_signature (4), vocal_language (5),
|
| 673 |
+
# inference_steps (6), guidance_scale (7), random_seed_checkbox (8), seed (9),
|
| 674 |
+
# reference_audio (10), audio_duration (11), batch_size_input (12), src_audio (13),
|
| 675 |
+
# text2music_audio_code_string (14), repainting_start (15), repainting_end (16),
|
| 676 |
+
# instruction_display_gen (17), audio_cover_strength (18), task_type (19), ...
|
| 677 |
+
# ... lm_temperature (27), think_checkbox (28), ...
|
| 678 |
+
# ... instrumental_checkbox (at position after all regular params)
|
| 679 |
+
|
| 680 |
+
src_audio = args_list[13] if len(args_list) > 13 else None
|
| 681 |
+
task_type = args_list[19] if len(args_list) > 19 else "text2music"
|
| 682 |
+
|
| 683 |
+
# Validate: Cover and Repaint modes require source audio
|
| 684 |
+
if task_type in ["cover", "repaint"] and src_audio is None:
|
| 685 |
+
raise gr.Error(f"Source Audio is required for {task_type.capitalize()} mode. Please upload an audio file.")
|
| 686 |
+
|
| 687 |
+
# Handle Simple mode: first create sample, then generate
|
| 688 |
+
if generation_mode == "simple":
|
| 689 |
+
# Get instrumental from the main checkbox (args[-6] based on input order)
|
| 690 |
+
# The instrumental_checkbox is passed after all the regular generation params
|
| 691 |
+
instrumental = args_list[-6] if len(args_list) > 6 else False # instrumental_checkbox position
|
| 692 |
+
lm_temperature = args_list[27] if len(args_list) > 27 else 0.85
|
| 693 |
+
lm_top_k = args_list[30] if len(args_list) > 30 else 0
|
| 694 |
+
lm_top_p = args_list[31] if len(args_list) > 31 else 0.9
|
| 695 |
+
constrained_decoding_debug = args_list[38] if len(args_list) > 38 else False
|
| 696 |
+
|
| 697 |
+
# Call create_sample to generate caption/lyrics/metadata
|
| 698 |
+
from acestep.inference import create_sample
|
| 699 |
+
|
| 700 |
+
top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
|
| 701 |
+
top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
|
| 702 |
+
|
| 703 |
+
result = create_sample(
|
| 704 |
+
llm_handler=llm_handler,
|
| 705 |
+
query=simple_query_input,
|
| 706 |
+
instrumental=instrumental,
|
| 707 |
+
vocal_language=simple_vocal_language,
|
| 708 |
+
temperature=lm_temperature,
|
| 709 |
+
top_k=top_k_value,
|
| 710 |
+
top_p=top_p_value,
|
| 711 |
+
use_constrained_decoding=True,
|
| 712 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 713 |
+
)
|
| 714 |
+
|
| 715 |
+
if not result.success:
|
| 716 |
+
raise gr.Error(f"Failed to create sample: {result.status_message}")
|
| 717 |
+
|
| 718 |
+
# Update args with generated data
|
| 719 |
+
args_list[0] = result.caption # captions
|
| 720 |
+
args_list[1] = result.lyrics # lyrics
|
| 721 |
+
args_list[2] = result.bpm # bpm
|
| 722 |
+
args_list[3] = result.keyscale # key_scale
|
| 723 |
+
args_list[4] = result.timesignature # time_signature
|
| 724 |
+
args_list[5] = result.language # vocal_language
|
| 725 |
+
if result.duration and result.duration > 0:
|
| 726 |
+
args_list[11] = result.duration # audio_duration
|
| 727 |
+
# Enable thinking for Simple mode
|
| 728 |
+
args_list[28] = True # think_checkbox
|
| 729 |
+
# Mark as formatted caption (LM-generated sample)
|
| 730 |
+
args_list[36] = True # is_format_caption_state
|
| 731 |
+
|
| 732 |
+
# Determine which handler to use based on model selection
|
| 733 |
+
active_handler = dit_handler # Default to primary handler
|
| 734 |
+
if dit_handler_2 is not None and selected_model == config_path_2:
|
| 735 |
+
active_handler = dit_handler_2
|
| 736 |
+
yield from res_h.generate_with_batch_management(active_handler, llm_handler, *args_list)
|
| 737 |
+
|
| 738 |
+
# ========== Generation Handler ==========
|
| 739 |
+
generation_section["generate_btn"].click(
|
| 740 |
+
fn=generation_wrapper,
|
| 741 |
+
inputs=[
|
| 742 |
+
generation_section["dit_model_selector"], # Model selection input
|
| 743 |
+
generation_section["generation_mode"], # For Simple mode detection
|
| 744 |
+
generation_section["simple_query_input"], # Simple mode query
|
| 745 |
+
generation_section["simple_vocal_language"], # Simple mode vocal language
|
| 746 |
+
generation_section["captions"],
|
| 747 |
+
generation_section["lyrics"],
|
| 748 |
+
generation_section["bpm"],
|
| 749 |
+
generation_section["key_scale"],
|
| 750 |
+
generation_section["time_signature"],
|
| 751 |
+
generation_section["vocal_language"],
|
| 752 |
+
generation_section["inference_steps"],
|
| 753 |
+
generation_section["guidance_scale"],
|
| 754 |
+
generation_section["random_seed_checkbox"],
|
| 755 |
+
generation_section["seed"],
|
| 756 |
+
generation_section["reference_audio"],
|
| 757 |
+
generation_section["audio_duration"],
|
| 758 |
+
generation_section["batch_size_input"],
|
| 759 |
+
generation_section["src_audio"],
|
| 760 |
+
generation_section["text2music_audio_code_string"],
|
| 761 |
+
generation_section["repainting_start"],
|
| 762 |
+
generation_section["repainting_end"],
|
| 763 |
+
generation_section["instruction_display_gen"],
|
| 764 |
+
generation_section["audio_cover_strength"],
|
| 765 |
+
generation_section["task_type"],
|
| 766 |
+
generation_section["use_adg"],
|
| 767 |
+
generation_section["cfg_interval_start"],
|
| 768 |
+
generation_section["cfg_interval_end"],
|
| 769 |
+
generation_section["shift"],
|
| 770 |
+
generation_section["infer_method"],
|
| 771 |
+
generation_section["custom_timesteps"],
|
| 772 |
+
generation_section["audio_format"],
|
| 773 |
+
generation_section["lm_temperature"],
|
| 774 |
+
generation_section["think_checkbox"],
|
| 775 |
+
generation_section["lm_cfg_scale"],
|
| 776 |
+
generation_section["lm_top_k"],
|
| 777 |
+
generation_section["lm_top_p"],
|
| 778 |
+
generation_section["lm_negative_prompt"],
|
| 779 |
+
generation_section["use_cot_metas"],
|
| 780 |
+
generation_section["use_cot_caption"],
|
| 781 |
+
generation_section["use_cot_language"],
|
| 782 |
+
results_section["is_format_caption_state"],
|
| 783 |
+
generation_section["constrained_decoding_debug"],
|
| 784 |
+
generation_section["allow_lm_batch"],
|
| 785 |
+
generation_section["auto_score"],
|
| 786 |
+
generation_section["auto_lrc"],
|
| 787 |
+
generation_section["score_scale"],
|
| 788 |
+
generation_section["lm_batch_chunk_size"],
|
| 789 |
+
generation_section["track_name"],
|
| 790 |
+
generation_section["complete_track_classes"],
|
| 791 |
+
generation_section["autogen_checkbox"],
|
| 792 |
+
results_section["current_batch_index"],
|
| 793 |
+
results_section["total_batches"],
|
| 794 |
+
results_section["batch_queue"],
|
| 795 |
+
results_section["generation_params_state"],
|
| 796 |
+
],
|
| 797 |
+
outputs=[
|
| 798 |
+
results_section["generated_audio_1"],
|
| 799 |
+
results_section["generated_audio_2"],
|
| 800 |
+
results_section["generated_audio_3"],
|
| 801 |
+
results_section["generated_audio_4"],
|
| 802 |
+
results_section["generated_audio_5"],
|
| 803 |
+
results_section["generated_audio_6"],
|
| 804 |
+
results_section["generated_audio_7"],
|
| 805 |
+
results_section["generated_audio_8"],
|
| 806 |
+
results_section["generated_audio_batch"],
|
| 807 |
+
results_section["generation_info"],
|
| 808 |
+
results_section["status_output"],
|
| 809 |
+
generation_section["seed"],
|
| 810 |
+
results_section["score_display_1"],
|
| 811 |
+
results_section["score_display_2"],
|
| 812 |
+
results_section["score_display_3"],
|
| 813 |
+
results_section["score_display_4"],
|
| 814 |
+
results_section["score_display_5"],
|
| 815 |
+
results_section["score_display_6"],
|
| 816 |
+
results_section["score_display_7"],
|
| 817 |
+
results_section["score_display_8"],
|
| 818 |
+
results_section["codes_display_1"],
|
| 819 |
+
results_section["codes_display_2"],
|
| 820 |
+
results_section["codes_display_3"],
|
| 821 |
+
results_section["codes_display_4"],
|
| 822 |
+
results_section["codes_display_5"],
|
| 823 |
+
results_section["codes_display_6"],
|
| 824 |
+
results_section["codes_display_7"],
|
| 825 |
+
results_section["codes_display_8"],
|
| 826 |
+
results_section["details_accordion_1"],
|
| 827 |
+
results_section["details_accordion_2"],
|
| 828 |
+
results_section["details_accordion_3"],
|
| 829 |
+
results_section["details_accordion_4"],
|
| 830 |
+
results_section["details_accordion_5"],
|
| 831 |
+
results_section["details_accordion_6"],
|
| 832 |
+
results_section["details_accordion_7"],
|
| 833 |
+
results_section["details_accordion_8"],
|
| 834 |
+
results_section["lrc_display_1"],
|
| 835 |
+
results_section["lrc_display_2"],
|
| 836 |
+
results_section["lrc_display_3"],
|
| 837 |
+
results_section["lrc_display_4"],
|
| 838 |
+
results_section["lrc_display_5"],
|
| 839 |
+
results_section["lrc_display_6"],
|
| 840 |
+
results_section["lrc_display_7"],
|
| 841 |
+
results_section["lrc_display_8"],
|
| 842 |
+
results_section["lm_metadata_state"],
|
| 843 |
+
results_section["is_format_caption_state"],
|
| 844 |
+
results_section["current_batch_index"],
|
| 845 |
+
results_section["total_batches"],
|
| 846 |
+
results_section["batch_queue"],
|
| 847 |
+
results_section["generation_params_state"],
|
| 848 |
+
results_section["batch_indicator"],
|
| 849 |
+
results_section["prev_batch_btn"],
|
| 850 |
+
results_section["next_batch_btn"],
|
| 851 |
+
results_section["next_batch_status"],
|
| 852 |
+
results_section["restore_params_btn"],
|
| 853 |
+
]
|
| 854 |
+
).then(
|
| 855 |
+
fn=lambda selected_model, *args: res_h.generate_next_batch_background(
|
| 856 |
+
dit_handler_2 if (dit_handler_2 is not None and selected_model == config_path_2) else dit_handler,
|
| 857 |
+
llm_handler, *args
|
| 858 |
+
),
|
| 859 |
+
inputs=[
|
| 860 |
+
generation_section["dit_model_selector"], # Model selection input
|
| 861 |
+
generation_section["autogen_checkbox"],
|
| 862 |
+
results_section["generation_params_state"],
|
| 863 |
+
results_section["current_batch_index"],
|
| 864 |
+
results_section["total_batches"],
|
| 865 |
+
results_section["batch_queue"],
|
| 866 |
+
results_section["is_format_caption_state"],
|
| 867 |
+
],
|
| 868 |
+
outputs=[
|
| 869 |
+
results_section["batch_queue"],
|
| 870 |
+
results_section["total_batches"],
|
| 871 |
+
results_section["next_batch_status"],
|
| 872 |
+
results_section["next_batch_btn"],
|
| 873 |
+
]
|
| 874 |
+
)
|
| 875 |
+
|
| 876 |
+
# ========== Batch Navigation Handlers ==========
|
| 877 |
+
results_section["prev_batch_btn"].click(
|
| 878 |
+
fn=res_h.navigate_to_previous_batch,
|
| 879 |
+
inputs=[
|
| 880 |
+
results_section["current_batch_index"],
|
| 881 |
+
results_section["batch_queue"],
|
| 882 |
+
],
|
| 883 |
+
outputs=[
|
| 884 |
+
results_section["generated_audio_1"],
|
| 885 |
+
results_section["generated_audio_2"],
|
| 886 |
+
results_section["generated_audio_3"],
|
| 887 |
+
results_section["generated_audio_4"],
|
| 888 |
+
results_section["generated_audio_5"],
|
| 889 |
+
results_section["generated_audio_6"],
|
| 890 |
+
results_section["generated_audio_7"],
|
| 891 |
+
results_section["generated_audio_8"],
|
| 892 |
+
results_section["generated_audio_batch"],
|
| 893 |
+
results_section["generation_info"],
|
| 894 |
+
results_section["current_batch_index"],
|
| 895 |
+
results_section["batch_indicator"],
|
| 896 |
+
results_section["prev_batch_btn"],
|
| 897 |
+
results_section["next_batch_btn"],
|
| 898 |
+
results_section["status_output"],
|
| 899 |
+
results_section["score_display_1"],
|
| 900 |
+
results_section["score_display_2"],
|
| 901 |
+
results_section["score_display_3"],
|
| 902 |
+
results_section["score_display_4"],
|
| 903 |
+
results_section["score_display_5"],
|
| 904 |
+
results_section["score_display_6"],
|
| 905 |
+
results_section["score_display_7"],
|
| 906 |
+
results_section["score_display_8"],
|
| 907 |
+
results_section["codes_display_1"],
|
| 908 |
+
results_section["codes_display_2"],
|
| 909 |
+
results_section["codes_display_3"],
|
| 910 |
+
results_section["codes_display_4"],
|
| 911 |
+
results_section["codes_display_5"],
|
| 912 |
+
results_section["codes_display_6"],
|
| 913 |
+
results_section["codes_display_7"],
|
| 914 |
+
results_section["codes_display_8"],
|
| 915 |
+
results_section["lrc_display_1"],
|
| 916 |
+
results_section["lrc_display_2"],
|
| 917 |
+
results_section["lrc_display_3"],
|
| 918 |
+
results_section["lrc_display_4"],
|
| 919 |
+
results_section["lrc_display_5"],
|
| 920 |
+
results_section["lrc_display_6"],
|
| 921 |
+
results_section["lrc_display_7"],
|
| 922 |
+
results_section["lrc_display_8"],
|
| 923 |
+
results_section["details_accordion_1"],
|
| 924 |
+
results_section["details_accordion_2"],
|
| 925 |
+
results_section["details_accordion_3"],
|
| 926 |
+
results_section["details_accordion_4"],
|
| 927 |
+
results_section["details_accordion_5"],
|
| 928 |
+
results_section["details_accordion_6"],
|
| 929 |
+
results_section["details_accordion_7"],
|
| 930 |
+
results_section["details_accordion_8"],
|
| 931 |
+
results_section["restore_params_btn"],
|
| 932 |
+
]
|
| 933 |
+
)
|
| 934 |
+
|
| 935 |
+
results_section["next_batch_btn"].click(
|
| 936 |
+
fn=res_h.capture_current_params,
|
| 937 |
+
inputs=[
|
| 938 |
+
generation_section["captions"],
|
| 939 |
+
generation_section["lyrics"],
|
| 940 |
+
generation_section["bpm"],
|
| 941 |
+
generation_section["key_scale"],
|
| 942 |
+
generation_section["time_signature"],
|
| 943 |
+
generation_section["vocal_language"],
|
| 944 |
+
generation_section["inference_steps"],
|
| 945 |
+
generation_section["guidance_scale"],
|
| 946 |
+
generation_section["random_seed_checkbox"],
|
| 947 |
+
generation_section["seed"],
|
| 948 |
+
generation_section["reference_audio"],
|
| 949 |
+
generation_section["audio_duration"],
|
| 950 |
+
generation_section["batch_size_input"],
|
| 951 |
+
generation_section["src_audio"],
|
| 952 |
+
generation_section["text2music_audio_code_string"],
|
| 953 |
+
generation_section["repainting_start"],
|
| 954 |
+
generation_section["repainting_end"],
|
| 955 |
+
generation_section["instruction_display_gen"],
|
| 956 |
+
generation_section["audio_cover_strength"],
|
| 957 |
+
generation_section["task_type"],
|
| 958 |
+
generation_section["use_adg"],
|
| 959 |
+
generation_section["cfg_interval_start"],
|
| 960 |
+
generation_section["cfg_interval_end"],
|
| 961 |
+
generation_section["shift"],
|
| 962 |
+
generation_section["infer_method"],
|
| 963 |
+
generation_section["custom_timesteps"],
|
| 964 |
+
generation_section["audio_format"],
|
| 965 |
+
generation_section["lm_temperature"],
|
| 966 |
+
generation_section["think_checkbox"],
|
| 967 |
+
generation_section["lm_cfg_scale"],
|
| 968 |
+
generation_section["lm_top_k"],
|
| 969 |
+
generation_section["lm_top_p"],
|
| 970 |
+
generation_section["lm_negative_prompt"],
|
| 971 |
+
generation_section["use_cot_metas"],
|
| 972 |
+
generation_section["use_cot_caption"],
|
| 973 |
+
generation_section["use_cot_language"],
|
| 974 |
+
generation_section["constrained_decoding_debug"],
|
| 975 |
+
generation_section["allow_lm_batch"],
|
| 976 |
+
generation_section["auto_score"],
|
| 977 |
+
generation_section["auto_lrc"],
|
| 978 |
+
generation_section["score_scale"],
|
| 979 |
+
generation_section["lm_batch_chunk_size"],
|
| 980 |
+
generation_section["track_name"],
|
| 981 |
+
generation_section["complete_track_classes"],
|
| 982 |
+
],
|
| 983 |
+
outputs=[results_section["generation_params_state"]]
|
| 984 |
+
).then(
|
| 985 |
+
fn=res_h.navigate_to_next_batch,
|
| 986 |
+
inputs=[
|
| 987 |
+
generation_section["autogen_checkbox"],
|
| 988 |
+
results_section["current_batch_index"],
|
| 989 |
+
results_section["total_batches"],
|
| 990 |
+
results_section["batch_queue"],
|
| 991 |
+
],
|
| 992 |
+
outputs=[
|
| 993 |
+
results_section["generated_audio_1"],
|
| 994 |
+
results_section["generated_audio_2"],
|
| 995 |
+
results_section["generated_audio_3"],
|
| 996 |
+
results_section["generated_audio_4"],
|
| 997 |
+
results_section["generated_audio_5"],
|
| 998 |
+
results_section["generated_audio_6"],
|
| 999 |
+
results_section["generated_audio_7"],
|
| 1000 |
+
results_section["generated_audio_8"],
|
| 1001 |
+
results_section["generated_audio_batch"],
|
| 1002 |
+
results_section["generation_info"],
|
| 1003 |
+
results_section["current_batch_index"],
|
| 1004 |
+
results_section["batch_indicator"],
|
| 1005 |
+
results_section["prev_batch_btn"],
|
| 1006 |
+
results_section["next_batch_btn"],
|
| 1007 |
+
results_section["status_output"],
|
| 1008 |
+
results_section["next_batch_status"],
|
| 1009 |
+
results_section["score_display_1"],
|
| 1010 |
+
results_section["score_display_2"],
|
| 1011 |
+
results_section["score_display_3"],
|
| 1012 |
+
results_section["score_display_4"],
|
| 1013 |
+
results_section["score_display_5"],
|
| 1014 |
+
results_section["score_display_6"],
|
| 1015 |
+
results_section["score_display_7"],
|
| 1016 |
+
results_section["score_display_8"],
|
| 1017 |
+
results_section["codes_display_1"],
|
| 1018 |
+
results_section["codes_display_2"],
|
| 1019 |
+
results_section["codes_display_3"],
|
| 1020 |
+
results_section["codes_display_4"],
|
| 1021 |
+
results_section["codes_display_5"],
|
| 1022 |
+
results_section["codes_display_6"],
|
| 1023 |
+
results_section["codes_display_7"],
|
| 1024 |
+
results_section["codes_display_8"],
|
| 1025 |
+
results_section["lrc_display_1"],
|
| 1026 |
+
results_section["lrc_display_2"],
|
| 1027 |
+
results_section["lrc_display_3"],
|
| 1028 |
+
results_section["lrc_display_4"],
|
| 1029 |
+
results_section["lrc_display_5"],
|
| 1030 |
+
results_section["lrc_display_6"],
|
| 1031 |
+
results_section["lrc_display_7"],
|
| 1032 |
+
results_section["lrc_display_8"],
|
| 1033 |
+
results_section["details_accordion_1"],
|
| 1034 |
+
results_section["details_accordion_2"],
|
| 1035 |
+
results_section["details_accordion_3"],
|
| 1036 |
+
results_section["details_accordion_4"],
|
| 1037 |
+
results_section["details_accordion_5"],
|
| 1038 |
+
results_section["details_accordion_6"],
|
| 1039 |
+
results_section["details_accordion_7"],
|
| 1040 |
+
results_section["details_accordion_8"],
|
| 1041 |
+
results_section["restore_params_btn"],
|
| 1042 |
+
]
|
| 1043 |
+
).then(
|
| 1044 |
+
fn=lambda selected_model, *args: res_h.generate_next_batch_background(
|
| 1045 |
+
dit_handler_2 if (dit_handler_2 is not None and selected_model == config_path_2) else dit_handler,
|
| 1046 |
+
llm_handler, *args
|
| 1047 |
+
),
|
| 1048 |
+
inputs=[
|
| 1049 |
+
generation_section["dit_model_selector"], # Model selection input
|
| 1050 |
+
generation_section["autogen_checkbox"],
|
| 1051 |
+
results_section["generation_params_state"],
|
| 1052 |
+
results_section["current_batch_index"],
|
| 1053 |
+
results_section["total_batches"],
|
| 1054 |
+
results_section["batch_queue"],
|
| 1055 |
+
results_section["is_format_caption_state"],
|
| 1056 |
+
],
|
| 1057 |
+
outputs=[
|
| 1058 |
+
results_section["batch_queue"],
|
| 1059 |
+
results_section["total_batches"],
|
| 1060 |
+
results_section["next_batch_status"],
|
| 1061 |
+
results_section["next_batch_btn"],
|
| 1062 |
+
]
|
| 1063 |
+
)
|
| 1064 |
+
|
| 1065 |
+
# ========== Restore Parameters Handler ==========
|
| 1066 |
+
results_section["restore_params_btn"].click(
|
| 1067 |
+
fn=res_h.restore_batch_parameters,
|
| 1068 |
+
inputs=[
|
| 1069 |
+
results_section["current_batch_index"],
|
| 1070 |
+
results_section["batch_queue"]
|
| 1071 |
+
],
|
| 1072 |
+
outputs=[
|
| 1073 |
+
generation_section["text2music_audio_code_string"],
|
| 1074 |
+
generation_section["captions"],
|
| 1075 |
+
generation_section["lyrics"],
|
| 1076 |
+
generation_section["bpm"],
|
| 1077 |
+
generation_section["key_scale"],
|
| 1078 |
+
generation_section["time_signature"],
|
| 1079 |
+
generation_section["vocal_language"],
|
| 1080 |
+
generation_section["audio_duration"],
|
| 1081 |
+
generation_section["batch_size_input"],
|
| 1082 |
+
generation_section["inference_steps"],
|
| 1083 |
+
generation_section["lm_temperature"],
|
| 1084 |
+
generation_section["lm_cfg_scale"],
|
| 1085 |
+
generation_section["lm_top_k"],
|
| 1086 |
+
generation_section["lm_top_p"],
|
| 1087 |
+
generation_section["think_checkbox"],
|
| 1088 |
+
generation_section["use_cot_caption"],
|
| 1089 |
+
generation_section["use_cot_language"],
|
| 1090 |
+
generation_section["allow_lm_batch"],
|
| 1091 |
+
generation_section["track_name"],
|
| 1092 |
+
generation_section["complete_track_classes"],
|
| 1093 |
+
]
|
| 1094 |
+
)
|
| 1095 |
+
|
| 1096 |
+
# ========== LRC Display Change Handlers ==========
|
| 1097 |
+
# NEW APPROACH: Use lrc_display.change() to update audio subtitles
|
| 1098 |
+
# This decouples audio value updates from subtitle updates, avoiding flickering.
|
| 1099 |
+
#
|
| 1100 |
+
# When lrc_display text changes (from generate, LRC button, or manual edit):
|
| 1101 |
+
# 1. lrc_display.change() is triggered
|
| 1102 |
+
# 2. update_audio_subtitles_from_lrc() parses LRC and updates audio subtitles
|
| 1103 |
+
# 3. Audio value is NEVER updated here - only subtitles
|
| 1104 |
+
for lrc_idx in range(1, 9):
|
| 1105 |
+
results_section[f"lrc_display_{lrc_idx}"].change(
|
| 1106 |
+
fn=res_h.update_audio_subtitles_from_lrc,
|
| 1107 |
+
inputs=[
|
| 1108 |
+
results_section[f"lrc_display_{lrc_idx}"],
|
| 1109 |
+
# audio_duration not needed - parse_lrc_to_subtitles calculates end time from timestamps
|
| 1110 |
+
],
|
| 1111 |
+
outputs=[
|
| 1112 |
+
results_section[f"generated_audio_{lrc_idx}"], # Only updates subtitles, not value
|
| 1113 |
+
]
|
| 1114 |
+
)
|
| 1115 |
+
|
| 1116 |
+
|
| 1117 |
+
def setup_training_event_handlers(demo, dit_handler, llm_handler, training_section):
|
| 1118 |
+
"""Setup event handlers for the training tab (dataset builder and LoRA training)"""
|
| 1119 |
+
|
| 1120 |
+
# ========== Load Existing Dataset (Top Section) ==========
|
| 1121 |
+
|
| 1122 |
+
# Load existing dataset JSON at the top of Dataset Builder
|
| 1123 |
+
training_section["load_json_btn"].click(
|
| 1124 |
+
fn=train_h.load_existing_dataset_for_preprocess,
|
| 1125 |
+
inputs=[
|
| 1126 |
+
training_section["load_json_path"],
|
| 1127 |
+
training_section["dataset_builder_state"],
|
| 1128 |
+
],
|
| 1129 |
+
outputs=[
|
| 1130 |
+
training_section["load_json_status"],
|
| 1131 |
+
training_section["audio_files_table"],
|
| 1132 |
+
training_section["sample_selector"],
|
| 1133 |
+
training_section["dataset_builder_state"],
|
| 1134 |
+
# Also update preview fields with first sample
|
| 1135 |
+
training_section["preview_audio"],
|
| 1136 |
+
training_section["preview_filename"],
|
| 1137 |
+
training_section["edit_caption"],
|
| 1138 |
+
training_section["edit_lyrics"],
|
| 1139 |
+
training_section["edit_bpm"],
|
| 1140 |
+
training_section["edit_keyscale"],
|
| 1141 |
+
training_section["edit_timesig"],
|
| 1142 |
+
training_section["edit_duration"],
|
| 1143 |
+
training_section["edit_language"],
|
| 1144 |
+
training_section["edit_instrumental"],
|
| 1145 |
+
]
|
| 1146 |
+
)
|
| 1147 |
+
|
| 1148 |
+
# ========== Dataset Builder Handlers ==========
|
| 1149 |
+
|
| 1150 |
+
# Scan directory for audio files
|
| 1151 |
+
training_section["scan_btn"].click(
|
| 1152 |
+
fn=lambda dir, name, tag, pos, instr, state: train_h.scan_directory(
|
| 1153 |
+
dir, name, tag, pos, instr, state
|
| 1154 |
+
),
|
| 1155 |
+
inputs=[
|
| 1156 |
+
training_section["audio_directory"],
|
| 1157 |
+
training_section["dataset_name"],
|
| 1158 |
+
training_section["custom_tag"],
|
| 1159 |
+
training_section["tag_position"],
|
| 1160 |
+
training_section["all_instrumental"],
|
| 1161 |
+
training_section["dataset_builder_state"],
|
| 1162 |
+
],
|
| 1163 |
+
outputs=[
|
| 1164 |
+
training_section["audio_files_table"],
|
| 1165 |
+
training_section["scan_status"],
|
| 1166 |
+
training_section["sample_selector"],
|
| 1167 |
+
training_section["dataset_builder_state"],
|
| 1168 |
+
]
|
| 1169 |
+
)
|
| 1170 |
+
|
| 1171 |
+
# Auto-label all samples
|
| 1172 |
+
training_section["auto_label_btn"].click(
|
| 1173 |
+
fn=lambda state, skip: train_h.auto_label_all(dit_handler, llm_handler, state, skip),
|
| 1174 |
+
inputs=[
|
| 1175 |
+
training_section["dataset_builder_state"],
|
| 1176 |
+
training_section["skip_metas"],
|
| 1177 |
+
],
|
| 1178 |
+
outputs=[
|
| 1179 |
+
training_section["audio_files_table"],
|
| 1180 |
+
training_section["label_progress"],
|
| 1181 |
+
training_section["dataset_builder_state"],
|
| 1182 |
+
]
|
| 1183 |
+
)
|
| 1184 |
+
|
| 1185 |
+
# Sample selector change - update preview
|
| 1186 |
+
training_section["sample_selector"].change(
|
| 1187 |
+
fn=train_h.get_sample_preview,
|
| 1188 |
+
inputs=[
|
| 1189 |
+
training_section["sample_selector"],
|
| 1190 |
+
training_section["dataset_builder_state"],
|
| 1191 |
+
],
|
| 1192 |
+
outputs=[
|
| 1193 |
+
training_section["preview_audio"],
|
| 1194 |
+
training_section["preview_filename"],
|
| 1195 |
+
training_section["edit_caption"],
|
| 1196 |
+
training_section["edit_lyrics"],
|
| 1197 |
+
training_section["edit_bpm"],
|
| 1198 |
+
training_section["edit_keyscale"],
|
| 1199 |
+
training_section["edit_timesig"],
|
| 1200 |
+
training_section["edit_duration"],
|
| 1201 |
+
training_section["edit_language"],
|
| 1202 |
+
training_section["edit_instrumental"],
|
| 1203 |
+
]
|
| 1204 |
+
)
|
| 1205 |
+
|
| 1206 |
+
# Save sample edit
|
| 1207 |
+
training_section["save_edit_btn"].click(
|
| 1208 |
+
fn=train_h.save_sample_edit,
|
| 1209 |
+
inputs=[
|
| 1210 |
+
training_section["sample_selector"],
|
| 1211 |
+
training_section["edit_caption"],
|
| 1212 |
+
training_section["edit_lyrics"],
|
| 1213 |
+
training_section["edit_bpm"],
|
| 1214 |
+
training_section["edit_keyscale"],
|
| 1215 |
+
training_section["edit_timesig"],
|
| 1216 |
+
training_section["edit_language"],
|
| 1217 |
+
training_section["edit_instrumental"],
|
| 1218 |
+
training_section["dataset_builder_state"],
|
| 1219 |
+
],
|
| 1220 |
+
outputs=[
|
| 1221 |
+
training_section["audio_files_table"],
|
| 1222 |
+
training_section["edit_status"],
|
| 1223 |
+
training_section["dataset_builder_state"],
|
| 1224 |
+
]
|
| 1225 |
+
)
|
| 1226 |
+
|
| 1227 |
+
# Update settings when changed
|
| 1228 |
+
for trigger in [training_section["custom_tag"], training_section["tag_position"], training_section["all_instrumental"]]:
|
| 1229 |
+
trigger.change(
|
| 1230 |
+
fn=train_h.update_settings,
|
| 1231 |
+
inputs=[
|
| 1232 |
+
training_section["custom_tag"],
|
| 1233 |
+
training_section["tag_position"],
|
| 1234 |
+
training_section["all_instrumental"],
|
| 1235 |
+
training_section["dataset_builder_state"],
|
| 1236 |
+
],
|
| 1237 |
+
outputs=[training_section["dataset_builder_state"]]
|
| 1238 |
+
)
|
| 1239 |
+
|
| 1240 |
+
# Save dataset
|
| 1241 |
+
training_section["save_dataset_btn"].click(
|
| 1242 |
+
fn=train_h.save_dataset,
|
| 1243 |
+
inputs=[
|
| 1244 |
+
training_section["save_path"],
|
| 1245 |
+
training_section["dataset_name"],
|
| 1246 |
+
training_section["dataset_builder_state"],
|
| 1247 |
+
],
|
| 1248 |
+
outputs=[training_section["save_status"]]
|
| 1249 |
+
)
|
| 1250 |
+
|
| 1251 |
+
# ========== Preprocess Handlers ==========
|
| 1252 |
+
|
| 1253 |
+
# Load existing dataset JSON for preprocessing
|
| 1254 |
+
# This also updates the preview section so users can view/edit samples
|
| 1255 |
+
training_section["load_existing_dataset_btn"].click(
|
| 1256 |
+
fn=train_h.load_existing_dataset_for_preprocess,
|
| 1257 |
+
inputs=[
|
| 1258 |
+
training_section["load_existing_dataset_path"],
|
| 1259 |
+
training_section["dataset_builder_state"],
|
| 1260 |
+
],
|
| 1261 |
+
outputs=[
|
| 1262 |
+
training_section["load_existing_status"],
|
| 1263 |
+
training_section["audio_files_table"],
|
| 1264 |
+
training_section["sample_selector"],
|
| 1265 |
+
training_section["dataset_builder_state"],
|
| 1266 |
+
# Also update preview fields with first sample
|
| 1267 |
+
training_section["preview_audio"],
|
| 1268 |
+
training_section["preview_filename"],
|
| 1269 |
+
training_section["edit_caption"],
|
| 1270 |
+
training_section["edit_lyrics"],
|
| 1271 |
+
training_section["edit_bpm"],
|
| 1272 |
+
training_section["edit_keyscale"],
|
| 1273 |
+
training_section["edit_timesig"],
|
| 1274 |
+
training_section["edit_duration"],
|
| 1275 |
+
training_section["edit_language"],
|
| 1276 |
+
training_section["edit_instrumental"],
|
| 1277 |
+
]
|
| 1278 |
+
)
|
| 1279 |
+
|
| 1280 |
+
# Preprocess dataset to tensor files
|
| 1281 |
+
training_section["preprocess_btn"].click(
|
| 1282 |
+
fn=lambda output_dir, state: train_h.preprocess_dataset(
|
| 1283 |
+
output_dir, dit_handler, state
|
| 1284 |
+
),
|
| 1285 |
+
inputs=[
|
| 1286 |
+
training_section["preprocess_output_dir"],
|
| 1287 |
+
training_section["dataset_builder_state"],
|
| 1288 |
+
],
|
| 1289 |
+
outputs=[training_section["preprocess_progress"]]
|
| 1290 |
+
)
|
| 1291 |
+
|
| 1292 |
+
# ========== Training Tab Handlers ==========
|
| 1293 |
+
|
| 1294 |
+
# Load preprocessed tensor dataset
|
| 1295 |
+
training_section["load_dataset_btn"].click(
|
| 1296 |
+
fn=train_h.load_training_dataset,
|
| 1297 |
+
inputs=[training_section["training_tensor_dir"]],
|
| 1298 |
+
outputs=[training_section["training_dataset_info"]]
|
| 1299 |
+
)
|
| 1300 |
+
|
| 1301 |
+
# Start training from preprocessed tensors
|
| 1302 |
+
def training_wrapper(tensor_dir, r, a, d, lr, ep, bs, ga, se, sh, sd, od, ts):
|
| 1303 |
+
try:
|
| 1304 |
+
for progress, log, plot, state in train_h.start_training(
|
| 1305 |
+
tensor_dir, dit_handler, r, a, d, lr, ep, bs, ga, se, sh, sd, od, ts
|
| 1306 |
+
):
|
| 1307 |
+
yield progress, log, plot, state
|
| 1308 |
+
except Exception as e:
|
| 1309 |
+
logger.exception("Training wrapper error")
|
| 1310 |
+
yield f"❌ Error: {str(e)}", str(e), None, ts
|
| 1311 |
+
|
| 1312 |
+
training_section["start_training_btn"].click(
|
| 1313 |
+
fn=training_wrapper,
|
| 1314 |
+
inputs=[
|
| 1315 |
+
training_section["training_tensor_dir"],
|
| 1316 |
+
training_section["lora_rank"],
|
| 1317 |
+
training_section["lora_alpha"],
|
| 1318 |
+
training_section["lora_dropout"],
|
| 1319 |
+
training_section["learning_rate"],
|
| 1320 |
+
training_section["train_epochs"],
|
| 1321 |
+
training_section["train_batch_size"],
|
| 1322 |
+
training_section["gradient_accumulation"],
|
| 1323 |
+
training_section["save_every_n_epochs"],
|
| 1324 |
+
training_section["training_shift"],
|
| 1325 |
+
training_section["training_seed"],
|
| 1326 |
+
training_section["lora_output_dir"],
|
| 1327 |
+
training_section["training_state"],
|
| 1328 |
+
],
|
| 1329 |
+
outputs=[
|
| 1330 |
+
training_section["training_progress"],
|
| 1331 |
+
training_section["training_log"],
|
| 1332 |
+
training_section["training_loss_plot"],
|
| 1333 |
+
training_section["training_state"],
|
| 1334 |
+
]
|
| 1335 |
+
)
|
| 1336 |
+
|
| 1337 |
+
# Stop training
|
| 1338 |
+
training_section["stop_training_btn"].click(
|
| 1339 |
+
fn=train_h.stop_training,
|
| 1340 |
+
inputs=[training_section["training_state"]],
|
| 1341 |
+
outputs=[
|
| 1342 |
+
training_section["training_progress"],
|
| 1343 |
+
training_section["training_state"],
|
| 1344 |
+
]
|
| 1345 |
+
)
|
| 1346 |
+
|
| 1347 |
+
# Export LoRA
|
| 1348 |
+
training_section["export_lora_btn"].click(
|
| 1349 |
+
fn=train_h.export_lora,
|
| 1350 |
+
inputs=[
|
| 1351 |
+
training_section["export_path"],
|
| 1352 |
+
training_section["lora_output_dir"],
|
| 1353 |
+
],
|
| 1354 |
+
outputs=[training_section["export_status"]]
|
| 1355 |
+
)
|
acestep/gradio_ui/events/generation_handlers.py
ADDED
|
@@ -0,0 +1,1071 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generation Input Handlers Module
|
| 3 |
+
Contains event handlers and helper functions related to generation inputs
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import random
|
| 8 |
+
import glob
|
| 9 |
+
import gradio as gr
|
| 10 |
+
from typing import Optional, List, Tuple
|
| 11 |
+
from loguru import logger
|
| 12 |
+
from acestep.constants import (
|
| 13 |
+
TASK_TYPES_TURBO,
|
| 14 |
+
TASK_TYPES_BASE,
|
| 15 |
+
)
|
| 16 |
+
from acestep.gradio_ui.i18n import t
|
| 17 |
+
from acestep.inference import understand_music, create_sample, format_sample
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# HuggingFace Space environment detection for ZeroGPU support
|
| 21 |
+
IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_spaces_gpu_decorator(duration=120):
|
| 25 |
+
"""
|
| 26 |
+
Get the @spaces.GPU decorator if running in HuggingFace Space environment.
|
| 27 |
+
Returns identity decorator if not in Space environment.
|
| 28 |
+
"""
|
| 29 |
+
if IS_HUGGINGFACE_SPACE:
|
| 30 |
+
try:
|
| 31 |
+
import spaces
|
| 32 |
+
return spaces.GPU(duration=duration)
|
| 33 |
+
except ImportError:
|
| 34 |
+
logger.warning("spaces package not found, GPU decorator disabled")
|
| 35 |
+
return lambda func: func
|
| 36 |
+
return lambda func: func
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def parse_and_validate_timesteps(
|
| 40 |
+
timesteps_str: str,
|
| 41 |
+
inference_steps: int
|
| 42 |
+
) -> Tuple[Optional[List[float]], bool, str]:
|
| 43 |
+
"""
|
| 44 |
+
Parse timesteps string and validate.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
timesteps_str: Comma-separated timesteps string (e.g., "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
|
| 48 |
+
inference_steps: Expected number of inference steps
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Tuple of (parsed_timesteps, has_warning, warning_message)
|
| 52 |
+
- parsed_timesteps: List of float timesteps, or None if invalid/empty
|
| 53 |
+
- has_warning: Whether a warning was shown
|
| 54 |
+
- warning_message: Description of the warning
|
| 55 |
+
"""
|
| 56 |
+
if not timesteps_str or not timesteps_str.strip():
|
| 57 |
+
return None, False, ""
|
| 58 |
+
|
| 59 |
+
# Parse comma-separated values
|
| 60 |
+
values = [v.strip() for v in timesteps_str.split(",") if v.strip()]
|
| 61 |
+
|
| 62 |
+
if not values:
|
| 63 |
+
return None, False, ""
|
| 64 |
+
|
| 65 |
+
# Handle optional trailing 0
|
| 66 |
+
if values[-1] != "0":
|
| 67 |
+
values.append("0")
|
| 68 |
+
|
| 69 |
+
try:
|
| 70 |
+
timesteps = [float(v) for v in values]
|
| 71 |
+
except ValueError:
|
| 72 |
+
gr.Warning(t("messages.invalid_timesteps_format"))
|
| 73 |
+
return None, True, "Invalid format"
|
| 74 |
+
|
| 75 |
+
# Validate range [0, 1]
|
| 76 |
+
if any(ts < 0 or ts > 1 for ts in timesteps):
|
| 77 |
+
gr.Warning(t("messages.timesteps_out_of_range"))
|
| 78 |
+
return None, True, "Out of range"
|
| 79 |
+
|
| 80 |
+
# Check if count matches inference_steps
|
| 81 |
+
actual_steps = len(timesteps) - 1
|
| 82 |
+
if actual_steps != inference_steps:
|
| 83 |
+
gr.Warning(t("messages.timesteps_count_mismatch", actual=actual_steps, expected=inference_steps))
|
| 84 |
+
return timesteps, True, f"Using {actual_steps} steps from timesteps"
|
| 85 |
+
|
| 86 |
+
return timesteps, False, ""
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def load_metadata(file_obj):
|
| 90 |
+
"""Load generation parameters from a JSON file"""
|
| 91 |
+
if file_obj is None:
|
| 92 |
+
gr.Warning(t("messages.no_file_selected"))
|
| 93 |
+
return [None] * 36 + [False] # Return None for all fields, False for is_format_caption
|
| 94 |
+
|
| 95 |
+
try:
|
| 96 |
+
# Read the uploaded file
|
| 97 |
+
if hasattr(file_obj, 'name'):
|
| 98 |
+
filepath = file_obj.name
|
| 99 |
+
else:
|
| 100 |
+
filepath = file_obj
|
| 101 |
+
|
| 102 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 103 |
+
metadata = json.load(f)
|
| 104 |
+
|
| 105 |
+
# Extract all fields
|
| 106 |
+
task_type = metadata.get('task_type', 'text2music')
|
| 107 |
+
captions = metadata.get('caption', '')
|
| 108 |
+
lyrics = metadata.get('lyrics', '')
|
| 109 |
+
vocal_language = metadata.get('vocal_language', 'unknown')
|
| 110 |
+
|
| 111 |
+
# Convert bpm
|
| 112 |
+
bpm_value = metadata.get('bpm')
|
| 113 |
+
if bpm_value is not None and bpm_value != "N/A":
|
| 114 |
+
try:
|
| 115 |
+
bpm = int(bpm_value) if bpm_value else None
|
| 116 |
+
except:
|
| 117 |
+
bpm = None
|
| 118 |
+
else:
|
| 119 |
+
bpm = None
|
| 120 |
+
|
| 121 |
+
key_scale = metadata.get('keyscale', '')
|
| 122 |
+
time_signature = metadata.get('timesignature', '')
|
| 123 |
+
|
| 124 |
+
# Convert duration
|
| 125 |
+
duration_value = metadata.get('duration', -1)
|
| 126 |
+
if duration_value is not None and duration_value != "N/A":
|
| 127 |
+
try:
|
| 128 |
+
audio_duration = float(duration_value)
|
| 129 |
+
except:
|
| 130 |
+
audio_duration = -1
|
| 131 |
+
else:
|
| 132 |
+
audio_duration = -1
|
| 133 |
+
|
| 134 |
+
batch_size = metadata.get('batch_size', 2)
|
| 135 |
+
inference_steps = metadata.get('inference_steps', 8)
|
| 136 |
+
guidance_scale = metadata.get('guidance_scale', 7.0)
|
| 137 |
+
seed = metadata.get('seed', '-1')
|
| 138 |
+
random_seed = False # Always set to False when loading to enable reproducibility with saved seed
|
| 139 |
+
use_adg = metadata.get('use_adg', False)
|
| 140 |
+
cfg_interval_start = metadata.get('cfg_interval_start', 0.0)
|
| 141 |
+
cfg_interval_end = metadata.get('cfg_interval_end', 1.0)
|
| 142 |
+
audio_format = metadata.get('audio_format', 'mp3')
|
| 143 |
+
lm_temperature = metadata.get('lm_temperature', 0.85)
|
| 144 |
+
lm_cfg_scale = metadata.get('lm_cfg_scale', 2.0)
|
| 145 |
+
lm_top_k = metadata.get('lm_top_k', 0)
|
| 146 |
+
lm_top_p = metadata.get('lm_top_p', 0.9)
|
| 147 |
+
lm_negative_prompt = metadata.get('lm_negative_prompt', 'NO USER INPUT')
|
| 148 |
+
use_cot_metas = metadata.get('use_cot_metas', True) # Added: read use_cot_metas
|
| 149 |
+
use_cot_caption = metadata.get('use_cot_caption', True)
|
| 150 |
+
use_cot_language = metadata.get('use_cot_language', True)
|
| 151 |
+
audio_cover_strength = metadata.get('audio_cover_strength', 1.0)
|
| 152 |
+
think = metadata.get('thinking', True) # Fixed: read 'thinking' not 'think'
|
| 153 |
+
audio_codes = metadata.get('audio_codes', '')
|
| 154 |
+
repainting_start = metadata.get('repainting_start', 0.0)
|
| 155 |
+
repainting_end = metadata.get('repainting_end', -1)
|
| 156 |
+
track_name = metadata.get('track_name')
|
| 157 |
+
complete_track_classes = metadata.get('complete_track_classes', [])
|
| 158 |
+
shift = metadata.get('shift', 3.0) # Default 3.0 for base models
|
| 159 |
+
infer_method = metadata.get('infer_method', 'ode') # Default 'ode' for diffusion inference
|
| 160 |
+
custom_timesteps = metadata.get('timesteps', '') # Custom timesteps (stored as 'timesteps' in JSON)
|
| 161 |
+
if custom_timesteps is None:
|
| 162 |
+
custom_timesteps = ''
|
| 163 |
+
instrumental = metadata.get('instrumental', False) # Added: read instrumental
|
| 164 |
+
|
| 165 |
+
gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))
|
| 166 |
+
|
| 167 |
+
return (
|
| 168 |
+
task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
|
| 169 |
+
audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
|
| 170 |
+
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method,
|
| 171 |
+
custom_timesteps, # Added: custom_timesteps (between infer_method and audio_format)
|
| 172 |
+
audio_format, lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 173 |
+
use_cot_metas, use_cot_caption, use_cot_language, audio_cover_strength,
|
| 174 |
+
think, audio_codes, repainting_start, repainting_end,
|
| 175 |
+
track_name, complete_track_classes, instrumental,
|
| 176 |
+
True # Set is_format_caption to True when loading from file
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
except json.JSONDecodeError as e:
|
| 180 |
+
gr.Warning(t("messages.invalid_json", error=str(e)))
|
| 181 |
+
return [None] * 36 + [False]
|
| 182 |
+
except Exception as e:
|
| 183 |
+
gr.Warning(t("messages.load_error", error=str(e)))
|
| 184 |
+
return [None] * 36 + [False]
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def load_random_example(task_type: str):
|
| 188 |
+
"""Load a random example from the task-specific examples directory
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
task_type: The task type (e.g., "text2music")
|
| 192 |
+
|
| 193 |
+
Returns:
|
| 194 |
+
Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
|
| 195 |
+
"""
|
| 196 |
+
try:
|
| 197 |
+
# Get the project root directory
|
| 198 |
+
current_file = os.path.abspath(__file__)
|
| 199 |
+
# This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
|
| 200 |
+
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
|
| 201 |
+
|
| 202 |
+
# Construct the examples directory path
|
| 203 |
+
examples_dir = os.path.join(project_root, "examples", task_type)
|
| 204 |
+
|
| 205 |
+
# Check if directory exists
|
| 206 |
+
if not os.path.exists(examples_dir):
|
| 207 |
+
gr.Warning(f"Examples directory not found: examples/{task_type}/")
|
| 208 |
+
return "", "", True, None, None, "", "", ""
|
| 209 |
+
|
| 210 |
+
# Find all JSON files in the directory
|
| 211 |
+
json_files = glob.glob(os.path.join(examples_dir, "*.json"))
|
| 212 |
+
|
| 213 |
+
if not json_files:
|
| 214 |
+
gr.Warning(f"No JSON files found in examples/{task_type}/")
|
| 215 |
+
return "", "", True, None, None, "", "", ""
|
| 216 |
+
|
| 217 |
+
# Randomly select one file
|
| 218 |
+
selected_file = random.choice(json_files)
|
| 219 |
+
|
| 220 |
+
# Read and parse JSON
|
| 221 |
+
try:
|
| 222 |
+
with open(selected_file, 'r', encoding='utf-8') as f:
|
| 223 |
+
data = json.load(f)
|
| 224 |
+
|
| 225 |
+
# Extract caption (prefer 'caption', fallback to 'prompt')
|
| 226 |
+
caption_value = data.get('caption', data.get('prompt', ''))
|
| 227 |
+
if not isinstance(caption_value, str):
|
| 228 |
+
caption_value = str(caption_value) if caption_value else ''
|
| 229 |
+
|
| 230 |
+
# Extract lyrics
|
| 231 |
+
lyrics_value = data.get('lyrics', '')
|
| 232 |
+
if not isinstance(lyrics_value, str):
|
| 233 |
+
lyrics_value = str(lyrics_value) if lyrics_value else ''
|
| 234 |
+
|
| 235 |
+
# Extract think (default to True if not present)
|
| 236 |
+
think_value = data.get('think', True)
|
| 237 |
+
if not isinstance(think_value, bool):
|
| 238 |
+
think_value = True
|
| 239 |
+
|
| 240 |
+
# Extract optional metadata fields
|
| 241 |
+
bpm_value = None
|
| 242 |
+
if 'bpm' in data and data['bpm'] not in [None, "N/A", ""]:
|
| 243 |
+
try:
|
| 244 |
+
bpm_value = int(data['bpm'])
|
| 245 |
+
except (ValueError, TypeError):
|
| 246 |
+
pass
|
| 247 |
+
|
| 248 |
+
duration_value = None
|
| 249 |
+
if 'duration' in data and data['duration'] not in [None, "N/A", ""]:
|
| 250 |
+
try:
|
| 251 |
+
duration_value = float(data['duration'])
|
| 252 |
+
except (ValueError, TypeError):
|
| 253 |
+
pass
|
| 254 |
+
|
| 255 |
+
keyscale_value = data.get('keyscale', '')
|
| 256 |
+
if keyscale_value in [None, "N/A"]:
|
| 257 |
+
keyscale_value = ''
|
| 258 |
+
|
| 259 |
+
language_value = data.get('language', '')
|
| 260 |
+
if language_value in [None, "N/A"]:
|
| 261 |
+
language_value = ''
|
| 262 |
+
|
| 263 |
+
timesignature_value = data.get('timesignature', '')
|
| 264 |
+
if timesignature_value in [None, "N/A"]:
|
| 265 |
+
timesignature_value = ''
|
| 266 |
+
|
| 267 |
+
gr.Info(t("messages.example_loaded", filename=os.path.basename(selected_file)))
|
| 268 |
+
return caption_value, lyrics_value, think_value, bpm_value, duration_value, keyscale_value, language_value, timesignature_value
|
| 269 |
+
|
| 270 |
+
except json.JSONDecodeError as e:
|
| 271 |
+
gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
|
| 272 |
+
return "", "", True, None, None, "", "", ""
|
| 273 |
+
except Exception as e:
|
| 274 |
+
gr.Warning(t("messages.example_error", error=str(e)))
|
| 275 |
+
return "", "", True, None, None, "", "", ""
|
| 276 |
+
|
| 277 |
+
except Exception as e:
|
| 278 |
+
gr.Warning(t("messages.example_error", error=str(e)))
|
| 279 |
+
return "", "", True, None, None, "", "", ""
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def sample_example_smart(llm_handler, task_type: str, constrained_decoding_debug: bool = False):
|
| 283 |
+
"""Smart sample function that uses LM if initialized, otherwise falls back to examples
|
| 284 |
+
|
| 285 |
+
This is a Gradio wrapper that uses the understand_music API from acestep.inference
|
| 286 |
+
to generate examples when LM is available.
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
llm_handler: LLM handler instance
|
| 290 |
+
task_type: The task type (e.g., "text2music")
|
| 291 |
+
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
|
| 292 |
+
|
| 293 |
+
Returns:
|
| 294 |
+
Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
|
| 295 |
+
"""
|
| 296 |
+
# Check if LM is initialized
|
| 297 |
+
if llm_handler.llm_initialized:
|
| 298 |
+
# Use LM to generate example via understand_music API
|
| 299 |
+
try:
|
| 300 |
+
result = understand_music(
|
| 301 |
+
llm_handler=llm_handler,
|
| 302 |
+
audio_codes="NO USER INPUT", # Empty input triggers example generation
|
| 303 |
+
temperature=0.85,
|
| 304 |
+
use_constrained_decoding=True,
|
| 305 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
if result.success:
|
| 309 |
+
gr.Info(t("messages.lm_generated"))
|
| 310 |
+
return (
|
| 311 |
+
result.caption,
|
| 312 |
+
result.lyrics,
|
| 313 |
+
True, # Always enable think when using LM-generated examples
|
| 314 |
+
result.bpm,
|
| 315 |
+
result.duration,
|
| 316 |
+
result.keyscale,
|
| 317 |
+
result.language,
|
| 318 |
+
result.timesignature,
|
| 319 |
+
)
|
| 320 |
+
else:
|
| 321 |
+
gr.Warning(t("messages.lm_fallback"))
|
| 322 |
+
return load_random_example(task_type)
|
| 323 |
+
|
| 324 |
+
except Exception as e:
|
| 325 |
+
gr.Warning(t("messages.lm_fallback"))
|
| 326 |
+
return load_random_example(task_type)
|
| 327 |
+
else:
|
| 328 |
+
# LM not initialized, use examples directory
|
| 329 |
+
return load_random_example(task_type)
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def load_random_simple_description():
|
| 333 |
+
"""Load a random description from the simple_mode examples directory.
|
| 334 |
+
|
| 335 |
+
Returns:
|
| 336 |
+
Tuple of (description, instrumental, vocal_language) for updating UI components
|
| 337 |
+
"""
|
| 338 |
+
try:
|
| 339 |
+
# Get the project root directory
|
| 340 |
+
current_file = os.path.abspath(__file__)
|
| 341 |
+
# This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
|
| 342 |
+
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
|
| 343 |
+
|
| 344 |
+
# Construct the examples directory path
|
| 345 |
+
examples_dir = os.path.join(project_root, "examples", "simple_mode")
|
| 346 |
+
|
| 347 |
+
# Check if directory exists
|
| 348 |
+
if not os.path.exists(examples_dir):
|
| 349 |
+
gr.Warning(t("messages.simple_examples_not_found"))
|
| 350 |
+
return gr.update(), gr.update(), gr.update()
|
| 351 |
+
|
| 352 |
+
# Find all JSON files in the directory
|
| 353 |
+
json_files = glob.glob(os.path.join(examples_dir, "*.json"))
|
| 354 |
+
|
| 355 |
+
if not json_files:
|
| 356 |
+
gr.Warning(t("messages.simple_examples_empty"))
|
| 357 |
+
return gr.update(), gr.update(), gr.update()
|
| 358 |
+
|
| 359 |
+
# Randomly select one file
|
| 360 |
+
selected_file = random.choice(json_files)
|
| 361 |
+
|
| 362 |
+
# Read and parse JSON
|
| 363 |
+
try:
|
| 364 |
+
with open(selected_file, 'r', encoding='utf-8') as f:
|
| 365 |
+
data = json.load(f)
|
| 366 |
+
|
| 367 |
+
# Extract fields
|
| 368 |
+
description = data.get('description', '')
|
| 369 |
+
instrumental = data.get('instrumental', False)
|
| 370 |
+
vocal_language = data.get('vocal_language', 'unknown')
|
| 371 |
+
|
| 372 |
+
# Ensure vocal_language is a string
|
| 373 |
+
if isinstance(vocal_language, list):
|
| 374 |
+
vocal_language = vocal_language[0] if vocal_language else 'unknown'
|
| 375 |
+
|
| 376 |
+
gr.Info(t("messages.simple_example_loaded", filename=os.path.basename(selected_file)))
|
| 377 |
+
return description, instrumental, vocal_language
|
| 378 |
+
|
| 379 |
+
except json.JSONDecodeError as e:
|
| 380 |
+
gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
|
| 381 |
+
return gr.update(), gr.update(), gr.update()
|
| 382 |
+
except Exception as e:
|
| 383 |
+
gr.Warning(t("messages.example_error", error=str(e)))
|
| 384 |
+
return gr.update(), gr.update(), gr.update()
|
| 385 |
+
|
| 386 |
+
except Exception as e:
|
| 387 |
+
gr.Warning(t("messages.example_error", error=str(e)))
|
| 388 |
+
return gr.update(), gr.update(), gr.update()
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def refresh_checkpoints(dit_handler):
|
| 392 |
+
"""Refresh available checkpoints"""
|
| 393 |
+
choices = dit_handler.get_available_checkpoints()
|
| 394 |
+
return gr.update(choices=choices)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def update_model_type_settings(config_path):
|
| 398 |
+
"""Update UI settings based on model type (fallback when handler not initialized yet)
|
| 399 |
+
|
| 400 |
+
Note: This is used as a fallback when the user changes config_path dropdown
|
| 401 |
+
before initializing the model. The actual settings are determined by the
|
| 402 |
+
handler's is_turbo_model() method after initialization.
|
| 403 |
+
"""
|
| 404 |
+
if config_path is None:
|
| 405 |
+
config_path = ""
|
| 406 |
+
config_path_lower = config_path.lower()
|
| 407 |
+
|
| 408 |
+
# Determine is_turbo based on config_path string
|
| 409 |
+
# This is a heuristic fallback - actual model type is determined after loading
|
| 410 |
+
if "turbo" in config_path_lower:
|
| 411 |
+
is_turbo = True
|
| 412 |
+
elif "base" in config_path_lower:
|
| 413 |
+
is_turbo = False
|
| 414 |
+
else:
|
| 415 |
+
# Default to turbo settings for unknown model types
|
| 416 |
+
is_turbo = True
|
| 417 |
+
|
| 418 |
+
return get_model_type_ui_settings(is_turbo)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
|
| 422 |
+
"""Wrapper for service initialization, returns status, button state, accordion state, and model type settings"""
|
| 423 |
+
# Initialize DiT handler
|
| 424 |
+
status, enable = dit_handler.initialize_service(
|
| 425 |
+
checkpoint, config_path, device,
|
| 426 |
+
use_flash_attention=use_flash_attention, compile_model=False,
|
| 427 |
+
offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu
|
| 428 |
+
)
|
| 429 |
+
|
| 430 |
+
# Initialize LM handler if requested
|
| 431 |
+
if init_llm:
|
| 432 |
+
# Get checkpoint directory
|
| 433 |
+
current_file = os.path.abspath(__file__)
|
| 434 |
+
# This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
|
| 435 |
+
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
|
| 436 |
+
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 437 |
+
|
| 438 |
+
lm_status, lm_success = llm_handler.initialize(
|
| 439 |
+
checkpoint_dir=checkpoint_dir,
|
| 440 |
+
lm_model_path=lm_model_path,
|
| 441 |
+
backend=backend,
|
| 442 |
+
device=device,
|
| 443 |
+
offload_to_cpu=offload_to_cpu,
|
| 444 |
+
dtype=dit_handler.dtype
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
if lm_success:
|
| 448 |
+
status += f"\n{lm_status}"
|
| 449 |
+
else:
|
| 450 |
+
status += f"\n{lm_status}"
|
| 451 |
+
# Don't fail the entire initialization if LM fails, but log it
|
| 452 |
+
# Keep enable as is (DiT initialization result) even if LM fails
|
| 453 |
+
|
| 454 |
+
# Check if model is initialized - if so, collapse the accordion
|
| 455 |
+
is_model_initialized = dit_handler.model is not None
|
| 456 |
+
accordion_state = gr.Accordion(open=not is_model_initialized)
|
| 457 |
+
|
| 458 |
+
# Get model type settings based on actual loaded model
|
| 459 |
+
is_turbo = dit_handler.is_turbo_model()
|
| 460 |
+
model_type_settings = get_model_type_ui_settings(is_turbo)
|
| 461 |
+
|
| 462 |
+
return (
|
| 463 |
+
status,
|
| 464 |
+
gr.update(interactive=enable),
|
| 465 |
+
accordion_state,
|
| 466 |
+
*model_type_settings
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def get_model_type_ui_settings(is_turbo: bool):
|
| 471 |
+
"""Get UI settings based on whether the model is turbo or base"""
|
| 472 |
+
if is_turbo:
|
| 473 |
+
# Turbo model: max 20 steps, default 8, show shift with default 3.0, only show text2music/repaint/cover
|
| 474 |
+
return (
|
| 475 |
+
gr.update(value=8, maximum=20, minimum=1), # inference_steps
|
| 476 |
+
gr.update(visible=False), # guidance_scale
|
| 477 |
+
gr.update(visible=False), # use_adg
|
| 478 |
+
gr.update(value=3.0, visible=True), # shift (show with default 3.0)
|
| 479 |
+
gr.update(visible=False), # cfg_interval_start
|
| 480 |
+
gr.update(visible=False), # cfg_interval_end
|
| 481 |
+
gr.update(choices=TASK_TYPES_TURBO), # task_type
|
| 482 |
+
)
|
| 483 |
+
else:
|
| 484 |
+
# Base model: max 200 steps, default 32, show CFG/ADG/shift, show all task types
|
| 485 |
+
return (
|
| 486 |
+
gr.update(value=32, maximum=200, minimum=1), # inference_steps
|
| 487 |
+
gr.update(visible=True), # guidance_scale
|
| 488 |
+
gr.update(visible=True), # use_adg
|
| 489 |
+
gr.update(value=3.0, visible=True), # shift (effective for base, default 3.0)
|
| 490 |
+
gr.update(visible=True), # cfg_interval_start
|
| 491 |
+
gr.update(visible=True), # cfg_interval_end
|
| 492 |
+
gr.update(choices=TASK_TYPES_BASE), # task_type
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def update_negative_prompt_visibility(init_llm_checked):
|
| 497 |
+
"""Update negative prompt visibility: show if Initialize 5Hz LM checkbox is checked"""
|
| 498 |
+
return gr.update(visible=init_llm_checked)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def update_audio_cover_strength_visibility(task_type_value, init_llm_checked):
|
| 502 |
+
"""Update audio_cover_strength visibility and label"""
|
| 503 |
+
# Show if task is cover OR if LM is initialized (but NOT for repaint mode)
|
| 504 |
+
# Repaint mode never shows this control
|
| 505 |
+
is_repaint = task_type_value == "repaint"
|
| 506 |
+
is_cover = task_type_value == "cover"
|
| 507 |
+
is_visible = is_cover or (init_llm_checked and not is_repaint)
|
| 508 |
+
|
| 509 |
+
# Change label based on context
|
| 510 |
+
if init_llm_checked and not is_cover:
|
| 511 |
+
label = "LM codes strength"
|
| 512 |
+
info = "Control how many denoising steps use LM-generated codes"
|
| 513 |
+
else:
|
| 514 |
+
label = "Audio Cover Strength"
|
| 515 |
+
info = "Control how many denoising steps use cover mode"
|
| 516 |
+
|
| 517 |
+
return gr.update(visible=is_visible, label=label, info=info)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def convert_src_audio_to_codes_wrapper(dit_handler, src_audio):
|
| 521 |
+
"""Wrapper for converting src audio to codes"""
|
| 522 |
+
codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
|
| 523 |
+
return codes_string
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def update_instruction_ui(
|
| 527 |
+
dit_handler,
|
| 528 |
+
task_type_value: str,
|
| 529 |
+
track_name_value: Optional[str],
|
| 530 |
+
complete_track_classes_value: list,
|
| 531 |
+
audio_codes_content: str = "",
|
| 532 |
+
init_llm_checked: bool = False
|
| 533 |
+
) -> tuple:
|
| 534 |
+
"""Update instruction and UI visibility based on task type."""
|
| 535 |
+
instruction = dit_handler.generate_instruction(
|
| 536 |
+
task_type=task_type_value,
|
| 537 |
+
track_name=track_name_value,
|
| 538 |
+
complete_track_classes=complete_track_classes_value
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
# Show track_name for lego and extract
|
| 542 |
+
track_name_visible = task_type_value in ["lego", "extract"]
|
| 543 |
+
# Show complete_track_classes for complete
|
| 544 |
+
complete_visible = task_type_value == "complete"
|
| 545 |
+
# Show audio_cover_strength for cover OR when LM is initialized (but NOT for repaint)
|
| 546 |
+
is_repaint = task_type_value == "repaint"
|
| 547 |
+
is_cover = task_type_value == "cover"
|
| 548 |
+
audio_cover_strength_visible = is_cover or (init_llm_checked and not is_repaint)
|
| 549 |
+
# Determine label and info based on context
|
| 550 |
+
if init_llm_checked and not is_cover:
|
| 551 |
+
audio_cover_strength_label = "LM codes strength"
|
| 552 |
+
audio_cover_strength_info = "Control how many denoising steps use LM-generated codes"
|
| 553 |
+
else:
|
| 554 |
+
audio_cover_strength_label = "Audio Cover Strength"
|
| 555 |
+
audio_cover_strength_info = "Control how many denoising steps use cover mode"
|
| 556 |
+
# Show repainting controls for repaint and lego
|
| 557 |
+
repainting_visible = task_type_value in ["repaint", "lego"]
|
| 558 |
+
# Show text2music_audio_codes if task is text2music OR if it has content
|
| 559 |
+
# This allows it to stay visible even if user switches task type but has codes
|
| 560 |
+
has_audio_codes = audio_codes_content and str(audio_codes_content).strip()
|
| 561 |
+
text2music_audio_codes_visible = task_type_value == "text2music" or has_audio_codes
|
| 562 |
+
|
| 563 |
+
return (
|
| 564 |
+
instruction, # instruction_display_gen
|
| 565 |
+
gr.update(visible=track_name_visible), # track_name
|
| 566 |
+
gr.update(visible=complete_visible), # complete_track_classes
|
| 567 |
+
gr.update(visible=audio_cover_strength_visible, label=audio_cover_strength_label, info=audio_cover_strength_info), # audio_cover_strength
|
| 568 |
+
gr.update(visible=repainting_visible), # repainting_group
|
| 569 |
+
gr.update(visible=text2music_audio_codes_visible), # text2music_audio_codes_group
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
def transcribe_audio_codes(llm_handler, audio_code_string, constrained_decoding_debug):
|
| 574 |
+
"""
|
| 575 |
+
Transcribe audio codes to metadata using LLM understanding.
|
| 576 |
+
If audio_code_string is empty, generate a sample example instead.
|
| 577 |
+
|
| 578 |
+
This is a Gradio wrapper around the understand_music API in acestep.inference.
|
| 579 |
+
|
| 580 |
+
Args:
|
| 581 |
+
llm_handler: LLM handler instance
|
| 582 |
+
audio_code_string: String containing audio codes (or empty for example generation)
|
| 583 |
+
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
|
| 584 |
+
|
| 585 |
+
Returns:
|
| 586 |
+
Tuple of (status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature, is_format_caption)
|
| 587 |
+
"""
|
| 588 |
+
# Call the inference API
|
| 589 |
+
result = understand_music(
|
| 590 |
+
llm_handler=llm_handler,
|
| 591 |
+
audio_codes=audio_code_string,
|
| 592 |
+
use_constrained_decoding=True,
|
| 593 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
# Handle error case with localized message
|
| 597 |
+
if not result.success:
|
| 598 |
+
# Use localized error message for LLM not initialized
|
| 599 |
+
if result.error == "LLM not initialized":
|
| 600 |
+
return t("messages.lm_not_initialized"), "", "", None, None, "", "", "", False
|
| 601 |
+
return result.status_message, "", "", None, None, "", "", "", False
|
| 602 |
+
|
| 603 |
+
return (
|
| 604 |
+
result.status_message,
|
| 605 |
+
result.caption,
|
| 606 |
+
result.lyrics,
|
| 607 |
+
result.bpm,
|
| 608 |
+
result.duration,
|
| 609 |
+
result.keyscale,
|
| 610 |
+
result.language,
|
| 611 |
+
result.timesignature,
|
| 612 |
+
True # Set is_format_caption to True (from Transcribe/LM understanding)
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
def update_transcribe_button_text(audio_code_string):
|
| 617 |
+
"""
|
| 618 |
+
Update the transcribe button text based on input content.
|
| 619 |
+
If empty: "Generate Example"
|
| 620 |
+
If has content: "Transcribe"
|
| 621 |
+
"""
|
| 622 |
+
if not audio_code_string or not audio_code_string.strip():
|
| 623 |
+
return gr.update(value="Generate Example")
|
| 624 |
+
else:
|
| 625 |
+
return gr.update(value="Transcribe")
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def reset_format_caption_flag():
|
| 629 |
+
"""Reset is_format_caption to False when user manually edits caption/metadata"""
|
| 630 |
+
return False
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def update_audio_uploads_accordion(reference_audio, src_audio):
|
| 634 |
+
"""Update Audio Uploads visibility based on whether audio files are present"""
|
| 635 |
+
has_audio = (reference_audio is not None) or (src_audio is not None)
|
| 636 |
+
return gr.update(visible=has_audio)
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
|
| 640 |
+
"""
|
| 641 |
+
Handle instrumental checkbox changes.
|
| 642 |
+
When checked: if no lyrics, fill with [Instrumental]
|
| 643 |
+
When unchecked: if lyrics is [Instrumental], clear it
|
| 644 |
+
"""
|
| 645 |
+
if instrumental_checked:
|
| 646 |
+
# If checked and no lyrics, fill with [Instrumental]
|
| 647 |
+
if not current_lyrics or not current_lyrics.strip():
|
| 648 |
+
return "[Instrumental]"
|
| 649 |
+
else:
|
| 650 |
+
# Has lyrics, don't change
|
| 651 |
+
return current_lyrics
|
| 652 |
+
else:
|
| 653 |
+
# If unchecked and lyrics is exactly [Instrumental], clear it
|
| 654 |
+
if current_lyrics and current_lyrics.strip() == "[Instrumental]":
|
| 655 |
+
return ""
|
| 656 |
+
else:
|
| 657 |
+
# Has other lyrics, don't change
|
| 658 |
+
return current_lyrics
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
def handle_simple_instrumental_change(is_instrumental: bool):
|
| 662 |
+
"""
|
| 663 |
+
Handle simple mode instrumental checkbox changes.
|
| 664 |
+
When checked: set vocal_language to "unknown" and disable editing.
|
| 665 |
+
When unchecked: enable vocal_language editing.
|
| 666 |
+
|
| 667 |
+
Args:
|
| 668 |
+
is_instrumental: Whether instrumental checkbox is checked
|
| 669 |
+
|
| 670 |
+
Returns:
|
| 671 |
+
gr.update for simple_vocal_language dropdown
|
| 672 |
+
"""
|
| 673 |
+
if is_instrumental:
|
| 674 |
+
return gr.update(value="unknown", interactive=False)
|
| 675 |
+
else:
|
| 676 |
+
return gr.update(interactive=True)
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
def update_audio_components_visibility(batch_size):
|
| 680 |
+
"""Show/hide individual audio components based on batch size (1-8)
|
| 681 |
+
|
| 682 |
+
Row 1: Components 1-4 (batch_size 1-4)
|
| 683 |
+
Row 2: Components 5-8 (batch_size 5-8)
|
| 684 |
+
"""
|
| 685 |
+
# Clamp batch size to 1-8 range for UI
|
| 686 |
+
batch_size = min(max(int(batch_size), 1), 8)
|
| 687 |
+
|
| 688 |
+
# Row 1 columns (1-4)
|
| 689 |
+
updates_row1 = (
|
| 690 |
+
gr.update(visible=True), # audio_col_1: always visible
|
| 691 |
+
gr.update(visible=batch_size >= 2), # audio_col_2
|
| 692 |
+
gr.update(visible=batch_size >= 3), # audio_col_3
|
| 693 |
+
gr.update(visible=batch_size >= 4), # audio_col_4
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
# Row 2 container and columns (5-8)
|
| 697 |
+
show_row_5_8 = batch_size >= 5
|
| 698 |
+
updates_row2 = (
|
| 699 |
+
gr.update(visible=show_row_5_8), # audio_row_5_8 (container)
|
| 700 |
+
gr.update(visible=batch_size >= 5), # audio_col_5
|
| 701 |
+
gr.update(visible=batch_size >= 6), # audio_col_6
|
| 702 |
+
gr.update(visible=batch_size >= 7), # audio_col_7
|
| 703 |
+
gr.update(visible=batch_size >= 8), # audio_col_8
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
return updates_row1 + updates_row2
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
def handle_generation_mode_change(mode: str):
|
| 710 |
+
"""
|
| 711 |
+
Handle generation mode change between Simple, Custom, Cover, and Repaint modes.
|
| 712 |
+
|
| 713 |
+
Modes:
|
| 714 |
+
- Simple: Show simple mode group, hide others
|
| 715 |
+
- Custom: Show custom content (prompt), hide others
|
| 716 |
+
- Cover: Show src_audio_group + custom content + LM codes strength
|
| 717 |
+
- Repaint: Show src_audio_group + custom content + repaint time controls (hide LM codes strength)
|
| 718 |
+
|
| 719 |
+
Args:
|
| 720 |
+
mode: "simple", "custom", "cover", or "repaint"
|
| 721 |
+
|
| 722 |
+
Returns:
|
| 723 |
+
Tuple of updates for:
|
| 724 |
+
- simple_mode_group (visibility)
|
| 725 |
+
- custom_mode_content (visibility)
|
| 726 |
+
- cover_mode_group (visibility) - legacy, always hidden
|
| 727 |
+
- repainting_group (visibility)
|
| 728 |
+
- task_type (value)
|
| 729 |
+
- generate_btn (interactive state)
|
| 730 |
+
- simple_sample_created (reset state)
|
| 731 |
+
- src_audio_group (visibility) - shown for cover and repaint
|
| 732 |
+
- audio_cover_strength (visibility) - shown only for cover mode
|
| 733 |
+
- think_checkbox (value and interactive) - disabled for cover/repaint modes
|
| 734 |
+
"""
|
| 735 |
+
is_simple = mode == "simple"
|
| 736 |
+
is_custom = mode == "custom"
|
| 737 |
+
is_cover = mode == "cover"
|
| 738 |
+
is_repaint = mode == "repaint"
|
| 739 |
+
|
| 740 |
+
# Map mode to task_type
|
| 741 |
+
task_type_map = {
|
| 742 |
+
"simple": "text2music",
|
| 743 |
+
"custom": "text2music",
|
| 744 |
+
"cover": "cover",
|
| 745 |
+
"repaint": "repaint",
|
| 746 |
+
}
|
| 747 |
+
task_type_value = task_type_map.get(mode, "text2music")
|
| 748 |
+
|
| 749 |
+
# think_checkbox: disabled and set to False for cover/repaint modes
|
| 750 |
+
# (these modes don't use LM thinking, they use source audio codes)
|
| 751 |
+
if is_cover or is_repaint:
|
| 752 |
+
think_checkbox_update = gr.update(value=False, interactive=False)
|
| 753 |
+
else:
|
| 754 |
+
think_checkbox_update = gr.update(value=True, interactive=True)
|
| 755 |
+
|
| 756 |
+
return (
|
| 757 |
+
gr.update(visible=is_simple), # simple_mode_group
|
| 758 |
+
gr.update(visible=not is_simple), # custom_mode_content - visible for custom/cover/repaint
|
| 759 |
+
gr.update(visible=False), # cover_mode_group - legacy, always hidden
|
| 760 |
+
gr.update(visible=is_repaint), # repainting_group - time range controls
|
| 761 |
+
gr.update(value=task_type_value), # task_type
|
| 762 |
+
gr.update(interactive=True), # generate_btn - always enabled (Simple mode does create+generate in one step)
|
| 763 |
+
False, # simple_sample_created - reset to False on mode change
|
| 764 |
+
gr.update(visible=is_cover or is_repaint), # src_audio_group - shown for cover and repaint
|
| 765 |
+
gr.update(visible=is_cover), # audio_cover_strength - only shown for cover mode
|
| 766 |
+
think_checkbox_update, # think_checkbox - disabled for cover/repaint modes
|
| 767 |
+
)
|
| 768 |
+
|
| 769 |
+
def process_source_audio(dit_handler, llm_handler, src_audio, constrained_decoding_debug):
|
| 770 |
+
"""
|
| 771 |
+
Process source audio: convert to codes and then transcribe.
|
| 772 |
+
This combines convert_src_audio_to_codes_wrapper + transcribe_audio_codes.
|
| 773 |
+
|
| 774 |
+
Args:
|
| 775 |
+
dit_handler: DiT handler instance
|
| 776 |
+
llm_handler: LLM handler instance
|
| 777 |
+
src_audio: Path to source audio file
|
| 778 |
+
constrained_decoding_debug: Whether to enable debug logging
|
| 779 |
+
|
| 780 |
+
Returns:
|
| 781 |
+
Tuple of (audio_codes, status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature, is_format_caption)
|
| 782 |
+
"""
|
| 783 |
+
if src_audio is None:
|
| 784 |
+
return ("", "No audio file provided", "", "", None, None, "", "", "", False)
|
| 785 |
+
|
| 786 |
+
# Step 1: Convert audio to codes
|
| 787 |
+
try:
|
| 788 |
+
codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
|
| 789 |
+
if not codes_string:
|
| 790 |
+
return ("", "Failed to convert audio to codes", "", "", None, None, "", "", "", False)
|
| 791 |
+
except Exception as e:
|
| 792 |
+
return ("", f"Error converting audio: {str(e)}", "", "", None, None, "", "", "", False)
|
| 793 |
+
|
| 794 |
+
# Step 2: Transcribe the codes
|
| 795 |
+
result = understand_music(
|
| 796 |
+
llm_handler=llm_handler,
|
| 797 |
+
audio_codes=codes_string,
|
| 798 |
+
use_constrained_decoding=True,
|
| 799 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
# Handle error case
|
| 803 |
+
if not result.success:
|
| 804 |
+
if result.error == "LLM not initialized":
|
| 805 |
+
return (codes_string, t("messages.lm_not_initialized"), "", "", None, None, "", "", "", False)
|
| 806 |
+
return (codes_string, result.status_message, "", "", None, None, "", "", "", False)
|
| 807 |
+
|
| 808 |
+
return (
|
| 809 |
+
codes_string,
|
| 810 |
+
result.status_message,
|
| 811 |
+
result.caption,
|
| 812 |
+
result.lyrics,
|
| 813 |
+
result.bpm,
|
| 814 |
+
result.duration,
|
| 815 |
+
result.keyscale,
|
| 816 |
+
result.language,
|
| 817 |
+
result.timesignature,
|
| 818 |
+
True # Set is_format_caption to True
|
| 819 |
+
)
|
| 820 |
+
|
| 821 |
+
def handle_create_sample(
|
| 822 |
+
llm_handler,
|
| 823 |
+
query: str,
|
| 824 |
+
instrumental: bool,
|
| 825 |
+
vocal_language: str,
|
| 826 |
+
lm_temperature: float,
|
| 827 |
+
lm_top_k: int,
|
| 828 |
+
lm_top_p: float,
|
| 829 |
+
constrained_decoding_debug: bool = False,
|
| 830 |
+
):
|
| 831 |
+
"""
|
| 832 |
+
Handle the Create Sample button click in Simple mode.
|
| 833 |
+
|
| 834 |
+
Creates a sample from the user's query using the LLM, then populates
|
| 835 |
+
the caption, lyrics, and metadata fields.
|
| 836 |
+
|
| 837 |
+
Note: cfg_scale and negative_prompt are not supported in create_sample mode.
|
| 838 |
+
|
| 839 |
+
Args:
|
| 840 |
+
llm_handler: LLM handler instance (unused, fetched from registry)
|
| 841 |
+
query: User's natural language music description
|
| 842 |
+
instrumental: Whether to generate instrumental music
|
| 843 |
+
vocal_language: Preferred vocal language for constrained decoding
|
| 844 |
+
lm_temperature: LLM temperature for generation
|
| 845 |
+
lm_top_k: LLM top-k sampling
|
| 846 |
+
lm_top_p: LLM top-p sampling
|
| 847 |
+
constrained_decoding_debug: Whether to enable debug logging
|
| 848 |
+
|
| 849 |
+
Returns:
|
| 850 |
+
Tuple of updates for:
|
| 851 |
+
- captions
|
| 852 |
+
- lyrics
|
| 853 |
+
- bpm
|
| 854 |
+
- audio_duration
|
| 855 |
+
- key_scale
|
| 856 |
+
- vocal_language
|
| 857 |
+
- time_signature
|
| 858 |
+
- instrumental_checkbox
|
| 859 |
+
- caption_accordion (open)
|
| 860 |
+
- lyrics_accordion (open)
|
| 861 |
+
- generate_btn (interactive)
|
| 862 |
+
- simple_sample_created (True)
|
| 863 |
+
- think_checkbox (True)
|
| 864 |
+
- is_format_caption_state (True)
|
| 865 |
+
- status_output
|
| 866 |
+
"""
|
| 867 |
+
# Check if LLM is initialized
|
| 868 |
+
if not llm_handler.llm_initialized:
|
| 869 |
+
gr.Warning(t("messages.lm_not_initialized"))
|
| 870 |
+
return (
|
| 871 |
+
gr.update(), # captions - no change
|
| 872 |
+
gr.update(), # lyrics - no change
|
| 873 |
+
gr.update(), # bpm - no change
|
| 874 |
+
gr.update(), # audio_duration - no change
|
| 875 |
+
gr.update(), # key_scale - no change
|
| 876 |
+
gr.update(), # vocal_language - no change
|
| 877 |
+
gr.update(), # time_signature - no change
|
| 878 |
+
gr.update(), # instrumental_checkbox - no change
|
| 879 |
+
gr.update(), # caption_accordion - no change
|
| 880 |
+
gr.update(), # lyrics_accordion - no change
|
| 881 |
+
gr.update(interactive=False), # generate_btn - keep disabled
|
| 882 |
+
False, # simple_sample_created - still False
|
| 883 |
+
gr.update(), # think_checkbox - no change
|
| 884 |
+
gr.update(), # is_format_caption_state - no change
|
| 885 |
+
t("messages.lm_not_initialized"), # status_output
|
| 886 |
+
)
|
| 887 |
+
|
| 888 |
+
# Convert LM parameters
|
| 889 |
+
top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
|
| 890 |
+
top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
|
| 891 |
+
|
| 892 |
+
# Call create_sample API
|
| 893 |
+
# Note: cfg_scale and negative_prompt are not supported in create_sample mode
|
| 894 |
+
result = create_sample(
|
| 895 |
+
llm_handler=llm_handler,
|
| 896 |
+
query=query,
|
| 897 |
+
instrumental=instrumental,
|
| 898 |
+
vocal_language=vocal_language,
|
| 899 |
+
temperature=lm_temperature,
|
| 900 |
+
top_k=top_k_value,
|
| 901 |
+
top_p=top_p_value,
|
| 902 |
+
use_constrained_decoding=True,
|
| 903 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 904 |
+
)
|
| 905 |
+
|
| 906 |
+
# Handle error
|
| 907 |
+
if not result.success:
|
| 908 |
+
gr.Warning(result.status_message or t("messages.sample_creation_failed"))
|
| 909 |
+
return (
|
| 910 |
+
gr.update(), # captions - no change
|
| 911 |
+
gr.update(), # lyrics - no change
|
| 912 |
+
gr.update(), # bpm - no change
|
| 913 |
+
gr.update(), # audio_duration - no change
|
| 914 |
+
gr.update(), # key_scale - no change
|
| 915 |
+
gr.update(), # vocal_language - no change
|
| 916 |
+
gr.update(), # simple vocal_language - no change
|
| 917 |
+
gr.update(), # time_signature - no change
|
| 918 |
+
gr.update(), # instrumental_checkbox - no change
|
| 919 |
+
gr.update(), # caption_accordion - no change
|
| 920 |
+
gr.update(), # lyrics_accordion - no change
|
| 921 |
+
gr.update(interactive=False), # generate_btn - keep disabled
|
| 922 |
+
False, # simple_sample_created - still False
|
| 923 |
+
gr.update(), # think_checkbox - no change
|
| 924 |
+
gr.update(), # is_format_caption_state - no change
|
| 925 |
+
result.status_message or t("messages.sample_creation_failed"), # status_output
|
| 926 |
+
)
|
| 927 |
+
|
| 928 |
+
# Success - populate fields
|
| 929 |
+
gr.Info(t("messages.sample_created"))
|
| 930 |
+
|
| 931 |
+
return (
|
| 932 |
+
result.caption, # captions
|
| 933 |
+
result.lyrics, # lyrics
|
| 934 |
+
result.bpm, # bpm
|
| 935 |
+
result.duration if result.duration and result.duration > 0 else -1, # audio_duration
|
| 936 |
+
result.keyscale, # key_scale
|
| 937 |
+
result.language, # vocal_language
|
| 938 |
+
result.language, # simple vocal_language
|
| 939 |
+
result.timesignature, # time_signature
|
| 940 |
+
result.instrumental, # instrumental_checkbox
|
| 941 |
+
gr.Accordion(open=True), # caption_accordion - expand
|
| 942 |
+
gr.Accordion(open=True), # lyrics_accordion - expand
|
| 943 |
+
gr.update(interactive=True), # generate_btn - enable
|
| 944 |
+
True, # simple_sample_created - True
|
| 945 |
+
True, # think_checkbox - enable thinking
|
| 946 |
+
True, # is_format_caption_state - True (LM-generated)
|
| 947 |
+
result.status_message, # status_output
|
| 948 |
+
)
|
| 949 |
+
|
| 950 |
+
def handle_format_sample(
|
| 951 |
+
llm_handler,
|
| 952 |
+
caption: str,
|
| 953 |
+
lyrics: str,
|
| 954 |
+
bpm,
|
| 955 |
+
audio_duration,
|
| 956 |
+
key_scale: str,
|
| 957 |
+
time_signature: str,
|
| 958 |
+
lm_temperature: float,
|
| 959 |
+
lm_top_k: int,
|
| 960 |
+
lm_top_p: float,
|
| 961 |
+
constrained_decoding_debug: bool = False,
|
| 962 |
+
):
|
| 963 |
+
"""
|
| 964 |
+
Handle the Format button click to format caption and lyrics.
|
| 965 |
+
|
| 966 |
+
Takes user-provided caption and lyrics, and uses the LLM to generate
|
| 967 |
+
structured music metadata and an enhanced description.
|
| 968 |
+
|
| 969 |
+
Note: cfg_scale and negative_prompt are not supported in format mode.
|
| 970 |
+
|
| 971 |
+
Args:
|
| 972 |
+
llm_handler: LLM handler instance (unused, fetched from registry)
|
| 973 |
+
caption: User's caption/description
|
| 974 |
+
lyrics: User's lyrics
|
| 975 |
+
bpm: User-provided BPM (optional, for constrained decoding)
|
| 976 |
+
audio_duration: User-provided duration (optional, for constrained decoding)
|
| 977 |
+
key_scale: User-provided key scale (optional, for constrained decoding)
|
| 978 |
+
time_signature: User-provided time signature (optional, for constrained decoding)
|
| 979 |
+
lm_temperature: LLM temperature for generation
|
| 980 |
+
lm_top_k: LLM top-k sampling
|
| 981 |
+
lm_top_p: LLM top-p sampling
|
| 982 |
+
constrained_decoding_debug: Whether to enable debug logging
|
| 983 |
+
|
| 984 |
+
Returns:
|
| 985 |
+
Tuple of updates for:
|
| 986 |
+
- captions
|
| 987 |
+
- lyrics
|
| 988 |
+
- bpm
|
| 989 |
+
- audio_duration
|
| 990 |
+
- key_scale
|
| 991 |
+
- vocal_language
|
| 992 |
+
- time_signature
|
| 993 |
+
- is_format_caption_state
|
| 994 |
+
- status_output
|
| 995 |
+
"""
|
| 996 |
+
# Check if LLM is initialized
|
| 997 |
+
if not llm_handler.llm_initialized:
|
| 998 |
+
gr.Warning(t("messages.lm_not_initialized"))
|
| 999 |
+
return (
|
| 1000 |
+
gr.update(), # captions - no change
|
| 1001 |
+
gr.update(), # lyrics - no change
|
| 1002 |
+
gr.update(), # bpm - no change
|
| 1003 |
+
gr.update(), # audio_duration - no change
|
| 1004 |
+
gr.update(), # key_scale - no change
|
| 1005 |
+
gr.update(), # vocal_language - no change
|
| 1006 |
+
gr.update(), # time_signature - no change
|
| 1007 |
+
gr.update(), # is_format_caption_state - no change
|
| 1008 |
+
t("messages.lm_not_initialized"), # status_output
|
| 1009 |
+
)
|
| 1010 |
+
|
| 1011 |
+
# Build user_metadata from provided values for constrained decoding
|
| 1012 |
+
user_metadata = {}
|
| 1013 |
+
if bpm is not None and bpm > 0:
|
| 1014 |
+
user_metadata['bpm'] = int(bpm)
|
| 1015 |
+
if audio_duration is not None and audio_duration > 0:
|
| 1016 |
+
user_metadata['duration'] = int(audio_duration)
|
| 1017 |
+
if key_scale and key_scale.strip():
|
| 1018 |
+
user_metadata['keyscale'] = key_scale.strip()
|
| 1019 |
+
if time_signature and time_signature.strip():
|
| 1020 |
+
user_metadata['timesignature'] = time_signature.strip()
|
| 1021 |
+
|
| 1022 |
+
# Only pass user_metadata if we have at least one field
|
| 1023 |
+
user_metadata_to_pass = user_metadata if user_metadata else None
|
| 1024 |
+
|
| 1025 |
+
# Convert LM parameters
|
| 1026 |
+
top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
|
| 1027 |
+
top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
|
| 1028 |
+
|
| 1029 |
+
# Call format_sample API
|
| 1030 |
+
result = format_sample(
|
| 1031 |
+
llm_handler=llm_handler,
|
| 1032 |
+
caption=caption,
|
| 1033 |
+
lyrics=lyrics,
|
| 1034 |
+
user_metadata=user_metadata_to_pass,
|
| 1035 |
+
temperature=lm_temperature,
|
| 1036 |
+
top_k=top_k_value,
|
| 1037 |
+
top_p=top_p_value,
|
| 1038 |
+
use_constrained_decoding=True,
|
| 1039 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 1040 |
+
)
|
| 1041 |
+
|
| 1042 |
+
# Handle error
|
| 1043 |
+
if not result.success:
|
| 1044 |
+
gr.Warning(result.status_message or t("messages.format_failed"))
|
| 1045 |
+
return (
|
| 1046 |
+
gr.update(), # captions - no change
|
| 1047 |
+
gr.update(), # lyrics - no change
|
| 1048 |
+
gr.update(), # bpm - no change
|
| 1049 |
+
gr.update(), # audio_duration - no change
|
| 1050 |
+
gr.update(), # key_scale - no change
|
| 1051 |
+
gr.update(), # vocal_language - no change
|
| 1052 |
+
gr.update(), # time_signature - no change
|
| 1053 |
+
gr.update(), # is_format_caption_state - no change
|
| 1054 |
+
result.status_message or t("messages.format_failed"), # status_output
|
| 1055 |
+
)
|
| 1056 |
+
|
| 1057 |
+
# Success - populate fields
|
| 1058 |
+
gr.Info(t("messages.format_success"))
|
| 1059 |
+
|
| 1060 |
+
return (
|
| 1061 |
+
result.caption, # captions
|
| 1062 |
+
result.lyrics, # lyrics
|
| 1063 |
+
result.bpm, # bpm
|
| 1064 |
+
result.duration if result.duration and result.duration > 0 else -1, # audio_duration
|
| 1065 |
+
result.keyscale, # key_scale
|
| 1066 |
+
result.language, # vocal_language
|
| 1067 |
+
result.timesignature, # time_signature
|
| 1068 |
+
True, # is_format_caption_state - True (LM-formatted)
|
| 1069 |
+
result.status_message, # status_output
|
| 1070 |
+
)
|
| 1071 |
+
|
acestep/gradio_ui/events/results_handlers.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/gradio_ui/events/training_handlers.py
ADDED
|
@@ -0,0 +1,644 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Event Handlers for Training Tab
|
| 3 |
+
|
| 4 |
+
Contains all event handler functions for the dataset builder and training UI.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
from typing import Any, Dict, List, Tuple, Optional
|
| 10 |
+
from loguru import logger
|
| 11 |
+
import gradio as gr
|
| 12 |
+
|
| 13 |
+
from acestep.training.dataset_builder import DatasetBuilder, AudioSample
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def create_dataset_builder() -> DatasetBuilder:
|
| 17 |
+
"""Create a new DatasetBuilder instance."""
|
| 18 |
+
return DatasetBuilder()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def scan_directory(
|
| 22 |
+
audio_dir: str,
|
| 23 |
+
dataset_name: str,
|
| 24 |
+
custom_tag: str,
|
| 25 |
+
tag_position: str,
|
| 26 |
+
all_instrumental: bool,
|
| 27 |
+
builder_state: Optional[DatasetBuilder],
|
| 28 |
+
) -> Tuple[Any, str, Any, DatasetBuilder]:
|
| 29 |
+
"""Scan a directory for audio files.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
Tuple of (table_data, status, slider_update, builder_state)
|
| 33 |
+
"""
|
| 34 |
+
if not audio_dir or not audio_dir.strip():
|
| 35 |
+
return [], "❌ Please enter a directory path", gr.Slider(maximum=0, value=0), builder_state
|
| 36 |
+
|
| 37 |
+
# Create or use existing builder
|
| 38 |
+
builder = builder_state if builder_state else DatasetBuilder()
|
| 39 |
+
|
| 40 |
+
# Set metadata before scanning
|
| 41 |
+
builder.metadata.name = dataset_name
|
| 42 |
+
builder.metadata.custom_tag = custom_tag
|
| 43 |
+
builder.metadata.tag_position = tag_position
|
| 44 |
+
builder.metadata.all_instrumental = all_instrumental
|
| 45 |
+
|
| 46 |
+
# Scan directory
|
| 47 |
+
samples, status = builder.scan_directory(audio_dir.strip())
|
| 48 |
+
|
| 49 |
+
if not samples:
|
| 50 |
+
return [], status, gr.Slider(maximum=0, value=0), builder
|
| 51 |
+
|
| 52 |
+
# Set instrumental and tag for all samples
|
| 53 |
+
builder.set_all_instrumental(all_instrumental)
|
| 54 |
+
if custom_tag:
|
| 55 |
+
builder.set_custom_tag(custom_tag, tag_position)
|
| 56 |
+
|
| 57 |
+
# Get table data
|
| 58 |
+
table_data = builder.get_samples_dataframe_data()
|
| 59 |
+
|
| 60 |
+
# Calculate slider max and return as Slider update
|
| 61 |
+
slider_max = max(0, len(samples) - 1)
|
| 62 |
+
|
| 63 |
+
return table_data, status, gr.Slider(maximum=slider_max, value=0), builder
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def auto_label_all(
|
| 67 |
+
dit_handler,
|
| 68 |
+
llm_handler,
|
| 69 |
+
builder_state: Optional[DatasetBuilder],
|
| 70 |
+
skip_metas: bool = False,
|
| 71 |
+
progress=None,
|
| 72 |
+
) -> Tuple[List[List[Any]], str, DatasetBuilder]:
|
| 73 |
+
"""Auto-label all samples in the dataset.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
dit_handler: DiT handler for audio processing
|
| 77 |
+
llm_handler: LLM handler for caption generation
|
| 78 |
+
builder_state: Dataset builder state
|
| 79 |
+
skip_metas: If True, skip LLM labeling. BPM/Key/TimeSig = N/A, Language = unknown for instrumental
|
| 80 |
+
progress: Progress callback
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
Tuple of (table_data, status, builder_state)
|
| 84 |
+
"""
|
| 85 |
+
if builder_state is None:
|
| 86 |
+
return [], "❌ Please scan a directory first", builder_state
|
| 87 |
+
|
| 88 |
+
if not builder_state.samples:
|
| 89 |
+
return [], "❌ No samples to label. Please scan a directory first.", builder_state
|
| 90 |
+
|
| 91 |
+
# If skip_metas is True, just set default values without LLM
|
| 92 |
+
if skip_metas:
|
| 93 |
+
for sample in builder_state.samples:
|
| 94 |
+
sample.bpm = None # Will display as N/A
|
| 95 |
+
sample.keyscale = "N/A"
|
| 96 |
+
sample.timesignature = "N/A"
|
| 97 |
+
# For instrumental, language should be "unknown"
|
| 98 |
+
if sample.is_instrumental:
|
| 99 |
+
sample.language = "unknown"
|
| 100 |
+
else:
|
| 101 |
+
sample.language = "unknown"
|
| 102 |
+
# Use custom tag as caption if set, otherwise use filename
|
| 103 |
+
if builder_state.metadata.custom_tag:
|
| 104 |
+
sample.caption = builder_state.metadata.custom_tag
|
| 105 |
+
else:
|
| 106 |
+
sample.caption = sample.filename
|
| 107 |
+
|
| 108 |
+
table_data = builder_state.get_samples_dataframe_data()
|
| 109 |
+
return table_data, f"✅ Skipped AI labeling. {len(builder_state.samples)} samples set with default values.", builder_state
|
| 110 |
+
|
| 111 |
+
# Check if handlers are initialized
|
| 112 |
+
if dit_handler is None or dit_handler.model is None:
|
| 113 |
+
return builder_state.get_samples_dataframe_data(), "❌ Model not initialized. Please initialize the service first.", builder_state
|
| 114 |
+
|
| 115 |
+
if llm_handler is None or not llm_handler.llm_initialized:
|
| 116 |
+
return builder_state.get_samples_dataframe_data(), "❌ LLM not initialized. Please initialize the service with LLM enabled.", builder_state
|
| 117 |
+
|
| 118 |
+
def progress_callback(msg):
|
| 119 |
+
if progress:
|
| 120 |
+
try:
|
| 121 |
+
progress(msg)
|
| 122 |
+
except:
|
| 123 |
+
pass
|
| 124 |
+
|
| 125 |
+
# Label all samples
|
| 126 |
+
samples, status = builder_state.label_all_samples(
|
| 127 |
+
dit_handler=dit_handler,
|
| 128 |
+
llm_handler=llm_handler,
|
| 129 |
+
progress_callback=progress_callback,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Get updated table data
|
| 133 |
+
table_data = builder_state.get_samples_dataframe_data()
|
| 134 |
+
|
| 135 |
+
return table_data, status, builder_state
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def get_sample_preview(
|
| 139 |
+
sample_idx: int,
|
| 140 |
+
builder_state: Optional[DatasetBuilder],
|
| 141 |
+
) -> Tuple[str, str, str, str, Optional[int], str, str, float, str, bool]:
|
| 142 |
+
"""Get preview data for a specific sample.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
Tuple of (audio_path, filename, caption, lyrics, bpm, keyscale, timesig, duration, language, instrumental)
|
| 146 |
+
"""
|
| 147 |
+
if builder_state is None or not builder_state.samples:
|
| 148 |
+
return None, "", "", "", None, "", "", 0.0, "instrumental", True
|
| 149 |
+
|
| 150 |
+
idx = int(sample_idx)
|
| 151 |
+
if idx < 0 or idx >= len(builder_state.samples):
|
| 152 |
+
return None, "", "", "", None, "", "", 0.0, "instrumental", True
|
| 153 |
+
|
| 154 |
+
sample = builder_state.samples[idx]
|
| 155 |
+
|
| 156 |
+
return (
|
| 157 |
+
sample.audio_path,
|
| 158 |
+
sample.filename,
|
| 159 |
+
sample.caption,
|
| 160 |
+
sample.lyrics,
|
| 161 |
+
sample.bpm,
|
| 162 |
+
sample.keyscale,
|
| 163 |
+
sample.timesignature,
|
| 164 |
+
sample.duration,
|
| 165 |
+
sample.language,
|
| 166 |
+
sample.is_instrumental,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def save_sample_edit(
|
| 171 |
+
sample_idx: int,
|
| 172 |
+
caption: str,
|
| 173 |
+
lyrics: str,
|
| 174 |
+
bpm: Optional[int],
|
| 175 |
+
keyscale: str,
|
| 176 |
+
timesig: str,
|
| 177 |
+
language: str,
|
| 178 |
+
is_instrumental: bool,
|
| 179 |
+
builder_state: Optional[DatasetBuilder],
|
| 180 |
+
) -> Tuple[List[List[Any]], str, DatasetBuilder]:
|
| 181 |
+
"""Save edits to a sample.
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
Tuple of (table_data, status, builder_state)
|
| 185 |
+
"""
|
| 186 |
+
if builder_state is None:
|
| 187 |
+
return [], "❌ No dataset loaded", builder_state
|
| 188 |
+
|
| 189 |
+
idx = int(sample_idx)
|
| 190 |
+
|
| 191 |
+
# Update sample
|
| 192 |
+
sample, status = builder_state.update_sample(
|
| 193 |
+
idx,
|
| 194 |
+
caption=caption,
|
| 195 |
+
lyrics=lyrics if not is_instrumental else "[Instrumental]",
|
| 196 |
+
bpm=int(bpm) if bpm else None,
|
| 197 |
+
keyscale=keyscale,
|
| 198 |
+
timesignature=timesig,
|
| 199 |
+
language="instrumental" if is_instrumental else language,
|
| 200 |
+
is_instrumental=is_instrumental,
|
| 201 |
+
labeled=True,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# Get updated table data
|
| 205 |
+
table_data = builder_state.get_samples_dataframe_data()
|
| 206 |
+
|
| 207 |
+
return table_data, status, builder_state
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def update_settings(
|
| 211 |
+
custom_tag: str,
|
| 212 |
+
tag_position: str,
|
| 213 |
+
all_instrumental: bool,
|
| 214 |
+
builder_state: Optional[DatasetBuilder],
|
| 215 |
+
) -> DatasetBuilder:
|
| 216 |
+
"""Update dataset settings.
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
Updated builder_state
|
| 220 |
+
"""
|
| 221 |
+
if builder_state is None:
|
| 222 |
+
return builder_state
|
| 223 |
+
|
| 224 |
+
if custom_tag:
|
| 225 |
+
builder_state.set_custom_tag(custom_tag, tag_position)
|
| 226 |
+
|
| 227 |
+
builder_state.set_all_instrumental(all_instrumental)
|
| 228 |
+
|
| 229 |
+
return builder_state
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def save_dataset(
|
| 233 |
+
save_path: str,
|
| 234 |
+
dataset_name: str,
|
| 235 |
+
builder_state: Optional[DatasetBuilder],
|
| 236 |
+
) -> str:
|
| 237 |
+
"""Save the dataset to a JSON file.
|
| 238 |
+
|
| 239 |
+
Returns:
|
| 240 |
+
Status message
|
| 241 |
+
"""
|
| 242 |
+
if builder_state is None:
|
| 243 |
+
return "❌ No dataset to save. Please scan a directory first."
|
| 244 |
+
|
| 245 |
+
if not builder_state.samples:
|
| 246 |
+
return "❌ No samples in dataset."
|
| 247 |
+
|
| 248 |
+
if not save_path or not save_path.strip():
|
| 249 |
+
return "❌ Please enter a save path."
|
| 250 |
+
|
| 251 |
+
# Check if any samples are labeled
|
| 252 |
+
labeled_count = builder_state.get_labeled_count()
|
| 253 |
+
if labeled_count == 0:
|
| 254 |
+
return "⚠️ Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway..."
|
| 255 |
+
|
| 256 |
+
return builder_state.save_dataset(save_path.strip(), dataset_name)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def load_existing_dataset_for_preprocess(
|
| 260 |
+
dataset_path: str,
|
| 261 |
+
builder_state: Optional[DatasetBuilder],
|
| 262 |
+
) -> Tuple[str, Any, Any, DatasetBuilder, str, str, str, str, Optional[int], str, str, float, str, bool]:
|
| 263 |
+
"""Load an existing dataset JSON file for preprocessing.
|
| 264 |
+
|
| 265 |
+
This allows users to load a previously saved dataset and proceed to preprocessing
|
| 266 |
+
without having to re-scan and re-label.
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
Tuple of (status, table_data, slider_update, builder_state,
|
| 270 |
+
audio_path, filename, caption, lyrics, bpm, keyscale, timesig, duration, language, instrumental)
|
| 271 |
+
"""
|
| 272 |
+
empty_preview = (None, "", "", "", None, "", "", 0.0, "instrumental", True)
|
| 273 |
+
|
| 274 |
+
if not dataset_path or not dataset_path.strip():
|
| 275 |
+
return ("❌ Please enter a dataset path", [], gr.Slider(maximum=0, value=0), builder_state) + empty_preview
|
| 276 |
+
|
| 277 |
+
dataset_path = dataset_path.strip()
|
| 278 |
+
|
| 279 |
+
if not os.path.exists(dataset_path):
|
| 280 |
+
return (f"❌ Dataset not found: {dataset_path}", [], gr.Slider(maximum=0, value=0), builder_state) + empty_preview
|
| 281 |
+
|
| 282 |
+
# Create new builder (don't reuse old state when loading a file)
|
| 283 |
+
builder = DatasetBuilder()
|
| 284 |
+
|
| 285 |
+
# Load the dataset
|
| 286 |
+
samples, status = builder.load_dataset(dataset_path)
|
| 287 |
+
|
| 288 |
+
if not samples:
|
| 289 |
+
return (status, [], gr.Slider(maximum=0, value=0), builder) + empty_preview
|
| 290 |
+
|
| 291 |
+
# Get table data
|
| 292 |
+
table_data = builder.get_samples_dataframe_data()
|
| 293 |
+
|
| 294 |
+
# Calculate slider max
|
| 295 |
+
slider_max = max(0, len(samples) - 1)
|
| 296 |
+
|
| 297 |
+
# Create info text
|
| 298 |
+
labeled_count = builder.get_labeled_count()
|
| 299 |
+
info = f"✅ Loaded dataset: {builder.metadata.name}\n"
|
| 300 |
+
info += f"📊 Samples: {len(samples)} ({labeled_count} labeled)\n"
|
| 301 |
+
info += f"🏷️ Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
|
| 302 |
+
info += "📝 Ready for preprocessing! You can also edit samples below."
|
| 303 |
+
|
| 304 |
+
# Get first sample preview
|
| 305 |
+
first_sample = builder.samples[0]
|
| 306 |
+
preview = (
|
| 307 |
+
first_sample.audio_path,
|
| 308 |
+
first_sample.filename,
|
| 309 |
+
first_sample.caption,
|
| 310 |
+
first_sample.lyrics,
|
| 311 |
+
first_sample.bpm,
|
| 312 |
+
first_sample.keyscale,
|
| 313 |
+
first_sample.timesignature,
|
| 314 |
+
first_sample.duration,
|
| 315 |
+
first_sample.language,
|
| 316 |
+
first_sample.is_instrumental,
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
return (info, table_data, gr.Slider(maximum=slider_max, value=0), builder) + preview
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def preprocess_dataset(
|
| 323 |
+
output_dir: str,
|
| 324 |
+
dit_handler,
|
| 325 |
+
builder_state: Optional[DatasetBuilder],
|
| 326 |
+
progress=None,
|
| 327 |
+
) -> str:
|
| 328 |
+
"""Preprocess dataset to tensor files for fast training.
|
| 329 |
+
|
| 330 |
+
This converts audio files to VAE latents and text to embeddings.
|
| 331 |
+
|
| 332 |
+
Returns:
|
| 333 |
+
Status message
|
| 334 |
+
"""
|
| 335 |
+
if builder_state is None:
|
| 336 |
+
return "❌ No dataset loaded. Please scan a directory first."
|
| 337 |
+
|
| 338 |
+
if not builder_state.samples:
|
| 339 |
+
return "❌ No samples in dataset."
|
| 340 |
+
|
| 341 |
+
labeled_count = builder_state.get_labeled_count()
|
| 342 |
+
if labeled_count == 0:
|
| 343 |
+
return "❌ No labeled samples. Please auto-label or manually label samples first."
|
| 344 |
+
|
| 345 |
+
if not output_dir or not output_dir.strip():
|
| 346 |
+
return "❌ Please enter an output directory."
|
| 347 |
+
|
| 348 |
+
if dit_handler is None or dit_handler.model is None:
|
| 349 |
+
return "❌ Model not initialized. Please initialize the service first."
|
| 350 |
+
|
| 351 |
+
def progress_callback(msg):
|
| 352 |
+
if progress:
|
| 353 |
+
try:
|
| 354 |
+
progress(msg)
|
| 355 |
+
except:
|
| 356 |
+
pass
|
| 357 |
+
|
| 358 |
+
# Run preprocessing
|
| 359 |
+
output_paths, status = builder_state.preprocess_to_tensors(
|
| 360 |
+
dit_handler=dit_handler,
|
| 361 |
+
output_dir=output_dir.strip(),
|
| 362 |
+
progress_callback=progress_callback,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
return status
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def load_training_dataset(
|
| 369 |
+
tensor_dir: str,
|
| 370 |
+
) -> str:
|
| 371 |
+
"""Load a preprocessed tensor dataset for training.
|
| 372 |
+
|
| 373 |
+
Returns:
|
| 374 |
+
Info text about the dataset
|
| 375 |
+
"""
|
| 376 |
+
if not tensor_dir or not tensor_dir.strip():
|
| 377 |
+
return "❌ Please enter a tensor directory path"
|
| 378 |
+
|
| 379 |
+
tensor_dir = tensor_dir.strip()
|
| 380 |
+
|
| 381 |
+
if not os.path.exists(tensor_dir):
|
| 382 |
+
return f"❌ Directory not found: {tensor_dir}"
|
| 383 |
+
|
| 384 |
+
if not os.path.isdir(tensor_dir):
|
| 385 |
+
return f"❌ Not a directory: {tensor_dir}"
|
| 386 |
+
|
| 387 |
+
# Check for manifest
|
| 388 |
+
manifest_path = os.path.join(tensor_dir, "manifest.json")
|
| 389 |
+
if os.path.exists(manifest_path):
|
| 390 |
+
try:
|
| 391 |
+
with open(manifest_path, 'r') as f:
|
| 392 |
+
manifest = json.load(f)
|
| 393 |
+
|
| 394 |
+
num_samples = manifest.get("num_samples", 0)
|
| 395 |
+
metadata = manifest.get("metadata", {})
|
| 396 |
+
name = metadata.get("name", "Unknown")
|
| 397 |
+
custom_tag = metadata.get("custom_tag", "")
|
| 398 |
+
|
| 399 |
+
info = f"✅ Loaded preprocessed dataset: {name}\n"
|
| 400 |
+
info += f"📊 Samples: {num_samples} preprocessed tensors\n"
|
| 401 |
+
info += f"🏷️ Custom Tag: {custom_tag or '(none)'}"
|
| 402 |
+
|
| 403 |
+
return info
|
| 404 |
+
except Exception as e:
|
| 405 |
+
logger.warning(f"Failed to read manifest: {e}")
|
| 406 |
+
|
| 407 |
+
# Fallback: count .pt files
|
| 408 |
+
pt_files = [f for f in os.listdir(tensor_dir) if f.endswith('.pt')]
|
| 409 |
+
|
| 410 |
+
if not pt_files:
|
| 411 |
+
return f"❌ No .pt tensor files found in {tensor_dir}"
|
| 412 |
+
|
| 413 |
+
info = f"✅ Found {len(pt_files)} tensor files in {tensor_dir}\n"
|
| 414 |
+
info += "⚠️ No manifest.json found - using all .pt files"
|
| 415 |
+
|
| 416 |
+
return info
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
# Training handlers
|
| 420 |
+
|
| 421 |
+
import time
|
| 422 |
+
import re
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def _format_duration(seconds):
|
| 426 |
+
"""Format seconds to human readable string."""
|
| 427 |
+
seconds = int(seconds)
|
| 428 |
+
if seconds < 60:
|
| 429 |
+
return f"{seconds}s"
|
| 430 |
+
elif seconds < 3600:
|
| 431 |
+
return f"{seconds // 60}m {seconds % 60}s"
|
| 432 |
+
else:
|
| 433 |
+
return f"{seconds // 3600}h {(seconds % 3600) // 60}m"
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def start_training(
|
| 437 |
+
tensor_dir: str,
|
| 438 |
+
dit_handler,
|
| 439 |
+
lora_rank: int,
|
| 440 |
+
lora_alpha: int,
|
| 441 |
+
lora_dropout: float,
|
| 442 |
+
learning_rate: float,
|
| 443 |
+
train_epochs: int,
|
| 444 |
+
train_batch_size: int,
|
| 445 |
+
gradient_accumulation: int,
|
| 446 |
+
save_every_n_epochs: int,
|
| 447 |
+
training_shift: float,
|
| 448 |
+
training_seed: int,
|
| 449 |
+
lora_output_dir: str,
|
| 450 |
+
training_state: Dict,
|
| 451 |
+
progress=None,
|
| 452 |
+
):
|
| 453 |
+
"""Start LoRA training from preprocessed tensors.
|
| 454 |
+
|
| 455 |
+
This is a generator function that yields progress updates.
|
| 456 |
+
"""
|
| 457 |
+
if not tensor_dir or not tensor_dir.strip():
|
| 458 |
+
yield "❌ Please enter a tensor directory path", "", None, training_state
|
| 459 |
+
return
|
| 460 |
+
|
| 461 |
+
tensor_dir = tensor_dir.strip()
|
| 462 |
+
|
| 463 |
+
if not os.path.exists(tensor_dir):
|
| 464 |
+
yield f"❌ Tensor directory not found: {tensor_dir}", "", None, training_state
|
| 465 |
+
return
|
| 466 |
+
|
| 467 |
+
if dit_handler is None or dit_handler.model is None:
|
| 468 |
+
yield "❌ Model not initialized. Please initialize the service first.", "", None, training_state
|
| 469 |
+
return
|
| 470 |
+
|
| 471 |
+
# Check for required training dependencies
|
| 472 |
+
try:
|
| 473 |
+
from lightning.fabric import Fabric
|
| 474 |
+
from peft import get_peft_model, LoraConfig
|
| 475 |
+
except ImportError as e:
|
| 476 |
+
yield f"❌ Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
|
| 477 |
+
return
|
| 478 |
+
|
| 479 |
+
training_state["is_training"] = True
|
| 480 |
+
training_state["should_stop"] = False
|
| 481 |
+
|
| 482 |
+
try:
|
| 483 |
+
from acestep.training.trainer import LoRATrainer
|
| 484 |
+
from acestep.training.configs import LoRAConfig as LoRAConfigClass, TrainingConfig
|
| 485 |
+
|
| 486 |
+
# Create configs
|
| 487 |
+
lora_config = LoRAConfigClass(
|
| 488 |
+
r=lora_rank,
|
| 489 |
+
alpha=lora_alpha,
|
| 490 |
+
dropout=lora_dropout,
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
training_config = TrainingConfig(
|
| 494 |
+
shift=training_shift,
|
| 495 |
+
learning_rate=learning_rate,
|
| 496 |
+
batch_size=train_batch_size,
|
| 497 |
+
gradient_accumulation_steps=gradient_accumulation,
|
| 498 |
+
max_epochs=train_epochs,
|
| 499 |
+
save_every_n_epochs=save_every_n_epochs,
|
| 500 |
+
seed=training_seed,
|
| 501 |
+
output_dir=lora_output_dir,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
import pandas as pd
|
| 505 |
+
|
| 506 |
+
# Initialize training log and loss history
|
| 507 |
+
log_lines = []
|
| 508 |
+
loss_data = pd.DataFrame({"step": [0], "loss": [0.0]})
|
| 509 |
+
|
| 510 |
+
# Start timer
|
| 511 |
+
start_time = time.time()
|
| 512 |
+
|
| 513 |
+
yield f"🚀 Starting training from {tensor_dir}...", "", loss_data, training_state
|
| 514 |
+
|
| 515 |
+
# Create trainer
|
| 516 |
+
trainer = LoRATrainer(
|
| 517 |
+
dit_handler=dit_handler,
|
| 518 |
+
lora_config=lora_config,
|
| 519 |
+
training_config=training_config,
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
# Collect loss history
|
| 523 |
+
step_list = []
|
| 524 |
+
loss_list = []
|
| 525 |
+
|
| 526 |
+
# Train with progress updates using preprocessed tensors
|
| 527 |
+
for step, loss, status in trainer.train_from_preprocessed(tensor_dir, training_state):
|
| 528 |
+
# Calculate elapsed time and ETA
|
| 529 |
+
elapsed_seconds = time.time() - start_time
|
| 530 |
+
time_info = f"⏱️ Elapsed: {_format_duration(elapsed_seconds)}"
|
| 531 |
+
|
| 532 |
+
# Parse "Epoch x/y" from status to calculate ETA
|
| 533 |
+
match = re.search(r"Epoch\s+(\d+)/(\d+)", str(status))
|
| 534 |
+
if match:
|
| 535 |
+
current_ep = int(match.group(1))
|
| 536 |
+
total_ep = int(match.group(2))
|
| 537 |
+
if current_ep > 0:
|
| 538 |
+
eta_seconds = (elapsed_seconds / current_ep) * (total_ep - current_ep)
|
| 539 |
+
time_info += f" | ETA: ~{_format_duration(eta_seconds)}"
|
| 540 |
+
|
| 541 |
+
# Display status with time info
|
| 542 |
+
display_status = f"{status}\n{time_info}"
|
| 543 |
+
|
| 544 |
+
# Terminal log
|
| 545 |
+
log_msg = f"[{_format_duration(elapsed_seconds)}] Step {step}: {status}"
|
| 546 |
+
logger.info(log_msg)
|
| 547 |
+
|
| 548 |
+
# Add to UI log
|
| 549 |
+
log_lines.append(status)
|
| 550 |
+
if len(log_lines) > 15:
|
| 551 |
+
log_lines = log_lines[-15:]
|
| 552 |
+
log_text = "\n".join(log_lines)
|
| 553 |
+
|
| 554 |
+
# Track loss for plot (only valid values)
|
| 555 |
+
if step > 0 and loss is not None and loss == loss: # Check for NaN
|
| 556 |
+
step_list.append(step)
|
| 557 |
+
loss_list.append(float(loss))
|
| 558 |
+
loss_data = pd.DataFrame({"step": step_list, "loss": loss_list})
|
| 559 |
+
|
| 560 |
+
yield display_status, log_text, loss_data, training_state
|
| 561 |
+
|
| 562 |
+
if training_state.get("should_stop", False):
|
| 563 |
+
logger.info("⏹️ Training stopped by user")
|
| 564 |
+
log_lines.append("⏹️ Training stopped by user")
|
| 565 |
+
yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
|
| 566 |
+
break
|
| 567 |
+
|
| 568 |
+
total_time = time.time() - start_time
|
| 569 |
+
training_state["is_training"] = False
|
| 570 |
+
completion_msg = f"✅ Training completed! Total time: {_format_duration(total_time)}"
|
| 571 |
+
|
| 572 |
+
logger.info(completion_msg)
|
| 573 |
+
log_lines.append(completion_msg)
|
| 574 |
+
|
| 575 |
+
yield completion_msg, "\n".join(log_lines[-15:]), loss_data, training_state
|
| 576 |
+
|
| 577 |
+
except Exception as e:
|
| 578 |
+
logger.exception("Training error")
|
| 579 |
+
training_state["is_training"] = False
|
| 580 |
+
import pandas as pd
|
| 581 |
+
empty_df = pd.DataFrame({"step": [], "loss": []})
|
| 582 |
+
yield f"❌ Error: {str(e)}", str(e), empty_df, training_state
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def stop_training(training_state: Dict) -> Tuple[str, Dict]:
|
| 586 |
+
"""Stop the current training process.
|
| 587 |
+
|
| 588 |
+
Returns:
|
| 589 |
+
Tuple of (status, training_state)
|
| 590 |
+
"""
|
| 591 |
+
if not training_state.get("is_training", False):
|
| 592 |
+
return "⚠️ No training in progress", training_state
|
| 593 |
+
|
| 594 |
+
training_state["should_stop"] = True
|
| 595 |
+
return "⏹️ Stopping training...", training_state
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def export_lora(
|
| 599 |
+
export_path: str,
|
| 600 |
+
lora_output_dir: str,
|
| 601 |
+
) -> str:
|
| 602 |
+
"""Export the trained LoRA weights.
|
| 603 |
+
|
| 604 |
+
Returns:
|
| 605 |
+
Status message
|
| 606 |
+
"""
|
| 607 |
+
if not export_path or not export_path.strip():
|
| 608 |
+
return "❌ Please enter an export path"
|
| 609 |
+
|
| 610 |
+
# Check if there's a trained model to export
|
| 611 |
+
final_dir = os.path.join(lora_output_dir, "final")
|
| 612 |
+
checkpoint_dir = os.path.join(lora_output_dir, "checkpoints")
|
| 613 |
+
|
| 614 |
+
# Prefer final, fallback to checkpoints
|
| 615 |
+
if os.path.exists(final_dir):
|
| 616 |
+
source_path = final_dir
|
| 617 |
+
elif os.path.exists(checkpoint_dir):
|
| 618 |
+
# Find the latest checkpoint
|
| 619 |
+
checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith("epoch_")]
|
| 620 |
+
if not checkpoints:
|
| 621 |
+
return "❌ No checkpoints found"
|
| 622 |
+
|
| 623 |
+
checkpoints.sort(key=lambda x: int(x.split("_")[1]))
|
| 624 |
+
latest = checkpoints[-1]
|
| 625 |
+
source_path = os.path.join(checkpoint_dir, latest)
|
| 626 |
+
else:
|
| 627 |
+
return f"❌ No trained model found in {lora_output_dir}"
|
| 628 |
+
|
| 629 |
+
try:
|
| 630 |
+
import shutil
|
| 631 |
+
|
| 632 |
+
export_path = export_path.strip()
|
| 633 |
+
os.makedirs(os.path.dirname(export_path) if os.path.dirname(export_path) else ".", exist_ok=True)
|
| 634 |
+
|
| 635 |
+
if os.path.exists(export_path):
|
| 636 |
+
shutil.rmtree(export_path)
|
| 637 |
+
|
| 638 |
+
shutil.copytree(source_path, export_path)
|
| 639 |
+
|
| 640 |
+
return f"✅ LoRA exported to {export_path}"
|
| 641 |
+
|
| 642 |
+
except Exception as e:
|
| 643 |
+
logger.exception("Export error")
|
| 644 |
+
return f"❌ Export failed: {str(e)}"
|
acestep/gradio_ui/i18n.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Internationalization (i18n) module for Gradio UI
|
| 3 |
+
Supports multiple languages with easy translation management
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
from typing import Dict, Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class I18n:
|
| 11 |
+
"""Internationalization handler"""
|
| 12 |
+
|
| 13 |
+
def __init__(self, default_language: str = "en"):
|
| 14 |
+
"""
|
| 15 |
+
Initialize i18n handler
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
default_language: Default language code (en, zh, ja, etc.)
|
| 19 |
+
"""
|
| 20 |
+
self.current_language = default_language
|
| 21 |
+
self.translations: Dict[str, Dict[str, str]] = {}
|
| 22 |
+
self._load_all_translations()
|
| 23 |
+
|
| 24 |
+
def _load_all_translations(self):
|
| 25 |
+
"""Load all translation files from i18n directory"""
|
| 26 |
+
current_file = os.path.abspath(__file__)
|
| 27 |
+
module_dir = os.path.dirname(current_file)
|
| 28 |
+
i18n_dir = os.path.join(module_dir, "i18n")
|
| 29 |
+
|
| 30 |
+
if not os.path.exists(i18n_dir):
|
| 31 |
+
# Create i18n directory if it doesn't exist
|
| 32 |
+
os.makedirs(i18n_dir)
|
| 33 |
+
return
|
| 34 |
+
|
| 35 |
+
# Load all JSON files in i18n directory
|
| 36 |
+
for filename in os.listdir(i18n_dir):
|
| 37 |
+
if filename.endswith(".json"):
|
| 38 |
+
lang_code = filename[:-5] # Remove .json extension
|
| 39 |
+
filepath = os.path.join(i18n_dir, filename)
|
| 40 |
+
try:
|
| 41 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 42 |
+
self.translations[lang_code] = json.load(f)
|
| 43 |
+
except Exception as e:
|
| 44 |
+
print(f"Error loading translation file {filename}: {e}")
|
| 45 |
+
|
| 46 |
+
def set_language(self, language: str):
|
| 47 |
+
"""Set current language"""
|
| 48 |
+
if language in self.translations:
|
| 49 |
+
self.current_language = language
|
| 50 |
+
else:
|
| 51 |
+
print(f"Warning: Language '{language}' not found, using default")
|
| 52 |
+
|
| 53 |
+
def t(self, key: str, **kwargs) -> str:
|
| 54 |
+
"""
|
| 55 |
+
Translate a key to current language
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
key: Translation key (dot-separated for nested keys)
|
| 59 |
+
**kwargs: Optional format parameters
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
Translated string
|
| 63 |
+
"""
|
| 64 |
+
# Get translation from current language
|
| 65 |
+
translation = self._get_nested_value(
|
| 66 |
+
self.translations.get(self.current_language, {}),
|
| 67 |
+
key
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Fallback to English if not found
|
| 71 |
+
if translation is None:
|
| 72 |
+
translation = self._get_nested_value(
|
| 73 |
+
self.translations.get('en', {}),
|
| 74 |
+
key
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Final fallback to key itself
|
| 78 |
+
if translation is None:
|
| 79 |
+
translation = key
|
| 80 |
+
|
| 81 |
+
# Apply formatting if kwargs provided
|
| 82 |
+
if kwargs:
|
| 83 |
+
try:
|
| 84 |
+
translation = translation.format(**kwargs)
|
| 85 |
+
except KeyError:
|
| 86 |
+
pass
|
| 87 |
+
|
| 88 |
+
return translation
|
| 89 |
+
|
| 90 |
+
def _get_nested_value(self, data: dict, key: str) -> Optional[str]:
|
| 91 |
+
"""
|
| 92 |
+
Get nested dictionary value using dot notation
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
data: Dictionary to search
|
| 96 |
+
key: Dot-separated key (e.g., "section.subsection.key")
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
Value if found, None otherwise
|
| 100 |
+
"""
|
| 101 |
+
keys = key.split('.')
|
| 102 |
+
current = data
|
| 103 |
+
|
| 104 |
+
for k in keys:
|
| 105 |
+
if isinstance(current, dict) and k in current:
|
| 106 |
+
current = current[k]
|
| 107 |
+
else:
|
| 108 |
+
return None
|
| 109 |
+
|
| 110 |
+
return current if isinstance(current, str) else None
|
| 111 |
+
|
| 112 |
+
def get_available_languages(self) -> list:
|
| 113 |
+
"""Get list of available language codes"""
|
| 114 |
+
return list(self.translations.keys())
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Global i18n instance
|
| 118 |
+
_i18n_instance: Optional[I18n] = None
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_i18n(language: Optional[str] = None) -> I18n:
|
| 122 |
+
"""
|
| 123 |
+
Get global i18n instance
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
language: Optional language to set
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
I18n instance
|
| 130 |
+
"""
|
| 131 |
+
global _i18n_instance
|
| 132 |
+
|
| 133 |
+
if _i18n_instance is None:
|
| 134 |
+
_i18n_instance = I18n(default_language=language or "en")
|
| 135 |
+
elif language is not None:
|
| 136 |
+
_i18n_instance.set_language(language)
|
| 137 |
+
|
| 138 |
+
return _i18n_instance
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def t(key: str, **kwargs) -> str:
|
| 142 |
+
"""
|
| 143 |
+
Convenience function for translation
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
key: Translation key
|
| 147 |
+
**kwargs: Optional format parameters
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
Translated string
|
| 151 |
+
"""
|
| 152 |
+
return get_i18n().t(key, **kwargs)
|
acestep/gradio_ui/i18n/en.json
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"app": {
|
| 3 |
+
"title": "🎛️ ACE-Step V1.5 Playground💡",
|
| 4 |
+
"subtitle": "Pushing the Boundaries of Open-Source Music Generation"
|
| 5 |
+
},
|
| 6 |
+
"dataset": {
|
| 7 |
+
"title": "📊 Dataset Explorer",
|
| 8 |
+
"dataset_label": "Dataset",
|
| 9 |
+
"dataset_info": "Choose dataset to explore",
|
| 10 |
+
"import_btn": "📥 Import Dataset",
|
| 11 |
+
"search_type_label": "Search Type",
|
| 12 |
+
"search_type_info": "How to find items",
|
| 13 |
+
"search_value_label": "Search Value",
|
| 14 |
+
"search_value_placeholder": "Enter keys or index (leave empty for random)",
|
| 15 |
+
"search_value_info": "Keys: exact match, Index: 0 to dataset size-1",
|
| 16 |
+
"instruction_label": "📝 Instruction",
|
| 17 |
+
"instruction_placeholder": "No instruction available",
|
| 18 |
+
"metadata_title": "📋 Item Metadata (JSON)",
|
| 19 |
+
"metadata_label": "Complete Item Information",
|
| 20 |
+
"source_audio": "Source Audio",
|
| 21 |
+
"target_audio": "Target Audio",
|
| 22 |
+
"reference_audio": "Reference Audio",
|
| 23 |
+
"get_item_btn": "🔍 Get Item",
|
| 24 |
+
"use_src_checkbox": "Use Source Audio from Dataset",
|
| 25 |
+
"use_src_info": "Check to use the source audio from dataset",
|
| 26 |
+
"data_status_label": "📊 Data Status",
|
| 27 |
+
"data_status_default": "❌ No dataset imported",
|
| 28 |
+
"autofill_btn": "📋 Auto-fill Generation Form"
|
| 29 |
+
},
|
| 30 |
+
"service": {
|
| 31 |
+
"title": "🔧 Service Configuration",
|
| 32 |
+
"checkpoint_label": "Checkpoint File",
|
| 33 |
+
"checkpoint_info": "Select a trained model checkpoint file (full path or filename)",
|
| 34 |
+
"refresh_btn": "🔄 Refresh",
|
| 35 |
+
"model_path_label": "Main Model Path",
|
| 36 |
+
"model_path_info": "Select the model configuration directory (auto-scanned from checkpoints)",
|
| 37 |
+
"device_label": "Device",
|
| 38 |
+
"device_info": "Processing device (auto-detect recommended)",
|
| 39 |
+
"lm_model_path_label": "5Hz LM Model Path",
|
| 40 |
+
"lm_model_path_info": "Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)",
|
| 41 |
+
"backend_label": "5Hz LM Backend",
|
| 42 |
+
"backend_info": "Select backend for 5Hz LM: vllm (faster) or pt (PyTorch, more compatible)",
|
| 43 |
+
"init_llm_label": "Initialize 5Hz LM",
|
| 44 |
+
"init_llm_info": "Check to initialize 5Hz LM during service initialization",
|
| 45 |
+
"flash_attention_label": "Use Flash Attention",
|
| 46 |
+
"flash_attention_info_enabled": "Enable flash attention for faster inference (requires flash_attn package)",
|
| 47 |
+
"flash_attention_info_disabled": "Flash attention not available (flash_attn package not installed)",
|
| 48 |
+
"offload_cpu_label": "Offload to CPU",
|
| 49 |
+
"offload_cpu_info": "Offload models to CPU when not in use to save GPU memory",
|
| 50 |
+
"offload_dit_cpu_label": "Offload DiT to CPU",
|
| 51 |
+
"offload_dit_cpu_info": "Offload DiT to CPU (needs Offload to CPU)",
|
| 52 |
+
"init_btn": "Initialize Service",
|
| 53 |
+
"status_label": "Status",
|
| 54 |
+
"language_label": "UI Language",
|
| 55 |
+
"language_info": "Select interface language"
|
| 56 |
+
},
|
| 57 |
+
"generation": {
|
| 58 |
+
"required_inputs": "📝 Required Inputs",
|
| 59 |
+
"task_type_label": "Task Type",
|
| 60 |
+
"task_type_info": "Select the task type for generation",
|
| 61 |
+
"instruction_label": "Instruction",
|
| 62 |
+
"instruction_info": "Instruction is automatically generated based on task type",
|
| 63 |
+
"load_btn": "Load",
|
| 64 |
+
"track_name_label": "Track Name",
|
| 65 |
+
"track_name_info": "Select track name for lego/extract tasks",
|
| 66 |
+
"track_classes_label": "Track Names",
|
| 67 |
+
"track_classes_info": "Select multiple track classes for complete task",
|
| 68 |
+
"audio_uploads": "🎵 Audio Uploads",
|
| 69 |
+
"reference_audio": "Reference Audio (optional)",
|
| 70 |
+
"source_audio": "Source Audio (optional)",
|
| 71 |
+
"convert_codes_btn": "Convert to Codes",
|
| 72 |
+
"lm_codes_hints": "🎼 LM Codes Hints",
|
| 73 |
+
"lm_codes_label": "LM Codes Hints",
|
| 74 |
+
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
| 75 |
+
"lm_codes_info": "Paste LM codes hints for text2music generation",
|
| 76 |
+
"lm_codes_sample": "LM Codes Hints (Sample {n})",
|
| 77 |
+
"lm_codes_sample_info": "Codes for sample {n}",
|
| 78 |
+
"transcribe_btn": "Transcribe",
|
| 79 |
+
"repainting_controls": "🎨 Repainting Controls (seconds)",
|
| 80 |
+
"repainting_start": "Repainting Start",
|
| 81 |
+
"repainting_end": "Repainting End",
|
| 82 |
+
"mode_label": "Generation Mode",
|
| 83 |
+
"mode_info": "Simple: describe music in natural language. Custom: full control over caption and lyrics.",
|
| 84 |
+
"mode_simple": "Simple",
|
| 85 |
+
"mode_custom": "Custom",
|
| 86 |
+
"simple_query_label": "Song Description",
|
| 87 |
+
"simple_query_placeholder": "Describe the music you want to create, e.g., 'a soft Bengali love song for a quiet evening'. Leave empty for a random sample.",
|
| 88 |
+
"simple_query_info": "Enter a natural language description of the music you want to generate",
|
| 89 |
+
"simple_vocal_language_label": "Vocal Language (optional)",
|
| 90 |
+
"simple_vocal_language_info": "Select preferred language(s) for lyrics. Use 'unknown' for any language.",
|
| 91 |
+
"create_sample_btn": "Create Sample",
|
| 92 |
+
"caption_title": "📝 Music Caption",
|
| 93 |
+
"caption_label": "Music Caption (optional)",
|
| 94 |
+
"caption_placeholder": "A peaceful acoustic guitar melody with soft vocals...",
|
| 95 |
+
"caption_info": "Describe the style, genre, instruments, and mood",
|
| 96 |
+
"lyrics_title": "📝 Lyrics",
|
| 97 |
+
"lyrics_label": "Lyrics (optional)",
|
| 98 |
+
"lyrics_placeholder": "[Verse 1]\\nUnder the starry night\\nI feel so alive...",
|
| 99 |
+
"lyrics_info": "Song lyrics with structure",
|
| 100 |
+
"instrumental_label": "Instrumental",
|
| 101 |
+
"format_btn": "Format",
|
| 102 |
+
"optional_params": "⚙️ Optional Parameters",
|
| 103 |
+
"vocal_language_label": "Vocal Language (optional)",
|
| 104 |
+
"vocal_language_info": "use `unknown` for inst",
|
| 105 |
+
"bpm_label": "BPM (optional)",
|
| 106 |
+
"bpm_info": "leave empty for N/A",
|
| 107 |
+
"keyscale_label": "KeyScale (optional)",
|
| 108 |
+
"keyscale_placeholder": "Leave empty for N/A",
|
| 109 |
+
"keyscale_info": "A-G, #/♭, major/minor",
|
| 110 |
+
"timesig_label": "Time Signature (optional)",
|
| 111 |
+
"timesig_info": "2/4, 3/4, 4/4...",
|
| 112 |
+
"duration_label": "Audio Duration (seconds)",
|
| 113 |
+
"duration_info": "Use -1 for random",
|
| 114 |
+
"batch_size_label": "Batch Size",
|
| 115 |
+
"batch_size_info": "Number of audio to generate (max 8)",
|
| 116 |
+
"advanced_settings": "🔧 Advanced Settings",
|
| 117 |
+
"inference_steps_label": "DiT Inference Steps",
|
| 118 |
+
"inference_steps_info": "Turbo: max 8, Base: max 200",
|
| 119 |
+
"guidance_scale_label": "DiT Guidance Scale (Only support for base model)",
|
| 120 |
+
"guidance_scale_info": "Higher values follow text more closely",
|
| 121 |
+
"seed_label": "Seed",
|
| 122 |
+
"seed_info": "Use comma-separated values for batches",
|
| 123 |
+
"random_seed_label": "Random Seed",
|
| 124 |
+
"random_seed_info": "Enable to auto-generate seeds",
|
| 125 |
+
"audio_format_label": "Audio Format",
|
| 126 |
+
"audio_format_info": "Audio format for saved files",
|
| 127 |
+
"use_adg_label": "Use ADG",
|
| 128 |
+
"use_adg_info": "Enable Angle Domain Guidance",
|
| 129 |
+
"shift_label": "Shift",
|
| 130 |
+
"shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
|
| 131 |
+
"infer_method_label": "Inference Method",
|
| 132 |
+
"infer_method_info": "Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
|
| 133 |
+
"custom_timesteps_label": "Custom Timesteps",
|
| 134 |
+
"custom_timesteps_info": "Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
|
| 135 |
+
"cfg_interval_start": "CFG Interval Start",
|
| 136 |
+
"cfg_interval_end": "CFG Interval End",
|
| 137 |
+
"lm_params_title": "🤖 LM Generation Parameters",
|
| 138 |
+
"lm_temperature_label": "LM Temperature",
|
| 139 |
+
"lm_temperature_info": "5Hz LM temperature (higher = more random)",
|
| 140 |
+
"lm_cfg_scale_label": "LM CFG Scale",
|
| 141 |
+
"lm_cfg_scale_info": "5Hz LM CFG (1.0 = no CFG)",
|
| 142 |
+
"lm_top_k_label": "LM Top-K",
|
| 143 |
+
"lm_top_k_info": "Top-K (0 = disabled)",
|
| 144 |
+
"lm_top_p_label": "LM Top-P",
|
| 145 |
+
"lm_top_p_info": "Top-P (1.0 = disabled)",
|
| 146 |
+
"lm_negative_prompt_label": "LM Negative Prompt",
|
| 147 |
+
"lm_negative_prompt_placeholder": "Enter negative prompt for CFG (default: NO USER INPUT)",
|
| 148 |
+
"lm_negative_prompt_info": "Negative prompt (use when LM CFG Scale > 1.0)",
|
| 149 |
+
"cot_metas_label": "CoT Metas",
|
| 150 |
+
"cot_metas_info": "Use LM to generate CoT metadata (uncheck to skip LM CoT generation)",
|
| 151 |
+
"cot_language_label": "CoT Language",
|
| 152 |
+
"cot_language_info": "Generate language in CoT (chain-of-thought)",
|
| 153 |
+
"constrained_debug_label": "Constrained Decoding Debug",
|
| 154 |
+
"constrained_debug_info": "Enable debug logging for constrained decoding (check to see detailed logs)",
|
| 155 |
+
"auto_score_label": "Auto Score",
|
| 156 |
+
"auto_score_info": "Automatically calculate quality scores for all generated audios",
|
| 157 |
+
"auto_lrc_label": "Auto LRC",
|
| 158 |
+
"auto_lrc_info": "Automatically generate LRC lyrics timestamps for all generated audios",
|
| 159 |
+
"lm_batch_chunk_label": "LM Batch Chunk Size",
|
| 160 |
+
"lm_batch_chunk_info": "Max items per LM batch chunk (default: 8, limited by GPU memory)",
|
| 161 |
+
"codes_strength_label": "LM Codes Strength",
|
| 162 |
+
"codes_strength_info": "Control how many denoising steps use LM-generated codes",
|
| 163 |
+
"cover_strength_label": "Audio Cover Strength",
|
| 164 |
+
"cover_strength_info": "Control how many denoising steps use cover mode",
|
| 165 |
+
"score_sensitivity_label": "Quality Score Sensitivity",
|
| 166 |
+
"score_sensitivity_info": "Lower = more sensitive (default: 1.0). Adjusts how PMI maps to [0,1]",
|
| 167 |
+
"think_label": "Think",
|
| 168 |
+
"parallel_thinking_label": "ParallelThinking",
|
| 169 |
+
"generate_btn": "🎵 Generate Music",
|
| 170 |
+
"autogen_label": "AutoGen",
|
| 171 |
+
"caption_rewrite_label": "CaptionRewrite"
|
| 172 |
+
},
|
| 173 |
+
"results": {
|
| 174 |
+
"title": "🎵 Results",
|
| 175 |
+
"generated_music": "🎵 Generated Music (Sample {n})",
|
| 176 |
+
"send_to_src_btn": "🔗 Send To Src Audio",
|
| 177 |
+
"send_to_cover_btn": "🔗 Send To Cover",
|
| 178 |
+
"send_to_repaint_btn": "🔗 Send To Repaint",
|
| 179 |
+
"save_btn": "💾 Save",
|
| 180 |
+
"score_btn": "📊 Score",
|
| 181 |
+
"lrc_btn": "🎵 LRC",
|
| 182 |
+
"quality_score_label": "Quality Score (Sample {n})",
|
| 183 |
+
"quality_score_placeholder": "Click 'Score' to calculate perplexity-based quality score",
|
| 184 |
+
"codes_label": "LM Codes (Sample {n})",
|
| 185 |
+
"lrc_label": "Lyrics Timestamps (Sample {n})",
|
| 186 |
+
"lrc_placeholder": "Click 'LRC' to generate timestamps",
|
| 187 |
+
"details_accordion": "📊 Score & LRC & LM Codes",
|
| 188 |
+
"generation_status": "Generation Status",
|
| 189 |
+
"current_batch": "Current Batch",
|
| 190 |
+
"batch_indicator": "Batch {current} / {total}",
|
| 191 |
+
"next_batch_status": "Next Batch Status",
|
| 192 |
+
"prev_btn": "◀ Previous",
|
| 193 |
+
"next_btn": "Next ▶",
|
| 194 |
+
"restore_params_btn": "↙️ Apply These Settings to UI (Restore Batch Parameters)",
|
| 195 |
+
"batch_results_title": "📁 Batch Results & Generation Details",
|
| 196 |
+
"all_files_label": "📁 All Generated Files (Download)",
|
| 197 |
+
"generation_details": "Generation Details"
|
| 198 |
+
},
|
| 199 |
+
"messages": {
|
| 200 |
+
"no_audio_to_save": "❌ No audio to save",
|
| 201 |
+
"save_success": "✅ Saved audio and metadata to {filename}",
|
| 202 |
+
"save_failed": "❌ Failed to save: {error}",
|
| 203 |
+
"no_file_selected": "⚠️ No file selected",
|
| 204 |
+
"params_loaded": "✅ Parameters loaded from {filename}",
|
| 205 |
+
"invalid_json": "❌ Invalid JSON file: {error}",
|
| 206 |
+
"load_error": "❌ Error loading file: {error}",
|
| 207 |
+
"example_loaded": "📁 Loaded example from {filename}",
|
| 208 |
+
"example_failed": "Failed to parse JSON file {filename}: {error}",
|
| 209 |
+
"example_error": "Error loading example: {error}",
|
| 210 |
+
"lm_generated": "🤖 Generated example using LM",
|
| 211 |
+
"lm_fallback": "Failed to generate example using LM, falling back to examples directory",
|
| 212 |
+
"lm_not_initialized": "❌ 5Hz LM not initialized. Please initialize it first.",
|
| 213 |
+
"autogen_enabled": "🔄 AutoGen enabled - next batch will generate after this",
|
| 214 |
+
"batch_ready": "✅ Batch {n} ready! Click 'Next' to view.",
|
| 215 |
+
"batch_generating": "🔄 Starting background generation for Batch {n}...",
|
| 216 |
+
"batch_failed": "❌ Background generation failed: {error}",
|
| 217 |
+
"viewing_batch": "✅ Viewing Batch {n}",
|
| 218 |
+
"at_first_batch": "Already at first batch",
|
| 219 |
+
"at_last_batch": "No next batch available",
|
| 220 |
+
"batch_not_found": "Batch {n} not found in queue",
|
| 221 |
+
"no_batch_data": "No batch data found to restore.",
|
| 222 |
+
"params_restored": "✅ UI Parameters restored from Batch {n}",
|
| 223 |
+
"scoring_failed": "❌ Error: Batch data not found",
|
| 224 |
+
"no_codes": "❌ No audio codes available. Please generate music first.",
|
| 225 |
+
"score_failed": "❌ Scoring failed: {error}",
|
| 226 |
+
"score_error": "❌ Error calculating score: {error}",
|
| 227 |
+
"lrc_no_batch_data": "❌ No batch data found. Please generate music first.",
|
| 228 |
+
"lrc_no_extra_outputs": "❌ No extra outputs found. Condition tensors not available.",
|
| 229 |
+
"lrc_missing_tensors": "❌ Missing required tensors for LRC generation.",
|
| 230 |
+
"lrc_sample_not_exist": "❌ Sample does not exist in current batch.",
|
| 231 |
+
"lrc_empty_result": "⚠️ LRC generation produced empty result.",
|
| 232 |
+
"empty_query": "⚠️ Please enter a music description.",
|
| 233 |
+
"sample_creation_failed": "❌ Failed to create sample. Please try again.",
|
| 234 |
+
"sample_created": "✅ Sample created! Review the caption and lyrics, then click Generate Music.",
|
| 235 |
+
"simple_examples_not_found": "⚠️ Simple mode examples directory not found.",
|
| 236 |
+
"simple_examples_empty": "⚠️ No example files found in simple mode examples.",
|
| 237 |
+
"simple_example_loaded": "🎲 Loaded random example from {filename}",
|
| 238 |
+
"format_success": "✅ Caption and lyrics formatted successfully",
|
| 239 |
+
"format_failed": "❌ Format failed: {error}",
|
| 240 |
+
"skipping_metas_cot": "⚡ Skipping Phase 1 metas COT (sample already formatted)",
|
| 241 |
+
"invalid_timesteps_format": "⚠️ Invalid timesteps format. Using default schedule.",
|
| 242 |
+
"timesteps_out_of_range": "⚠️ Timesteps must be in range [0, 1]. Using default schedule.",
|
| 243 |
+
"timesteps_count_mismatch": "⚠️ Timesteps count ({actual}) differs from inference_steps ({expected}). Using timesteps count."
|
| 244 |
+
}
|
| 245 |
+
}
|
acestep/gradio_ui/i18n/ja.json
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"app": {
|
| 3 |
+
"title": "🎛️ ACE-Step V1.5 プレイグラウンド💡",
|
| 4 |
+
"subtitle": "オープンソース音楽生成の限界を押し広げる"
|
| 5 |
+
},
|
| 6 |
+
"dataset": {
|
| 7 |
+
"title": "📊 データセットエクスプローラー",
|
| 8 |
+
"dataset_label": "データセット",
|
| 9 |
+
"dataset_info": "探索するデータセットを選択",
|
| 10 |
+
"import_btn": "📥 データセットをインポート",
|
| 11 |
+
"search_type_label": "検索タイプ",
|
| 12 |
+
"search_type_info": "アイテムの検索方法",
|
| 13 |
+
"search_value_label": "検索値",
|
| 14 |
+
"search_value_placeholder": "キーまたはインデックスを入力(空白の場合はランダム)",
|
| 15 |
+
"search_value_info": "キー: 完全一致、インデックス: 0からデータセットサイズ-1",
|
| 16 |
+
"instruction_label": "📝 指示",
|
| 17 |
+
"instruction_placeholder": "利用可能な指示がありません",
|
| 18 |
+
"metadata_title": "📋 アイテムメタデータ (JSON)",
|
| 19 |
+
"metadata_label": "完全なアイテム情報",
|
| 20 |
+
"source_audio": "ソースオーディオ",
|
| 21 |
+
"target_audio": "ターゲットオーディオ",
|
| 22 |
+
"reference_audio": "リファレンスオーディオ",
|
| 23 |
+
"get_item_btn": "🔍 アイテムを取得",
|
| 24 |
+
"use_src_checkbox": "データセットのソースオーディオを使用",
|
| 25 |
+
"use_src_info": "データセットのソースオーディオを使用する場合はチェック",
|
| 26 |
+
"data_status_label": "📊 データステータス",
|
| 27 |
+
"data_status_default": "❌ データセットがインポートされていません",
|
| 28 |
+
"autofill_btn": "📋 生成フォームを自動入力"
|
| 29 |
+
},
|
| 30 |
+
"service": {
|
| 31 |
+
"title": "🔧 サービス設定",
|
| 32 |
+
"checkpoint_label": "チェックポイントファイル",
|
| 33 |
+
"checkpoint_info": "訓練済みモデルのチェックポイントファイルを選択(フルパスまたはファイル名)",
|
| 34 |
+
"refresh_btn": "🔄 更新",
|
| 35 |
+
"model_path_label": "メインモデルパス",
|
| 36 |
+
"model_path_info": "モデル設定ディレクトリを選択(チェックポイントから自動スキャン)",
|
| 37 |
+
"device_label": "デバイス",
|
| 38 |
+
"device_info": "処理デバイス(自動検出を推奨)",
|
| 39 |
+
"lm_model_path_label": "5Hz LM モデルパス",
|
| 40 |
+
"lm_model_path_info": "5Hz LMモデルチェックポイントを選択(チェックポイントから自動スキャン)",
|
| 41 |
+
"backend_label": "5Hz LM バックエンド",
|
| 42 |
+
"backend_info": "5Hz LMのバックエンドを選択: vllm(高速)またはpt(PyTorch、より互換性あり)",
|
| 43 |
+
"init_llm_label": "5Hz LM を初期化",
|
| 44 |
+
"init_llm_info": "サービス初期化中に5Hz LMを初期化する場合はチェック",
|
| 45 |
+
"flash_attention_label": "Flash Attention を使用",
|
| 46 |
+
"flash_attention_info_enabled": "推論を高速化するためにflash attentionを有効にする(flash_attnパッケージが必要)",
|
| 47 |
+
"flash_attention_info_disabled": "Flash attentionは利用できません(flash_attnパッケージがインストールされていません)",
|
| 48 |
+
"offload_cpu_label": "CPUにオフロード",
|
| 49 |
+
"offload_cpu_info": "使用していない時にモデルをCPUにオフロードしてGPUメモリを節約",
|
| 50 |
+
"offload_dit_cpu_label": "DiTをCPUにオフロード",
|
| 51 |
+
"offload_dit_cpu_info": "DiTをCPUにオフロード(CPUへのオフロードが必要)",
|
| 52 |
+
"init_btn": "サービスを初期化",
|
| 53 |
+
"status_label": "ステータス",
|
| 54 |
+
"language_label": "UI言語",
|
| 55 |
+
"language_info": "インターフェース言語を選択"
|
| 56 |
+
},
|
| 57 |
+
"generation": {
|
| 58 |
+
"required_inputs": "📝 必須入力",
|
| 59 |
+
"task_type_label": "タスクタイプ",
|
| 60 |
+
"task_type_info": "生成のタスクタイプを選択",
|
| 61 |
+
"instruction_label": "指示",
|
| 62 |
+
"instruction_info": "指示はタスクタイプに基づいて自動生成されます",
|
| 63 |
+
"load_btn": "読み込む",
|
| 64 |
+
"track_name_label": "トラック名",
|
| 65 |
+
"track_name_info": "lego/extractタスクのトラック名を選択",
|
| 66 |
+
"track_classes_label": "トラック名",
|
| 67 |
+
"track_classes_info": "completeタスクの複数のトラッククラスを選択",
|
| 68 |
+
"audio_uploads": "🎵 オーディオアップロード",
|
| 69 |
+
"reference_audio": "リファレンスオーディオ(オプション)",
|
| 70 |
+
"source_audio": "ソースオーディオ(オプション)",
|
| 71 |
+
"convert_codes_btn": "コードに変換",
|
| 72 |
+
"lm_codes_hints": "🎼 LM コードヒント",
|
| 73 |
+
"lm_codes_label": "LM コードヒント",
|
| 74 |
+
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
| 75 |
+
"lm_codes_info": "text2music生成用のLMコードヒントを貼り付け",
|
| 76 |
+
"lm_codes_sample": "LM コードヒント(サンプル {n})",
|
| 77 |
+
"lm_codes_sample_info": "サンプル{n}のコード",
|
| 78 |
+
"transcribe_btn": "転写",
|
| 79 |
+
"repainting_controls": "🎨 再描画コントロール(秒)",
|
| 80 |
+
"repainting_start": "再描画開始",
|
| 81 |
+
"repainting_end": "再描画終了",
|
| 82 |
+
"mode_label": "生成モード",
|
| 83 |
+
"mode_info": "シンプル:自然言語で音楽を説明��カスタム:キャプションと歌詞を完全にコントロール。",
|
| 84 |
+
"mode_simple": "シンプル",
|
| 85 |
+
"mode_custom": "カスタム",
|
| 86 |
+
"simple_query_label": "曲の説明",
|
| 87 |
+
"simple_query_placeholder": "作成したい音楽を説明してください。例:'静かな夜のための優しいベンガルのラブソング'。空欄の場合はランダムなサンプルが生成されます。",
|
| 88 |
+
"simple_query_info": "生成したい音楽の自然言語の説明を入力",
|
| 89 |
+
"simple_vocal_language_label": "ボーカル言語(オプション)",
|
| 90 |
+
"simple_vocal_language_info": "歌詞の希望言語を選択。任意の言語の場合は'unknown'を使用。",
|
| 91 |
+
"create_sample_btn": "サンプル作成",
|
| 92 |
+
"caption_title": "📝 音楽キャプション",
|
| 93 |
+
"caption_label": "音楽キャプション(オプション)",
|
| 94 |
+
"caption_placeholder": "柔らかいボーカルを伴う穏やかなアコースティックギターのメロディー...",
|
| 95 |
+
"caption_info": "スタイル、ジャンル、楽器、ムードを説明",
|
| 96 |
+
"lyrics_title": "📝 歌詞",
|
| 97 |
+
"lyrics_label": "歌詞(オプション)",
|
| 98 |
+
"lyrics_placeholder": "[バース1]\\n星空の下で\\nとても生きていると感じる...",
|
| 99 |
+
"lyrics_info": "構造を持つ曲の歌詞",
|
| 100 |
+
"instrumental_label": "インストゥルメンタル",
|
| 101 |
+
"format_btn": "フォーマット",
|
| 102 |
+
"optional_params": "⚙️ オプションパラメータ",
|
| 103 |
+
"vocal_language_label": "ボーカル言語(オプション)",
|
| 104 |
+
"vocal_language_info": "インストには`unknown`を使用",
|
| 105 |
+
"bpm_label": "BPM(オプション)",
|
| 106 |
+
"bpm_info": "空白の場合はN/A",
|
| 107 |
+
"keyscale_label": "キースケール(オプション)",
|
| 108 |
+
"keyscale_placeholder": "空白の場合はN/A",
|
| 109 |
+
"keyscale_info": "A-G, #/♭, メジャー/マイナー",
|
| 110 |
+
"timesig_label": "拍子記号(オプション)",
|
| 111 |
+
"timesig_info": "2/4, 3/4, 4/4...",
|
| 112 |
+
"duration_label": "オーディオ長(秒)",
|
| 113 |
+
"duration_info": "ランダムの場合は-1を使用",
|
| 114 |
+
"batch_size_label": "バッチサイズ",
|
| 115 |
+
"batch_size_info": "生成するオーディオの数(最大8)",
|
| 116 |
+
"advanced_settings": "🔧 詳細設定",
|
| 117 |
+
"inference_steps_label": "DiT 推論ステップ",
|
| 118 |
+
"inference_steps_info": "Turbo: 最大8、Base: 最大200",
|
| 119 |
+
"guidance_scale_label": "DiT ガイダンススケール(baseモデルのみサポート)",
|
| 120 |
+
"guidance_scale_info": "値が高いほどテキストに忠実に従う",
|
| 121 |
+
"seed_label": "シード",
|
| 122 |
+
"seed_info": "バッチにはカンマ区切りの値を使用",
|
| 123 |
+
"random_seed_label": "ランダムシード",
|
| 124 |
+
"random_seed_info": "有効にすると自動的にシードを生成",
|
| 125 |
+
"audio_format_label": "オーディオフォーマット",
|
| 126 |
+
"audio_format_info": "保存ファイルのオーディオフォーマット",
|
| 127 |
+
"use_adg_label": "ADG を使用",
|
| 128 |
+
"use_adg_info": "角度ドメインガイダンスを有効化",
|
| 129 |
+
"shift_label": "シフト",
|
| 130 |
+
"shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
|
| 131 |
+
"infer_method_label": "推論方法",
|
| 132 |
+
"infer_method_info": "拡散推論方法。ODE (オイラー) は高速、SDE (確率的) は異なる結果を生成する可能性があります。",
|
| 133 |
+
"custom_timesteps_label": "カスタムタイムステップ",
|
| 134 |
+
"custom_timesteps_info": "オプション:1.0から0.0へのカンマ区切り値(例:'0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。推論ステップとシフトを上書きします。",
|
| 135 |
+
"cfg_interval_start": "CFG 間隔開始",
|
| 136 |
+
"cfg_interval_end": "CFG 間隔終了",
|
| 137 |
+
"lm_params_title": "🤖 LM 生成パラメータ",
|
| 138 |
+
"lm_temperature_label": "LM 温度",
|
| 139 |
+
"lm_temperature_info": "5Hz LM温度(高いほどランダム)",
|
| 140 |
+
"lm_cfg_scale_label": "LM CFG スケール",
|
| 141 |
+
"lm_cfg_scale_info": "5Hz LM CFG (1.0 = CFGなし)",
|
| 142 |
+
"lm_top_k_label": "LM Top-K",
|
| 143 |
+
"lm_top_k_info": "Top-K (0 = 無効)",
|
| 144 |
+
"lm_top_p_label": "LM Top-P",
|
| 145 |
+
"lm_top_p_info": "Top-P (1.0 = 無効)",
|
| 146 |
+
"lm_negative_prompt_label": "LM ネガティブプロンプト",
|
| 147 |
+
"lm_negative_prompt_placeholder": "CFGのネガティブプロンプトを入力(デフォルト: NO USER INPUT)",
|
| 148 |
+
"lm_negative_prompt_info": "ネガティブプロンプト(LM CFGスケール > 1.0の場合に使用)",
|
| 149 |
+
"cot_metas_label": "CoT メタデータ",
|
| 150 |
+
"cot_metas_info": "LMを使用してCoTメタデータを生成(チェックを外すとLM CoT生成をスキップ)",
|
| 151 |
+
"cot_language_label": "CoT 言語",
|
| 152 |
+
"cot_language_info": "CoTで言語を生成(思考の連鎖)",
|
| 153 |
+
"constrained_debug_label": "制約付きデコーディングデバッグ",
|
| 154 |
+
"constrained_debug_info": "制約付きデコーディングのデバッグログを有効化(チェックすると詳細ログを表示)",
|
| 155 |
+
"auto_score_label": "自動スコアリング",
|
| 156 |
+
"auto_score_info": "生成���れたすべてのオーディオの品質スコアを自動計算",
|
| 157 |
+
"auto_lrc_label": "自動 LRC",
|
| 158 |
+
"auto_lrc_info": "生成されたすべてのオーディオのLRC歌詞タイムスタンプを自動生成",
|
| 159 |
+
"lm_batch_chunk_label": "LM バッチチャンクサイズ",
|
| 160 |
+
"lm_batch_chunk_info": "LMバッチチャンクあたりの最大アイテム数(デフォルト: 8、GPUメモリによる制限)",
|
| 161 |
+
"codes_strength_label": "LM コード強度",
|
| 162 |
+
"codes_strength_info": "LM生成コードを使用するデノイジングステップ数を制御",
|
| 163 |
+
"cover_strength_label": "オーディオカバー強度",
|
| 164 |
+
"cover_strength_info": "カバーモードを使用するデノイジングステップ数を制御",
|
| 165 |
+
"score_sensitivity_label": "品質スコア感度",
|
| 166 |
+
"score_sensitivity_info": "低い = より敏感(デフォルト: 1.0)。PMIが[0,1]にマッピングする方法を調整",
|
| 167 |
+
"think_label": "思考",
|
| 168 |
+
"parallel_thinking_label": "並列思考",
|
| 169 |
+
"generate_btn": "🎵 音楽を生成",
|
| 170 |
+
"autogen_label": "自動生成",
|
| 171 |
+
"caption_rewrite_label": "キャプション書き換え"
|
| 172 |
+
},
|
| 173 |
+
"results": {
|
| 174 |
+
"title": "🎵 結果",
|
| 175 |
+
"generated_music": "🎵 生成された音楽(サンプル {n})",
|
| 176 |
+
"send_to_src_btn": "🔗 ソースオーディオに送信",
|
| 177 |
+
"send_to_cover_btn": "🔗 Send To Cover",
|
| 178 |
+
"send_to_repaint_btn": "🔗 Send To Repaint",
|
| 179 |
+
"save_btn": "💾 保存",
|
| 180 |
+
"score_btn": "📊 スコア",
|
| 181 |
+
"lrc_btn": "🎵 LRC",
|
| 182 |
+
"quality_score_label": "品質スコア(サンプル {n})",
|
| 183 |
+
"quality_score_placeholder": "'スコア'をクリックしてパープレキシティベースの品質スコアを計算",
|
| 184 |
+
"codes_label": "LM コード(サンプル {n})",
|
| 185 |
+
"lrc_label": "歌詞タイムスタンプ(サンプル {n})",
|
| 186 |
+
"lrc_placeholder": "'LRC'をクリックしてタイムスタンプを生成",
|
| 187 |
+
"details_accordion": "📊 スコア & LRC & LM コード",
|
| 188 |
+
"generation_status": "生成ステータス",
|
| 189 |
+
"current_batch": "現在のバッチ",
|
| 190 |
+
"batch_indicator": "バッチ {current} / {total}",
|
| 191 |
+
"next_batch_status": "次のバッチステータス",
|
| 192 |
+
"prev_btn": "◀ 前へ",
|
| 193 |
+
"next_btn": "次へ ▶",
|
| 194 |
+
"restore_params_btn": "↙️ これらの設定をUIに適用(バッチパラメータを復元)",
|
| 195 |
+
"batch_results_title": "📁 バッチ結果と生成詳細",
|
| 196 |
+
"all_files_label": "📁 すべての生成ファイル(ダウンロード)",
|
| 197 |
+
"generation_details": "生成詳細"
|
| 198 |
+
},
|
| 199 |
+
"messages": {
|
| 200 |
+
"no_audio_to_save": "❌ 保存するオーディオがありません",
|
| 201 |
+
"save_success": "✅ オーディオとメタデータを {filename} に保存しました",
|
| 202 |
+
"save_failed": "❌ 保存に失敗しました: {error}",
|
| 203 |
+
"no_file_selected": "⚠️ ファイルが選択されていません",
|
| 204 |
+
"params_loaded": "✅ {filename} からパラメータを読み込みました",
|
| 205 |
+
"invalid_json": "❌ 無効なJSONファイル: {error}",
|
| 206 |
+
"load_error": "❌ ファイルの読み込みエラー: {error}",
|
| 207 |
+
"example_loaded": "📁 {filename} からサンプルを読み込みました",
|
| 208 |
+
"example_failed": "JSONファイル {filename} の解析に失敗しました: {error}",
|
| 209 |
+
"example_error": "サンプル読み込みエラー: {error}",
|
| 210 |
+
"lm_generated": "🤖 LMを使用してサンプルを生成しました",
|
| 211 |
+
"lm_fallback": "LMを使用したサンプル生成に失敗、サンプルディレクトリにフォールバック",
|
| 212 |
+
"lm_not_initialized": "❌ 5Hz LMが初期化されていません。最初に初期化してください。",
|
| 213 |
+
"autogen_enabled": "🔄 自動生成が有効 - このあと次のバッチを生成します",
|
| 214 |
+
"batch_ready": "✅ バッチ {n} の準備完了!'次へ'をクリックして表示。",
|
| 215 |
+
"batch_generating": "🔄 バッチ {n} のバックグラウンド生成を開始...",
|
| 216 |
+
"batch_failed": "❌ バックグラウンド生成に失敗しました: {error}",
|
| 217 |
+
"viewing_batch": "✅ バッチ {n} を表示中",
|
| 218 |
+
"at_first_batch": "すでに最初のバッチです",
|
| 219 |
+
"at_last_batch": "次のバッチはありません",
|
| 220 |
+
"batch_not_found": "キューにバッチ {n} が見つかりません",
|
| 221 |
+
"no_batch_data": "復元するバッチデータがありません。",
|
| 222 |
+
"params_restored": "✅ バッチ {n} からUIパラメータを復元しました",
|
| 223 |
+
"scoring_failed": "❌ エラー: バッチデータが見つかりません",
|
| 224 |
+
"no_codes": "❌ 利用可能なオーディオコードがありません。最初に音楽を生成してください。",
|
| 225 |
+
"score_failed": "❌ スコアリングに失敗しました: {error}",
|
| 226 |
+
"score_error": "❌ スコア計算エラー: {error}",
|
| 227 |
+
"lrc_no_batch_data": "❌ バッチデータが見つかりません。最初に音楽を生成してください。",
|
| 228 |
+
"lrc_no_extra_outputs": "❌ 追加出力が見つかりません。条件テンソルが利用できません。",
|
| 229 |
+
"lrc_missing_tensors": "❌ LRC生成に必要なテンソルがありません。",
|
| 230 |
+
"lrc_sample_not_exist": "❌ 現在のバッチにサンプルが存在しません。",
|
| 231 |
+
"lrc_empty_result": "⚠️ LRC生成の結果が空です。",
|
| 232 |
+
"empty_query": "⚠️ 音楽の説明を入力してください。",
|
| 233 |
+
"sample_creation_failed": "❌ サンプルの作成に失敗しました。もう一度お試しください。",
|
| 234 |
+
"sample_created": "✅ サンプルが作成されました!キャプションと歌詞を確認して、音楽を生成をクリックしてください。",
|
| 235 |
+
"simple_examples_not_found": "⚠️ シンプルモードサンプルディレクトリが見つかりません。",
|
| 236 |
+
"simple_examples_empty": "⚠️ シンプルモードサンプルにファイルがありません。",
|
| 237 |
+
"simple_example_loaded": "🎲 {filename} からランダムサンプルを読み込みました",
|
| 238 |
+
"format_success": "✅ キャプションと歌詞のフォーマットに成功しました",
|
| 239 |
+
"format_failed": "❌ フォーマットに失敗しました: {error}",
|
| 240 |
+
"skipping_metas_cot": "⚡ Phase 1 メタデータ COT をスキップ(サンプルは既にフォーマット済み)",
|
| 241 |
+
"invalid_timesteps_format": "⚠️ タイムステップ形式が無効です。デフォルトスケジュールを使用します。",
|
| 242 |
+
"timesteps_out_of_range": "⚠️ タイムステップは [0, 1] の範囲内である必要があります。デフォルトスケジュールを使用します。",
|
| 243 |
+
"timesteps_count_mismatch": "⚠️ タイムステップ数 ({actual}) が推論ステップ数 ({expected}) と異なります。タイムステップ数を使用します。"
|
| 244 |
+
}
|
| 245 |
+
}
|
acestep/gradio_ui/i18n/zh.json
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"app": {
|
| 3 |
+
"title": "🎛️ ACE-Step V1.5 演练场💡",
|
| 4 |
+
"subtitle": "推动开源音乐生成的边界"
|
| 5 |
+
},
|
| 6 |
+
"dataset": {
|
| 7 |
+
"title": "📊 数据集浏览器",
|
| 8 |
+
"dataset_label": "数据集",
|
| 9 |
+
"dataset_info": "选择要浏览的数据集",
|
| 10 |
+
"import_btn": "📥 导入数据集",
|
| 11 |
+
"search_type_label": "搜索类型",
|
| 12 |
+
"search_type_info": "如何查找项目",
|
| 13 |
+
"search_value_label": "搜索值",
|
| 14 |
+
"search_value_placeholder": "输入键或索引(留空表示随机)",
|
| 15 |
+
"search_value_info": "键: 精确匹配, 索引: 0到数据集大小-1",
|
| 16 |
+
"instruction_label": "📝 指令",
|
| 17 |
+
"instruction_placeholder": "无可用指令",
|
| 18 |
+
"metadata_title": "📋 项目元数据 (JSON)",
|
| 19 |
+
"metadata_label": "完整项目信息",
|
| 20 |
+
"source_audio": "源音频",
|
| 21 |
+
"target_audio": "目标音频",
|
| 22 |
+
"reference_audio": "参考音频",
|
| 23 |
+
"get_item_btn": "🔍 获取项目",
|
| 24 |
+
"use_src_checkbox": "使用数据集中的源音频",
|
| 25 |
+
"use_src_info": "勾选以使用数据集中的源音频",
|
| 26 |
+
"data_status_label": "📊 数据状态",
|
| 27 |
+
"data_status_default": "❌ 未导入数据集",
|
| 28 |
+
"autofill_btn": "📋 自动填充生成表单"
|
| 29 |
+
},
|
| 30 |
+
"service": {
|
| 31 |
+
"title": "🔧 服务配置",
|
| 32 |
+
"checkpoint_label": "检查点文件",
|
| 33 |
+
"checkpoint_info": "选择训练好的模型检查点文件(完整路径或文件名)",
|
| 34 |
+
"refresh_btn": "🔄 刷新",
|
| 35 |
+
"model_path_label": "主模型路径",
|
| 36 |
+
"model_path_info": "选择模型配置目录(从检查点自动扫描)",
|
| 37 |
+
"device_label": "设备",
|
| 38 |
+
"device_info": "处理设备(建议自动检测)",
|
| 39 |
+
"lm_model_path_label": "5Hz LM 模型路径",
|
| 40 |
+
"lm_model_path_info": "选择5Hz LM模型检查点(从检查点自动扫描)",
|
| 41 |
+
"backend_label": "5Hz LM 后端",
|
| 42 |
+
"backend_info": "选择5Hz LM的后端: vllm(更快)或pt(PyTorch, 更兼容)",
|
| 43 |
+
"init_llm_label": "初始化 5Hz LM",
|
| 44 |
+
"init_llm_info": "勾选以在服务初始化期间初始化5Hz LM",
|
| 45 |
+
"flash_attention_label": "使用Flash Attention",
|
| 46 |
+
"flash_attention_info_enabled": "启用flash attention以加快推理速度(需要flash_attn包)",
|
| 47 |
+
"flash_attention_info_disabled": "Flash attention不可用(未安装flash_attn包)",
|
| 48 |
+
"offload_cpu_label": "卸载到CPU",
|
| 49 |
+
"offload_cpu_info": "不使用时将模型卸载到CPU以节省GPU内存",
|
| 50 |
+
"offload_dit_cpu_label": "将DiT卸载到CPU",
|
| 51 |
+
"offload_dit_cpu_info": "将DiT卸载到CPU(需要启用卸载到CPU)",
|
| 52 |
+
"init_btn": "初始化服务",
|
| 53 |
+
"status_label": "状态",
|
| 54 |
+
"language_label": "界面语言",
|
| 55 |
+
"language_info": "选择界面语言"
|
| 56 |
+
},
|
| 57 |
+
"generation": {
|
| 58 |
+
"required_inputs": "📝 必需输入",
|
| 59 |
+
"task_type_label": "任务类型",
|
| 60 |
+
"task_type_info": "选择生成的任务类型",
|
| 61 |
+
"instruction_label": "指令",
|
| 62 |
+
"instruction_info": "指令根据任务类型自动生成",
|
| 63 |
+
"load_btn": "加载",
|
| 64 |
+
"track_name_label": "音轨名称",
|
| 65 |
+
"track_name_info": "为lego/extract任务选择音轨名称",
|
| 66 |
+
"track_classes_label": "音轨名称",
|
| 67 |
+
"track_classes_info": "为complete任务选择多个音轨类别",
|
| 68 |
+
"audio_uploads": "🎵 音频上传",
|
| 69 |
+
"reference_audio": "参考音频(可选)",
|
| 70 |
+
"source_audio": "源音频(可选)",
|
| 71 |
+
"convert_codes_btn": "转换为代码",
|
| 72 |
+
"lm_codes_hints": "🎼 LM 代码提示",
|
| 73 |
+
"lm_codes_label": "LM 代码提示",
|
| 74 |
+
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
| 75 |
+
"lm_codes_info": "粘贴用于text2music生成的LM代码提示",
|
| 76 |
+
"lm_codes_sample": "LM 代码提示(样本 {n})",
|
| 77 |
+
"lm_codes_sample_info": "样本{n}的代码",
|
| 78 |
+
"transcribe_btn": "转录",
|
| 79 |
+
"repainting_controls": "🎨 重绘控制(秒)",
|
| 80 |
+
"repainting_start": "重绘开始",
|
| 81 |
+
"repainting_end": "重绘结束",
|
| 82 |
+
"mode_label": "生成模式",
|
| 83 |
+
"mode_info": "简单模式:用自然语言描述音乐。自定义模式:完全控制描述和歌词。",
|
| 84 |
+
"mode_simple": "简单",
|
| 85 |
+
"mode_custom": "自定义",
|
| 86 |
+
"simple_query_label": "歌曲描述",
|
| 87 |
+
"simple_query_placeholder": "描述你想创作的音乐,例如:'给我生成一首暗黑的戏剧古风,歌词要华丽'。留空则随机生成样本。",
|
| 88 |
+
"simple_query_info": "输入你想生成的音乐的自然语言描述",
|
| 89 |
+
"simple_vocal_language_label": "人声语言(可选)",
|
| 90 |
+
"simple_vocal_language_info": "选择歌词的首选语言。使用 'unknown' 表示任意语言。",
|
| 91 |
+
"create_sample_btn": "创建样本",
|
| 92 |
+
"caption_title": "📝 音乐描述",
|
| 93 |
+
"caption_label": "音乐描述(可选)",
|
| 94 |
+
"caption_placeholder": "一段平和的原声吉他旋律,配有柔和的人声...",
|
| 95 |
+
"caption_info": "描述风格、流派、乐器和情绪",
|
| 96 |
+
"lyrics_title": "📝 歌词",
|
| 97 |
+
"lyrics_label": "歌词(可选)",
|
| 98 |
+
"lyrics_placeholder": "[第一段]\\n在星空下\\n我感到如此活跃...",
|
| 99 |
+
"lyrics_info": "带有结构的歌曲歌词",
|
| 100 |
+
"instrumental_label": "纯音乐",
|
| 101 |
+
"format_btn": "格式化",
|
| 102 |
+
"optional_params": "⚙️ 可选参数",
|
| 103 |
+
"vocal_language_label": "人声语言(可选)",
|
| 104 |
+
"vocal_language_info": "纯音乐使用 `unknown`",
|
| 105 |
+
"bpm_label": "BPM(可选)",
|
| 106 |
+
"bpm_info": "留空表示N/A",
|
| 107 |
+
"keyscale_label": "调性(可选)",
|
| 108 |
+
"keyscale_placeholder": "留空表示N/A",
|
| 109 |
+
"keyscale_info": "A-G, #/♭, 大调/小调",
|
| 110 |
+
"timesig_label": "拍号(可选)",
|
| 111 |
+
"timesig_info": "2/4, 3/4, 4/4...",
|
| 112 |
+
"duration_label": "音频时长(秒)",
|
| 113 |
+
"duration_info": "使用-1表示随机",
|
| 114 |
+
"batch_size_label": "批量大小",
|
| 115 |
+
"batch_size_info": "要生成的音频数量(最多8个)",
|
| 116 |
+
"advanced_settings": "🔧 高级设置",
|
| 117 |
+
"inference_steps_label": "DiT 推理步数",
|
| 118 |
+
"inference_steps_info": "Turbo: 最多8, Base: 最多200",
|
| 119 |
+
"guidance_scale_label": "DiT 引导比例(仅支持base模型)",
|
| 120 |
+
"guidance_scale_info": "更高的值更紧密地遵循文本",
|
| 121 |
+
"seed_label": "种子",
|
| 122 |
+
"seed_info": "批量使用逗号分隔的值",
|
| 123 |
+
"random_seed_label": "随机种子",
|
| 124 |
+
"random_seed_info": "启用以自动生成种子",
|
| 125 |
+
"audio_format_label": "音频格式",
|
| 126 |
+
"audio_format_info": "保存文件的音频格式",
|
| 127 |
+
"use_adg_label": "使用 ADG",
|
| 128 |
+
"use_adg_info": "启用角域引导",
|
| 129 |
+
"shift_label": "Shift",
|
| 130 |
+
"shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
|
| 131 |
+
"infer_method_label": "推理方法",
|
| 132 |
+
"infer_method_info": "扩散推理方法。ODE (欧拉) 更快,SDE (随机) 可能产生不同结果。",
|
| 133 |
+
"custom_timesteps_label": "自定义时间步",
|
| 134 |
+
"custom_timesteps_info": "可选:从 1.0 到 0.0 的逗号分隔值(例如 '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。会覆盖推理步数和 shift 设置。",
|
| 135 |
+
"cfg_interval_start": "CFG 间隔开始",
|
| 136 |
+
"cfg_interval_end": "CFG 间隔结束",
|
| 137 |
+
"lm_params_title": "🤖 LM 生成参数",
|
| 138 |
+
"lm_temperature_label": "LM 温度",
|
| 139 |
+
"lm_temperature_info": "5Hz LM温度(越高越随机)",
|
| 140 |
+
"lm_cfg_scale_label": "LM CFG 比例",
|
| 141 |
+
"lm_cfg_scale_info": "5Hz LM CFG (1.0 = 无CFG)",
|
| 142 |
+
"lm_top_k_label": "LM Top-K",
|
| 143 |
+
"lm_top_k_info": "Top-K (0 = 禁用)",
|
| 144 |
+
"lm_top_p_label": "LM Top-P",
|
| 145 |
+
"lm_top_p_info": "Top-P (1.0 = 禁用)",
|
| 146 |
+
"lm_negative_prompt_label": "LM 负面提示",
|
| 147 |
+
"lm_negative_prompt_placeholder": "输入CFG的负面提示(默认: NO USER INPUT)",
|
| 148 |
+
"lm_negative_prompt_info": "负面提示(当LM CFG比例 > 1.0时使用)",
|
| 149 |
+
"cot_metas_label": "CoT 元数据",
|
| 150 |
+
"cot_metas_info": "使用LM生成CoT元数据(取消勾选以跳过LM CoT生成)",
|
| 151 |
+
"cot_language_label": "CoT 语言",
|
| 152 |
+
"cot_language_info": "在CoT中生成语言(思维链)",
|
| 153 |
+
"constrained_debug_label": "约束解码调试",
|
| 154 |
+
"constrained_debug_info": "启用约束解码的调试日志(勾选以查看详细日志)",
|
| 155 |
+
"auto_score_label": "自动评分",
|
| 156 |
+
"auto_score_info": "自动计算所有生成音频的质量分数",
|
| 157 |
+
"auto_lrc_label": "自动 LRC",
|
| 158 |
+
"auto_lrc_info": "自动为所有生成的音频生成LRC歌词时间戳",
|
| 159 |
+
"lm_batch_chunk_label": "LM 批量块大小",
|
| 160 |
+
"lm_batch_chunk_info": "每个LM批量块的最大项目数(默认: 8, 受GPU内存限制)",
|
| 161 |
+
"codes_strength_label": "LM 代码强度",
|
| 162 |
+
"codes_strength_info": "控制使用LM生成代码的去噪步骤数量",
|
| 163 |
+
"cover_strength_label": "音频覆盖强度",
|
| 164 |
+
"cover_strength_info": "控制使用覆盖模式的去噪步骤数量",
|
| 165 |
+
"score_sensitivity_label": "质量评分敏感度",
|
| 166 |
+
"score_sensitivity_info": "更低 = 更敏感(默认: 1.0). 调整PMI如何映射到[0,1]",
|
| 167 |
+
"think_label": "思考",
|
| 168 |
+
"parallel_thinking_label": "并行思考",
|
| 169 |
+
"generate_btn": "🎵 生成音乐",
|
| 170 |
+
"autogen_label": "自动生成",
|
| 171 |
+
"caption_rewrite_label": "描述重写"
|
| 172 |
+
},
|
| 173 |
+
"results": {
|
| 174 |
+
"title": "🎵 结果",
|
| 175 |
+
"generated_music": "🎵 生成的音乐(样本 {n})",
|
| 176 |
+
"send_to_src_btn": "🔗 发送到源音频",
|
| 177 |
+
"send_to_cover_btn": "🔗 Send To Cover",
|
| 178 |
+
"send_to_repaint_btn": "🔗 Send To Repaint",
|
| 179 |
+
"save_btn": "💾 保存",
|
| 180 |
+
"score_btn": "📊 评分",
|
| 181 |
+
"lrc_btn": "🎵 LRC",
|
| 182 |
+
"quality_score_label": "质量分数(样本 {n})",
|
| 183 |
+
"quality_score_placeholder": "点击'评分'以计算基于困惑度的质量分数",
|
| 184 |
+
"codes_label": "LM 代码(样本 {n})",
|
| 185 |
+
"lrc_label": "歌词时间戳(样本 {n})",
|
| 186 |
+
"lrc_placeholder": "点击'LRC'生成时间戳",
|
| 187 |
+
"details_accordion": "📊 评分与LRC与LM代码",
|
| 188 |
+
"generation_status": "生成状态",
|
| 189 |
+
"current_batch": "当前批次",
|
| 190 |
+
"batch_indicator": "批次 {current} / {total}",
|
| 191 |
+
"next_batch_status": "下一批次状态",
|
| 192 |
+
"prev_btn": "◀ 上一个",
|
| 193 |
+
"next_btn": "下一个 ▶",
|
| 194 |
+
"restore_params_btn": "↙️ 将这些设置应用到UI(恢复批次参数)",
|
| 195 |
+
"batch_results_title": "📁 批量结果和生成详情",
|
| 196 |
+
"all_files_label": "📁 所有生成的文件(��载)",
|
| 197 |
+
"generation_details": "生成详情"
|
| 198 |
+
},
|
| 199 |
+
"messages": {
|
| 200 |
+
"no_audio_to_save": "❌ 没有要保存的音频",
|
| 201 |
+
"save_success": "✅ 已将音频和元数据保存到 {filename}",
|
| 202 |
+
"save_failed": "❌ 保存失败: {error}",
|
| 203 |
+
"no_file_selected": "⚠️ 未选择文件",
|
| 204 |
+
"params_loaded": "✅ 已从 {filename} 加载参数",
|
| 205 |
+
"invalid_json": "❌ 无效的JSON文件: {error}",
|
| 206 |
+
"load_error": "❌ 加载文件时出错: {error}",
|
| 207 |
+
"example_loaded": "📁 已从 {filename} 加载示例",
|
| 208 |
+
"example_failed": "解析JSON文件 {filename} 失败: {error}",
|
| 209 |
+
"example_error": "加载示例时出错: {error}",
|
| 210 |
+
"lm_generated": "🤖 使用LM生成的示例",
|
| 211 |
+
"lm_fallback": "使用LM生成示例失败,回退到示例目录",
|
| 212 |
+
"lm_not_initialized": "❌ 5Hz LM未初始化。请先初始化它。",
|
| 213 |
+
"autogen_enabled": "🔄 已启用自动生成 - 下一批次将在此之后生成",
|
| 214 |
+
"batch_ready": "✅ 批次 {n} 就绪!点击'下一个'查看。",
|
| 215 |
+
"batch_generating": "🔄 开始为批次 {n} 进行后台生成...",
|
| 216 |
+
"batch_failed": "❌ 后台生成失败: {error}",
|
| 217 |
+
"viewing_batch": "✅ 查看批次 {n}",
|
| 218 |
+
"at_first_batch": "已在第一批次",
|
| 219 |
+
"at_last_batch": "没有下一批次可用",
|
| 220 |
+
"batch_not_found": "在队列中未找到批次 {n}",
|
| 221 |
+
"no_batch_data": "没有要恢复的批次数据。",
|
| 222 |
+
"params_restored": "✅ 已从批次 {n} 恢复UI参数",
|
| 223 |
+
"scoring_failed": "❌ 错误: 未找到批次数据",
|
| 224 |
+
"no_codes": "❌ 没有可用的音频代码。请先生成音乐。",
|
| 225 |
+
"score_failed": "❌ 评分失败: {error}",
|
| 226 |
+
"score_error": "❌ 计算分数时出错: {error}",
|
| 227 |
+
"lrc_no_batch_data": "❌ 未找到批次数据。请先生成音乐。",
|
| 228 |
+
"lrc_no_extra_outputs": "❌ 未找到额外输出。条件张量不可用。",
|
| 229 |
+
"lrc_missing_tensors": "❌ 缺少LRC生成所需的张量。",
|
| 230 |
+
"lrc_sample_not_exist": "❌ 当前批次中不存在该样本。",
|
| 231 |
+
"lrc_empty_result": "⚠️ LRC生成结果为空。",
|
| 232 |
+
"empty_query": "⚠️ 请输入音乐描述。",
|
| 233 |
+
"sample_creation_failed": "❌ 创建样本失败。请重试。",
|
| 234 |
+
"sample_created": "✅ 样本已创建!检查描述和歌词,然后点击生成音乐。",
|
| 235 |
+
"simple_examples_not_found": "⚠️ 未找到简单模式示例目录。",
|
| 236 |
+
"simple_examples_empty": "⚠️ 简单模式示例中没有示例文件。",
|
| 237 |
+
"simple_example_loaded": "🎲 已从 {filename} 加载随机示例",
|
| 238 |
+
"format_success": "✅ 描述和歌词格式化成功",
|
| 239 |
+
"format_failed": "❌ 格式化失败: {error}",
|
| 240 |
+
"skipping_metas_cot": "⚡ 跳过 Phase 1 元数据 COT(样本已格式化)",
|
| 241 |
+
"invalid_timesteps_format": "⚠️ 时间步格式无效,使用默认调度。",
|
| 242 |
+
"timesteps_out_of_range": "⚠️ 时间步必须在 [0, 1] 范围内,使用默认调度。",
|
| 243 |
+
"timesteps_count_mismatch": "⚠️ 时间步数量 ({actual}) 与推理步数 ({expected}) 不匹配,将使用时间步数量。"
|
| 244 |
+
}
|
| 245 |
+
}
|
acestep/gradio_ui/interfaces/__init__.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Components Module
|
| 3 |
+
Contains all Gradio interface component definitions and layouts
|
| 4 |
+
"""
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from acestep.gradio_ui.i18n import get_i18n, t
|
| 7 |
+
from acestep.gradio_ui.interfaces.dataset import create_dataset_section
|
| 8 |
+
from acestep.gradio_ui.interfaces.generation import create_generation_section
|
| 9 |
+
from acestep.gradio_ui.interfaces.result import create_results_section
|
| 10 |
+
from acestep.gradio_ui.interfaces.training import create_training_section
|
| 11 |
+
from acestep.gradio_ui.events import setup_event_handlers, setup_training_event_handlers
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None, language='en') -> gr.Blocks:
|
| 15 |
+
"""
|
| 16 |
+
Create Gradio interface
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
dit_handler: DiT handler instance
|
| 20 |
+
llm_handler: LM handler instance
|
| 21 |
+
dataset_handler: Dataset handler instance
|
| 22 |
+
init_params: Dictionary containing initialization parameters and state.
|
| 23 |
+
If None, service will not be pre-initialized.
|
| 24 |
+
language: UI language code ('en', 'zh', 'ja', default: 'en')
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
Gradio Blocks instance
|
| 28 |
+
"""
|
| 29 |
+
# Initialize i18n with selected language
|
| 30 |
+
i18n = get_i18n(language)
|
| 31 |
+
|
| 32 |
+
with gr.Blocks(
|
| 33 |
+
title=t("app.title"),
|
| 34 |
+
theme=gr.themes.Soft(),
|
| 35 |
+
css="""
|
| 36 |
+
.main-header {
|
| 37 |
+
text-align: center;
|
| 38 |
+
margin-bottom: 2rem;
|
| 39 |
+
}
|
| 40 |
+
.section-header {
|
| 41 |
+
background: linear-gradient(90deg, #4CAF50, #45a049);
|
| 42 |
+
color: white;
|
| 43 |
+
padding: 10px;
|
| 44 |
+
border-radius: 5px;
|
| 45 |
+
margin: 10px 0;
|
| 46 |
+
}
|
| 47 |
+
.lm-hints-row {
|
| 48 |
+
align-items: stretch;
|
| 49 |
+
}
|
| 50 |
+
.lm-hints-col {
|
| 51 |
+
display: flex;
|
| 52 |
+
}
|
| 53 |
+
.lm-hints-col > div {
|
| 54 |
+
flex: 1;
|
| 55 |
+
display: flex;
|
| 56 |
+
}
|
| 57 |
+
.lm-hints-btn button {
|
| 58 |
+
height: 100%;
|
| 59 |
+
width: 100%;
|
| 60 |
+
}
|
| 61 |
+
"""
|
| 62 |
+
) as demo:
|
| 63 |
+
|
| 64 |
+
gr.HTML(f"""
|
| 65 |
+
<div class="main-header">
|
| 66 |
+
<h1>{t("app.title")}</h1>
|
| 67 |
+
<p>{t("app.subtitle")}</p>
|
| 68 |
+
<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 10px 20px; border-radius: 8px; text-align: center; margin: 8px auto; max-width: 600px;">
|
| 69 |
+
<span style="color: white; font-size: 15px;">
|
| 70 |
+
🚀 Want faster & more stable experience? Try
|
| 71 |
+
<a href="https://acemusic.ai" target="_blank" style="color: #ffd700; font-weight: bold; text-decoration: underline;">acemusic.ai</a>
|
| 72 |
+
— 100% free!
|
| 73 |
+
</span>
|
| 74 |
+
</div>
|
| 75 |
+
<p style="margin-top: 0.5rem;">
|
| 76 |
+
<a href="https://ace-step.github.io/ace-step-v1.5.github.io/" target="_blank">Project</a> |
|
| 77 |
+
<a href="https://huggingface.co/collections/ACE-Step/ace-step-15" target="_blank">Hugging Face</a> |
|
| 78 |
+
<a href="https://modelscope.cn/models/ACE-Step/ACE-Step-v1-5" target="_blank">ModelScope</a> |
|
| 79 |
+
<a href="https://github.com/ACE-Step/ACE-Step-1.5" target="_blank">GitHub</a> |
|
| 80 |
+
<a href="https://discord.gg/PeWDxrkdj7" target="_blank">Discord</a> |
|
| 81 |
+
<a href="https://arxiv.org/abs/2602.00744" target="_blank">Technical Report</a>
|
| 82 |
+
</p>
|
| 83 |
+
</div>
|
| 84 |
+
""")
|
| 85 |
+
|
| 86 |
+
# Dataset Explorer Section
|
| 87 |
+
dataset_section = create_dataset_section(dataset_handler)
|
| 88 |
+
|
| 89 |
+
# Generation Section (pass init_params and language to support pre-initialization)
|
| 90 |
+
generation_section = create_generation_section(dit_handler, llm_handler, init_params=init_params, language=language)
|
| 91 |
+
|
| 92 |
+
# Results Section
|
| 93 |
+
results_section = create_results_section(dit_handler)
|
| 94 |
+
|
| 95 |
+
# Training Section (LoRA training and dataset builder)
|
| 96 |
+
# Pass init_params to support hiding in service mode
|
| 97 |
+
training_section = create_training_section(dit_handler, llm_handler, init_params=init_params)
|
| 98 |
+
|
| 99 |
+
# Connect event handlers (pass init_params for multi-model support)
|
| 100 |
+
setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section, init_params=init_params)
|
| 101 |
+
|
| 102 |
+
# Connect training event handlers
|
| 103 |
+
setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)
|
| 104 |
+
|
| 105 |
+
return demo
|
acestep/gradio_ui/interfaces/dataset.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Dataset Section Module
|
| 3 |
+
Contains dataset explorer section component definitions
|
| 4 |
+
"""
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_dataset_section(dataset_handler) -> dict:
|
| 9 |
+
"""Create dataset explorer section"""
|
| 10 |
+
with gr.Accordion("📊 Dataset Explorer", open=False, visible=False):
|
| 11 |
+
with gr.Row(equal_height=True):
|
| 12 |
+
dataset_type = gr.Dropdown(
|
| 13 |
+
choices=["train", "test"],
|
| 14 |
+
value="train",
|
| 15 |
+
label="Dataset",
|
| 16 |
+
info="Choose dataset to explore",
|
| 17 |
+
scale=2
|
| 18 |
+
)
|
| 19 |
+
import_dataset_btn = gr.Button("📥 Import Dataset", variant="primary", scale=1)
|
| 20 |
+
|
| 21 |
+
search_type = gr.Dropdown(
|
| 22 |
+
choices=["keys", "idx", "random"],
|
| 23 |
+
value="random",
|
| 24 |
+
label="Search Type",
|
| 25 |
+
info="How to find items",
|
| 26 |
+
scale=1
|
| 27 |
+
)
|
| 28 |
+
search_value = gr.Textbox(
|
| 29 |
+
label="Search Value",
|
| 30 |
+
placeholder="Enter keys or index (leave empty for random)",
|
| 31 |
+
info="Keys: exact match, Index: 0 to dataset size-1",
|
| 32 |
+
scale=2
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
instruction_display = gr.Textbox(
|
| 36 |
+
label="📝 Instruction",
|
| 37 |
+
interactive=False,
|
| 38 |
+
placeholder="No instruction available",
|
| 39 |
+
lines=1
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
repaint_viz_plot = gr.Plot()
|
| 43 |
+
|
| 44 |
+
with gr.Accordion("📋 Item Metadata (JSON)", open=False):
|
| 45 |
+
item_info_json = gr.Code(
|
| 46 |
+
label="Complete Item Information",
|
| 47 |
+
language="json",
|
| 48 |
+
interactive=False,
|
| 49 |
+
lines=15
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
with gr.Row(equal_height=True):
|
| 53 |
+
item_src_audio = gr.Audio(
|
| 54 |
+
label="Source Audio",
|
| 55 |
+
type="filepath",
|
| 56 |
+
interactive=False,
|
| 57 |
+
scale=8
|
| 58 |
+
)
|
| 59 |
+
get_item_btn = gr.Button("🔍 Get Item", variant="secondary", interactive=False, scale=2)
|
| 60 |
+
|
| 61 |
+
with gr.Row(equal_height=True):
|
| 62 |
+
item_target_audio = gr.Audio(
|
| 63 |
+
label="Target Audio",
|
| 64 |
+
type="filepath",
|
| 65 |
+
interactive=False,
|
| 66 |
+
scale=8
|
| 67 |
+
)
|
| 68 |
+
item_refer_audio = gr.Audio(
|
| 69 |
+
label="Reference Audio",
|
| 70 |
+
type="filepath",
|
| 71 |
+
interactive=False,
|
| 72 |
+
scale=2
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
with gr.Row():
|
| 76 |
+
use_src_checkbox = gr.Checkbox(
|
| 77 |
+
label="Use Source Audio from Dataset",
|
| 78 |
+
value=True,
|
| 79 |
+
info="Check to use the source audio from dataset"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
data_status = gr.Textbox(label="📊 Data Status", interactive=False, value="❌ No dataset imported")
|
| 83 |
+
auto_fill_btn = gr.Button("📋 Auto-fill Generation Form", variant="primary")
|
| 84 |
+
|
| 85 |
+
return {
|
| 86 |
+
"dataset_type": dataset_type,
|
| 87 |
+
"import_dataset_btn": import_dataset_btn,
|
| 88 |
+
"search_type": search_type,
|
| 89 |
+
"search_value": search_value,
|
| 90 |
+
"instruction_display": instruction_display,
|
| 91 |
+
"repaint_viz_plot": repaint_viz_plot,
|
| 92 |
+
"item_info_json": item_info_json,
|
| 93 |
+
"item_src_audio": item_src_audio,
|
| 94 |
+
"get_item_btn": get_item_btn,
|
| 95 |
+
"item_target_audio": item_target_audio,
|
| 96 |
+
"item_refer_audio": item_refer_audio,
|
| 97 |
+
"use_src_checkbox": use_src_checkbox,
|
| 98 |
+
"data_status": data_status,
|
| 99 |
+
"auto_fill_btn": auto_fill_btn,
|
| 100 |
+
}
|
| 101 |
+
|
acestep/gradio_ui/interfaces/generation.py
ADDED
|
@@ -0,0 +1,694 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Generation Section Module
|
| 3 |
+
Contains generation section component definitions - Simplified UI
|
| 4 |
+
"""
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from acestep.constants import (
|
| 7 |
+
VALID_LANGUAGES,
|
| 8 |
+
TRACK_NAMES,
|
| 9 |
+
TASK_TYPES_TURBO,
|
| 10 |
+
TASK_TYPES_BASE,
|
| 11 |
+
DEFAULT_DIT_INSTRUCTION,
|
| 12 |
+
)
|
| 13 |
+
from acestep.gradio_ui.i18n import t
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def create_generation_section(dit_handler, llm_handler, init_params=None, language='en') -> dict:
|
| 17 |
+
"""Create generation section with simplified UI
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
dit_handler: DiT handler instance
|
| 21 |
+
llm_handler: LM handler instance
|
| 22 |
+
init_params: Dictionary containing initialization parameters and state.
|
| 23 |
+
If None, service will not be pre-initialized.
|
| 24 |
+
language: UI language code ('en', 'zh', 'ja')
|
| 25 |
+
"""
|
| 26 |
+
# Check if service is pre-initialized
|
| 27 |
+
service_pre_initialized = init_params is not None and init_params.get('pre_initialized', False)
|
| 28 |
+
|
| 29 |
+
# Check if running in service mode (restricted UI)
|
| 30 |
+
service_mode = init_params is not None and init_params.get('service_mode', False)
|
| 31 |
+
|
| 32 |
+
# Get current language from init_params if available
|
| 33 |
+
current_language = init_params.get('language', language) if init_params else language
|
| 34 |
+
|
| 35 |
+
# Get available models
|
| 36 |
+
available_dit_models = init_params.get('available_dit_models', []) if init_params else []
|
| 37 |
+
current_model_value = init_params.get('config_path', '') if init_params else ''
|
| 38 |
+
show_model_selector = len(available_dit_models) > 1
|
| 39 |
+
|
| 40 |
+
with gr.Group():
|
| 41 |
+
# ==================== Service Configuration (Hidden in service mode) ====================
|
| 42 |
+
accordion_open = not service_pre_initialized
|
| 43 |
+
accordion_visible = not service_pre_initialized
|
| 44 |
+
with gr.Accordion(t("service.title"), open=accordion_open, visible=accordion_visible) as service_config_accordion:
|
| 45 |
+
# Language selector at the top
|
| 46 |
+
with gr.Row():
|
| 47 |
+
language_dropdown = gr.Dropdown(
|
| 48 |
+
choices=[
|
| 49 |
+
("English", "en"),
|
| 50 |
+
("中文", "zh"),
|
| 51 |
+
("日本語", "ja"),
|
| 52 |
+
],
|
| 53 |
+
value=current_language,
|
| 54 |
+
label=t("service.language_label"),
|
| 55 |
+
info=t("service.language_info"),
|
| 56 |
+
scale=1,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
with gr.Row(equal_height=True):
|
| 60 |
+
with gr.Column(scale=4):
|
| 61 |
+
checkpoint_value = init_params.get('checkpoint') if service_pre_initialized else None
|
| 62 |
+
checkpoint_dropdown = gr.Dropdown(
|
| 63 |
+
label=t("service.checkpoint_label"),
|
| 64 |
+
choices=dit_handler.get_available_checkpoints(),
|
| 65 |
+
value=checkpoint_value,
|
| 66 |
+
info=t("service.checkpoint_info")
|
| 67 |
+
)
|
| 68 |
+
with gr.Column(scale=1, min_width=90):
|
| 69 |
+
refresh_btn = gr.Button(t("service.refresh_btn"), size="sm")
|
| 70 |
+
|
| 71 |
+
with gr.Row():
|
| 72 |
+
available_models = dit_handler.get_available_acestep_v15_models()
|
| 73 |
+
default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
|
| 74 |
+
config_path_value = init_params.get('config_path', default_model) if service_pre_initialized else default_model
|
| 75 |
+
config_path = gr.Dropdown(
|
| 76 |
+
label=t("service.model_path_label"),
|
| 77 |
+
choices=available_models,
|
| 78 |
+
value=config_path_value,
|
| 79 |
+
info=t("service.model_path_info")
|
| 80 |
+
)
|
| 81 |
+
device_value = init_params.get('device', 'auto') if service_pre_initialized else 'auto'
|
| 82 |
+
device = gr.Dropdown(
|
| 83 |
+
choices=["auto", "cuda", "cpu"],
|
| 84 |
+
value=device_value,
|
| 85 |
+
label=t("service.device_label"),
|
| 86 |
+
info=t("service.device_info")
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
with gr.Row():
|
| 90 |
+
available_lm_models = llm_handler.get_available_5hz_lm_models()
|
| 91 |
+
default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
|
| 92 |
+
lm_model_path_value = init_params.get('lm_model_path', default_lm_model) if service_pre_initialized else default_lm_model
|
| 93 |
+
lm_model_path = gr.Dropdown(
|
| 94 |
+
label=t("service.lm_model_path_label"),
|
| 95 |
+
choices=available_lm_models,
|
| 96 |
+
value=lm_model_path_value,
|
| 97 |
+
info=t("service.lm_model_path_info")
|
| 98 |
+
)
|
| 99 |
+
backend_value = init_params.get('backend', 'vllm') if service_pre_initialized else 'vllm'
|
| 100 |
+
backend_dropdown = gr.Dropdown(
|
| 101 |
+
choices=["vllm", "pt"],
|
| 102 |
+
value=backend_value,
|
| 103 |
+
label=t("service.backend_label"),
|
| 104 |
+
info=t("service.backend_info")
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
with gr.Row():
|
| 108 |
+
init_llm_value = init_params.get('init_llm', True) if service_pre_initialized else True
|
| 109 |
+
init_llm_checkbox = gr.Checkbox(
|
| 110 |
+
label=t("service.init_llm_label"),
|
| 111 |
+
value=init_llm_value,
|
| 112 |
+
info=t("service.init_llm_info"),
|
| 113 |
+
)
|
| 114 |
+
flash_attn_available = dit_handler.is_flash_attention_available()
|
| 115 |
+
use_flash_attention_value = init_params.get('use_flash_attention', flash_attn_available) if service_pre_initialized else flash_attn_available
|
| 116 |
+
use_flash_attention_checkbox = gr.Checkbox(
|
| 117 |
+
label=t("service.flash_attention_label"),
|
| 118 |
+
value=use_flash_attention_value,
|
| 119 |
+
interactive=flash_attn_available,
|
| 120 |
+
info=t("service.flash_attention_info_enabled") if flash_attn_available else t("service.flash_attention_info_disabled")
|
| 121 |
+
)
|
| 122 |
+
offload_to_cpu_value = init_params.get('offload_to_cpu', False) if service_pre_initialized else False
|
| 123 |
+
offload_to_cpu_checkbox = gr.Checkbox(
|
| 124 |
+
label=t("service.offload_cpu_label"),
|
| 125 |
+
value=offload_to_cpu_value,
|
| 126 |
+
info=t("service.offload_cpu_info")
|
| 127 |
+
)
|
| 128 |
+
offload_dit_to_cpu_value = init_params.get('offload_dit_to_cpu', False) if service_pre_initialized else False
|
| 129 |
+
offload_dit_to_cpu_checkbox = gr.Checkbox(
|
| 130 |
+
label=t("service.offload_dit_cpu_label"),
|
| 131 |
+
value=offload_dit_to_cpu_value,
|
| 132 |
+
info=t("service.offload_dit_cpu_info")
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg")
|
| 136 |
+
init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
|
| 137 |
+
init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
|
| 138 |
+
|
| 139 |
+
# LoRA Configuration Section
|
| 140 |
+
gr.HTML("<hr><h4>🔧 LoRA Adapter</h4>")
|
| 141 |
+
with gr.Row():
|
| 142 |
+
lora_path = gr.Textbox(
|
| 143 |
+
label="LoRA Path",
|
| 144 |
+
placeholder="./lora_output/final/adapter",
|
| 145 |
+
info="Path to trained LoRA adapter directory",
|
| 146 |
+
scale=3,
|
| 147 |
+
)
|
| 148 |
+
load_lora_btn = gr.Button("📥 Load LoRA", variant="secondary", scale=1)
|
| 149 |
+
unload_lora_btn = gr.Button("🗑️ Unload", variant="secondary", scale=1)
|
| 150 |
+
with gr.Row():
|
| 151 |
+
use_lora_checkbox = gr.Checkbox(
|
| 152 |
+
label="Use LoRA",
|
| 153 |
+
value=False,
|
| 154 |
+
info="Enable LoRA adapter for inference",
|
| 155 |
+
scale=1,
|
| 156 |
+
)
|
| 157 |
+
lora_status = gr.Textbox(
|
| 158 |
+
label="LoRA Status",
|
| 159 |
+
value="No LoRA loaded",
|
| 160 |
+
interactive=False,
|
| 161 |
+
scale=2,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
# ==================== Model Selector (Top, only when multiple models) ====================
|
| 165 |
+
with gr.Row(visible=show_model_selector):
|
| 166 |
+
dit_model_selector = gr.Dropdown(
|
| 167 |
+
choices=available_dit_models,
|
| 168 |
+
value=current_model_value,
|
| 169 |
+
label="models",
|
| 170 |
+
scale=1,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# Hidden dropdown when only one model (for event handler compatibility)
|
| 174 |
+
if not show_model_selector:
|
| 175 |
+
dit_model_selector = gr.Dropdown(
|
| 176 |
+
choices=available_dit_models if available_dit_models else [current_model_value],
|
| 177 |
+
value=current_model_value,
|
| 178 |
+
visible=False,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# ==================== Generation Mode (4 modes) ====================
|
| 182 |
+
gr.HTML("<div style='background: #4a5568; color: white; padding: 8px 16px; border-radius: 4px; font-weight: bold;'>Generation Mode</div>")
|
| 183 |
+
with gr.Row():
|
| 184 |
+
generation_mode = gr.Radio(
|
| 185 |
+
choices=[
|
| 186 |
+
("Simple", "simple"),
|
| 187 |
+
("Custom", "custom"),
|
| 188 |
+
("Cover", "cover"),
|
| 189 |
+
("Repaint", "repaint"),
|
| 190 |
+
],
|
| 191 |
+
value="custom",
|
| 192 |
+
label="",
|
| 193 |
+
show_label=False,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# ==================== Simple Mode Group ====================
|
| 197 |
+
with gr.Column(visible=False) as simple_mode_group:
|
| 198 |
+
# Row: Song Description + Vocal Language + Random button
|
| 199 |
+
with gr.Row(equal_height=True):
|
| 200 |
+
simple_query_input = gr.Textbox(
|
| 201 |
+
label=t("generation.simple_query_label"),
|
| 202 |
+
placeholder=t("generation.simple_query_placeholder"),
|
| 203 |
+
lines=2,
|
| 204 |
+
info=t("generation.simple_query_info"),
|
| 205 |
+
scale=10,
|
| 206 |
+
)
|
| 207 |
+
simple_vocal_language = gr.Dropdown(
|
| 208 |
+
choices=VALID_LANGUAGES,
|
| 209 |
+
value="unknown",
|
| 210 |
+
allow_custom_value=True,
|
| 211 |
+
label=t("generation.simple_vocal_language_label"),
|
| 212 |
+
interactive=True,
|
| 213 |
+
info="use unknown for instrumental",
|
| 214 |
+
scale=2,
|
| 215 |
+
)
|
| 216 |
+
with gr.Column(scale=1, min_width=60):
|
| 217 |
+
random_desc_btn = gr.Button(
|
| 218 |
+
"🎲",
|
| 219 |
+
variant="primary",
|
| 220 |
+
size="lg",
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# Hidden components (kept for compatibility but not shown)
|
| 224 |
+
simple_instrumental_checkbox = gr.Checkbox(
|
| 225 |
+
label=t("generation.instrumental_label"),
|
| 226 |
+
value=False,
|
| 227 |
+
visible=False,
|
| 228 |
+
)
|
| 229 |
+
create_sample_btn = gr.Button(
|
| 230 |
+
t("generation.create_sample_btn"),
|
| 231 |
+
variant="primary",
|
| 232 |
+
size="lg",
|
| 233 |
+
visible=False,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# State to track if sample has been created in Simple mode
|
| 237 |
+
simple_sample_created = gr.State(value=False)
|
| 238 |
+
|
| 239 |
+
# ==================== Source Audio (for Cover/Repaint) ====================
|
| 240 |
+
# This is shown above the main content for Cover and Repaint modes
|
| 241 |
+
with gr.Column(visible=False) as src_audio_group:
|
| 242 |
+
with gr.Row(equal_height=True):
|
| 243 |
+
# Source Audio - scale=10 to match (refer_audio=2 + prompt/lyrics=8)
|
| 244 |
+
src_audio = gr.Audio(
|
| 245 |
+
label="Source Audio",
|
| 246 |
+
type="filepath",
|
| 247 |
+
scale=10,
|
| 248 |
+
)
|
| 249 |
+
# Process button - scale=1 to align with random button
|
| 250 |
+
with gr.Column(scale=1, min_width=80):
|
| 251 |
+
process_src_btn = gr.Button(
|
| 252 |
+
"Analyze",
|
| 253 |
+
variant="secondary",
|
| 254 |
+
size="lg",
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Hidden Audio Codes storage (needed internally but not displayed)
|
| 258 |
+
text2music_audio_code_string = gr.Textbox(
|
| 259 |
+
label="Audio Codes",
|
| 260 |
+
visible=False,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# ==================== Custom/Cover/Repaint Mode Content ====================
|
| 264 |
+
with gr.Column() as custom_mode_content:
|
| 265 |
+
with gr.Row(equal_height=True):
|
| 266 |
+
# Left: Reference Audio
|
| 267 |
+
with gr.Column(scale=2, min_width=200):
|
| 268 |
+
reference_audio = gr.Audio(
|
| 269 |
+
label="Reference Audio (optional)",
|
| 270 |
+
type="filepath",
|
| 271 |
+
show_label=True,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# Middle: Prompt + Lyrics + Format button
|
| 275 |
+
with gr.Column(scale=8):
|
| 276 |
+
# Row 1: Prompt and Lyrics
|
| 277 |
+
with gr.Row(equal_height=True):
|
| 278 |
+
captions = gr.Textbox(
|
| 279 |
+
label="Prompt",
|
| 280 |
+
placeholder="Describe the music style, mood, instruments...",
|
| 281 |
+
lines=12,
|
| 282 |
+
max_lines=12,
|
| 283 |
+
scale=1,
|
| 284 |
+
)
|
| 285 |
+
lyrics = gr.Textbox(
|
| 286 |
+
label="Lyrics",
|
| 287 |
+
placeholder="Enter lyrics here... Use [Verse], [Chorus] etc. for structure",
|
| 288 |
+
lines=12,
|
| 289 |
+
max_lines=12,
|
| 290 |
+
scale=1,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# Row 2: Format button (only below Prompt and Lyrics)
|
| 294 |
+
format_btn = gr.Button(
|
| 295 |
+
"Format",
|
| 296 |
+
variant="secondary",
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
# Right: Random button
|
| 300 |
+
with gr.Column(scale=1, min_width=60):
|
| 301 |
+
sample_btn = gr.Button(
|
| 302 |
+
"🎲",
|
| 303 |
+
variant="primary",
|
| 304 |
+
size="lg",
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
# Placeholder for removed audio_uploads_accordion (for compatibility)
|
| 308 |
+
audio_uploads_accordion = gr.Column(visible=False)
|
| 309 |
+
|
| 310 |
+
# Legacy cover_mode_group (hidden, for backward compatibility)
|
| 311 |
+
cover_mode_group = gr.Column(visible=False)
|
| 312 |
+
# Legacy convert button (hidden, for backward compatibility)
|
| 313 |
+
convert_src_to_codes_btn = gr.Button("Convert to Codes", visible=False)
|
| 314 |
+
|
| 315 |
+
# ==================== Repaint Mode: Source + Time Range ====================
|
| 316 |
+
with gr.Column(visible=False) as repainting_group:
|
| 317 |
+
with gr.Row():
|
| 318 |
+
repainting_start = gr.Number(
|
| 319 |
+
label="Start (seconds)",
|
| 320 |
+
value=0.0,
|
| 321 |
+
step=0.1,
|
| 322 |
+
scale=1,
|
| 323 |
+
)
|
| 324 |
+
repainting_end = gr.Number(
|
| 325 |
+
label="End (seconds, -1 for end)",
|
| 326 |
+
value=-1,
|
| 327 |
+
minimum=-1,
|
| 328 |
+
step=0.1,
|
| 329 |
+
scale=1,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
# ==================== Optional Parameters ====================
|
| 333 |
+
with gr.Accordion("⚙️ Optional Parameters", open=False, visible=False) as optional_params_accordion:
|
| 334 |
+
pass
|
| 335 |
+
|
| 336 |
+
# ==================== Advanced Settings ====================
|
| 337 |
+
with gr.Accordion("🔧 Advanced Settings", open=False) as advanced_options_accordion:
|
| 338 |
+
with gr.Row():
|
| 339 |
+
bpm = gr.Number(
|
| 340 |
+
label="BPM (optional)",
|
| 341 |
+
value=0,
|
| 342 |
+
step=1,
|
| 343 |
+
info="leave empty for N/A",
|
| 344 |
+
scale=1,
|
| 345 |
+
)
|
| 346 |
+
key_scale = gr.Textbox(
|
| 347 |
+
label="Key Signature (optional)",
|
| 348 |
+
placeholder="Leave empty for N/A",
|
| 349 |
+
value="",
|
| 350 |
+
info="A-G, #/♭, major/minor",
|
| 351 |
+
scale=1,
|
| 352 |
+
)
|
| 353 |
+
time_signature = gr.Dropdown(
|
| 354 |
+
choices=["", "2", "3", "4"],
|
| 355 |
+
value="",
|
| 356 |
+
label="Time Signature (optional)",
|
| 357 |
+
allow_custom_value=True,
|
| 358 |
+
info="2/4, 3/4, 4/4...",
|
| 359 |
+
scale=1,
|
| 360 |
+
)
|
| 361 |
+
audio_duration = gr.Number(
|
| 362 |
+
label="Audio Duration (seconds)",
|
| 363 |
+
value=-1,
|
| 364 |
+
minimum=-1,
|
| 365 |
+
maximum=600.0,
|
| 366 |
+
step=1,
|
| 367 |
+
info="Use -1 for auto, or 10-600 seconds",
|
| 368 |
+
scale=1,
|
| 369 |
+
)
|
| 370 |
+
vocal_language = gr.Dropdown(
|
| 371 |
+
choices=VALID_LANGUAGES,
|
| 372 |
+
value="unknown",
|
| 373 |
+
label="Vocal Language",
|
| 374 |
+
allow_custom_value=True,
|
| 375 |
+
info="use `unknown` for instrumental",
|
| 376 |
+
scale=1,
|
| 377 |
+
)
|
| 378 |
+
batch_size_input = gr.Number(
|
| 379 |
+
label="batch size",
|
| 380 |
+
info="max 8",
|
| 381 |
+
value=2,
|
| 382 |
+
minimum=1,
|
| 383 |
+
maximum=8,
|
| 384 |
+
step=1,
|
| 385 |
+
scale=1,
|
| 386 |
+
interactive=False,
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
# Row 1: DiT Inference Steps, Seed, Audio Format
|
| 390 |
+
with gr.Row():
|
| 391 |
+
inference_steps = gr.Slider(
|
| 392 |
+
minimum=1,
|
| 393 |
+
maximum=20,
|
| 394 |
+
value=8,
|
| 395 |
+
step=1,
|
| 396 |
+
label="DiT Inference Steps",
|
| 397 |
+
info="Turbo: max 8, Base: max 200",
|
| 398 |
+
)
|
| 399 |
+
seed = gr.Textbox(
|
| 400 |
+
label="Seed",
|
| 401 |
+
value="-1",
|
| 402 |
+
info="Use comma-separated values for batches",
|
| 403 |
+
)
|
| 404 |
+
audio_format = gr.Dropdown(
|
| 405 |
+
choices=["mp3", "flac"],
|
| 406 |
+
value="mp3",
|
| 407 |
+
label="Audio Format",
|
| 408 |
+
info="Audio format for saved files",
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
# Row 2: Shift, Random Seed, Inference Method
|
| 412 |
+
with gr.Row():
|
| 413 |
+
shift = gr.Slider(
|
| 414 |
+
minimum=1.0,
|
| 415 |
+
maximum=5.0,
|
| 416 |
+
value=3.0,
|
| 417 |
+
step=0.1,
|
| 418 |
+
label="Shift",
|
| 419 |
+
info="Timestep shift factor for base models (range 1.0-5.0, default 3.0). Not effective for turbo models.",
|
| 420 |
+
)
|
| 421 |
+
random_seed_checkbox = gr.Checkbox(
|
| 422 |
+
label="Random Seed",
|
| 423 |
+
value=True,
|
| 424 |
+
info="Enable to auto-generate seeds",
|
| 425 |
+
)
|
| 426 |
+
infer_method = gr.Dropdown(
|
| 427 |
+
choices=["ode", "sde"],
|
| 428 |
+
value="ode",
|
| 429 |
+
label="Inference Method",
|
| 430 |
+
info="Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
# Row 3: Custom Timesteps (full width)
|
| 434 |
+
custom_timesteps = gr.Textbox(
|
| 435 |
+
label="Custom Timesteps",
|
| 436 |
+
placeholder="0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0",
|
| 437 |
+
value="",
|
| 438 |
+
info="Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
# Section: LM Generation Parameters
|
| 442 |
+
gr.HTML("<h4>🎵 LM Generation Parameters</h4>")
|
| 443 |
+
|
| 444 |
+
# Row 4: LM Temperature, LM CFG Scale, LM Top-K, LM Top-P
|
| 445 |
+
with gr.Row():
|
| 446 |
+
lm_temperature = gr.Slider(
|
| 447 |
+
minimum=0.0,
|
| 448 |
+
maximum=2.0,
|
| 449 |
+
value=0.85,
|
| 450 |
+
step=0.05,
|
| 451 |
+
label="LM Temperature",
|
| 452 |
+
info="5Hz LM temperature (higher = more random)",
|
| 453 |
+
)
|
| 454 |
+
lm_cfg_scale = gr.Slider(
|
| 455 |
+
minimum=1.0,
|
| 456 |
+
maximum=3.0,
|
| 457 |
+
value=2.0,
|
| 458 |
+
step=0.1,
|
| 459 |
+
label="LM CFG Scale",
|
| 460 |
+
info="5Hz LM CFG (1.0 = no CFG)",
|
| 461 |
+
)
|
| 462 |
+
lm_top_k = gr.Slider(
|
| 463 |
+
minimum=0,
|
| 464 |
+
maximum=100,
|
| 465 |
+
value=0,
|
| 466 |
+
step=1,
|
| 467 |
+
label="LM Top-K",
|
| 468 |
+
info="Top-k (0 = disabled)",
|
| 469 |
+
)
|
| 470 |
+
lm_top_p = gr.Slider(
|
| 471 |
+
minimum=0.0,
|
| 472 |
+
maximum=1.0,
|
| 473 |
+
value=0.9,
|
| 474 |
+
step=0.01,
|
| 475 |
+
label="LM Top-P",
|
| 476 |
+
info="Top-p (1.0 = disabled)",
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
# Row 5: LM Negative Prompt (full width)
|
| 480 |
+
lm_negative_prompt = gr.Textbox(
|
| 481 |
+
label="LM Negative Prompt",
|
| 482 |
+
value="NO USER INPUT",
|
| 483 |
+
placeholder="Things to avoid in generation...",
|
| 484 |
+
lines=2,
|
| 485 |
+
info="Negative prompt (use when LM CFG Scale > 1.0)",
|
| 486 |
+
)
|
| 487 |
+
# audio_cover_strength remains hidden for now
|
| 488 |
+
audio_cover_strength = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, visible=False)
|
| 489 |
+
|
| 490 |
+
# Note: audio_duration, bpm, key_scale, time_signature are now visible in Optional Parameters
|
| 491 |
+
# ==================== Generate Button Row ====================
|
| 492 |
+
generate_btn_interactive = init_params.get('enable_generate', False) if service_pre_initialized else False
|
| 493 |
+
with gr.Row(equal_height=True):
|
| 494 |
+
# Left: Thinking and Instrumental checkboxes
|
| 495 |
+
with gr.Column(scale=1, min_width=120):
|
| 496 |
+
think_checkbox = gr.Checkbox(
|
| 497 |
+
label="Thinking",
|
| 498 |
+
value=True,
|
| 499 |
+
)
|
| 500 |
+
instrumental_checkbox = gr.Checkbox(
|
| 501 |
+
label="Instrumental",
|
| 502 |
+
value=False,
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
# Center: Generate button
|
| 506 |
+
with gr.Column(scale=4):
|
| 507 |
+
generate_btn = gr.Button(
|
| 508 |
+
"🎵 Generate Music",
|
| 509 |
+
variant="primary",
|
| 510 |
+
size="lg",
|
| 511 |
+
interactive=generate_btn_interactive,
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
# Right: auto_score, auto_lrc
|
| 515 |
+
with gr.Column(scale=1, min_width=120):
|
| 516 |
+
auto_score = gr.Checkbox(
|
| 517 |
+
label="Get Scores",
|
| 518 |
+
value=False,
|
| 519 |
+
)
|
| 520 |
+
auto_lrc = gr.Checkbox(
|
| 521 |
+
label="Get LRC",
|
| 522 |
+
value=False,
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
# ==================== Hidden Components (for internal use) ====================
|
| 526 |
+
# These are needed for event handlers but not shown in UI
|
| 527 |
+
|
| 528 |
+
# Task type (set automatically based on generation_mode)
|
| 529 |
+
actual_model = init_params.get('config_path', 'acestep-v15-turbo') if service_pre_initialized else 'acestep-v15-turbo'
|
| 530 |
+
actual_model_lower = (actual_model or "").lower()
|
| 531 |
+
if "turbo" in actual_model_lower:
|
| 532 |
+
initial_task_choices = TASK_TYPES_TURBO
|
| 533 |
+
else:
|
| 534 |
+
initial_task_choices = TASK_TYPES_BASE
|
| 535 |
+
|
| 536 |
+
task_type = gr.Dropdown(
|
| 537 |
+
choices=initial_task_choices,
|
| 538 |
+
value="text2music",
|
| 539 |
+
visible=False,
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
instruction_display_gen = gr.Textbox(
|
| 543 |
+
value=DEFAULT_DIT_INSTRUCTION,
|
| 544 |
+
visible=False,
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
track_name = gr.Dropdown(
|
| 548 |
+
choices=TRACK_NAMES,
|
| 549 |
+
value=None,
|
| 550 |
+
visible=False,
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
complete_track_classes = gr.CheckboxGroup(
|
| 554 |
+
choices=TRACK_NAMES,
|
| 555 |
+
visible=False,
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
# Note: lyrics, vocal_language, instrumental_checkbox, format_btn are now visible in custom_mode_content
|
| 559 |
+
|
| 560 |
+
# Hidden advanced settings (keep defaults)
|
| 561 |
+
# Note: Most parameters are now visible in Advanced Settings section above
|
| 562 |
+
guidance_scale = gr.Slider(value=7.0, visible=False)
|
| 563 |
+
use_adg = gr.Checkbox(value=False, visible=False)
|
| 564 |
+
cfg_interval_start = gr.Slider(value=0.0, visible=False)
|
| 565 |
+
cfg_interval_end = gr.Slider(value=1.0, visible=False)
|
| 566 |
+
|
| 567 |
+
# LM parameters (remaining hidden ones)
|
| 568 |
+
use_cot_metas = gr.Checkbox(value=True, visible=False)
|
| 569 |
+
use_cot_caption = gr.Checkbox(value=True, visible=False)
|
| 570 |
+
use_cot_language = gr.Checkbox(value=True, visible=False)
|
| 571 |
+
constrained_decoding_debug = gr.Checkbox(value=False, visible=False)
|
| 572 |
+
allow_lm_batch = gr.Checkbox(value=True, visible=False)
|
| 573 |
+
lm_batch_chunk_size = gr.Number(value=8, visible=False)
|
| 574 |
+
score_scale = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, visible=False)
|
| 575 |
+
autogen_checkbox = gr.Checkbox(value=False, visible=False)
|
| 576 |
+
|
| 577 |
+
# Transcribe button (hidden)
|
| 578 |
+
transcribe_btn = gr.Button(value="Transcribe", visible=False)
|
| 579 |
+
text2music_audio_codes_group = gr.Group(visible=False)
|
| 580 |
+
|
| 581 |
+
# Note: format_btn is now visible in custom_mode_content
|
| 582 |
+
|
| 583 |
+
# Load file button (hidden for now)
|
| 584 |
+
load_file = gr.UploadButton(
|
| 585 |
+
label="Load",
|
| 586 |
+
file_types=[".json"],
|
| 587 |
+
file_count="single",
|
| 588 |
+
visible=False,
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
# Caption/Lyrics accordions (not used in new UI but needed for compatibility)
|
| 592 |
+
caption_accordion = gr.Accordion("Caption", visible=False)
|
| 593 |
+
lyrics_accordion = gr.Accordion("Lyrics", visible=False)
|
| 594 |
+
# Note: optional_params_accordion is now visible above
|
| 595 |
+
|
| 596 |
+
return {
|
| 597 |
+
"service_config_accordion": service_config_accordion,
|
| 598 |
+
"language_dropdown": language_dropdown,
|
| 599 |
+
"checkpoint_dropdown": checkpoint_dropdown,
|
| 600 |
+
"refresh_btn": refresh_btn,
|
| 601 |
+
"config_path": config_path,
|
| 602 |
+
"device": device,
|
| 603 |
+
"init_btn": init_btn,
|
| 604 |
+
"init_status": init_status,
|
| 605 |
+
"lm_model_path": lm_model_path,
|
| 606 |
+
"init_llm_checkbox": init_llm_checkbox,
|
| 607 |
+
"backend_dropdown": backend_dropdown,
|
| 608 |
+
"use_flash_attention_checkbox": use_flash_attention_checkbox,
|
| 609 |
+
"offload_to_cpu_checkbox": offload_to_cpu_checkbox,
|
| 610 |
+
"offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
|
| 611 |
+
# LoRA components
|
| 612 |
+
"lora_path": lora_path,
|
| 613 |
+
"load_lora_btn": load_lora_btn,
|
| 614 |
+
"unload_lora_btn": unload_lora_btn,
|
| 615 |
+
"use_lora_checkbox": use_lora_checkbox,
|
| 616 |
+
"lora_status": lora_status,
|
| 617 |
+
# DiT model selector
|
| 618 |
+
"dit_model_selector": dit_model_selector,
|
| 619 |
+
"task_type": task_type,
|
| 620 |
+
"instruction_display_gen": instruction_display_gen,
|
| 621 |
+
"track_name": track_name,
|
| 622 |
+
"complete_track_classes": complete_track_classes,
|
| 623 |
+
"audio_uploads_accordion": audio_uploads_accordion,
|
| 624 |
+
"reference_audio": reference_audio,
|
| 625 |
+
"src_audio": src_audio,
|
| 626 |
+
"convert_src_to_codes_btn": convert_src_to_codes_btn,
|
| 627 |
+
"text2music_audio_code_string": text2music_audio_code_string,
|
| 628 |
+
"transcribe_btn": transcribe_btn,
|
| 629 |
+
"text2music_audio_codes_group": text2music_audio_codes_group,
|
| 630 |
+
"lm_temperature": lm_temperature,
|
| 631 |
+
"lm_cfg_scale": lm_cfg_scale,
|
| 632 |
+
"lm_top_k": lm_top_k,
|
| 633 |
+
"lm_top_p": lm_top_p,
|
| 634 |
+
"lm_negative_prompt": lm_negative_prompt,
|
| 635 |
+
"use_cot_metas": use_cot_metas,
|
| 636 |
+
"use_cot_caption": use_cot_caption,
|
| 637 |
+
"use_cot_language": use_cot_language,
|
| 638 |
+
"repainting_group": repainting_group,
|
| 639 |
+
"repainting_start": repainting_start,
|
| 640 |
+
"repainting_end": repainting_end,
|
| 641 |
+
"audio_cover_strength": audio_cover_strength,
|
| 642 |
+
# Generation mode components
|
| 643 |
+
"generation_mode": generation_mode,
|
| 644 |
+
"simple_mode_group": simple_mode_group,
|
| 645 |
+
"simple_query_input": simple_query_input,
|
| 646 |
+
"random_desc_btn": random_desc_btn,
|
| 647 |
+
"simple_instrumental_checkbox": simple_instrumental_checkbox,
|
| 648 |
+
"simple_vocal_language": simple_vocal_language,
|
| 649 |
+
"create_sample_btn": create_sample_btn,
|
| 650 |
+
"simple_sample_created": simple_sample_created,
|
| 651 |
+
"caption_accordion": caption_accordion,
|
| 652 |
+
"lyrics_accordion": lyrics_accordion,
|
| 653 |
+
"optional_params_accordion": optional_params_accordion,
|
| 654 |
+
# Custom mode components
|
| 655 |
+
"custom_mode_content": custom_mode_content,
|
| 656 |
+
"cover_mode_group": cover_mode_group,
|
| 657 |
+
# Source audio group for Cover/Repaint
|
| 658 |
+
"src_audio_group": src_audio_group,
|
| 659 |
+
"process_src_btn": process_src_btn,
|
| 660 |
+
"advanced_options_accordion": advanced_options_accordion,
|
| 661 |
+
# Existing components
|
| 662 |
+
"captions": captions,
|
| 663 |
+
"sample_btn": sample_btn,
|
| 664 |
+
"load_file": load_file,
|
| 665 |
+
"lyrics": lyrics,
|
| 666 |
+
"vocal_language": vocal_language,
|
| 667 |
+
"bpm": bpm,
|
| 668 |
+
"key_scale": key_scale,
|
| 669 |
+
"time_signature": time_signature,
|
| 670 |
+
"audio_duration": audio_duration,
|
| 671 |
+
"batch_size_input": batch_size_input,
|
| 672 |
+
"inference_steps": inference_steps,
|
| 673 |
+
"guidance_scale": guidance_scale,
|
| 674 |
+
"seed": seed,
|
| 675 |
+
"random_seed_checkbox": random_seed_checkbox,
|
| 676 |
+
"use_adg": use_adg,
|
| 677 |
+
"cfg_interval_start": cfg_interval_start,
|
| 678 |
+
"cfg_interval_end": cfg_interval_end,
|
| 679 |
+
"shift": shift,
|
| 680 |
+
"infer_method": infer_method,
|
| 681 |
+
"custom_timesteps": custom_timesteps,
|
| 682 |
+
"audio_format": audio_format,
|
| 683 |
+
"think_checkbox": think_checkbox,
|
| 684 |
+
"autogen_checkbox": autogen_checkbox,
|
| 685 |
+
"generate_btn": generate_btn,
|
| 686 |
+
"instrumental_checkbox": instrumental_checkbox,
|
| 687 |
+
"format_btn": format_btn,
|
| 688 |
+
"constrained_decoding_debug": constrained_decoding_debug,
|
| 689 |
+
"score_scale": score_scale,
|
| 690 |
+
"allow_lm_batch": allow_lm_batch,
|
| 691 |
+
"auto_score": auto_score,
|
| 692 |
+
"auto_lrc": auto_lrc,
|
| 693 |
+
"lm_batch_chunk_size": lm_batch_chunk_size,
|
| 694 |
+
}
|
acestep/gradio_ui/interfaces/result.py
ADDED
|
@@ -0,0 +1,598 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Results Section Module
|
| 3 |
+
Contains results display section component definitions
|
| 4 |
+
"""
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from acestep.gradio_ui.i18n import t
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def create_results_section(dit_handler) -> dict:
|
| 10 |
+
"""Create results display section"""
|
| 11 |
+
with gr.Accordion(t("results.title"), open=True):
|
| 12 |
+
# Hidden state to store LM-generated metadata
|
| 13 |
+
lm_metadata_state = gr.State(value=None)
|
| 14 |
+
|
| 15 |
+
# Hidden state to track if caption/metadata is from formatted source (LM/transcription)
|
| 16 |
+
is_format_caption_state = gr.State(value=False)
|
| 17 |
+
|
| 18 |
+
# Batch management states
|
| 19 |
+
current_batch_index = gr.State(value=0) # Currently displayed batch index
|
| 20 |
+
total_batches = gr.State(value=1) # Total number of batches generated
|
| 21 |
+
batch_queue = gr.State(value={}) # Dictionary storing all batch data
|
| 22 |
+
generation_params_state = gr.State(value={}) # Store generation parameters for next batches
|
| 23 |
+
is_generating_background = gr.State(value=False) # Background generation flag
|
| 24 |
+
|
| 25 |
+
# All audio components in one row with dynamic visibility
|
| 26 |
+
with gr.Row():
|
| 27 |
+
with gr.Column(visible=True) as audio_col_1:
|
| 28 |
+
generated_audio_1 = gr.Audio(
|
| 29 |
+
label=t("results.generated_music", n=1),
|
| 30 |
+
type="filepath",
|
| 31 |
+
interactive=False,
|
| 32 |
+
buttons=[]
|
| 33 |
+
)
|
| 34 |
+
with gr.Row(equal_height=True):
|
| 35 |
+
send_to_cover_btn_1 = gr.Button(
|
| 36 |
+
t("results.send_to_cover_btn"),
|
| 37 |
+
variant="secondary",
|
| 38 |
+
size="sm",
|
| 39 |
+
scale=1
|
| 40 |
+
)
|
| 41 |
+
send_to_repaint_btn_1 = gr.Button(
|
| 42 |
+
t("results.send_to_repaint_btn"),
|
| 43 |
+
variant="secondary",
|
| 44 |
+
size="sm",
|
| 45 |
+
scale=1
|
| 46 |
+
)
|
| 47 |
+
save_btn_1 = gr.Button(
|
| 48 |
+
t("results.save_btn"),
|
| 49 |
+
variant="primary",
|
| 50 |
+
size="sm",
|
| 51 |
+
scale=1
|
| 52 |
+
)
|
| 53 |
+
score_btn_1 = gr.Button(
|
| 54 |
+
t("results.score_btn"),
|
| 55 |
+
variant="secondary",
|
| 56 |
+
size="sm",
|
| 57 |
+
scale=1,
|
| 58 |
+
visible=False
|
| 59 |
+
)
|
| 60 |
+
lrc_btn_1 = gr.Button(
|
| 61 |
+
t("results.lrc_btn"),
|
| 62 |
+
variant="secondary",
|
| 63 |
+
size="sm",
|
| 64 |
+
scale=1,
|
| 65 |
+
visible=False
|
| 66 |
+
)
|
| 67 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_1:
|
| 68 |
+
score_display_1 = gr.Textbox(
|
| 69 |
+
label=t("results.quality_score_label", n=1),
|
| 70 |
+
interactive=False,
|
| 71 |
+
buttons=["copy"],
|
| 72 |
+
lines=6,
|
| 73 |
+
max_lines=6,
|
| 74 |
+
visible=True
|
| 75 |
+
)
|
| 76 |
+
lrc_display_1 = gr.Textbox(
|
| 77 |
+
label=t("results.lrc_label", n=1),
|
| 78 |
+
interactive=True,
|
| 79 |
+
buttons=["copy"],
|
| 80 |
+
lines=8,
|
| 81 |
+
max_lines=8,
|
| 82 |
+
visible=True
|
| 83 |
+
)
|
| 84 |
+
codes_display_1 = gr.Textbox(
|
| 85 |
+
label=t("results.codes_label", n=1),
|
| 86 |
+
interactive=False,
|
| 87 |
+
buttons=["copy"],
|
| 88 |
+
lines=4,
|
| 89 |
+
max_lines=4,
|
| 90 |
+
visible=True
|
| 91 |
+
)
|
| 92 |
+
with gr.Column(visible=True) as audio_col_2:
|
| 93 |
+
generated_audio_2 = gr.Audio(
|
| 94 |
+
label=t("results.generated_music", n=2),
|
| 95 |
+
type="filepath",
|
| 96 |
+
interactive=False,
|
| 97 |
+
buttons=[]
|
| 98 |
+
)
|
| 99 |
+
with gr.Row(equal_height=True):
|
| 100 |
+
send_to_cover_btn_2 = gr.Button(
|
| 101 |
+
t("results.send_to_cover_btn"),
|
| 102 |
+
variant="secondary",
|
| 103 |
+
size="sm",
|
| 104 |
+
scale=1
|
| 105 |
+
)
|
| 106 |
+
send_to_repaint_btn_2 = gr.Button(
|
| 107 |
+
t("results.send_to_repaint_btn"),
|
| 108 |
+
variant="secondary",
|
| 109 |
+
size="sm",
|
| 110 |
+
scale=1
|
| 111 |
+
)
|
| 112 |
+
save_btn_2 = gr.Button(
|
| 113 |
+
t("results.save_btn"),
|
| 114 |
+
variant="primary",
|
| 115 |
+
size="sm",
|
| 116 |
+
scale=1
|
| 117 |
+
)
|
| 118 |
+
score_btn_2 = gr.Button(
|
| 119 |
+
t("results.score_btn"),
|
| 120 |
+
variant="secondary",
|
| 121 |
+
size="sm",
|
| 122 |
+
scale=1,
|
| 123 |
+
visible=False
|
| 124 |
+
)
|
| 125 |
+
lrc_btn_2 = gr.Button(
|
| 126 |
+
t("results.lrc_btn"),
|
| 127 |
+
variant="secondary",
|
| 128 |
+
size="sm",
|
| 129 |
+
scale=1,
|
| 130 |
+
visible=False
|
| 131 |
+
)
|
| 132 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_2:
|
| 133 |
+
score_display_2 = gr.Textbox(
|
| 134 |
+
label=t("results.quality_score_label", n=2),
|
| 135 |
+
interactive=False,
|
| 136 |
+
buttons=["copy"],
|
| 137 |
+
lines=6,
|
| 138 |
+
max_lines=6,
|
| 139 |
+
visible=True
|
| 140 |
+
)
|
| 141 |
+
lrc_display_2 = gr.Textbox(
|
| 142 |
+
label=t("results.lrc_label", n=2),
|
| 143 |
+
interactive=True,
|
| 144 |
+
buttons=["copy"],
|
| 145 |
+
lines=8,
|
| 146 |
+
max_lines=8,
|
| 147 |
+
visible=True
|
| 148 |
+
)
|
| 149 |
+
codes_display_2 = gr.Textbox(
|
| 150 |
+
label=t("results.codes_label", n=2),
|
| 151 |
+
interactive=False,
|
| 152 |
+
buttons=["copy"],
|
| 153 |
+
lines=4,
|
| 154 |
+
max_lines=4,
|
| 155 |
+
visible=True
|
| 156 |
+
)
|
| 157 |
+
with gr.Column(visible=False) as audio_col_3:
|
| 158 |
+
generated_audio_3 = gr.Audio(
|
| 159 |
+
label=t("results.generated_music", n=3),
|
| 160 |
+
type="filepath",
|
| 161 |
+
interactive=False,
|
| 162 |
+
buttons=[]
|
| 163 |
+
)
|
| 164 |
+
with gr.Row(equal_height=True):
|
| 165 |
+
send_to_cover_btn_3 = gr.Button(
|
| 166 |
+
t("results.send_to_cover_btn"),
|
| 167 |
+
variant="secondary",
|
| 168 |
+
size="sm",
|
| 169 |
+
scale=1
|
| 170 |
+
)
|
| 171 |
+
send_to_repaint_btn_3 = gr.Button(
|
| 172 |
+
t("results.send_to_repaint_btn"),
|
| 173 |
+
variant="secondary",
|
| 174 |
+
size="sm",
|
| 175 |
+
scale=1
|
| 176 |
+
)
|
| 177 |
+
save_btn_3 = gr.Button(
|
| 178 |
+
t("results.save_btn"),
|
| 179 |
+
variant="primary",
|
| 180 |
+
size="sm",
|
| 181 |
+
scale=1
|
| 182 |
+
)
|
| 183 |
+
score_btn_3 = gr.Button(
|
| 184 |
+
t("results.score_btn"),
|
| 185 |
+
variant="secondary",
|
| 186 |
+
size="sm",
|
| 187 |
+
scale=1,
|
| 188 |
+
visible=False
|
| 189 |
+
)
|
| 190 |
+
lrc_btn_3 = gr.Button(
|
| 191 |
+
t("results.lrc_btn"),
|
| 192 |
+
variant="secondary",
|
| 193 |
+
size="sm",
|
| 194 |
+
scale=1,
|
| 195 |
+
visible=False
|
| 196 |
+
)
|
| 197 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_3:
|
| 198 |
+
score_display_3 = gr.Textbox(
|
| 199 |
+
label=t("results.quality_score_label", n=3),
|
| 200 |
+
interactive=False,
|
| 201 |
+
buttons=["copy"],
|
| 202 |
+
lines=6,
|
| 203 |
+
max_lines=6,
|
| 204 |
+
visible=True
|
| 205 |
+
)
|
| 206 |
+
lrc_display_3 = gr.Textbox(
|
| 207 |
+
label=t("results.lrc_label", n=3),
|
| 208 |
+
interactive=True,
|
| 209 |
+
buttons=["copy"],
|
| 210 |
+
lines=8,
|
| 211 |
+
max_lines=8,
|
| 212 |
+
visible=True
|
| 213 |
+
)
|
| 214 |
+
codes_display_3 = gr.Textbox(
|
| 215 |
+
label=t("results.codes_label", n=3),
|
| 216 |
+
interactive=False,
|
| 217 |
+
buttons=["copy"],
|
| 218 |
+
lines=4,
|
| 219 |
+
max_lines=4,
|
| 220 |
+
visible=True
|
| 221 |
+
)
|
| 222 |
+
with gr.Column(visible=False) as audio_col_4:
|
| 223 |
+
generated_audio_4 = gr.Audio(
|
| 224 |
+
label=t("results.generated_music", n=4),
|
| 225 |
+
type="filepath",
|
| 226 |
+
interactive=False,
|
| 227 |
+
buttons=[]
|
| 228 |
+
)
|
| 229 |
+
with gr.Row(equal_height=True):
|
| 230 |
+
send_to_cover_btn_4 = gr.Button(
|
| 231 |
+
t("results.send_to_cover_btn"),
|
| 232 |
+
variant="secondary",
|
| 233 |
+
size="sm",
|
| 234 |
+
scale=1
|
| 235 |
+
)
|
| 236 |
+
send_to_repaint_btn_4 = gr.Button(
|
| 237 |
+
t("results.send_to_repaint_btn"),
|
| 238 |
+
variant="secondary",
|
| 239 |
+
size="sm",
|
| 240 |
+
scale=1
|
| 241 |
+
)
|
| 242 |
+
save_btn_4 = gr.Button(
|
| 243 |
+
t("results.save_btn"),
|
| 244 |
+
variant="primary",
|
| 245 |
+
size="sm",
|
| 246 |
+
scale=1
|
| 247 |
+
)
|
| 248 |
+
score_btn_4 = gr.Button(
|
| 249 |
+
t("results.score_btn"),
|
| 250 |
+
variant="secondary",
|
| 251 |
+
size="sm",
|
| 252 |
+
scale=1,
|
| 253 |
+
visible=False
|
| 254 |
+
)
|
| 255 |
+
lrc_btn_4 = gr.Button(
|
| 256 |
+
t("results.lrc_btn"),
|
| 257 |
+
variant="secondary",
|
| 258 |
+
size="sm",
|
| 259 |
+
scale=1,
|
| 260 |
+
visible=False
|
| 261 |
+
)
|
| 262 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_4:
|
| 263 |
+
score_display_4 = gr.Textbox(
|
| 264 |
+
label=t("results.quality_score_label", n=4),
|
| 265 |
+
interactive=False,
|
| 266 |
+
buttons=["copy"],
|
| 267 |
+
lines=6,
|
| 268 |
+
max_lines=6,
|
| 269 |
+
visible=True
|
| 270 |
+
)
|
| 271 |
+
lrc_display_4 = gr.Textbox(
|
| 272 |
+
label=t("results.lrc_label", n=4),
|
| 273 |
+
interactive=True,
|
| 274 |
+
buttons=["copy"],
|
| 275 |
+
lines=8,
|
| 276 |
+
max_lines=8,
|
| 277 |
+
visible=True
|
| 278 |
+
)
|
| 279 |
+
codes_display_4 = gr.Textbox(
|
| 280 |
+
label=t("results.codes_label", n=4),
|
| 281 |
+
interactive=False,
|
| 282 |
+
buttons=["copy"],
|
| 283 |
+
lines=4,
|
| 284 |
+
max_lines=4,
|
| 285 |
+
visible=True
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
# Second row for batch size 5-8 (initially hidden)
|
| 289 |
+
with gr.Row(visible=False) as audio_row_5_8:
|
| 290 |
+
with gr.Column() as audio_col_5:
|
| 291 |
+
generated_audio_5 = gr.Audio(
|
| 292 |
+
label=t("results.generated_music", n=5),
|
| 293 |
+
type="filepath",
|
| 294 |
+
interactive=False,
|
| 295 |
+
buttons=[]
|
| 296 |
+
)
|
| 297 |
+
with gr.Row(equal_height=True):
|
| 298 |
+
send_to_cover_btn_5 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
|
| 299 |
+
send_to_repaint_btn_5 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
|
| 300 |
+
save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 301 |
+
score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 302 |
+
lrc_btn_5 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 303 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_5:
|
| 304 |
+
score_display_5 = gr.Textbox(
|
| 305 |
+
label=t("results.quality_score_label", n=5),
|
| 306 |
+
interactive=False,
|
| 307 |
+
buttons=["copy"],
|
| 308 |
+
lines=6,
|
| 309 |
+
max_lines=6,
|
| 310 |
+
visible=True
|
| 311 |
+
)
|
| 312 |
+
lrc_display_5 = gr.Textbox(
|
| 313 |
+
label=t("results.lrc_label", n=5),
|
| 314 |
+
interactive=True,
|
| 315 |
+
buttons=["copy"],
|
| 316 |
+
lines=8,
|
| 317 |
+
max_lines=8,
|
| 318 |
+
visible=True
|
| 319 |
+
)
|
| 320 |
+
codes_display_5 = gr.Textbox(
|
| 321 |
+
label=t("results.codes_label", n=5),
|
| 322 |
+
interactive=False,
|
| 323 |
+
buttons=["copy"],
|
| 324 |
+
lines=4,
|
| 325 |
+
max_lines=4,
|
| 326 |
+
visible=True
|
| 327 |
+
)
|
| 328 |
+
with gr.Column() as audio_col_6:
|
| 329 |
+
generated_audio_6 = gr.Audio(
|
| 330 |
+
label=t("results.generated_music", n=6),
|
| 331 |
+
type="filepath",
|
| 332 |
+
interactive=False,
|
| 333 |
+
buttons=[]
|
| 334 |
+
)
|
| 335 |
+
with gr.Row(equal_height=True):
|
| 336 |
+
send_to_cover_btn_6 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
|
| 337 |
+
send_to_repaint_btn_6 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
|
| 338 |
+
save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 339 |
+
score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 340 |
+
lrc_btn_6 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 341 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_6:
|
| 342 |
+
score_display_6 = gr.Textbox(
|
| 343 |
+
label=t("results.quality_score_label", n=6),
|
| 344 |
+
interactive=False,
|
| 345 |
+
buttons=["copy"],
|
| 346 |
+
lines=6,
|
| 347 |
+
max_lines=6,
|
| 348 |
+
visible=True
|
| 349 |
+
)
|
| 350 |
+
lrc_display_6 = gr.Textbox(
|
| 351 |
+
label=t("results.lrc_label", n=6),
|
| 352 |
+
interactive=True,
|
| 353 |
+
buttons=["copy"],
|
| 354 |
+
lines=8,
|
| 355 |
+
max_lines=8,
|
| 356 |
+
visible=True
|
| 357 |
+
)
|
| 358 |
+
codes_display_6 = gr.Textbox(
|
| 359 |
+
label=t("results.codes_label", n=6),
|
| 360 |
+
interactive=False,
|
| 361 |
+
buttons=["copy"],
|
| 362 |
+
lines=4,
|
| 363 |
+
max_lines=4,
|
| 364 |
+
visible=True
|
| 365 |
+
)
|
| 366 |
+
with gr.Column() as audio_col_7:
|
| 367 |
+
generated_audio_7 = gr.Audio(
|
| 368 |
+
label=t("results.generated_music", n=7),
|
| 369 |
+
type="filepath",
|
| 370 |
+
interactive=False,
|
| 371 |
+
buttons=[]
|
| 372 |
+
)
|
| 373 |
+
with gr.Row(equal_height=True):
|
| 374 |
+
send_to_cover_btn_7 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
|
| 375 |
+
send_to_repaint_btn_7 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
|
| 376 |
+
save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 377 |
+
score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 378 |
+
lrc_btn_7 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 379 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_7:
|
| 380 |
+
score_display_7 = gr.Textbox(
|
| 381 |
+
label=t("results.quality_score_label", n=7),
|
| 382 |
+
interactive=False,
|
| 383 |
+
buttons=["copy"],
|
| 384 |
+
lines=6,
|
| 385 |
+
max_lines=6,
|
| 386 |
+
visible=True
|
| 387 |
+
)
|
| 388 |
+
lrc_display_7 = gr.Textbox(
|
| 389 |
+
label=t("results.lrc_label", n=7),
|
| 390 |
+
interactive=True,
|
| 391 |
+
buttons=["copy"],
|
| 392 |
+
lines=8,
|
| 393 |
+
max_lines=8,
|
| 394 |
+
visible=True
|
| 395 |
+
)
|
| 396 |
+
codes_display_7 = gr.Textbox(
|
| 397 |
+
label=t("results.codes_label", n=7),
|
| 398 |
+
interactive=False,
|
| 399 |
+
buttons=["copy"],
|
| 400 |
+
lines=4,
|
| 401 |
+
max_lines=4,
|
| 402 |
+
visible=True
|
| 403 |
+
)
|
| 404 |
+
with gr.Column() as audio_col_8:
|
| 405 |
+
generated_audio_8 = gr.Audio(
|
| 406 |
+
label=t("results.generated_music", n=8),
|
| 407 |
+
type="filepath",
|
| 408 |
+
interactive=False,
|
| 409 |
+
buttons=[]
|
| 410 |
+
)
|
| 411 |
+
with gr.Row(equal_height=True):
|
| 412 |
+
send_to_cover_btn_8 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
|
| 413 |
+
send_to_repaint_btn_8 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
|
| 414 |
+
save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 415 |
+
score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 416 |
+
lrc_btn_8 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 417 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_8:
|
| 418 |
+
score_display_8 = gr.Textbox(
|
| 419 |
+
label=t("results.quality_score_label", n=8),
|
| 420 |
+
interactive=False,
|
| 421 |
+
buttons=["copy"],
|
| 422 |
+
lines=6,
|
| 423 |
+
max_lines=6,
|
| 424 |
+
visible=True
|
| 425 |
+
)
|
| 426 |
+
lrc_display_8 = gr.Textbox(
|
| 427 |
+
label=t("results.lrc_label", n=8),
|
| 428 |
+
interactive=True,
|
| 429 |
+
buttons=["copy"],
|
| 430 |
+
lines=8,
|
| 431 |
+
max_lines=8,
|
| 432 |
+
visible=True
|
| 433 |
+
)
|
| 434 |
+
codes_display_8 = gr.Textbox(
|
| 435 |
+
label=t("results.codes_label", n=8),
|
| 436 |
+
interactive=False,
|
| 437 |
+
buttons=["copy"],
|
| 438 |
+
lines=4,
|
| 439 |
+
max_lines=4,
|
| 440 |
+
visible=True
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
|
| 444 |
+
|
| 445 |
+
# Batch navigation controls (hidden for simplified UI)
|
| 446 |
+
with gr.Row(equal_height=True, visible=False):
|
| 447 |
+
prev_batch_btn = gr.Button(
|
| 448 |
+
t("results.prev_btn"),
|
| 449 |
+
variant="secondary",
|
| 450 |
+
interactive=False,
|
| 451 |
+
scale=1,
|
| 452 |
+
size="sm"
|
| 453 |
+
)
|
| 454 |
+
batch_indicator = gr.Textbox(
|
| 455 |
+
label=t("results.current_batch"),
|
| 456 |
+
value=t("results.batch_indicator", current=1, total=1),
|
| 457 |
+
interactive=False,
|
| 458 |
+
scale=3
|
| 459 |
+
)
|
| 460 |
+
next_batch_status = gr.Textbox(
|
| 461 |
+
label=t("results.next_batch_status"),
|
| 462 |
+
value="",
|
| 463 |
+
interactive=False,
|
| 464 |
+
scale=3
|
| 465 |
+
)
|
| 466 |
+
next_batch_btn = gr.Button(
|
| 467 |
+
t("results.next_btn"),
|
| 468 |
+
variant="primary",
|
| 469 |
+
interactive=False,
|
| 470 |
+
scale=1,
|
| 471 |
+
size="sm"
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# One-click restore parameters button (hidden for simplified UI)
|
| 475 |
+
restore_params_btn = gr.Button(
|
| 476 |
+
t("results.restore_params_btn"),
|
| 477 |
+
variant="secondary",
|
| 478 |
+
interactive=False,
|
| 479 |
+
size="sm",
|
| 480 |
+
visible=False
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
with gr.Accordion(t("results.batch_results_title"), open=True):
|
| 484 |
+
generated_audio_batch = gr.File(
|
| 485 |
+
label=t("results.all_files_label"),
|
| 486 |
+
file_count="multiple",
|
| 487 |
+
interactive=False,
|
| 488 |
+
visible=False
|
| 489 |
+
)
|
| 490 |
+
generation_info = gr.Markdown(label=t("results.generation_details"))
|
| 491 |
+
|
| 492 |
+
return {
|
| 493 |
+
"lm_metadata_state": lm_metadata_state,
|
| 494 |
+
"is_format_caption_state": is_format_caption_state,
|
| 495 |
+
"current_batch_index": current_batch_index,
|
| 496 |
+
"total_batches": total_batches,
|
| 497 |
+
"batch_queue": batch_queue,
|
| 498 |
+
"generation_params_state": generation_params_state,
|
| 499 |
+
"is_generating_background": is_generating_background,
|
| 500 |
+
"status_output": status_output,
|
| 501 |
+
"prev_batch_btn": prev_batch_btn,
|
| 502 |
+
"batch_indicator": batch_indicator,
|
| 503 |
+
"next_batch_btn": next_batch_btn,
|
| 504 |
+
"next_batch_status": next_batch_status,
|
| 505 |
+
"restore_params_btn": restore_params_btn,
|
| 506 |
+
"generated_audio_1": generated_audio_1,
|
| 507 |
+
"generated_audio_2": generated_audio_2,
|
| 508 |
+
"generated_audio_3": generated_audio_3,
|
| 509 |
+
"generated_audio_4": generated_audio_4,
|
| 510 |
+
"generated_audio_5": generated_audio_5,
|
| 511 |
+
"generated_audio_6": generated_audio_6,
|
| 512 |
+
"generated_audio_7": generated_audio_7,
|
| 513 |
+
"generated_audio_8": generated_audio_8,
|
| 514 |
+
"audio_row_5_8": audio_row_5_8,
|
| 515 |
+
"audio_col_1": audio_col_1,
|
| 516 |
+
"audio_col_2": audio_col_2,
|
| 517 |
+
"audio_col_3": audio_col_3,
|
| 518 |
+
"audio_col_4": audio_col_4,
|
| 519 |
+
"audio_col_5": audio_col_5,
|
| 520 |
+
"audio_col_6": audio_col_6,
|
| 521 |
+
"audio_col_7": audio_col_7,
|
| 522 |
+
"audio_col_8": audio_col_8,
|
| 523 |
+
"send_to_cover_btn_1": send_to_cover_btn_1,
|
| 524 |
+
"send_to_cover_btn_2": send_to_cover_btn_2,
|
| 525 |
+
"send_to_cover_btn_3": send_to_cover_btn_3,
|
| 526 |
+
"send_to_cover_btn_4": send_to_cover_btn_4,
|
| 527 |
+
"send_to_cover_btn_5": send_to_cover_btn_5,
|
| 528 |
+
"send_to_cover_btn_6": send_to_cover_btn_6,
|
| 529 |
+
"send_to_cover_btn_7": send_to_cover_btn_7,
|
| 530 |
+
"send_to_cover_btn_8": send_to_cover_btn_8,
|
| 531 |
+
"send_to_repaint_btn_1": send_to_repaint_btn_1,
|
| 532 |
+
"send_to_repaint_btn_2": send_to_repaint_btn_2,
|
| 533 |
+
"send_to_repaint_btn_3": send_to_repaint_btn_3,
|
| 534 |
+
"send_to_repaint_btn_4": send_to_repaint_btn_4,
|
| 535 |
+
"send_to_repaint_btn_5": send_to_repaint_btn_5,
|
| 536 |
+
"send_to_repaint_btn_6": send_to_repaint_btn_6,
|
| 537 |
+
"send_to_repaint_btn_7": send_to_repaint_btn_7,
|
| 538 |
+
"send_to_repaint_btn_8": send_to_repaint_btn_8,
|
| 539 |
+
"save_btn_1": save_btn_1,
|
| 540 |
+
"save_btn_2": save_btn_2,
|
| 541 |
+
"save_btn_3": save_btn_3,
|
| 542 |
+
"save_btn_4": save_btn_4,
|
| 543 |
+
"save_btn_5": save_btn_5,
|
| 544 |
+
"save_btn_6": save_btn_6,
|
| 545 |
+
"save_btn_7": save_btn_7,
|
| 546 |
+
"save_btn_8": save_btn_8,
|
| 547 |
+
"score_btn_1": score_btn_1,
|
| 548 |
+
"score_btn_2": score_btn_2,
|
| 549 |
+
"score_btn_3": score_btn_3,
|
| 550 |
+
"score_btn_4": score_btn_4,
|
| 551 |
+
"score_btn_5": score_btn_5,
|
| 552 |
+
"score_btn_6": score_btn_6,
|
| 553 |
+
"score_btn_7": score_btn_7,
|
| 554 |
+
"score_btn_8": score_btn_8,
|
| 555 |
+
"score_display_1": score_display_1,
|
| 556 |
+
"score_display_2": score_display_2,
|
| 557 |
+
"score_display_3": score_display_3,
|
| 558 |
+
"score_display_4": score_display_4,
|
| 559 |
+
"score_display_5": score_display_5,
|
| 560 |
+
"score_display_6": score_display_6,
|
| 561 |
+
"score_display_7": score_display_7,
|
| 562 |
+
"score_display_8": score_display_8,
|
| 563 |
+
"codes_display_1": codes_display_1,
|
| 564 |
+
"codes_display_2": codes_display_2,
|
| 565 |
+
"codes_display_3": codes_display_3,
|
| 566 |
+
"codes_display_4": codes_display_4,
|
| 567 |
+
"codes_display_5": codes_display_5,
|
| 568 |
+
"codes_display_6": codes_display_6,
|
| 569 |
+
"codes_display_7": codes_display_7,
|
| 570 |
+
"codes_display_8": codes_display_8,
|
| 571 |
+
"lrc_btn_1": lrc_btn_1,
|
| 572 |
+
"lrc_btn_2": lrc_btn_2,
|
| 573 |
+
"lrc_btn_3": lrc_btn_3,
|
| 574 |
+
"lrc_btn_4": lrc_btn_4,
|
| 575 |
+
"lrc_btn_5": lrc_btn_5,
|
| 576 |
+
"lrc_btn_6": lrc_btn_6,
|
| 577 |
+
"lrc_btn_7": lrc_btn_7,
|
| 578 |
+
"lrc_btn_8": lrc_btn_8,
|
| 579 |
+
"lrc_display_1": lrc_display_1,
|
| 580 |
+
"lrc_display_2": lrc_display_2,
|
| 581 |
+
"lrc_display_3": lrc_display_3,
|
| 582 |
+
"lrc_display_4": lrc_display_4,
|
| 583 |
+
"lrc_display_5": lrc_display_5,
|
| 584 |
+
"lrc_display_6": lrc_display_6,
|
| 585 |
+
"lrc_display_7": lrc_display_7,
|
| 586 |
+
"lrc_display_8": lrc_display_8,
|
| 587 |
+
"details_accordion_1": details_accordion_1,
|
| 588 |
+
"details_accordion_2": details_accordion_2,
|
| 589 |
+
"details_accordion_3": details_accordion_3,
|
| 590 |
+
"details_accordion_4": details_accordion_4,
|
| 591 |
+
"details_accordion_5": details_accordion_5,
|
| 592 |
+
"details_accordion_6": details_accordion_6,
|
| 593 |
+
"details_accordion_7": details_accordion_7,
|
| 594 |
+
"details_accordion_8": details_accordion_8,
|
| 595 |
+
"generated_audio_batch": generated_audio_batch,
|
| 596 |
+
"generation_info": generation_info,
|
| 597 |
+
}
|
| 598 |
+
|
acestep/gradio_ui/interfaces/training.py
ADDED
|
@@ -0,0 +1,562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Training Tab Module
|
| 3 |
+
|
| 4 |
+
Contains the dataset builder and LoRA training interface components.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from acestep.gradio_ui.i18n import t
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_training_section(dit_handler, llm_handler, init_params=None) -> dict:
|
| 13 |
+
"""Create the training tab section with dataset builder and training controls.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
dit_handler: DiT handler instance
|
| 17 |
+
llm_handler: LLM handler instance
|
| 18 |
+
init_params: Dictionary containing initialization parameters and state.
|
| 19 |
+
If None, service will not be pre-initialized.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
Dictionary of Gradio components for event handling
|
| 23 |
+
"""
|
| 24 |
+
# Check if running in service mode (hide training tab)
|
| 25 |
+
service_mode = init_params is not None and init_params.get('service_mode', False)
|
| 26 |
+
|
| 27 |
+
with gr.Tab("🎓 LoRA Training", visible=not service_mode):
|
| 28 |
+
gr.HTML("""
|
| 29 |
+
<div style="text-align: center; padding: 10px; margin-bottom: 15px;">
|
| 30 |
+
<h2>🎵 LoRA Training for ACE-Step</h2>
|
| 31 |
+
<p>Build datasets from your audio files and train custom LoRA adapters</p>
|
| 32 |
+
</div>
|
| 33 |
+
""")
|
| 34 |
+
|
| 35 |
+
with gr.Tabs():
|
| 36 |
+
# ==================== Dataset Builder Tab ====================
|
| 37 |
+
with gr.Tab("📁 Dataset Builder"):
|
| 38 |
+
# ========== Load Existing OR Scan New ==========
|
| 39 |
+
gr.HTML("""
|
| 40 |
+
<div style="padding: 10px; margin-bottom: 10px; border: 1px solid #4a4a6a; border-radius: 8px; background: linear-gradient(135deg, #2a2a4a 0%, #1a1a3a 100%);">
|
| 41 |
+
<h3 style="margin: 0 0 5px 0;">🚀 Quick Start</h3>
|
| 42 |
+
<p style="margin: 0; color: #aaa;">Choose one: <b>Load existing dataset</b> OR <b>Scan new directory</b></p>
|
| 43 |
+
</div>
|
| 44 |
+
""")
|
| 45 |
+
|
| 46 |
+
with gr.Row():
|
| 47 |
+
with gr.Column(scale=1):
|
| 48 |
+
gr.HTML("<h4>📂 Load Existing Dataset</h4>")
|
| 49 |
+
with gr.Row():
|
| 50 |
+
load_json_path = gr.Textbox(
|
| 51 |
+
label="Dataset JSON Path",
|
| 52 |
+
placeholder="./datasets/my_lora_dataset.json",
|
| 53 |
+
info="Load a previously saved dataset",
|
| 54 |
+
scale=3,
|
| 55 |
+
)
|
| 56 |
+
load_json_btn = gr.Button("📂 Load", variant="primary", scale=1)
|
| 57 |
+
load_json_status = gr.Textbox(
|
| 58 |
+
label="Load Status",
|
| 59 |
+
interactive=False,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
with gr.Column(scale=1):
|
| 63 |
+
gr.HTML("<h4>🔍 Scan New Directory</h4>")
|
| 64 |
+
with gr.Row():
|
| 65 |
+
audio_directory = gr.Textbox(
|
| 66 |
+
label="Audio Directory Path",
|
| 67 |
+
placeholder="/path/to/your/audio/folder",
|
| 68 |
+
info="Scan for audio files (wav, mp3, flac, ogg, opus)",
|
| 69 |
+
scale=3,
|
| 70 |
+
)
|
| 71 |
+
scan_btn = gr.Button("🔍 Scan", variant="secondary", scale=1)
|
| 72 |
+
scan_status = gr.Textbox(
|
| 73 |
+
label="Scan Status",
|
| 74 |
+
interactive=False,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
gr.HTML("<hr>")
|
| 78 |
+
|
| 79 |
+
with gr.Row():
|
| 80 |
+
with gr.Column(scale=2):
|
| 81 |
+
|
| 82 |
+
# Audio files table
|
| 83 |
+
audio_files_table = gr.Dataframe(
|
| 84 |
+
headers=["#", "Filename", "Duration", "Labeled", "BPM", "Key", "Caption"],
|
| 85 |
+
datatype=["number", "str", "str", "str", "str", "str", "str"],
|
| 86 |
+
label="Found Audio Files",
|
| 87 |
+
interactive=False,
|
| 88 |
+
wrap=True,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
with gr.Column(scale=1):
|
| 92 |
+
gr.HTML("<h3>⚙️ Dataset Settings</h3>")
|
| 93 |
+
|
| 94 |
+
dataset_name = gr.Textbox(
|
| 95 |
+
label="Dataset Name",
|
| 96 |
+
value="my_lora_dataset",
|
| 97 |
+
placeholder="Enter dataset name",
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
all_instrumental = gr.Checkbox(
|
| 101 |
+
label="All Instrumental",
|
| 102 |
+
value=True,
|
| 103 |
+
info="Check if all tracks are instrumental (no vocals)",
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
need_lyrics = gr.Checkbox(
|
| 107 |
+
label="Transcribe Lyrics",
|
| 108 |
+
value=False,
|
| 109 |
+
info="Attempt to transcribe lyrics (slower)",
|
| 110 |
+
interactive=False, # Disabled for now
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
custom_tag = gr.Textbox(
|
| 114 |
+
label="Custom Activation Tag",
|
| 115 |
+
placeholder="e.g., 8bit_retro, my_style",
|
| 116 |
+
info="Unique tag to activate this LoRA's style",
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
tag_position = gr.Radio(
|
| 120 |
+
choices=[
|
| 121 |
+
("Prepend (tag, caption)", "prepend"),
|
| 122 |
+
("Append (caption, tag)", "append"),
|
| 123 |
+
("Replace caption", "replace"),
|
| 124 |
+
],
|
| 125 |
+
value="replace",
|
| 126 |
+
label="Tag Position",
|
| 127 |
+
info="Where to place the custom tag in the caption",
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
gr.HTML("<hr><h3>🤖 Step 2: Auto-Label with AI</h3>")
|
| 131 |
+
|
| 132 |
+
with gr.Row():
|
| 133 |
+
with gr.Column(scale=3):
|
| 134 |
+
gr.Markdown("""
|
| 135 |
+
Click the button below to automatically generate metadata for all audio files using AI:
|
| 136 |
+
- **Caption**: Music style, genre, mood description
|
| 137 |
+
- **BPM**: Beats per minute
|
| 138 |
+
- **Key**: Musical key (e.g., C Major, Am)
|
| 139 |
+
- **Time Signature**: 4/4, 3/4, etc.
|
| 140 |
+
""")
|
| 141 |
+
skip_metas = gr.Checkbox(
|
| 142 |
+
label="Skip Metas (No LLM)",
|
| 143 |
+
value=False,
|
| 144 |
+
info="Skip AI labeling. BPM/Key/Time Signature will be N/A, Language will be 'unknown' for instrumental",
|
| 145 |
+
)
|
| 146 |
+
with gr.Column(scale=1):
|
| 147 |
+
auto_label_btn = gr.Button(
|
| 148 |
+
"🏷️ Auto-Label All",
|
| 149 |
+
variant="primary",
|
| 150 |
+
size="lg",
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
label_progress = gr.Textbox(
|
| 154 |
+
label="Labeling Progress",
|
| 155 |
+
interactive=False,
|
| 156 |
+
lines=2,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
gr.HTML("<hr><h3>👀 Step 3: Preview & Edit</h3>")
|
| 160 |
+
|
| 161 |
+
with gr.Row():
|
| 162 |
+
with gr.Column(scale=1):
|
| 163 |
+
sample_selector = gr.Slider(
|
| 164 |
+
minimum=0,
|
| 165 |
+
maximum=0,
|
| 166 |
+
step=1,
|
| 167 |
+
value=0,
|
| 168 |
+
label="Select Sample #",
|
| 169 |
+
info="Choose a sample to preview and edit",
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
preview_audio = gr.Audio(
|
| 173 |
+
label="Audio Preview",
|
| 174 |
+
type="filepath",
|
| 175 |
+
interactive=False,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
preview_filename = gr.Textbox(
|
| 179 |
+
label="Filename",
|
| 180 |
+
interactive=False,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
with gr.Column(scale=2):
|
| 184 |
+
with gr.Row():
|
| 185 |
+
edit_caption = gr.Textbox(
|
| 186 |
+
label="Caption",
|
| 187 |
+
lines=3,
|
| 188 |
+
placeholder="Music description...",
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
with gr.Row():
|
| 192 |
+
edit_lyrics = gr.Textbox(
|
| 193 |
+
label="Lyrics",
|
| 194 |
+
lines=4,
|
| 195 |
+
placeholder="[Verse 1]\nLyrics here...\n\n[Chorus]\n...",
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
with gr.Row():
|
| 199 |
+
edit_bpm = gr.Number(
|
| 200 |
+
label="BPM",
|
| 201 |
+
precision=0,
|
| 202 |
+
)
|
| 203 |
+
edit_keyscale = gr.Textbox(
|
| 204 |
+
label="Key",
|
| 205 |
+
placeholder="C Major",
|
| 206 |
+
)
|
| 207 |
+
edit_timesig = gr.Dropdown(
|
| 208 |
+
choices=["", "2", "3", "4", "6"],
|
| 209 |
+
label="Time Signature",
|
| 210 |
+
)
|
| 211 |
+
edit_duration = gr.Number(
|
| 212 |
+
label="Duration (s)",
|
| 213 |
+
precision=1,
|
| 214 |
+
interactive=False,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
with gr.Row():
|
| 218 |
+
edit_language = gr.Dropdown(
|
| 219 |
+
choices=["instrumental", "en", "zh", "ja", "ko", "es", "fr", "de", "pt", "ru", "unknown"],
|
| 220 |
+
value="instrumental",
|
| 221 |
+
label="Language",
|
| 222 |
+
)
|
| 223 |
+
edit_instrumental = gr.Checkbox(
|
| 224 |
+
label="Instrumental",
|
| 225 |
+
value=True,
|
| 226 |
+
)
|
| 227 |
+
save_edit_btn = gr.Button("💾 Save Changes", variant="secondary")
|
| 228 |
+
|
| 229 |
+
edit_status = gr.Textbox(
|
| 230 |
+
label="Edit Status",
|
| 231 |
+
interactive=False,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
gr.HTML("<hr><h3>💾 Step 4: Save Dataset</h3>")
|
| 235 |
+
|
| 236 |
+
with gr.Row():
|
| 237 |
+
with gr.Column(scale=3):
|
| 238 |
+
save_path = gr.Textbox(
|
| 239 |
+
label="Save Path",
|
| 240 |
+
value="./datasets/my_lora_dataset.json",
|
| 241 |
+
placeholder="./datasets/dataset_name.json",
|
| 242 |
+
info="Path where the dataset JSON will be saved",
|
| 243 |
+
)
|
| 244 |
+
with gr.Column(scale=1):
|
| 245 |
+
save_dataset_btn = gr.Button(
|
| 246 |
+
"💾 Save Dataset",
|
| 247 |
+
variant="primary",
|
| 248 |
+
size="lg",
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
save_status = gr.Textbox(
|
| 252 |
+
label="Save Status",
|
| 253 |
+
interactive=False,
|
| 254 |
+
lines=2,
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
gr.HTML("<hr><h3>⚡ Step 5: Preprocess to Tensors</h3>")
|
| 258 |
+
|
| 259 |
+
gr.Markdown("""
|
| 260 |
+
**Preprocessing converts your dataset to pre-computed tensors for fast training.**
|
| 261 |
+
|
| 262 |
+
You can either:
|
| 263 |
+
- Use the dataset from Steps 1-4 above, **OR**
|
| 264 |
+
- Load an existing dataset JSON file (if you've already saved one)
|
| 265 |
+
""")
|
| 266 |
+
|
| 267 |
+
with gr.Row():
|
| 268 |
+
with gr.Column(scale=3):
|
| 269 |
+
load_existing_dataset_path = gr.Textbox(
|
| 270 |
+
label="Load Existing Dataset (Optional)",
|
| 271 |
+
placeholder="./datasets/my_lora_dataset.json",
|
| 272 |
+
info="Path to a previously saved dataset JSON file",
|
| 273 |
+
)
|
| 274 |
+
with gr.Column(scale=1):
|
| 275 |
+
load_existing_dataset_btn = gr.Button(
|
| 276 |
+
"📂 Load Dataset",
|
| 277 |
+
variant="secondary",
|
| 278 |
+
size="lg",
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
load_existing_status = gr.Textbox(
|
| 282 |
+
label="Load Status",
|
| 283 |
+
interactive=False,
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
gr.Markdown("""
|
| 287 |
+
This step:
|
| 288 |
+
- Encodes audio to VAE latents
|
| 289 |
+
- Encodes captions and lyrics to text embeddings
|
| 290 |
+
- Runs the condition encoder
|
| 291 |
+
- Saves all tensors to `.pt` files
|
| 292 |
+
|
| 293 |
+
⚠️ **This requires the model to be loaded and may take a few minutes.**
|
| 294 |
+
""")
|
| 295 |
+
|
| 296 |
+
with gr.Row():
|
| 297 |
+
with gr.Column(scale=3):
|
| 298 |
+
preprocess_output_dir = gr.Textbox(
|
| 299 |
+
label="Tensor Output Directory",
|
| 300 |
+
value="./datasets/preprocessed_tensors",
|
| 301 |
+
placeholder="./datasets/preprocessed_tensors",
|
| 302 |
+
info="Directory to save preprocessed tensor files",
|
| 303 |
+
)
|
| 304 |
+
with gr.Column(scale=1):
|
| 305 |
+
preprocess_btn = gr.Button(
|
| 306 |
+
"⚡ Preprocess",
|
| 307 |
+
variant="primary",
|
| 308 |
+
size="lg",
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
preprocess_progress = gr.Textbox(
|
| 312 |
+
label="Preprocessing Progress",
|
| 313 |
+
interactive=False,
|
| 314 |
+
lines=3,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
# ==================== Training Tab ====================
|
| 318 |
+
with gr.Tab("🚀 Train LoRA"):
|
| 319 |
+
with gr.Row():
|
| 320 |
+
with gr.Column(scale=2):
|
| 321 |
+
gr.HTML("<h3>📊 Preprocessed Dataset Selection</h3>")
|
| 322 |
+
|
| 323 |
+
gr.Markdown("""
|
| 324 |
+
Select the directory containing preprocessed tensor files (`.pt` files).
|
| 325 |
+
These are created in the "Dataset Builder" tab using the "Preprocess" button.
|
| 326 |
+
""")
|
| 327 |
+
|
| 328 |
+
training_tensor_dir = gr.Textbox(
|
| 329 |
+
label="Preprocessed Tensors Directory",
|
| 330 |
+
placeholder="./datasets/preprocessed_tensors",
|
| 331 |
+
value="./datasets/preprocessed_tensors",
|
| 332 |
+
info="Directory containing preprocessed .pt tensor files",
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
load_dataset_btn = gr.Button("📂 Load Dataset", variant="secondary")
|
| 336 |
+
|
| 337 |
+
training_dataset_info = gr.Textbox(
|
| 338 |
+
label="Dataset Info",
|
| 339 |
+
interactive=False,
|
| 340 |
+
lines=3,
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
with gr.Column(scale=1):
|
| 344 |
+
gr.HTML("<h3>⚙️ LoRA Settings</h3>")
|
| 345 |
+
|
| 346 |
+
lora_rank = gr.Slider(
|
| 347 |
+
minimum=4,
|
| 348 |
+
maximum=256,
|
| 349 |
+
step=4,
|
| 350 |
+
value=64,
|
| 351 |
+
label="LoRA Rank (r)",
|
| 352 |
+
info="Higher = more capacity, more memory",
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
lora_alpha = gr.Slider(
|
| 356 |
+
minimum=4,
|
| 357 |
+
maximum=512,
|
| 358 |
+
step=4,
|
| 359 |
+
value=128,
|
| 360 |
+
label="LoRA Alpha",
|
| 361 |
+
info="Scaling factor (typically 2x rank)",
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
lora_dropout = gr.Slider(
|
| 365 |
+
minimum=0.0,
|
| 366 |
+
maximum=0.5,
|
| 367 |
+
step=0.05,
|
| 368 |
+
value=0.1,
|
| 369 |
+
label="LoRA Dropout",
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
gr.HTML("<hr><h3>🎛️ Training Parameters</h3>")
|
| 373 |
+
|
| 374 |
+
with gr.Row():
|
| 375 |
+
learning_rate = gr.Number(
|
| 376 |
+
label="Learning Rate",
|
| 377 |
+
value=1e-4,
|
| 378 |
+
info="Start with 1e-4, adjust if needed",
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
train_epochs = gr.Slider(
|
| 382 |
+
minimum=100,
|
| 383 |
+
maximum=4000,
|
| 384 |
+
step=100,
|
| 385 |
+
value=500,
|
| 386 |
+
label="Max Epochs",
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
train_batch_size = gr.Slider(
|
| 390 |
+
minimum=1,
|
| 391 |
+
maximum=8,
|
| 392 |
+
step=1,
|
| 393 |
+
value=1,
|
| 394 |
+
label="Batch Size",
|
| 395 |
+
info="Increase if you have enough VRAM",
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
gradient_accumulation = gr.Slider(
|
| 399 |
+
minimum=1,
|
| 400 |
+
maximum=16,
|
| 401 |
+
step=1,
|
| 402 |
+
value=1,
|
| 403 |
+
label="Gradient Accumulation",
|
| 404 |
+
info="Effective batch = batch_size × accumulation",
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
with gr.Row():
|
| 408 |
+
save_every_n_epochs = gr.Slider(
|
| 409 |
+
minimum=50,
|
| 410 |
+
maximum=1000,
|
| 411 |
+
step=50,
|
| 412 |
+
value=200,
|
| 413 |
+
label="Save Every N Epochs",
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
training_shift = gr.Slider(
|
| 417 |
+
minimum=1.0,
|
| 418 |
+
maximum=5.0,
|
| 419 |
+
step=0.5,
|
| 420 |
+
value=3.0,
|
| 421 |
+
label="Shift",
|
| 422 |
+
info="Timestep shift for turbo model",
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
training_seed = gr.Number(
|
| 426 |
+
label="Seed",
|
| 427 |
+
value=42,
|
| 428 |
+
precision=0,
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
with gr.Row():
|
| 432 |
+
lora_output_dir = gr.Textbox(
|
| 433 |
+
label="Output Directory",
|
| 434 |
+
value="./lora_output",
|
| 435 |
+
placeholder="./lora_output",
|
| 436 |
+
info="Directory to save trained LoRA weights",
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
gr.HTML("<hr>")
|
| 440 |
+
|
| 441 |
+
with gr.Row():
|
| 442 |
+
with gr.Column(scale=1):
|
| 443 |
+
start_training_btn = gr.Button(
|
| 444 |
+
"🚀 Start Training",
|
| 445 |
+
variant="primary",
|
| 446 |
+
size="lg",
|
| 447 |
+
)
|
| 448 |
+
with gr.Column(scale=1):
|
| 449 |
+
stop_training_btn = gr.Button(
|
| 450 |
+
"⏹️ Stop Training",
|
| 451 |
+
variant="stop",
|
| 452 |
+
size="lg",
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
training_progress = gr.Textbox(
|
| 456 |
+
label="Training Progress",
|
| 457 |
+
interactive=False,
|
| 458 |
+
lines=2,
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
with gr.Row():
|
| 462 |
+
training_log = gr.Textbox(
|
| 463 |
+
label="Training Log",
|
| 464 |
+
interactive=False,
|
| 465 |
+
lines=10,
|
| 466 |
+
max_lines=15,
|
| 467 |
+
scale=1,
|
| 468 |
+
)
|
| 469 |
+
training_loss_plot = gr.LinePlot(
|
| 470 |
+
x="step",
|
| 471 |
+
y="loss",
|
| 472 |
+
title="Training Loss",
|
| 473 |
+
x_title="Step",
|
| 474 |
+
y_title="Loss",
|
| 475 |
+
scale=1,
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
gr.HTML("<hr><h3>📦 Export LoRA</h3>")
|
| 479 |
+
|
| 480 |
+
with gr.Row():
|
| 481 |
+
export_path = gr.Textbox(
|
| 482 |
+
label="Export Path",
|
| 483 |
+
value="./lora_output/final_lora",
|
| 484 |
+
placeholder="./lora_output/my_lora",
|
| 485 |
+
)
|
| 486 |
+
export_lora_btn = gr.Button("📦 Export LoRA", variant="secondary")
|
| 487 |
+
|
| 488 |
+
export_status = gr.Textbox(
|
| 489 |
+
label="Export Status",
|
| 490 |
+
interactive=False,
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
# Store dataset builder state
|
| 494 |
+
dataset_builder_state = gr.State(None)
|
| 495 |
+
training_state = gr.State({"is_training": False, "should_stop": False})
|
| 496 |
+
|
| 497 |
+
return {
|
| 498 |
+
# Dataset Builder - Load or Scan
|
| 499 |
+
"load_json_path": load_json_path,
|
| 500 |
+
"load_json_btn": load_json_btn,
|
| 501 |
+
"load_json_status": load_json_status,
|
| 502 |
+
"audio_directory": audio_directory,
|
| 503 |
+
"scan_btn": scan_btn,
|
| 504 |
+
"scan_status": scan_status,
|
| 505 |
+
"audio_files_table": audio_files_table,
|
| 506 |
+
"dataset_name": dataset_name,
|
| 507 |
+
"all_instrumental": all_instrumental,
|
| 508 |
+
"need_lyrics": need_lyrics,
|
| 509 |
+
"custom_tag": custom_tag,
|
| 510 |
+
"tag_position": tag_position,
|
| 511 |
+
"skip_metas": skip_metas,
|
| 512 |
+
"auto_label_btn": auto_label_btn,
|
| 513 |
+
"label_progress": label_progress,
|
| 514 |
+
"sample_selector": sample_selector,
|
| 515 |
+
"preview_audio": preview_audio,
|
| 516 |
+
"preview_filename": preview_filename,
|
| 517 |
+
"edit_caption": edit_caption,
|
| 518 |
+
"edit_lyrics": edit_lyrics,
|
| 519 |
+
"edit_bpm": edit_bpm,
|
| 520 |
+
"edit_keyscale": edit_keyscale,
|
| 521 |
+
"edit_timesig": edit_timesig,
|
| 522 |
+
"edit_duration": edit_duration,
|
| 523 |
+
"edit_language": edit_language,
|
| 524 |
+
"edit_instrumental": edit_instrumental,
|
| 525 |
+
"save_edit_btn": save_edit_btn,
|
| 526 |
+
"edit_status": edit_status,
|
| 527 |
+
"save_path": save_path,
|
| 528 |
+
"save_dataset_btn": save_dataset_btn,
|
| 529 |
+
"save_status": save_status,
|
| 530 |
+
# Preprocessing
|
| 531 |
+
"load_existing_dataset_path": load_existing_dataset_path,
|
| 532 |
+
"load_existing_dataset_btn": load_existing_dataset_btn,
|
| 533 |
+
"load_existing_status": load_existing_status,
|
| 534 |
+
"preprocess_output_dir": preprocess_output_dir,
|
| 535 |
+
"preprocess_btn": preprocess_btn,
|
| 536 |
+
"preprocess_progress": preprocess_progress,
|
| 537 |
+
"dataset_builder_state": dataset_builder_state,
|
| 538 |
+
# Training
|
| 539 |
+
"training_tensor_dir": training_tensor_dir,
|
| 540 |
+
"load_dataset_btn": load_dataset_btn,
|
| 541 |
+
"training_dataset_info": training_dataset_info,
|
| 542 |
+
"lora_rank": lora_rank,
|
| 543 |
+
"lora_alpha": lora_alpha,
|
| 544 |
+
"lora_dropout": lora_dropout,
|
| 545 |
+
"learning_rate": learning_rate,
|
| 546 |
+
"train_epochs": train_epochs,
|
| 547 |
+
"train_batch_size": train_batch_size,
|
| 548 |
+
"gradient_accumulation": gradient_accumulation,
|
| 549 |
+
"save_every_n_epochs": save_every_n_epochs,
|
| 550 |
+
"training_shift": training_shift,
|
| 551 |
+
"training_seed": training_seed,
|
| 552 |
+
"lora_output_dir": lora_output_dir,
|
| 553 |
+
"start_training_btn": start_training_btn,
|
| 554 |
+
"stop_training_btn": stop_training_btn,
|
| 555 |
+
"training_progress": training_progress,
|
| 556 |
+
"training_log": training_log,
|
| 557 |
+
"training_loss_plot": training_loss_plot,
|
| 558 |
+
"export_path": export_path,
|
| 559 |
+
"export_lora_btn": export_lora_btn,
|
| 560 |
+
"export_status": export_status,
|
| 561 |
+
"training_state": training_state,
|
| 562 |
+
}
|
acestep/handler.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/inference.py
ADDED
|
@@ -0,0 +1,1181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACE-Step Inference API Module
|
| 3 |
+
|
| 4 |
+
This module provides a standardized inference interface for music generation,
|
| 5 |
+
designed for third-party integration. It offers both a simplified API and
|
| 6 |
+
backward-compatible Gradio UI support.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
import os
|
| 11 |
+
import tempfile
|
| 12 |
+
from typing import Optional, Union, List, Dict, Any, Tuple
|
| 13 |
+
from dataclasses import dataclass, field, asdict
|
| 14 |
+
from loguru import logger
|
| 15 |
+
|
| 16 |
+
from acestep.audio_utils import AudioSaver, generate_uuid_from_params
|
| 17 |
+
|
| 18 |
+
# HuggingFace Space environment detection
|
| 19 |
+
IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class GenerationParams:
|
| 24 |
+
"""Configuration for music generation parameters.
|
| 25 |
+
|
| 26 |
+
Attributes:
|
| 27 |
+
# Text Inputs
|
| 28 |
+
caption: A short text prompt describing the desired music (main prompt). < 512 characters
|
| 29 |
+
lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
|
| 30 |
+
instrumental: If True, generate instrumental music regardless of lyrics.
|
| 31 |
+
|
| 32 |
+
# Music Metadata
|
| 33 |
+
bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
|
| 34 |
+
keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
|
| 35 |
+
timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
|
| 36 |
+
vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". see acestep/constants.py:VALID_LANGUAGES
|
| 37 |
+
duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600
|
| 38 |
+
|
| 39 |
+
# Generation Parameters
|
| 40 |
+
inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model).
|
| 41 |
+
guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only support for non-turbo model.
|
| 42 |
+
seed: Integer seed for reproducibility. -1 means use random seed each time.
|
| 43 |
+
|
| 44 |
+
# Advanced DiT Parameters
|
| 45 |
+
use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
|
| 46 |
+
cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
|
| 47 |
+
cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
|
| 48 |
+
shift: Timestep shift factor (default 1.0). When != 1.0, applies t = shift * t / (1 + (shift - 1) * t) to timesteps.
|
| 49 |
+
|
| 50 |
+
# Task-Specific Parameters
|
| 51 |
+
task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
|
| 52 |
+
reference_audio: Path to a reference audio file for style transfer or cover tasks.
|
| 53 |
+
src_audio: Path to a source audio file for audio-to-audio tasks.
|
| 54 |
+
audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
|
| 55 |
+
repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
|
| 56 |
+
repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
|
| 57 |
+
audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0). set smaller (0.2) for style transfer tasks.
|
| 58 |
+
instruction: Optional task instruction prompt. If empty, auto-generated by system.
|
| 59 |
+
|
| 60 |
+
# 5Hz Language Model Parameters for CoT reasoning
|
| 61 |
+
thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
|
| 62 |
+
lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results.
|
| 63 |
+
lm_cfg_scale: Classifier-free guidance scale for the LLM.
|
| 64 |
+
lm_top_k: LLM top-k sampling (0 = disabled).
|
| 65 |
+
lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
|
| 66 |
+
lm_negative_prompt: Negative prompt to use for LLM (for control).
|
| 67 |
+
use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
|
| 68 |
+
use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
|
| 69 |
+
use_cot_language: Whether to let LLM detect vocal language via CoT.
|
| 70 |
+
"""
|
| 71 |
+
# Required Inputs
|
| 72 |
+
task_type: str = "text2music"
|
| 73 |
+
instruction: str = "Fill the audio semantic mask based on the given conditions:"
|
| 74 |
+
|
| 75 |
+
# Audio Uploads
|
| 76 |
+
reference_audio: Optional[str] = None
|
| 77 |
+
src_audio: Optional[str] = None
|
| 78 |
+
|
| 79 |
+
# LM Codes Hints
|
| 80 |
+
audio_codes: str = ""
|
| 81 |
+
|
| 82 |
+
# Text Inputs
|
| 83 |
+
caption: str = ""
|
| 84 |
+
lyrics: str = ""
|
| 85 |
+
instrumental: bool = False
|
| 86 |
+
|
| 87 |
+
# Metadata
|
| 88 |
+
vocal_language: str = "unknown"
|
| 89 |
+
bpm: Optional[int] = None
|
| 90 |
+
keyscale: str = ""
|
| 91 |
+
timesignature: str = ""
|
| 92 |
+
duration: float = -1.0
|
| 93 |
+
|
| 94 |
+
# Advanced Settings
|
| 95 |
+
inference_steps: int = 8
|
| 96 |
+
seed: int = -1
|
| 97 |
+
guidance_scale: float = 7.0
|
| 98 |
+
use_adg: bool = False
|
| 99 |
+
cfg_interval_start: float = 0.0
|
| 100 |
+
cfg_interval_end: float = 1.0
|
| 101 |
+
shift: float = 1.0
|
| 102 |
+
infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
|
| 103 |
+
# Custom timesteps (parsed from string like "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
|
| 104 |
+
# If provided, overrides inference_steps and shift
|
| 105 |
+
timesteps: Optional[List[float]] = None
|
| 106 |
+
|
| 107 |
+
repainting_start: float = 0.0
|
| 108 |
+
repainting_end: float = -1
|
| 109 |
+
audio_cover_strength: float = 1.0
|
| 110 |
+
|
| 111 |
+
# 5Hz Language Model Parameters
|
| 112 |
+
thinking: bool = True
|
| 113 |
+
lm_temperature: float = 0.85
|
| 114 |
+
lm_cfg_scale: float = 2.0
|
| 115 |
+
lm_top_k: int = 0
|
| 116 |
+
lm_top_p: float = 0.9
|
| 117 |
+
lm_negative_prompt: str = "NO USER INPUT"
|
| 118 |
+
use_cot_metas: bool = True
|
| 119 |
+
use_cot_caption: bool = True
|
| 120 |
+
use_cot_lyrics: bool = False # TODO: not used yet
|
| 121 |
+
use_cot_language: bool = True
|
| 122 |
+
use_constrained_decoding: bool = True
|
| 123 |
+
|
| 124 |
+
cot_bpm: Optional[int] = None
|
| 125 |
+
cot_keyscale: str = ""
|
| 126 |
+
cot_timesignature: str = ""
|
| 127 |
+
cot_duration: Optional[float] = None
|
| 128 |
+
cot_vocal_language: str = "unknown"
|
| 129 |
+
cot_caption: str = ""
|
| 130 |
+
cot_lyrics: str = ""
|
| 131 |
+
|
| 132 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 133 |
+
"""Convert config to dictionary for JSON serialization."""
|
| 134 |
+
return asdict(self)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@dataclass
|
| 138 |
+
class GenerationConfig:
|
| 139 |
+
"""Configuration for music generation.
|
| 140 |
+
|
| 141 |
+
Attributes:
|
| 142 |
+
batch_size: Number of audio samples to generate
|
| 143 |
+
allow_lm_batch: Whether to allow batch processing in LM
|
| 144 |
+
use_random_seed: Whether to use random seed
|
| 145 |
+
seeds: Seed(s) for batch generation. Can be:
|
| 146 |
+
- None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
|
| 147 |
+
- List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
|
| 148 |
+
- int: Single seed value (will be converted to list and padded)
|
| 149 |
+
lm_batch_chunk_size: Batch chunk size for LM processing
|
| 150 |
+
constrained_decoding_debug: Whether to enable constrained decoding debug
|
| 151 |
+
audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
|
| 152 |
+
"""
|
| 153 |
+
batch_size: int = 2
|
| 154 |
+
allow_lm_batch: bool = False
|
| 155 |
+
use_random_seed: bool = True
|
| 156 |
+
seeds: Optional[List[int]] = None
|
| 157 |
+
lm_batch_chunk_size: int = 8
|
| 158 |
+
constrained_decoding_debug: bool = False
|
| 159 |
+
audio_format: str = "flac" # Default to FLAC for fast saving
|
| 160 |
+
|
| 161 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 162 |
+
"""Convert config to dictionary for JSON serialization."""
|
| 163 |
+
return asdict(self)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@dataclass
|
| 167 |
+
class GenerationResult:
|
| 168 |
+
"""Result of music generation.
|
| 169 |
+
|
| 170 |
+
Attributes:
|
| 171 |
+
# Audio Outputs
|
| 172 |
+
audios: List of audio dictionaries with paths, keys, params
|
| 173 |
+
status_message: Status message from generation
|
| 174 |
+
extra_outputs: Extra outputs from generation
|
| 175 |
+
success: Whether generation completed successfully
|
| 176 |
+
error: Error message if generation failed
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
# Audio Outputs
|
| 180 |
+
audios: List[Dict[str, Any]] = field(default_factory=list)
|
| 181 |
+
# Generation Information
|
| 182 |
+
status_message: str = ""
|
| 183 |
+
extra_outputs: Dict[str, Any] = field(default_factory=dict)
|
| 184 |
+
# Success Status
|
| 185 |
+
success: bool = True
|
| 186 |
+
error: Optional[str] = None
|
| 187 |
+
|
| 188 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 189 |
+
"""Convert result to dictionary for JSON serialization."""
|
| 190 |
+
return asdict(self)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@dataclass
|
| 194 |
+
class UnderstandResult:
|
| 195 |
+
"""Result of music understanding from audio codes.
|
| 196 |
+
|
| 197 |
+
Attributes:
|
| 198 |
+
# Metadata Fields
|
| 199 |
+
caption: Generated caption describing the music
|
| 200 |
+
lyrics: Generated or extracted lyrics
|
| 201 |
+
bpm: Beats per minute (None if not detected)
|
| 202 |
+
duration: Duration in seconds (None if not detected)
|
| 203 |
+
keyscale: Musical key (e.g., "C Major")
|
| 204 |
+
language: Vocal language code (e.g., "en", "zh")
|
| 205 |
+
timesignature: Time signature (e.g., "4/4")
|
| 206 |
+
|
| 207 |
+
# Status
|
| 208 |
+
status_message: Status message from understanding
|
| 209 |
+
success: Whether understanding completed successfully
|
| 210 |
+
error: Error message if understanding failed
|
| 211 |
+
"""
|
| 212 |
+
# Metadata Fields
|
| 213 |
+
caption: str = ""
|
| 214 |
+
lyrics: str = ""
|
| 215 |
+
bpm: Optional[int] = None
|
| 216 |
+
duration: Optional[float] = None
|
| 217 |
+
keyscale: str = ""
|
| 218 |
+
language: str = ""
|
| 219 |
+
timesignature: str = ""
|
| 220 |
+
|
| 221 |
+
# Status
|
| 222 |
+
status_message: str = ""
|
| 223 |
+
success: bool = True
|
| 224 |
+
error: Optional[str] = None
|
| 225 |
+
|
| 226 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 227 |
+
"""Convert result to dictionary for JSON serialization."""
|
| 228 |
+
return asdict(self)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _update_metadata_from_lm(
|
| 232 |
+
metadata: Dict[str, Any],
|
| 233 |
+
bpm: Optional[int],
|
| 234 |
+
key_scale: str,
|
| 235 |
+
time_signature: str,
|
| 236 |
+
audio_duration: Optional[float],
|
| 237 |
+
vocal_language: str,
|
| 238 |
+
caption: str,
|
| 239 |
+
lyrics: str,
|
| 240 |
+
) -> Tuple[Optional[int], str, str, Optional[float]]:
|
| 241 |
+
"""Update metadata fields from LM output if not provided by user."""
|
| 242 |
+
|
| 243 |
+
if bpm is None and metadata.get('bpm'):
|
| 244 |
+
bpm_value = metadata.get('bpm')
|
| 245 |
+
if bpm_value not in ["N/A", ""]:
|
| 246 |
+
try:
|
| 247 |
+
bpm = int(bpm_value)
|
| 248 |
+
except (ValueError, TypeError):
|
| 249 |
+
pass
|
| 250 |
+
|
| 251 |
+
if not key_scale and metadata.get('keyscale'):
|
| 252 |
+
key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
|
| 253 |
+
if key_scale_value != "N/A":
|
| 254 |
+
key_scale = key_scale_value
|
| 255 |
+
|
| 256 |
+
if not time_signature and metadata.get('timesignature'):
|
| 257 |
+
time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
|
| 258 |
+
if time_signature_value != "N/A":
|
| 259 |
+
time_signature = time_signature_value
|
| 260 |
+
|
| 261 |
+
if audio_duration is None or audio_duration <= 0:
|
| 262 |
+
audio_duration_value = metadata.get('duration', -1)
|
| 263 |
+
if audio_duration_value not in ["N/A", ""]:
|
| 264 |
+
try:
|
| 265 |
+
audio_duration = float(audio_duration_value)
|
| 266 |
+
except (ValueError, TypeError):
|
| 267 |
+
pass
|
| 268 |
+
|
| 269 |
+
if not vocal_language and metadata.get('vocal_language'):
|
| 270 |
+
vocal_language = metadata.get('vocal_language')
|
| 271 |
+
if not caption and metadata.get('caption'):
|
| 272 |
+
caption = metadata.get('caption')
|
| 273 |
+
if not lyrics and metadata.get('lyrics'):
|
| 274 |
+
lyrics = metadata.get('lyrics')
|
| 275 |
+
return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def generate_music(
|
| 279 |
+
dit_handler,
|
| 280 |
+
llm_handler,
|
| 281 |
+
params: GenerationParams,
|
| 282 |
+
config: GenerationConfig,
|
| 283 |
+
save_dir: Optional[str] = None,
|
| 284 |
+
progress=None,
|
| 285 |
+
) -> GenerationResult:
|
| 286 |
+
"""Generate music using ACE-Step model with optional LM reasoning.
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
dit_handler: Initialized DiT model handler (AceStepHandler instance)
|
| 290 |
+
llm_handler: Initialized LLM handler (LLMHandler instance)
|
| 291 |
+
params: Generation parameters (GenerationParams instance)
|
| 292 |
+
config: Generation configuration (GenerationConfig instance)
|
| 293 |
+
|
| 294 |
+
Returns:
|
| 295 |
+
GenerationResult with generated audio files and metadata
|
| 296 |
+
"""
|
| 297 |
+
try:
|
| 298 |
+
# Phase 1: LM-based metadata and code generation (if enabled)
|
| 299 |
+
audio_code_string_to_use = params.audio_codes
|
| 300 |
+
lm_generated_metadata = None
|
| 301 |
+
lm_generated_audio_codes_list = []
|
| 302 |
+
lm_total_time_costs = {
|
| 303 |
+
"phase1_time": 0.0,
|
| 304 |
+
"phase2_time": 0.0,
|
| 305 |
+
"total_time": 0.0,
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
# Extract mutable copies of metadata (will be updated by LM if needed)
|
| 309 |
+
bpm = params.bpm
|
| 310 |
+
key_scale = params.keyscale
|
| 311 |
+
time_signature = params.timesignature
|
| 312 |
+
audio_duration = params.duration
|
| 313 |
+
dit_input_caption = params.caption
|
| 314 |
+
dit_input_vocal_language = params.vocal_language
|
| 315 |
+
dit_input_lyrics = params.lyrics
|
| 316 |
+
# Determine if we need to generate audio codes
|
| 317 |
+
# If user has provided audio_codes, we don't need to generate them
|
| 318 |
+
# Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
|
| 319 |
+
user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
|
| 320 |
+
|
| 321 |
+
# Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
|
| 322 |
+
# For now, we use "llm_dit" if batch mode or if user hasn't provided codes
|
| 323 |
+
# Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
|
| 324 |
+
# Note: This logic can be refined based on specific requirements
|
| 325 |
+
need_audio_codes = not user_provided_audio_codes
|
| 326 |
+
|
| 327 |
+
# Determine if we should use chunk-based LM generation (always use chunks for consistency)
|
| 328 |
+
# Determine actual batch size for chunk processing
|
| 329 |
+
actual_batch_size = config.batch_size if config.batch_size is not None else 1
|
| 330 |
+
|
| 331 |
+
# Prepare seeds for batch generation
|
| 332 |
+
# Use config.seed if provided, otherwise fallback to params.seed
|
| 333 |
+
# Convert config.seed (None, int, or List[int]) to format that prepare_seeds accepts
|
| 334 |
+
seed_for_generation = ""
|
| 335 |
+
if config.seeds is not None and len(config.seeds) > 0:
|
| 336 |
+
if isinstance(config.seeds, list):
|
| 337 |
+
# Convert List[int] to comma-separated string
|
| 338 |
+
seed_for_generation = ",".join(str(s) for s in config.seeds)
|
| 339 |
+
|
| 340 |
+
# Use dit_handler.prepare_seeds to handle seed list generation and padding
|
| 341 |
+
# This will handle all the logic: padding with random seeds if needed, etc.
|
| 342 |
+
actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
|
| 343 |
+
|
| 344 |
+
# LM-based Chain-of-Thought reasoning
|
| 345 |
+
# Skip LM for cover/repaint tasks - these tasks use reference/src audio directly
|
| 346 |
+
# and don't need LM to generate audio codes
|
| 347 |
+
skip_lm_tasks = {"cover", "repaint"}
|
| 348 |
+
|
| 349 |
+
# Determine if we should use LLM
|
| 350 |
+
# LLM is needed for:
|
| 351 |
+
# 1. thinking=True: generate audio codes via LM
|
| 352 |
+
# 2. use_cot_caption=True: enhance/generate caption via CoT
|
| 353 |
+
# 3. use_cot_language=True: detect vocal language via CoT
|
| 354 |
+
# 4. use_cot_metas=True: fill missing metadata via CoT
|
| 355 |
+
need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
|
| 356 |
+
use_lm = (params.thinking or need_lm_for_cot) and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
|
| 357 |
+
lm_status = []
|
| 358 |
+
|
| 359 |
+
if params.task_type in skip_lm_tasks:
|
| 360 |
+
logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
|
| 361 |
+
|
| 362 |
+
logger.info(f"[generate_music] LLM usage decision: thinking={params.thinking}, "
|
| 363 |
+
f"use_cot_caption={params.use_cot_caption}, use_cot_language={params.use_cot_language}, "
|
| 364 |
+
f"use_cot_metas={params.use_cot_metas}, need_lm_for_cot={need_lm_for_cot}, "
|
| 365 |
+
f"llm_initialized={llm_handler.llm_initialized if llm_handler else False}, use_lm={use_lm}")
|
| 366 |
+
|
| 367 |
+
if use_lm:
|
| 368 |
+
# Convert sampling parameters - handle None values safely
|
| 369 |
+
top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
|
| 370 |
+
top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
|
| 371 |
+
|
| 372 |
+
# Build user_metadata from user-provided values
|
| 373 |
+
user_metadata = {}
|
| 374 |
+
if bpm is not None:
|
| 375 |
+
try:
|
| 376 |
+
bpm_value = float(bpm)
|
| 377 |
+
if bpm_value > 0:
|
| 378 |
+
user_metadata['bpm'] = int(bpm_value)
|
| 379 |
+
except (ValueError, TypeError):
|
| 380 |
+
pass
|
| 381 |
+
|
| 382 |
+
if key_scale and key_scale.strip():
|
| 383 |
+
key_scale_clean = key_scale.strip()
|
| 384 |
+
if key_scale_clean.lower() not in ["n/a", ""]:
|
| 385 |
+
user_metadata['keyscale'] = key_scale_clean
|
| 386 |
+
|
| 387 |
+
if time_signature and time_signature.strip():
|
| 388 |
+
time_sig_clean = time_signature.strip()
|
| 389 |
+
if time_sig_clean.lower() not in ["n/a", ""]:
|
| 390 |
+
user_metadata['timesignature'] = time_sig_clean
|
| 391 |
+
|
| 392 |
+
if audio_duration is not None:
|
| 393 |
+
try:
|
| 394 |
+
duration_value = float(audio_duration)
|
| 395 |
+
if duration_value > 0:
|
| 396 |
+
user_metadata['duration'] = int(duration_value)
|
| 397 |
+
except (ValueError, TypeError):
|
| 398 |
+
pass
|
| 399 |
+
|
| 400 |
+
user_metadata_to_pass = user_metadata if user_metadata else None
|
| 401 |
+
|
| 402 |
+
# Determine infer_type based on whether we need audio codes
|
| 403 |
+
# - "llm_dit": generates both metas and audio codes (two-phase internally)
|
| 404 |
+
# - "dit": generates only metas (single phase)
|
| 405 |
+
infer_type = "llm_dit" if need_audio_codes and params.thinking else "dit"
|
| 406 |
+
|
| 407 |
+
# Use chunk size from config, or default to batch_size if not set
|
| 408 |
+
max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
|
| 409 |
+
num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
|
| 410 |
+
|
| 411 |
+
all_metadata_list = []
|
| 412 |
+
all_audio_codes_list = []
|
| 413 |
+
|
| 414 |
+
for chunk_idx in range(num_chunks):
|
| 415 |
+
chunk_start = chunk_idx * max_inference_batch_size
|
| 416 |
+
chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
|
| 417 |
+
chunk_size = chunk_end - chunk_start
|
| 418 |
+
chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
|
| 419 |
+
|
| 420 |
+
logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
|
| 421 |
+
f"(size: {chunk_size}, seeds: {chunk_seeds})")
|
| 422 |
+
|
| 423 |
+
# Use the determined infer_type
|
| 424 |
+
# - "llm_dit" will internally run two phases (metas + codes)
|
| 425 |
+
# - "dit" will only run phase 1 (metas only)
|
| 426 |
+
result = llm_handler.generate_with_stop_condition(
|
| 427 |
+
caption=params.caption or "",
|
| 428 |
+
lyrics=params.lyrics or "",
|
| 429 |
+
infer_type=infer_type,
|
| 430 |
+
temperature=params.lm_temperature,
|
| 431 |
+
cfg_scale=params.lm_cfg_scale,
|
| 432 |
+
negative_prompt=params.lm_negative_prompt,
|
| 433 |
+
top_k=top_k_value,
|
| 434 |
+
top_p=top_p_value,
|
| 435 |
+
target_duration=audio_duration, # Pass duration to limit audio codes generation
|
| 436 |
+
user_metadata=user_metadata_to_pass,
|
| 437 |
+
use_cot_caption=params.use_cot_caption,
|
| 438 |
+
use_cot_language=params.use_cot_language,
|
| 439 |
+
use_cot_metas=params.use_cot_metas,
|
| 440 |
+
use_constrained_decoding=params.use_constrained_decoding,
|
| 441 |
+
constrained_decoding_debug=config.constrained_decoding_debug,
|
| 442 |
+
batch_size=chunk_size,
|
| 443 |
+
seeds=chunk_seeds,
|
| 444 |
+
progress=progress,
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
# Check if LM generation failed
|
| 448 |
+
if not result.get("success", False):
|
| 449 |
+
error_msg = result.get("error", "Unknown LM error")
|
| 450 |
+
lm_status.append(f"❌ LM Error: {error_msg}")
|
| 451 |
+
# Return early with error
|
| 452 |
+
return GenerationResult(
|
| 453 |
+
audios=[],
|
| 454 |
+
status_message=f"❌ LM generation failed: {error_msg}",
|
| 455 |
+
extra_outputs={},
|
| 456 |
+
success=False,
|
| 457 |
+
error=error_msg,
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
# Extract metadata and audio_codes from result dict
|
| 461 |
+
if chunk_size > 1:
|
| 462 |
+
metadata_list = result.get("metadata", [])
|
| 463 |
+
audio_codes_list = result.get("audio_codes", [])
|
| 464 |
+
all_metadata_list.extend(metadata_list)
|
| 465 |
+
all_audio_codes_list.extend(audio_codes_list)
|
| 466 |
+
else:
|
| 467 |
+
metadata = result.get("metadata", {})
|
| 468 |
+
audio_codes = result.get("audio_codes", "")
|
| 469 |
+
all_metadata_list.append(metadata)
|
| 470 |
+
all_audio_codes_list.append(audio_codes)
|
| 471 |
+
|
| 472 |
+
# Collect time costs from LM extra_outputs
|
| 473 |
+
lm_extra = result.get("extra_outputs", {})
|
| 474 |
+
lm_chunk_time_costs = lm_extra.get("time_costs", {})
|
| 475 |
+
if lm_chunk_time_costs:
|
| 476 |
+
# Accumulate time costs from all chunks
|
| 477 |
+
for key in ["phase1_time", "phase2_time", "total_time"]:
|
| 478 |
+
if key in lm_chunk_time_costs:
|
| 479 |
+
lm_total_time_costs[key] += lm_chunk_time_costs[key]
|
| 480 |
+
|
| 481 |
+
time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
|
| 482 |
+
lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")
|
| 483 |
+
|
| 484 |
+
lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
|
| 485 |
+
lm_generated_audio_codes_list = all_audio_codes_list
|
| 486 |
+
|
| 487 |
+
# Set audio_code_string_to_use based on infer_type
|
| 488 |
+
if infer_type == "llm_dit":
|
| 489 |
+
# If batch mode, use list; otherwise use single string
|
| 490 |
+
if actual_batch_size > 1:
|
| 491 |
+
audio_code_string_to_use = all_audio_codes_list
|
| 492 |
+
else:
|
| 493 |
+
audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
|
| 494 |
+
else:
|
| 495 |
+
# For "dit" mode, keep user-provided codes or empty
|
| 496 |
+
audio_code_string_to_use = params.audio_codes
|
| 497 |
+
|
| 498 |
+
# Update metadata from LM if not provided by user
|
| 499 |
+
if lm_generated_metadata:
|
| 500 |
+
bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
|
| 501 |
+
metadata=lm_generated_metadata,
|
| 502 |
+
bpm=bpm,
|
| 503 |
+
key_scale=key_scale,
|
| 504 |
+
time_signature=time_signature,
|
| 505 |
+
audio_duration=audio_duration,
|
| 506 |
+
vocal_language=dit_input_vocal_language,
|
| 507 |
+
caption=dit_input_caption,
|
| 508 |
+
lyrics=dit_input_lyrics)
|
| 509 |
+
if not params.bpm:
|
| 510 |
+
params.cot_bpm = bpm
|
| 511 |
+
if not params.keyscale:
|
| 512 |
+
params.cot_keyscale = key_scale
|
| 513 |
+
if not params.timesignature:
|
| 514 |
+
params.cot_timesignature = time_signature
|
| 515 |
+
if not params.duration:
|
| 516 |
+
params.cot_duration = audio_duration
|
| 517 |
+
if not params.vocal_language:
|
| 518 |
+
params.cot_vocal_language = vocal_language
|
| 519 |
+
if not params.caption:
|
| 520 |
+
params.cot_caption = caption
|
| 521 |
+
if not params.lyrics:
|
| 522 |
+
params.cot_lyrics = lyrics
|
| 523 |
+
|
| 524 |
+
# set cot caption and language if needed
|
| 525 |
+
if params.use_cot_caption:
|
| 526 |
+
dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
|
| 527 |
+
if params.use_cot_language:
|
| 528 |
+
dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)
|
| 529 |
+
|
| 530 |
+
# Phase 2: DiT music generation
|
| 531 |
+
# Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
|
| 532 |
+
result = dit_handler.generate_music(
|
| 533 |
+
captions=dit_input_caption,
|
| 534 |
+
lyrics=dit_input_lyrics,
|
| 535 |
+
bpm=bpm,
|
| 536 |
+
key_scale=key_scale,
|
| 537 |
+
time_signature=time_signature,
|
| 538 |
+
vocal_language=dit_input_vocal_language,
|
| 539 |
+
inference_steps=params.inference_steps,
|
| 540 |
+
guidance_scale=params.guidance_scale,
|
| 541 |
+
use_random_seed=config.use_random_seed,
|
| 542 |
+
seed=seed_for_generation, # Use config.seed (or params.seed fallback) instead of params.seed directly
|
| 543 |
+
reference_audio=params.reference_audio,
|
| 544 |
+
audio_duration=audio_duration,
|
| 545 |
+
batch_size=config.batch_size if config.batch_size is not None else 1,
|
| 546 |
+
src_audio=params.src_audio,
|
| 547 |
+
audio_code_string=audio_code_string_to_use,
|
| 548 |
+
repainting_start=params.repainting_start,
|
| 549 |
+
repainting_end=params.repainting_end,
|
| 550 |
+
instruction=params.instruction,
|
| 551 |
+
audio_cover_strength=params.audio_cover_strength,
|
| 552 |
+
task_type=params.task_type,
|
| 553 |
+
use_adg=params.use_adg,
|
| 554 |
+
cfg_interval_start=params.cfg_interval_start,
|
| 555 |
+
cfg_interval_end=params.cfg_interval_end,
|
| 556 |
+
shift=params.shift,
|
| 557 |
+
infer_method=params.infer_method,
|
| 558 |
+
timesteps=params.timesteps,
|
| 559 |
+
progress=progress,
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
# Check if generation failed
|
| 563 |
+
if not result.get("success", False):
|
| 564 |
+
return GenerationResult(
|
| 565 |
+
audios=[],
|
| 566 |
+
status_message=result.get("status_message", ""),
|
| 567 |
+
extra_outputs={},
|
| 568 |
+
success=False,
|
| 569 |
+
error=result.get("error"),
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
# Extract results from dit_handler.generate_music dict
|
| 573 |
+
dit_audios = result.get("audios", [])
|
| 574 |
+
status_message = result.get("status_message", "")
|
| 575 |
+
dit_extra_outputs = result.get("extra_outputs", {})
|
| 576 |
+
|
| 577 |
+
# Use the seed list already prepared above (from config.seed or params.seed fallback)
|
| 578 |
+
# actual_seed_list was computed earlier using dit_handler.prepare_seeds
|
| 579 |
+
seed_list = actual_seed_list
|
| 580 |
+
|
| 581 |
+
# Get base params dictionary
|
| 582 |
+
base_params_dict = params.to_dict()
|
| 583 |
+
|
| 584 |
+
# Save audio files using AudioSaver (format from config)
|
| 585 |
+
audio_format = config.audio_format if config.audio_format else "flac"
|
| 586 |
+
audio_saver = AudioSaver(default_format=audio_format)
|
| 587 |
+
|
| 588 |
+
# Use handler's temp_dir for saving files
|
| 589 |
+
if save_dir is not None:
|
| 590 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 591 |
+
|
| 592 |
+
# Build audios list for GenerationResult with params and save files
|
| 593 |
+
# Audio saving and UUID generation handled here, outside of handler
|
| 594 |
+
audios = []
|
| 595 |
+
for idx, dit_audio in enumerate(dit_audios):
|
| 596 |
+
# Create a copy of params dict for this audio
|
| 597 |
+
audio_params = base_params_dict.copy()
|
| 598 |
+
|
| 599 |
+
# Update audio-specific values
|
| 600 |
+
audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
|
| 601 |
+
|
| 602 |
+
# Add audio codes if batch mode
|
| 603 |
+
if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
|
| 604 |
+
audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
|
| 605 |
+
|
| 606 |
+
# Get audio tensor and metadata
|
| 607 |
+
audio_tensor = dit_audio.get("tensor")
|
| 608 |
+
sample_rate = dit_audio.get("sample_rate", 48000)
|
| 609 |
+
|
| 610 |
+
# Generate UUID for this audio (moved from handler)
|
| 611 |
+
batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
|
| 612 |
+
audio_code_str = lm_generated_audio_codes_list[idx] if (
|
| 613 |
+
lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
|
| 614 |
+
if isinstance(audio_code_str, list):
|
| 615 |
+
audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
|
| 616 |
+
|
| 617 |
+
audio_key = generate_uuid_from_params(audio_params)
|
| 618 |
+
|
| 619 |
+
# Save audio file (handled outside handler)
|
| 620 |
+
audio_path = None
|
| 621 |
+
if audio_tensor is not None and save_dir is not None:
|
| 622 |
+
try:
|
| 623 |
+
audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
|
| 624 |
+
audio_path = audio_saver.save_audio(audio_tensor,
|
| 625 |
+
audio_file,
|
| 626 |
+
sample_rate=sample_rate,
|
| 627 |
+
format=audio_format,
|
| 628 |
+
channels_first=True)
|
| 629 |
+
except Exception as e:
|
| 630 |
+
logger.error(f"[generate_music] Failed to save audio file: {e}")
|
| 631 |
+
audio_path = "" # Fallback to empty path
|
| 632 |
+
|
| 633 |
+
audio_dict = {
|
| 634 |
+
"path": audio_path or "", # File path (saved here, not in handler)
|
| 635 |
+
"tensor": audio_tensor, # Audio tensor [channels, samples], CPU, float32
|
| 636 |
+
"key": audio_key,
|
| 637 |
+
"sample_rate": sample_rate,
|
| 638 |
+
"params": audio_params,
|
| 639 |
+
}
|
| 640 |
+
|
| 641 |
+
audios.append(audio_dict)
|
| 642 |
+
|
| 643 |
+
# Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
|
| 644 |
+
extra_outputs = dit_extra_outputs.copy()
|
| 645 |
+
extra_outputs["lm_metadata"] = lm_generated_metadata
|
| 646 |
+
|
| 647 |
+
# Merge time_costs from both LM and DiT into a unified dictionary
|
| 648 |
+
unified_time_costs = {}
|
| 649 |
+
|
| 650 |
+
# Add LM time costs (if LM was used)
|
| 651 |
+
if use_lm and lm_total_time_costs:
|
| 652 |
+
for key, value in lm_total_time_costs.items():
|
| 653 |
+
unified_time_costs[f"lm_{key}"] = value
|
| 654 |
+
|
| 655 |
+
# Add DiT time costs (if available)
|
| 656 |
+
dit_time_costs = dit_extra_outputs.get("time_costs", {})
|
| 657 |
+
if dit_time_costs:
|
| 658 |
+
for key, value in dit_time_costs.items():
|
| 659 |
+
unified_time_costs[f"dit_{key}"] = value
|
| 660 |
+
|
| 661 |
+
# Calculate total pipeline time
|
| 662 |
+
if unified_time_costs:
|
| 663 |
+
lm_total = unified_time_costs.get("lm_total_time", 0.0)
|
| 664 |
+
dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
|
| 665 |
+
unified_time_costs["pipeline_total_time"] = lm_total + dit_total
|
| 666 |
+
|
| 667 |
+
# Update extra_outputs with unified time_costs
|
| 668 |
+
extra_outputs["time_costs"] = unified_time_costs
|
| 669 |
+
|
| 670 |
+
if lm_status:
|
| 671 |
+
status_message = "\n".join(lm_status) + "\n" + status_message
|
| 672 |
+
else:
|
| 673 |
+
status_message = status_message
|
| 674 |
+
# Create and return GenerationResult
|
| 675 |
+
return GenerationResult(
|
| 676 |
+
audios=audios,
|
| 677 |
+
status_message=status_message,
|
| 678 |
+
extra_outputs=extra_outputs,
|
| 679 |
+
success=True,
|
| 680 |
+
error=None,
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
except Exception as e:
|
| 684 |
+
logger.exception("Music generation failed")
|
| 685 |
+
return GenerationResult(
|
| 686 |
+
audios=[],
|
| 687 |
+
status_message=f"Error: {str(e)}",
|
| 688 |
+
extra_outputs={},
|
| 689 |
+
success=False,
|
| 690 |
+
error=str(e),
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
def understand_music(
|
| 695 |
+
llm_handler,
|
| 696 |
+
audio_codes: str,
|
| 697 |
+
temperature: float = 0.85,
|
| 698 |
+
top_k: Optional[int] = None,
|
| 699 |
+
top_p: Optional[float] = None,
|
| 700 |
+
repetition_penalty: float = 1.0,
|
| 701 |
+
use_constrained_decoding: bool = True,
|
| 702 |
+
constrained_decoding_debug: bool = False,
|
| 703 |
+
) -> UnderstandResult:
|
| 704 |
+
"""Understand music from audio codes using the 5Hz Language Model.
|
| 705 |
+
|
| 706 |
+
This function analyzes audio semantic codes and generates metadata about the music,
|
| 707 |
+
including caption, lyrics, BPM, duration, key scale, language, and time signature.
|
| 708 |
+
|
| 709 |
+
If audio_codes is empty or "NO USER INPUT", the LM will generate a sample example
|
| 710 |
+
instead of analyzing existing codes.
|
| 711 |
+
|
| 712 |
+
Note: cfg_scale and negative_prompt are not supported in understand mode.
|
| 713 |
+
|
| 714 |
+
Args:
|
| 715 |
+
llm_handler: Initialized LLM handler (LLMHandler instance)
|
| 716 |
+
audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...")
|
| 717 |
+
Use empty string or "NO USER INPUT" to generate a sample example.
|
| 718 |
+
temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
|
| 719 |
+
top_k: Top-K sampling (None or 0 = disabled)
|
| 720 |
+
top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
|
| 721 |
+
repetition_penalty: Repetition penalty (1.0 = no penalty)
|
| 722 |
+
use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
|
| 723 |
+
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
|
| 724 |
+
|
| 725 |
+
Returns:
|
| 726 |
+
UnderstandResult with parsed metadata fields and status
|
| 727 |
+
|
| 728 |
+
Example:
|
| 729 |
+
>>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
|
| 730 |
+
>>> if result.success:
|
| 731 |
+
... print(f"Caption: {result.caption}")
|
| 732 |
+
... print(f"BPM: {result.bpm}")
|
| 733 |
+
... print(f"Lyrics: {result.lyrics}")
|
| 734 |
+
"""
|
| 735 |
+
# Check if LLM is initialized
|
| 736 |
+
if not llm_handler.llm_initialized:
|
| 737 |
+
return UnderstandResult(
|
| 738 |
+
status_message="5Hz LM not initialized. Please initialize it first.",
|
| 739 |
+
success=False,
|
| 740 |
+
error="LLM not initialized",
|
| 741 |
+
)
|
| 742 |
+
|
| 743 |
+
# If codes are empty, use "NO USER INPUT" to generate a sample example
|
| 744 |
+
if not audio_codes or not audio_codes.strip():
|
| 745 |
+
audio_codes = "NO USER INPUT"
|
| 746 |
+
|
| 747 |
+
try:
|
| 748 |
+
# Call LLM understanding
|
| 749 |
+
metadata, status = llm_handler.understand_audio_from_codes(
|
| 750 |
+
audio_codes=audio_codes,
|
| 751 |
+
temperature=temperature,
|
| 752 |
+
top_k=top_k,
|
| 753 |
+
top_p=top_p,
|
| 754 |
+
repetition_penalty=repetition_penalty,
|
| 755 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 756 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
# Check if LLM returned empty metadata (error case)
|
| 760 |
+
if not metadata:
|
| 761 |
+
return UnderstandResult(
|
| 762 |
+
status_message=status or "Failed to understand audio codes",
|
| 763 |
+
success=False,
|
| 764 |
+
error=status or "Empty metadata returned",
|
| 765 |
+
)
|
| 766 |
+
|
| 767 |
+
# Extract and convert fields
|
| 768 |
+
caption = metadata.get('caption', '')
|
| 769 |
+
lyrics = metadata.get('lyrics', '')
|
| 770 |
+
keyscale = metadata.get('keyscale', '')
|
| 771 |
+
language = metadata.get('language', metadata.get('vocal_language', ''))
|
| 772 |
+
timesignature = metadata.get('timesignature', '')
|
| 773 |
+
|
| 774 |
+
# Convert BPM to int
|
| 775 |
+
bpm = None
|
| 776 |
+
bpm_value = metadata.get('bpm')
|
| 777 |
+
if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
|
| 778 |
+
try:
|
| 779 |
+
bpm = int(bpm_value)
|
| 780 |
+
except (ValueError, TypeError):
|
| 781 |
+
pass
|
| 782 |
+
|
| 783 |
+
# Convert duration to float
|
| 784 |
+
duration = None
|
| 785 |
+
duration_value = metadata.get('duration')
|
| 786 |
+
if duration_value is not None and duration_value != 'N/A' and duration_value != '':
|
| 787 |
+
try:
|
| 788 |
+
duration = float(duration_value)
|
| 789 |
+
except (ValueError, TypeError):
|
| 790 |
+
pass
|
| 791 |
+
|
| 792 |
+
# Clean up N/A values
|
| 793 |
+
if keyscale == 'N/A':
|
| 794 |
+
keyscale = ''
|
| 795 |
+
if language == 'N/A':
|
| 796 |
+
language = ''
|
| 797 |
+
if timesignature == 'N/A':
|
| 798 |
+
timesignature = ''
|
| 799 |
+
|
| 800 |
+
return UnderstandResult(
|
| 801 |
+
caption=caption,
|
| 802 |
+
lyrics=lyrics,
|
| 803 |
+
bpm=bpm,
|
| 804 |
+
duration=duration,
|
| 805 |
+
keyscale=keyscale,
|
| 806 |
+
language=language,
|
| 807 |
+
timesignature=timesignature,
|
| 808 |
+
status_message=status,
|
| 809 |
+
success=True,
|
| 810 |
+
error=None,
|
| 811 |
+
)
|
| 812 |
+
|
| 813 |
+
except Exception as e:
|
| 814 |
+
logger.exception("Music understanding failed")
|
| 815 |
+
return UnderstandResult(
|
| 816 |
+
status_message=f"Error: {str(e)}",
|
| 817 |
+
success=False,
|
| 818 |
+
error=str(e),
|
| 819 |
+
)
|
| 820 |
+
|
| 821 |
+
|
| 822 |
+
@dataclass
|
| 823 |
+
class CreateSampleResult:
|
| 824 |
+
"""Result of creating a music sample from a natural language query.
|
| 825 |
+
|
| 826 |
+
This is used by the "Simple Mode" / "Inspiration Mode" feature where users
|
| 827 |
+
provide a natural language description and the LLM generates a complete
|
| 828 |
+
sample with caption, lyrics, and metadata.
|
| 829 |
+
|
| 830 |
+
Attributes:
|
| 831 |
+
# Metadata Fields
|
| 832 |
+
caption: Generated detailed music description/caption
|
| 833 |
+
lyrics: Generated lyrics (or "[Instrumental]" for instrumental music)
|
| 834 |
+
bpm: Beats per minute (None if not generated)
|
| 835 |
+
duration: Duration in seconds (None if not generated)
|
| 836 |
+
keyscale: Musical key (e.g., "C Major")
|
| 837 |
+
language: Vocal language code (e.g., "en", "zh")
|
| 838 |
+
timesignature: Time signature (e.g., "4")
|
| 839 |
+
instrumental: Whether this is an instrumental piece
|
| 840 |
+
|
| 841 |
+
# Status
|
| 842 |
+
status_message: Status message from sample creation
|
| 843 |
+
success: Whether sample creation completed successfully
|
| 844 |
+
error: Error message if sample creation failed
|
| 845 |
+
"""
|
| 846 |
+
# Metadata Fields
|
| 847 |
+
caption: str = ""
|
| 848 |
+
lyrics: str = ""
|
| 849 |
+
bpm: Optional[int] = None
|
| 850 |
+
duration: Optional[float] = None
|
| 851 |
+
keyscale: str = ""
|
| 852 |
+
language: str = ""
|
| 853 |
+
timesignature: str = ""
|
| 854 |
+
instrumental: bool = False
|
| 855 |
+
|
| 856 |
+
# Status
|
| 857 |
+
status_message: str = ""
|
| 858 |
+
success: bool = True
|
| 859 |
+
error: Optional[str] = None
|
| 860 |
+
|
| 861 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 862 |
+
"""Convert result to dictionary for JSON serialization."""
|
| 863 |
+
return asdict(self)
|
| 864 |
+
|
| 865 |
+
|
| 866 |
+
def create_sample(
|
| 867 |
+
llm_handler,
|
| 868 |
+
query: str,
|
| 869 |
+
instrumental: bool = False,
|
| 870 |
+
vocal_language: Optional[str] = None,
|
| 871 |
+
temperature: float = 0.85,
|
| 872 |
+
top_k: Optional[int] = None,
|
| 873 |
+
top_p: Optional[float] = None,
|
| 874 |
+
repetition_penalty: float = 1.0,
|
| 875 |
+
use_constrained_decoding: bool = True,
|
| 876 |
+
constrained_decoding_debug: bool = False,
|
| 877 |
+
) -> CreateSampleResult:
|
| 878 |
+
"""Create a music sample from a natural language query using the 5Hz Language Model.
|
| 879 |
+
|
| 880 |
+
This is the "Simple Mode" / "Inspiration Mode" feature that takes a user's natural
|
| 881 |
+
language description of music and generates a complete sample including:
|
| 882 |
+
- Detailed caption/description
|
| 883 |
+
- Lyrics (unless instrumental)
|
| 884 |
+
- Metadata (BPM, duration, key, language, time signature)
|
| 885 |
+
|
| 886 |
+
Note: cfg_scale and negative_prompt are not supported in create_sample mode.
|
| 887 |
+
|
| 888 |
+
Args:
|
| 889 |
+
llm_handler: Initialized LLM handler (LLMHandler instance)
|
| 890 |
+
query: User's natural language music description (e.g., "a soft Bengali love song")
|
| 891 |
+
instrumental: Whether to generate instrumental music (no vocals)
|
| 892 |
+
vocal_language: Allowed vocal language for constrained decoding (e.g., "en", "zh").
|
| 893 |
+
If provided, the model will be constrained to generate lyrics in this language.
|
| 894 |
+
If None or "unknown", no language constraint is applied.
|
| 895 |
+
temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
|
| 896 |
+
top_k: Top-K sampling (None or 0 = disabled)
|
| 897 |
+
top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
|
| 898 |
+
repetition_penalty: Repetition penalty (1.0 = no penalty)
|
| 899 |
+
use_constrained_decoding: Whether to use FSM-based constrained decoding
|
| 900 |
+
constrained_decoding_debug: Whether to enable debug logging
|
| 901 |
+
|
| 902 |
+
Returns:
|
| 903 |
+
CreateSampleResult with generated sample fields and status
|
| 904 |
+
|
| 905 |
+
Example:
|
| 906 |
+
>>> result = create_sample(llm_handler, "a soft Bengali love song for a quiet evening", vocal_language="bn")
|
| 907 |
+
>>> if result.success:
|
| 908 |
+
... print(f"Caption: {result.caption}")
|
| 909 |
+
... print(f"Lyrics: {result.lyrics}")
|
| 910 |
+
... print(f"BPM: {result.bpm}")
|
| 911 |
+
"""
|
| 912 |
+
import torch
|
| 913 |
+
# Debug logging for ZeroGPU diagnosis
|
| 914 |
+
logger.info(f"[create_sample Debug] Entry: IS_HUGGINGFACE_SPACE={IS_HUGGINGFACE_SPACE}")
|
| 915 |
+
logger.info(f"[create_sample Debug] torch.cuda.is_available()={torch.cuda.is_available()}")
|
| 916 |
+
if torch.cuda.is_available():
|
| 917 |
+
logger.info(f"[create_sample Debug] torch.cuda.current_device()={torch.cuda.current_device()}")
|
| 918 |
+
logger.info(f"[create_sample Debug] llm_handler.device={llm_handler.device}, llm_handler.offload_to_cpu={llm_handler.offload_to_cpu}")
|
| 919 |
+
if llm_handler.llm is not None:
|
| 920 |
+
try:
|
| 921 |
+
logger.info(f"[create_sample Debug] Model device: {next(llm_handler.llm.parameters()).device}")
|
| 922 |
+
except Exception as e:
|
| 923 |
+
logger.info(f"[create_sample Debug] Could not get model device: {e}")
|
| 924 |
+
|
| 925 |
+
# Check if LLM is initialized
|
| 926 |
+
if not llm_handler.llm_initialized:
|
| 927 |
+
return CreateSampleResult(
|
| 928 |
+
status_message="5Hz LM not initialized. Please initialize it first.",
|
| 929 |
+
success=False,
|
| 930 |
+
error="LLM not initialized",
|
| 931 |
+
)
|
| 932 |
+
|
| 933 |
+
try:
|
| 934 |
+
# Call LLM to create sample
|
| 935 |
+
metadata, status = llm_handler.create_sample_from_query(
|
| 936 |
+
query=query,
|
| 937 |
+
instrumental=instrumental,
|
| 938 |
+
vocal_language=vocal_language,
|
| 939 |
+
temperature=temperature,
|
| 940 |
+
top_k=top_k,
|
| 941 |
+
top_p=top_p,
|
| 942 |
+
repetition_penalty=repetition_penalty,
|
| 943 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 944 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 945 |
+
)
|
| 946 |
+
|
| 947 |
+
# Check if LLM returned empty metadata (error case)
|
| 948 |
+
if not metadata:
|
| 949 |
+
return CreateSampleResult(
|
| 950 |
+
status_message=status or "Failed to create sample",
|
| 951 |
+
success=False,
|
| 952 |
+
error=status or "Empty metadata returned",
|
| 953 |
+
)
|
| 954 |
+
|
| 955 |
+
# Extract and convert fields
|
| 956 |
+
caption = metadata.get('caption', '')
|
| 957 |
+
lyrics = metadata.get('lyrics', '')
|
| 958 |
+
keyscale = metadata.get('keyscale', '')
|
| 959 |
+
language = metadata.get('language', metadata.get('vocal_language', ''))
|
| 960 |
+
timesignature = metadata.get('timesignature', '')
|
| 961 |
+
is_instrumental = metadata.get('instrumental', instrumental)
|
| 962 |
+
|
| 963 |
+
# Convert BPM to int
|
| 964 |
+
bpm = None
|
| 965 |
+
bpm_value = metadata.get('bpm')
|
| 966 |
+
if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
|
| 967 |
+
try:
|
| 968 |
+
bpm = int(bpm_value)
|
| 969 |
+
except (ValueError, TypeError):
|
| 970 |
+
pass
|
| 971 |
+
|
| 972 |
+
# Convert duration to float
|
| 973 |
+
duration = None
|
| 974 |
+
duration_value = metadata.get('duration')
|
| 975 |
+
if duration_value is not None and duration_value != 'N/A' and duration_value != '':
|
| 976 |
+
try:
|
| 977 |
+
duration = float(duration_value)
|
| 978 |
+
except (ValueError, TypeError):
|
| 979 |
+
pass
|
| 980 |
+
|
| 981 |
+
# Clean up N/A values
|
| 982 |
+
if keyscale == 'N/A':
|
| 983 |
+
keyscale = ''
|
| 984 |
+
if language == 'N/A':
|
| 985 |
+
language = ''
|
| 986 |
+
if timesignature == 'N/A':
|
| 987 |
+
timesignature = ''
|
| 988 |
+
|
| 989 |
+
return CreateSampleResult(
|
| 990 |
+
caption=caption,
|
| 991 |
+
lyrics=lyrics,
|
| 992 |
+
bpm=bpm,
|
| 993 |
+
duration=duration,
|
| 994 |
+
keyscale=keyscale,
|
| 995 |
+
language=language,
|
| 996 |
+
timesignature=timesignature,
|
| 997 |
+
instrumental=is_instrumental,
|
| 998 |
+
status_message=status,
|
| 999 |
+
success=True,
|
| 1000 |
+
error=None,
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
except Exception as e:
|
| 1004 |
+
logger.exception("Sample creation failed")
|
| 1005 |
+
return CreateSampleResult(
|
| 1006 |
+
status_message=f"Error: {str(e)}",
|
| 1007 |
+
success=False,
|
| 1008 |
+
error=str(e),
|
| 1009 |
+
)
|
| 1010 |
+
|
| 1011 |
+
|
| 1012 |
+
@dataclass
|
| 1013 |
+
class FormatSampleResult:
|
| 1014 |
+
"""Result of formatting user-provided caption and lyrics.
|
| 1015 |
+
|
| 1016 |
+
This is used by the "Format" feature where users provide caption and lyrics,
|
| 1017 |
+
and the LLM formats them into structured music metadata and an enhanced description.
|
| 1018 |
+
|
| 1019 |
+
Attributes:
|
| 1020 |
+
# Metadata Fields
|
| 1021 |
+
caption: Enhanced/formatted music description/caption
|
| 1022 |
+
lyrics: Formatted lyrics (may be same as input or reformatted)
|
| 1023 |
+
bpm: Beats per minute (None if not detected)
|
| 1024 |
+
duration: Duration in seconds (None if not detected)
|
| 1025 |
+
keyscale: Musical key (e.g., "C Major")
|
| 1026 |
+
language: Vocal language code (e.g., "en", "zh")
|
| 1027 |
+
timesignature: Time signature (e.g., "4")
|
| 1028 |
+
|
| 1029 |
+
# Status
|
| 1030 |
+
status_message: Status message from formatting
|
| 1031 |
+
success: Whether formatting completed successfully
|
| 1032 |
+
error: Error message if formatting failed
|
| 1033 |
+
"""
|
| 1034 |
+
# Metadata Fields
|
| 1035 |
+
caption: str = ""
|
| 1036 |
+
lyrics: str = ""
|
| 1037 |
+
bpm: Optional[int] = None
|
| 1038 |
+
duration: Optional[float] = None
|
| 1039 |
+
keyscale: str = ""
|
| 1040 |
+
language: str = ""
|
| 1041 |
+
timesignature: str = ""
|
| 1042 |
+
|
| 1043 |
+
# Status
|
| 1044 |
+
status_message: str = ""
|
| 1045 |
+
success: bool = True
|
| 1046 |
+
error: Optional[str] = None
|
| 1047 |
+
|
| 1048 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 1049 |
+
"""Convert result to dictionary for JSON serialization."""
|
| 1050 |
+
return asdict(self)
|
| 1051 |
+
|
| 1052 |
+
|
| 1053 |
+
def format_sample(
|
| 1054 |
+
llm_handler,
|
| 1055 |
+
caption: str,
|
| 1056 |
+
lyrics: str,
|
| 1057 |
+
user_metadata: Optional[Dict[str, Any]] = None,
|
| 1058 |
+
temperature: float = 0.85,
|
| 1059 |
+
top_k: Optional[int] = None,
|
| 1060 |
+
top_p: Optional[float] = None,
|
| 1061 |
+
repetition_penalty: float = 1.0,
|
| 1062 |
+
use_constrained_decoding: bool = True,
|
| 1063 |
+
constrained_decoding_debug: bool = False,
|
| 1064 |
+
) -> FormatSampleResult:
|
| 1065 |
+
"""Format user-provided caption and lyrics using the 5Hz Language Model.
|
| 1066 |
+
|
| 1067 |
+
This function takes user input (caption and lyrics) and generates structured
|
| 1068 |
+
music metadata including an enhanced caption, BPM, duration, key, language,
|
| 1069 |
+
and time signature.
|
| 1070 |
+
|
| 1071 |
+
If user_metadata is provided, those values will be used to constrain the
|
| 1072 |
+
decoding, ensuring the output matches user-specified values.
|
| 1073 |
+
|
| 1074 |
+
Note: cfg_scale and negative_prompt are not supported in format mode.
|
| 1075 |
+
|
| 1076 |
+
Args:
|
| 1077 |
+
llm_handler: Initialized LLM handler (LLMHandler instance)
|
| 1078 |
+
caption: User's caption/description (e.g., "Latin pop, reggaeton")
|
| 1079 |
+
lyrics: User's lyrics with structure tags
|
| 1080 |
+
user_metadata: Optional dict with user-provided metadata to constrain decoding.
|
| 1081 |
+
Supported keys: bpm, duration, keyscale, timesignature, language
|
| 1082 |
+
temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
|
| 1083 |
+
top_k: Top-K sampling (None or 0 = disabled)
|
| 1084 |
+
top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
|
| 1085 |
+
repetition_penalty: Repetition penalty (1.0 = no penalty)
|
| 1086 |
+
use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
|
| 1087 |
+
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
|
| 1088 |
+
|
| 1089 |
+
Returns:
|
| 1090 |
+
FormatSampleResult with formatted metadata fields and status
|
| 1091 |
+
|
| 1092 |
+
Example:
|
| 1093 |
+
>>> result = format_sample(llm_handler, "Latin pop, reggaeton", "[Verse 1]\\nHola mundo...")
|
| 1094 |
+
>>> if result.success:
|
| 1095 |
+
... print(f"Caption: {result.caption}")
|
| 1096 |
+
... print(f"BPM: {result.bpm}")
|
| 1097 |
+
... print(f"Lyrics: {result.lyrics}")
|
| 1098 |
+
"""
|
| 1099 |
+
# Check if LLM is initialized
|
| 1100 |
+
if not llm_handler.llm_initialized:
|
| 1101 |
+
return FormatSampleResult(
|
| 1102 |
+
status_message="5Hz LM not initialized. Please initialize it first.",
|
| 1103 |
+
success=False,
|
| 1104 |
+
error="LLM not initialized",
|
| 1105 |
+
)
|
| 1106 |
+
|
| 1107 |
+
try:
|
| 1108 |
+
# Call LLM formatting
|
| 1109 |
+
metadata, status = llm_handler.format_sample_from_input(
|
| 1110 |
+
caption=caption,
|
| 1111 |
+
lyrics=lyrics,
|
| 1112 |
+
user_metadata=user_metadata,
|
| 1113 |
+
temperature=temperature,
|
| 1114 |
+
top_k=top_k,
|
| 1115 |
+
top_p=top_p,
|
| 1116 |
+
repetition_penalty=repetition_penalty,
|
| 1117 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 1118 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 1119 |
+
)
|
| 1120 |
+
|
| 1121 |
+
# Check if LLM returned empty metadata (error case)
|
| 1122 |
+
if not metadata:
|
| 1123 |
+
return FormatSampleResult(
|
| 1124 |
+
status_message=status or "Failed to format input",
|
| 1125 |
+
success=False,
|
| 1126 |
+
error=status or "Empty metadata returned",
|
| 1127 |
+
)
|
| 1128 |
+
|
| 1129 |
+
# Extract and convert fields
|
| 1130 |
+
result_caption = metadata.get('caption', '')
|
| 1131 |
+
result_lyrics = metadata.get('lyrics', lyrics) # Fall back to input lyrics
|
| 1132 |
+
keyscale = metadata.get('keyscale', '')
|
| 1133 |
+
language = metadata.get('language', metadata.get('vocal_language', ''))
|
| 1134 |
+
timesignature = metadata.get('timesignature', '')
|
| 1135 |
+
|
| 1136 |
+
# Convert BPM to int
|
| 1137 |
+
bpm = None
|
| 1138 |
+
bpm_value = metadata.get('bpm')
|
| 1139 |
+
if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
|
| 1140 |
+
try:
|
| 1141 |
+
bpm = int(bpm_value)
|
| 1142 |
+
except (ValueError, TypeError):
|
| 1143 |
+
pass
|
| 1144 |
+
|
| 1145 |
+
# Convert duration to float
|
| 1146 |
+
duration = None
|
| 1147 |
+
duration_value = metadata.get('duration')
|
| 1148 |
+
if duration_value is not None and duration_value != 'N/A' and duration_value != '':
|
| 1149 |
+
try:
|
| 1150 |
+
duration = float(duration_value)
|
| 1151 |
+
except (ValueError, TypeError):
|
| 1152 |
+
pass
|
| 1153 |
+
|
| 1154 |
+
# Clean up N/A values
|
| 1155 |
+
if keyscale == 'N/A':
|
| 1156 |
+
keyscale = ''
|
| 1157 |
+
if language == 'N/A':
|
| 1158 |
+
language = ''
|
| 1159 |
+
if timesignature == 'N/A':
|
| 1160 |
+
timesignature = ''
|
| 1161 |
+
|
| 1162 |
+
return FormatSampleResult(
|
| 1163 |
+
caption=result_caption,
|
| 1164 |
+
lyrics=result_lyrics,
|
| 1165 |
+
bpm=bpm,
|
| 1166 |
+
duration=duration,
|
| 1167 |
+
keyscale=keyscale,
|
| 1168 |
+
language=language,
|
| 1169 |
+
timesignature=timesignature,
|
| 1170 |
+
status_message=status,
|
| 1171 |
+
success=True,
|
| 1172 |
+
error=None,
|
| 1173 |
+
)
|
| 1174 |
+
|
| 1175 |
+
except Exception as e:
|
| 1176 |
+
logger.exception("Format sample failed")
|
| 1177 |
+
return FormatSampleResult(
|
| 1178 |
+
status_message=f"Error: {str(e)}",
|
| 1179 |
+
success=False,
|
| 1180 |
+
error=str(e),
|
| 1181 |
+
)
|
acestep/llm_inference.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/local_cache.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Local cache module to replace Redis
|
| 2 |
+
|
| 3 |
+
Uses diskcache as backend, provides Redis-compatible API.
|
| 4 |
+
Supports persistent storage and TTL expiration.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from typing import Any, Optional
|
| 10 |
+
from threading import Lock
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from diskcache import Cache
|
| 14 |
+
HAS_DISKCACHE = True
|
| 15 |
+
except ImportError:
|
| 16 |
+
HAS_DISKCACHE = False
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LocalCache:
|
| 20 |
+
"""
|
| 21 |
+
Local cache implementation with Redis-compatible API.
|
| 22 |
+
Uses diskcache as backend, supports persistence and TTL.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
_instance = None
|
| 26 |
+
_lock = Lock()
|
| 27 |
+
|
| 28 |
+
def __new__(cls, cache_dir: Optional[str] = None):
|
| 29 |
+
"""Singleton pattern"""
|
| 30 |
+
if cls._instance is None:
|
| 31 |
+
with cls._lock:
|
| 32 |
+
if cls._instance is None:
|
| 33 |
+
cls._instance = super().__new__(cls)
|
| 34 |
+
cls._instance._initialized = False
|
| 35 |
+
return cls._instance
|
| 36 |
+
|
| 37 |
+
def __init__(self, cache_dir: Optional[str] = None):
|
| 38 |
+
if getattr(self, '_initialized', False):
|
| 39 |
+
return
|
| 40 |
+
|
| 41 |
+
if not HAS_DISKCACHE:
|
| 42 |
+
raise ImportError(
|
| 43 |
+
"diskcache not installed. Run: pip install diskcache"
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
if cache_dir is None:
|
| 47 |
+
cache_dir = os.path.join(
|
| 48 |
+
os.path.dirname(os.path.dirname(__file__)),
|
| 49 |
+
".cache",
|
| 50 |
+
"local_redis"
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 54 |
+
self._cache = Cache(cache_dir)
|
| 55 |
+
self._initialized = True
|
| 56 |
+
|
| 57 |
+
def set(self, name: str, value: Any, ex: Optional[int] = None) -> bool:
|
| 58 |
+
"""
|
| 59 |
+
Set key-value pair
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
name: Key name
|
| 63 |
+
value: Value (auto-serialize dict/list)
|
| 64 |
+
ex: Expiration time (seconds)
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
bool: Success status
|
| 68 |
+
"""
|
| 69 |
+
if isinstance(value, (dict, list)):
|
| 70 |
+
value = json.dumps(value, ensure_ascii=False)
|
| 71 |
+
self._cache.set(name, value, expire=ex)
|
| 72 |
+
return True
|
| 73 |
+
|
| 74 |
+
def get(self, name: str) -> Optional[str]:
|
| 75 |
+
"""Get value"""
|
| 76 |
+
return self._cache.get(name)
|
| 77 |
+
|
| 78 |
+
def delete(self, name: str) -> int:
|
| 79 |
+
"""Delete key, returns number of deleted items"""
|
| 80 |
+
return 1 if self._cache.delete(name) else 0
|
| 81 |
+
|
| 82 |
+
def exists(self, name: str) -> bool:
|
| 83 |
+
"""Check if key exists"""
|
| 84 |
+
return name in self._cache
|
| 85 |
+
|
| 86 |
+
def keys(self, pattern: str = "*") -> list:
|
| 87 |
+
"""
|
| 88 |
+
Get list of matching keys
|
| 89 |
+
Note: Simplified implementation, only supports prefix and full matching
|
| 90 |
+
"""
|
| 91 |
+
if pattern == "*":
|
| 92 |
+
return list(self._cache.iterkeys())
|
| 93 |
+
|
| 94 |
+
prefix = pattern.rstrip("*")
|
| 95 |
+
return [k for k in self._cache.iterkeys() if k.startswith(prefix)]
|
| 96 |
+
|
| 97 |
+
def expire(self, name: str, seconds: int) -> bool:
|
| 98 |
+
"""Set key expiration time"""
|
| 99 |
+
value = self._cache.get(name)
|
| 100 |
+
if value is not None:
|
| 101 |
+
self._cache.set(name, value, expire=seconds)
|
| 102 |
+
return True
|
| 103 |
+
return False
|
| 104 |
+
|
| 105 |
+
def ttl(self, name: str) -> int:
|
| 106 |
+
"""
|
| 107 |
+
Get remaining time to live (seconds)
|
| 108 |
+
Note: diskcache does not directly support TTL queries
|
| 109 |
+
"""
|
| 110 |
+
if name in self._cache:
|
| 111 |
+
return -1 # Exists but TTL unknown
|
| 112 |
+
return -2 # Key does not exist
|
| 113 |
+
|
| 114 |
+
def close(self):
|
| 115 |
+
"""Close cache connection"""
|
| 116 |
+
if hasattr(self, '_cache'):
|
| 117 |
+
self._cache.close()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# Lazily initialized global instance
|
| 121 |
+
_local_cache: Optional[LocalCache] = None
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def get_local_cache(cache_dir: Optional[str] = None) -> LocalCache:
|
| 125 |
+
"""Get local cache instance"""
|
| 126 |
+
global _local_cache
|
| 127 |
+
if _local_cache is None:
|
| 128 |
+
_local_cache = LocalCache(cache_dir)
|
| 129 |
+
return _local_cache
|
acestep/test_time_scaling.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test-Time Scaling Module
|
| 3 |
+
Implements perplexity-based scoring for generated audio codes
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from typing import Tuple, Optional, Dict, Any, List
|
| 8 |
+
from loguru import logger
|
| 9 |
+
import yaml
|
| 10 |
+
import math
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def pmi_score(log_prob_conditional: float, log_prob_unconditional: float) -> float:
|
| 15 |
+
"""
|
| 16 |
+
Calculate Pointwise Mutual Information (PMI) score.
|
| 17 |
+
|
| 18 |
+
PMI = log P(condition|codes) - log P(condition)
|
| 19 |
+
= log [P(codes|condition) / P(codes)]
|
| 20 |
+
|
| 21 |
+
This removes the bias from P(condition) and measures how much the codes
|
| 22 |
+
improve our ability to predict the condition.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
log_prob_conditional: Average log probability of condition given codes
|
| 26 |
+
log_prob_unconditional: Average log probability of condition without codes
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
PMI score (higher is better, can be positive or negative)
|
| 30 |
+
- Positive: codes improve prediction → good match
|
| 31 |
+
- Zero: codes don't help → no correlation
|
| 32 |
+
- Negative: codes hurt prediction → poor match
|
| 33 |
+
"""
|
| 34 |
+
return log_prob_conditional - log_prob_unconditional
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float:
|
| 38 |
+
"""
|
| 39 |
+
Convert PMI score to normalized [0, 1] range using sigmoid function.
|
| 40 |
+
|
| 41 |
+
score = sigmoid(PMI / scale) = 1 / (1 + exp(-PMI / scale))
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
pmi: PMI score (can be positive or negative)
|
| 45 |
+
scale: Scale parameter to control sensitivity (default 0.1)
|
| 46 |
+
- Smaller scale: more sensitive to PMI changes
|
| 47 |
+
- Larger scale: less sensitive to PMI changes
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
Normalized score in [0, 1] range, where:
|
| 51 |
+
- PMI > 0 → score > 0.5 (good match)
|
| 52 |
+
- PMI = 0 → score = 0.5 (neutral)
|
| 53 |
+
- PMI < 0 → score < 0.5 (poor match)
|
| 54 |
+
|
| 55 |
+
Examples (scale=1.0):
|
| 56 |
+
PMI=2.0 → score≈0.88 (excellent)
|
| 57 |
+
PMI=1.0 → score≈0.73 (good)
|
| 58 |
+
PMI=0.0 → score=0.50 (neutral)
|
| 59 |
+
PMI=-1.0 → score≈0.27 (poor)
|
| 60 |
+
PMI=-2.0 → score≈0.12 (bad)
|
| 61 |
+
"""
|
| 62 |
+
return 1.0 / (1.0 + math.exp(-pmi / scale))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str,
|
| 66 |
+
target_text: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 67 |
+
"""
|
| 68 |
+
Args:
|
| 69 |
+
llm_handler: The handler containing the model and tokenizer.
|
| 70 |
+
formatted_prompt: The input context.
|
| 71 |
+
target_text: The text we want to calculate probability/recall for.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
Tuple of (target_logits, target_ids)
|
| 75 |
+
- target_logits: Logits used to predict the target tokens.
|
| 76 |
+
- target_ids: The ground truth token IDs of the target.
|
| 77 |
+
"""
|
| 78 |
+
model = llm_handler.get_hf_model_for_scoring()
|
| 79 |
+
tokenizer = llm_handler.llm_tokenizer
|
| 80 |
+
device = llm_handler.device if llm_handler.llm_backend == "pt" else next(model.parameters()).device
|
| 81 |
+
|
| 82 |
+
# 1. Tokenize prompt ONLY to get its length (used for slicing later).
|
| 83 |
+
# We must ensure special tokens are added to count the offset correctly.
|
| 84 |
+
prompt_tokens_temp = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=True)
|
| 85 |
+
prompt_len = prompt_tokens_temp['input_ids'].shape[1]
|
| 86 |
+
|
| 87 |
+
# 2. Tokenize the FULL text (Prompt + Target).
|
| 88 |
+
# This ensures subword merging at boundaries is handled correctly by the tokenizer.
|
| 89 |
+
full_text = formatted_prompt + target_text
|
| 90 |
+
full_tokens = tokenizer(full_text, return_tensors="pt", padding=False, truncation=True, add_special_tokens=True).to(device)
|
| 91 |
+
|
| 92 |
+
input_ids = full_tokens['input_ids']
|
| 93 |
+
|
| 94 |
+
# Safety check: if target was empty or truncated entirely
|
| 95 |
+
if input_ids.shape[1] <= prompt_len:
|
| 96 |
+
return torch.empty(0, device=device), torch.empty(0, device=device)
|
| 97 |
+
|
| 98 |
+
# 3. Forward Pass (Teacher Forcing)
|
| 99 |
+
with torch.no_grad():
|
| 100 |
+
with llm_handler._load_model_context():
|
| 101 |
+
outputs = model(input_ids=input_ids, attention_mask=full_tokens['attention_mask'])
|
| 102 |
+
all_logits = outputs.logits # [1, seq_len, vocab_size]
|
| 103 |
+
|
| 104 |
+
# 4. Extract Logits and Labels
|
| 105 |
+
# We need to predict `input_ids[i]`. The logit for this is at `all_logits[i-1]`.
|
| 106 |
+
# Target starts at index `prompt_len`.
|
| 107 |
+
# So we need logits from `prompt_len - 1` up to the second to last position.
|
| 108 |
+
|
| 109 |
+
target_logits = all_logits[0, prompt_len - 1:-1, :] # [target_len, vocab_size]
|
| 110 |
+
target_ids = input_ids[0, prompt_len:] # [target_len]
|
| 111 |
+
|
| 112 |
+
return target_logits, target_ids
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ==============================================================================
|
| 116 |
+
# Scoring Logic
|
| 117 |
+
# ==============================================================================
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _calculate_topk_recall(llm_handler,
|
| 121 |
+
formatted_prompt: str,
|
| 122 |
+
target_text: str,
|
| 123 |
+
topk: int = 10) -> Tuple[float, Dict[int, float]]:
|
| 124 |
+
"""
|
| 125 |
+
Calculate top-k recall for target text given prompt.
|
| 126 |
+
Checks if the ground truth token is within the top-k probabilities at each step.
|
| 127 |
+
"""
|
| 128 |
+
# Use the fixed helper to get aligned logits/labels
|
| 129 |
+
pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
|
| 130 |
+
|
| 131 |
+
if target_ids.shape[0] == 0:
|
| 132 |
+
return 0.0, {}
|
| 133 |
+
|
| 134 |
+
target_len = target_ids.shape[0]
|
| 135 |
+
|
| 136 |
+
# Get top-k indices for all positions at once
|
| 137 |
+
# topk_indices: [target_len, topk]
|
| 138 |
+
_, topk_indices = torch.topk(pred_logits, k=min(topk, pred_logits.shape[-1]), dim=-1)
|
| 139 |
+
|
| 140 |
+
recall_per_k = {}
|
| 141 |
+
position_scores = []
|
| 142 |
+
|
| 143 |
+
# Convert to list for faster CPU iteration
|
| 144 |
+
target_ids_list = target_ids.tolist()
|
| 145 |
+
topk_indices_list = topk_indices.tolist()
|
| 146 |
+
|
| 147 |
+
for k in range(1, topk + 1):
|
| 148 |
+
hits = 0
|
| 149 |
+
for pos in range(target_len):
|
| 150 |
+
gt_token = target_ids_list[pos]
|
| 151 |
+
# Check the top-k slice
|
| 152 |
+
topk_at_pos = topk_indices_list[pos][:k]
|
| 153 |
+
|
| 154 |
+
if gt_token in topk_at_pos:
|
| 155 |
+
hits += 1
|
| 156 |
+
# Calculate position-weighted score only once (when k=topk)
|
| 157 |
+
if k == topk:
|
| 158 |
+
rank = topk_at_pos.index(gt_token) + 1
|
| 159 |
+
# Rank 1 = 1.0, Rank k = small positive
|
| 160 |
+
position_weight = 1.0 - (rank - 1) / topk
|
| 161 |
+
position_scores.append(position_weight)
|
| 162 |
+
|
| 163 |
+
recall_per_k[k] = hits / target_len if target_len > 0 else 0.0
|
| 164 |
+
|
| 165 |
+
# Fill scores for positions where GT was NOT in top-k
|
| 166 |
+
while len(position_scores) < target_len:
|
| 167 |
+
position_scores.append(0.0)
|
| 168 |
+
|
| 169 |
+
average_recall = sum(position_scores) / len(position_scores) if position_scores else 0.0
|
| 170 |
+
|
| 171 |
+
return average_recall, recall_per_k
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _calculate_metadata_recall(llm_handler,
|
| 175 |
+
formatted_prompt: str,
|
| 176 |
+
fields_dict: Dict[str, Any],
|
| 177 |
+
topk: int = 10) -> Dict[str, float]:
|
| 178 |
+
"""
|
| 179 |
+
Args:
|
| 180 |
+
fields_dict: Dictionary of {field_name: field_value}
|
| 181 |
+
"""
|
| 182 |
+
if not fields_dict:
|
| 183 |
+
return {}
|
| 184 |
+
|
| 185 |
+
field_scores = {}
|
| 186 |
+
|
| 187 |
+
for field_name in sorted(fields_dict.keys()):
|
| 188 |
+
# Construct target text for this specific field
|
| 189 |
+
# e.g. <think>\nbpm: 120\n</think>\n
|
| 190 |
+
field_yaml = yaml.dump({field_name: fields_dict[field_name]}, allow_unicode=True, sort_keys=True).strip()
|
| 191 |
+
field_target_text = f"<think>\n{field_yaml}\n</think>\n"
|
| 192 |
+
|
| 193 |
+
# Calculate recall using the robust logic
|
| 194 |
+
avg_score, _ = _calculate_topk_recall(llm_handler, formatted_prompt, field_target_text, topk=topk)
|
| 195 |
+
|
| 196 |
+
field_scores[field_name] = avg_score
|
| 197 |
+
logger.debug(f"Recall for {field_name}: {avg_score:.4f}")
|
| 198 |
+
|
| 199 |
+
return field_scores
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def _calculate_log_prob(
|
| 203 |
+
llm_handler,
|
| 204 |
+
formatted_prompt: str,
|
| 205 |
+
target_text: str,
|
| 206 |
+
temperature: float = 1.0 # Kept for API compatibility, but ignored for scoring
|
| 207 |
+
) -> float:
|
| 208 |
+
"""
|
| 209 |
+
Calculate average log probability of target text given prompt.
|
| 210 |
+
"""
|
| 211 |
+
pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
|
| 212 |
+
|
| 213 |
+
if target_ids.shape[0] == 0:
|
| 214 |
+
return float('-inf')
|
| 215 |
+
|
| 216 |
+
# FIX: Do not divide by temperature.
|
| 217 |
+
# Log-probability for PMI/Perplexity should be exact.
|
| 218 |
+
|
| 219 |
+
# Calculate log probabilities (log_softmax)
|
| 220 |
+
log_probs = F.log_softmax(pred_logits, dim=-1) # [target_len, vocab_size]
|
| 221 |
+
|
| 222 |
+
# Gather log probabilities of the ground truth tokens
|
| 223 |
+
target_log_probs = log_probs[torch.arange(target_ids.shape[0]), target_ids]
|
| 224 |
+
|
| 225 |
+
# Return average log probability
|
| 226 |
+
mean_log_prob = target_log_probs.mean().item()
|
| 227 |
+
|
| 228 |
+
return mean_log_prob
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def calculate_reward_score(
|
| 232 |
+
scores: Dict[str, float],
|
| 233 |
+
weights_config: Optional[Dict[str, float]] = None
|
| 234 |
+
) -> Tuple[float, str]:
|
| 235 |
+
"""
|
| 236 |
+
Reward Model Calculator: Computes a final reward based on user priorities.
|
| 237 |
+
|
| 238 |
+
Priority Logic:
|
| 239 |
+
1. Caption (Highest): The overall vibe/style must match.
|
| 240 |
+
2. Lyrics (Medium): Content accuracy is important but secondary to vibe.
|
| 241 |
+
3. Metadata (Lowest): Technical constraints (BPM, Key) allow for slight deviations.
|
| 242 |
+
|
| 243 |
+
Strategy: Dynamic Weighted Sum
|
| 244 |
+
- Metadata fields are aggregated into a single 'metadata' score first.
|
| 245 |
+
- Weights are dynamically renormalized if any component (e.g., lyrics) is missing.
|
| 246 |
+
|
| 247 |
+
Args:
|
| 248 |
+
scores: Dictionary of raw scores (0.0 - 1.0) from the evaluation module.
|
| 249 |
+
weights_config: Optional custom weights. Defaults to:
|
| 250 |
+
Caption (50%), Lyrics (30%), Metadata (20%).
|
| 251 |
+
|
| 252 |
+
Returns:
|
| 253 |
+
final_reward: The calculated reward score (0.0 - 1.0).
|
| 254 |
+
explanation: A formatted string explaining how the score was derived.
|
| 255 |
+
"""
|
| 256 |
+
|
| 257 |
+
# 1. Default Preference Configuration
|
| 258 |
+
# These weights determine the relative importance of each component.
|
| 259 |
+
if weights_config is None:
|
| 260 |
+
weights_config = {
|
| 261 |
+
'caption': 0.50, # High priority: Style/Vibe
|
| 262 |
+
'lyrics': 0.30, # Medium priority: Content
|
| 263 |
+
'metadata': 0.20 # Low priority: Technical details
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
# 2. Extract and Group Scores
|
| 267 |
+
# Caption and Lyrics are standalone high-level features.
|
| 268 |
+
caption_score = scores.get('caption')
|
| 269 |
+
lyrics_score = scores.get('lyrics')
|
| 270 |
+
|
| 271 |
+
# Metadata fields (bpm, key, duration, etc.) are aggregated.
|
| 272 |
+
# We treat them as a single "Technical Score" to prevent them from
|
| 273 |
+
# diluting the weight of Caption/Lyrics simply by having many fields.
|
| 274 |
+
meta_scores_list = [
|
| 275 |
+
val for key, val in scores.items()
|
| 276 |
+
if key not in ['caption', 'lyrics']
|
| 277 |
+
]
|
| 278 |
+
|
| 279 |
+
# Calculate average of all metadata fields (if any exist)
|
| 280 |
+
meta_aggregate_score = None
|
| 281 |
+
if meta_scores_list:
|
| 282 |
+
meta_aggregate_score = sum(meta_scores_list) / len(meta_scores_list)
|
| 283 |
+
|
| 284 |
+
# 3. specific Active Components & Dynamic Weighting
|
| 285 |
+
# We only include components that actually exist in this generation.
|
| 286 |
+
active_components = {}
|
| 287 |
+
|
| 288 |
+
if caption_score is not None:
|
| 289 |
+
active_components['caption'] = (caption_score, weights_config['caption'])
|
| 290 |
+
|
| 291 |
+
if lyrics_score is not None:
|
| 292 |
+
active_components['lyrics'] = (lyrics_score, weights_config['lyrics'])
|
| 293 |
+
|
| 294 |
+
if meta_aggregate_score is not None:
|
| 295 |
+
active_components['metadata'] = (meta_aggregate_score, weights_config['metadata'])
|
| 296 |
+
|
| 297 |
+
# 4. Calculate Final Weighted Score
|
| 298 |
+
total_base_weight = sum(w for _, w in active_components.values())
|
| 299 |
+
total_score = 0.0
|
| 300 |
+
|
| 301 |
+
breakdown_lines = []
|
| 302 |
+
|
| 303 |
+
if total_base_weight == 0:
|
| 304 |
+
return 0.0, "❌ No valid scores available to calculate reward."
|
| 305 |
+
|
| 306 |
+
# Sort by weight (importance) for display
|
| 307 |
+
sorted_components = sorted(active_components.items(), key=lambda x: x[1][1], reverse=True)
|
| 308 |
+
|
| 309 |
+
for name, (score, base_weight) in sorted_components:
|
| 310 |
+
# Renormalize weight: If lyrics are missing, caption/metadata weights scale up proportionately.
|
| 311 |
+
normalized_weight = base_weight / total_base_weight
|
| 312 |
+
weighted_contribution = score * normalized_weight
|
| 313 |
+
total_score += weighted_contribution
|
| 314 |
+
|
| 315 |
+
breakdown_lines.append(
|
| 316 |
+
f" • {name.title():<8} | Score: {score:.4f} | Weight: {normalized_weight:.2f} "
|
| 317 |
+
f"-> Contrib: +{weighted_contribution:.4f}"
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
return total_score, "\n".join(breakdown_lines)
|
| 321 |
+
|
| 322 |
+
# ==============================================================================
|
| 323 |
+
# Main Public API
|
| 324 |
+
# ==============================================================================
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def calculate_pmi_score_per_condition(
|
| 328 |
+
llm_handler,
|
| 329 |
+
audio_codes: str,
|
| 330 |
+
caption: str = "",
|
| 331 |
+
lyrics: str = "",
|
| 332 |
+
metadata: Optional[Dict[str, Any]] = None,
|
| 333 |
+
temperature: float = 1.0,
|
| 334 |
+
topk: int = 10,
|
| 335 |
+
score_scale: float = 0.1,
|
| 336 |
+
) -> Tuple[Dict[str, float], float, str]:
|
| 337 |
+
"""
|
| 338 |
+
Calculate quality score separately for each condition.
|
| 339 |
+
- Metadata: Uses Top-k Recall.
|
| 340 |
+
- Caption/Lyrics: Uses PMI (Normalized).
|
| 341 |
+
"""
|
| 342 |
+
if not llm_handler.llm_initialized:
|
| 343 |
+
return {}, 0.0, "❌ LLM not initialized"
|
| 344 |
+
|
| 345 |
+
if not audio_codes or not audio_codes.strip():
|
| 346 |
+
return {}, 0.0, "❌ No audio codes provided"
|
| 347 |
+
|
| 348 |
+
if "caption" not in metadata:
|
| 349 |
+
metadata['caption'] = caption
|
| 350 |
+
|
| 351 |
+
formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes, is_negative_prompt=False)
|
| 352 |
+
prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
|
| 353 |
+
try:
|
| 354 |
+
# 1. Calculate Recall for Metadata Fields
|
| 355 |
+
if metadata and isinstance(metadata, dict):
|
| 356 |
+
scores = {}
|
| 357 |
+
# Define which fields use which metric
|
| 358 |
+
metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
|
| 359 |
+
metadata_pmi_keys = ['caption']
|
| 360 |
+
for key in metadata_recall_keys:
|
| 361 |
+
if key in metadata and metadata[key] is not None:
|
| 362 |
+
recall_metadata = {key: metadata[key]}
|
| 363 |
+
field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, recall_metadata, topk=topk)
|
| 364 |
+
scores.update(field_scores)
|
| 365 |
+
|
| 366 |
+
# 2. Calculate PMI for Caption
|
| 367 |
+
for key in metadata_pmi_keys:
|
| 368 |
+
if key in metadata and metadata[key] is not None:
|
| 369 |
+
cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
|
| 370 |
+
target_text = f"<think>\n{cot_yaml}\n</think>\n"
|
| 371 |
+
|
| 372 |
+
log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
|
| 373 |
+
log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
|
| 374 |
+
|
| 375 |
+
pmi_normalized = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
|
| 376 |
+
scores[key] = pmi_normalized
|
| 377 |
+
|
| 378 |
+
# 3. Calculate PMI for Lyrics
|
| 379 |
+
if lyrics:
|
| 380 |
+
target_text = f"<think>\n</think>\n# Lyric\n{lyrics}\n"
|
| 381 |
+
|
| 382 |
+
log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
|
| 383 |
+
|
| 384 |
+
prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
|
| 385 |
+
log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
|
| 386 |
+
|
| 387 |
+
scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
|
| 388 |
+
|
| 389 |
+
if not scores:
|
| 390 |
+
return {}, 0.0, "❌ No conditions to evaluate"
|
| 391 |
+
|
| 392 |
+
# 4. Global Score
|
| 393 |
+
global_score = sum(scores.values()) / len(scores)
|
| 394 |
+
global_score, breakdown_lines = calculate_reward_score(scores)
|
| 395 |
+
|
| 396 |
+
# Status Message
|
| 397 |
+
status_lines = [breakdown_lines, "\n✅ Per-condition scores (0-1):"]
|
| 398 |
+
for key, score in sorted(scores.items()):
|
| 399 |
+
metric = "Top-k Recall" if key in metadata_recall_keys else "PMI (Norm)"
|
| 400 |
+
status_lines.append(f" {key}: {score:.4f} ({metric})")
|
| 401 |
+
status = "\n".join(status_lines)
|
| 402 |
+
logger.info(f"Calculated scores: {global_score:.4f}\n{status}")
|
| 403 |
+
return scores, global_score, status
|
| 404 |
+
|
| 405 |
+
except Exception as e:
|
| 406 |
+
import traceback
|
| 407 |
+
error_msg = f"❌ Error: {str(e)}"
|
| 408 |
+
logger.error(error_msg)
|
| 409 |
+
logger.error(traceback.format_exc())
|
| 410 |
+
return {}, float('-inf'), error_msg
|
acestep/third_parts/nano-vllm/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Xingkai Yu
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
acestep/third_parts/nano-vllm/README.md
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<img width="300" src="assets/logo.png">
|
| 3 |
+
</p>
|
| 4 |
+
|
| 5 |
+
<p align="center">
|
| 6 |
+
<a href="https://trendshift.io/repositories/15323" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15323" alt="GeeeekExplorer%2Fnano-vllm | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
| 7 |
+
</p>
|
| 8 |
+
|
| 9 |
+
# Nano-vLLM
|
| 10 |
+
|
| 11 |
+
A lightweight vLLM implementation built from scratch.
|
| 12 |
+
|
| 13 |
+
## Key Features
|
| 14 |
+
|
| 15 |
+
* 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
|
| 16 |
+
* 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
|
| 17 |
+
* ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
|
| 18 |
+
|
| 19 |
+
## Installation
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
## Model Download
|
| 26 |
+
|
| 27 |
+
To download the model weights manually, use the following command:
|
| 28 |
+
```bash
|
| 29 |
+
huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
|
| 30 |
+
--local-dir ~/huggingface/Qwen3-0.6B/ \
|
| 31 |
+
--local-dir-use-symlinks False
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## Quick Start
|
| 35 |
+
|
| 36 |
+
See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:
|
| 37 |
+
```python
|
| 38 |
+
from nanovllm import LLM, SamplingParams
|
| 39 |
+
llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
|
| 40 |
+
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
|
| 41 |
+
prompts = ["Hello, Nano-vLLM."]
|
| 42 |
+
outputs = llm.generate(prompts, sampling_params)
|
| 43 |
+
outputs[0]["text"]
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
## Benchmark
|
| 47 |
+
|
| 48 |
+
See `bench.py` for benchmark.
|
| 49 |
+
|
| 50 |
+
**Test Configuration:**
|
| 51 |
+
- Hardware: RTX 4070 Laptop (8GB)
|
| 52 |
+
- Model: Qwen3-0.6B
|
| 53 |
+
- Total Requests: 256 sequences
|
| 54 |
+
- Input Length: Randomly sampled between 100–1024 tokens
|
| 55 |
+
- Output Length: Randomly sampled between 100–1024 tokens
|
| 56 |
+
|
| 57 |
+
**Performance Results:**
|
| 58 |
+
| Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
|
| 59 |
+
|----------------|-------------|----------|-----------------------|
|
| 60 |
+
| vLLM | 133,966 | 98.37 | 1361.84 |
|
| 61 |
+
| Nano-vLLM | 133,966 | 93.41 | 1434.13 |
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
## Star History
|
| 65 |
+
|
| 66 |
+
[](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date)
|
acestep/third_parts/nano-vllm/bench.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from random import randint, seed
|
| 4 |
+
from nanovllm import LLM, SamplingParams
|
| 5 |
+
# from vllm import LLM, SamplingParams
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main():
|
| 9 |
+
seed(0)
|
| 10 |
+
num_seqs = 256
|
| 11 |
+
max_input_len = 1024
|
| 12 |
+
max_ouput_len = 1024
|
| 13 |
+
|
| 14 |
+
path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
|
| 15 |
+
llm = LLM(path, enforce_eager=False, max_model_len=4096)
|
| 16 |
+
|
| 17 |
+
prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
|
| 18 |
+
sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
|
| 19 |
+
# uncomment the following line for vllm
|
| 20 |
+
# prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
|
| 21 |
+
|
| 22 |
+
llm.generate(["Benchmark: "], SamplingParams())
|
| 23 |
+
t = time.time()
|
| 24 |
+
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
| 25 |
+
t = (time.time() - t)
|
| 26 |
+
total_tokens = sum(sp.max_tokens for sp in sampling_params)
|
| 27 |
+
throughput = total_tokens / t
|
| 28 |
+
print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
main()
|
acestep/third_parts/nano-vllm/example.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from nanovllm import LLM, SamplingParams
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def main():
|
| 7 |
+
path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
|
| 8 |
+
tokenizer = AutoTokenizer.from_pretrained(path)
|
| 9 |
+
llm = LLM(path, enforce_eager=True, tensor_parallel_size=1)
|
| 10 |
+
|
| 11 |
+
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
|
| 12 |
+
prompts = [
|
| 13 |
+
"introduce yourself",
|
| 14 |
+
"list all prime numbers within 100",
|
| 15 |
+
]
|
| 16 |
+
prompts = [
|
| 17 |
+
tokenizer.apply_chat_template(
|
| 18 |
+
[{"role": "user", "content": prompt}],
|
| 19 |
+
tokenize=False,
|
| 20 |
+
add_generation_prompt=True,
|
| 21 |
+
)
|
| 22 |
+
for prompt in prompts
|
| 23 |
+
]
|
| 24 |
+
outputs = llm.generate(prompts, sampling_params)
|
| 25 |
+
|
| 26 |
+
for prompt, output in zip(prompts, outputs):
|
| 27 |
+
print("\n")
|
| 28 |
+
print(f"Prompt: {prompt!r}")
|
| 29 |
+
print(f"Completion: {output['text']!r}")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if __name__ == "__main__":
|
| 33 |
+
main()
|
acestep/third_parts/nano-vllm/nanovllm/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from nanovllm.llm import LLM
|
| 2 |
+
from nanovllm.sampling_params import SamplingParams
|
acestep/third_parts/nano-vllm/nanovllm/config.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from transformers import AutoConfig
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
|
| 7 |
+
class Config:
|
| 8 |
+
model: str
|
| 9 |
+
max_num_batched_tokens: int = 16384
|
| 10 |
+
max_num_seqs: int = 512
|
| 11 |
+
max_model_len: int = 4096
|
| 12 |
+
gpu_memory_utilization: float = 0.9
|
| 13 |
+
tensor_parallel_size: int = 1
|
| 14 |
+
enforce_eager: bool = False
|
| 15 |
+
hf_config: AutoConfig | None = None
|
| 16 |
+
eos: int = -1
|
| 17 |
+
kvcache_block_size: int = 256
|
| 18 |
+
num_kvcache_blocks: int = -1
|
| 19 |
+
|
| 20 |
+
def __post_init__(self):
|
| 21 |
+
assert os.path.isdir(self.model)
|
| 22 |
+
assert self.kvcache_block_size % 256 == 0
|
| 23 |
+
assert 1 <= self.tensor_parallel_size <= 8
|
| 24 |
+
self.hf_config = AutoConfig.from_pretrained(self.model)
|
| 25 |
+
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
|
| 26 |
+
assert self.max_num_batched_tokens >= self.max_model_len
|
acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import deque
|
| 2 |
+
import xxhash
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from nanovllm.engine.sequence import Sequence
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Block:
|
| 9 |
+
|
| 10 |
+
def __init__(self, block_id):
|
| 11 |
+
self.block_id = block_id
|
| 12 |
+
self.ref_count = 0
|
| 13 |
+
self.hash = -1
|
| 14 |
+
self.token_ids = []
|
| 15 |
+
|
| 16 |
+
def update(self, hash: int, token_ids: list[int]):
|
| 17 |
+
self.hash = hash
|
| 18 |
+
self.token_ids = token_ids
|
| 19 |
+
|
| 20 |
+
def reset(self):
|
| 21 |
+
self.ref_count = 1
|
| 22 |
+
self.hash = -1
|
| 23 |
+
self.token_ids = []
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class BlockManager:
|
| 27 |
+
|
| 28 |
+
def __init__(self, num_blocks: int, block_size: int):
|
| 29 |
+
self.block_size = block_size
|
| 30 |
+
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
|
| 31 |
+
self.hash_to_block_id: dict[int, int] = dict()
|
| 32 |
+
self.free_block_ids: deque[int] = deque(range(num_blocks))
|
| 33 |
+
self.used_block_ids: set[int] = set()
|
| 34 |
+
|
| 35 |
+
@classmethod
|
| 36 |
+
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
|
| 37 |
+
h = xxhash.xxh64()
|
| 38 |
+
if prefix != -1:
|
| 39 |
+
h.update(prefix.to_bytes(8, "little"))
|
| 40 |
+
h.update(np.array(token_ids).tobytes())
|
| 41 |
+
return h.intdigest()
|
| 42 |
+
|
| 43 |
+
def _allocate_block(self, block_id: int) -> Block:
|
| 44 |
+
block = self.blocks[block_id]
|
| 45 |
+
assert block.ref_count == 0
|
| 46 |
+
block.reset()
|
| 47 |
+
self.free_block_ids.remove(block_id)
|
| 48 |
+
self.used_block_ids.add(block_id)
|
| 49 |
+
return self.blocks[block_id]
|
| 50 |
+
|
| 51 |
+
def _deallocate_block(self, block_id: int) -> Block:
|
| 52 |
+
assert self.blocks[block_id].ref_count == 0
|
| 53 |
+
self.used_block_ids.remove(block_id)
|
| 54 |
+
self.free_block_ids.append(block_id)
|
| 55 |
+
|
| 56 |
+
def can_allocate(self, seq: Sequence) -> bool:
|
| 57 |
+
return len(self.free_block_ids) >= seq.num_blocks
|
| 58 |
+
|
| 59 |
+
def allocate(self, seq: Sequence):
|
| 60 |
+
assert not seq.block_table
|
| 61 |
+
h = -1
|
| 62 |
+
cache_miss = False
|
| 63 |
+
for i in range(seq.num_blocks):
|
| 64 |
+
token_ids = seq.block(i)
|
| 65 |
+
h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
|
| 66 |
+
block_id = self.hash_to_block_id.get(h, -1)
|
| 67 |
+
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
|
| 68 |
+
cache_miss = True
|
| 69 |
+
if cache_miss:
|
| 70 |
+
block_id = self.free_block_ids[0]
|
| 71 |
+
block = self._allocate_block(block_id)
|
| 72 |
+
else:
|
| 73 |
+
seq.num_cached_tokens += self.block_size
|
| 74 |
+
if block_id in self.used_block_ids:
|
| 75 |
+
block = self.blocks[block_id]
|
| 76 |
+
block.ref_count += 1
|
| 77 |
+
else:
|
| 78 |
+
block = self._allocate_block(block_id)
|
| 79 |
+
if h != -1:
|
| 80 |
+
block.update(h, token_ids)
|
| 81 |
+
self.hash_to_block_id[h] = block_id
|
| 82 |
+
seq.block_table.append(block_id)
|
| 83 |
+
|
| 84 |
+
def deallocate(self, seq: Sequence):
|
| 85 |
+
for block_id in reversed(seq.block_table):
|
| 86 |
+
block = self.blocks[block_id]
|
| 87 |
+
block.ref_count -= 1
|
| 88 |
+
if block.ref_count == 0:
|
| 89 |
+
# Fix: Clean up hash_to_block_id mapping to prevent stale references
|
| 90 |
+
# This prevents CUDA illegal memory access when prefix cache tries to
|
| 91 |
+
# reuse a block_id that has already been freed
|
| 92 |
+
if block.hash != -1:
|
| 93 |
+
cached_id = self.hash_to_block_id.get(block.hash)
|
| 94 |
+
if cached_id == block_id:
|
| 95 |
+
del self.hash_to_block_id[block.hash]
|
| 96 |
+
self._deallocate_block(block_id)
|
| 97 |
+
seq.num_cached_tokens = 0
|
| 98 |
+
seq.block_table.clear()
|
| 99 |
+
|
| 100 |
+
def can_append(self, seq: Sequence) -> bool:
|
| 101 |
+
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
|
| 102 |
+
|
| 103 |
+
def may_append(self, seq: Sequence):
|
| 104 |
+
block_table = seq.block_table
|
| 105 |
+
last_block = self.blocks[block_table[-1]]
|
| 106 |
+
if len(seq) % self.block_size == 1:
|
| 107 |
+
assert last_block.hash != -1
|
| 108 |
+
block_id = self.free_block_ids[0]
|
| 109 |
+
self._allocate_block(block_id)
|
| 110 |
+
block_table.append(block_id)
|
| 111 |
+
elif len(seq) % self.block_size == 0:
|
| 112 |
+
assert last_block.hash == -1
|
| 113 |
+
token_ids = seq.block(seq.num_blocks-1)
|
| 114 |
+
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
|
| 115 |
+
h = self.compute_hash(token_ids, prefix)
|
| 116 |
+
last_block.update(h, token_ids)
|
| 117 |
+
self.hash_to_block_id[h] = last_block.block_id
|
| 118 |
+
else:
|
| 119 |
+
assert last_block.hash == -1
|
acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import atexit
|
| 2 |
+
import threading
|
| 3 |
+
from dataclasses import fields
|
| 4 |
+
from time import perf_counter
|
| 5 |
+
from tqdm.auto import tqdm
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
import torch.multiprocessing as mp
|
| 8 |
+
|
| 9 |
+
from nanovllm.config import Config
|
| 10 |
+
from nanovllm.sampling_params import SamplingParams
|
| 11 |
+
from nanovllm.engine.sequence import Sequence
|
| 12 |
+
from nanovllm.engine.scheduler import Scheduler
|
| 13 |
+
from nanovllm.engine.model_runner import ModelRunner
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class LLMEngine:
|
| 17 |
+
|
| 18 |
+
def __init__(self, model, **kwargs):
|
| 19 |
+
config_fields = {field.name for field in fields(Config)}
|
| 20 |
+
config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
|
| 21 |
+
config = Config(model, **config_kwargs)
|
| 22 |
+
self.ps = []
|
| 23 |
+
self.events = []
|
| 24 |
+
# Thread-safety lock for generate().
|
| 25 |
+
# The scheduler, block manager, model runner, and CUDA graph buffers are all
|
| 26 |
+
# shared mutable state that is NOT thread-safe. In concurrent serving scenarios
|
| 27 |
+
# (API server with ThreadPoolExecutor, multiple queue workers, Gradio with
|
| 28 |
+
# concurrent requests), multiple threads can call generate() simultaneously.
|
| 29 |
+
# Without this lock, concurrent access corrupts scheduler state, block tables,
|
| 30 |
+
# and CUDA graph input buffers, leading to intermittent CUDA device-side
|
| 31 |
+
# assertion failures (illegal memory access in KV cache).
|
| 32 |
+
self._generate_lock = threading.Lock()
|
| 33 |
+
ctx = mp.get_context("spawn")
|
| 34 |
+
for i in range(1, config.tensor_parallel_size):
|
| 35 |
+
event = ctx.Event()
|
| 36 |
+
process = ctx.Process(target=ModelRunner, args=(config, i, event))
|
| 37 |
+
process.start()
|
| 38 |
+
self.ps.append(process)
|
| 39 |
+
self.events.append(event)
|
| 40 |
+
self.model_runner = ModelRunner(config, 0, self.events)
|
| 41 |
+
tokenizer = kwargs.get("tokenizer", None)
|
| 42 |
+
if tokenizer is not None:
|
| 43 |
+
self.tokenizer = tokenizer
|
| 44 |
+
else:
|
| 45 |
+
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
|
| 46 |
+
config.eos = self.tokenizer.eos_token_id
|
| 47 |
+
self.scheduler = Scheduler(config)
|
| 48 |
+
atexit.register(self.exit)
|
| 49 |
+
|
| 50 |
+
def exit(self):
|
| 51 |
+
self.model_runner.call("exit")
|
| 52 |
+
del self.model_runner
|
| 53 |
+
for p in self.ps:
|
| 54 |
+
p.join()
|
| 55 |
+
|
| 56 |
+
def add_request(self, prompt: str | list[int], sampling_params: SamplingParams, unconditional_prompt: str | list[int] | None = None):
|
| 57 |
+
if isinstance(prompt, str):
|
| 58 |
+
prompt = self.tokenizer.encode(prompt)
|
| 59 |
+
# For CFG: if cfg_scale > 1.0, create both conditional and unconditional sequences
|
| 60 |
+
if sampling_params.cfg_scale > 1.0:
|
| 61 |
+
if unconditional_prompt is None:
|
| 62 |
+
# Try to construct unconditional prompt by replacing user input with "NO USER INPUT"
|
| 63 |
+
# This is a fallback - ideally users should provide unconditional_prompt
|
| 64 |
+
if isinstance(prompt, list):
|
| 65 |
+
# For now, just use the same prompt (user should provide unconditional_prompt)
|
| 66 |
+
# TODO: Implement automatic "NO USER INPUT" replacement if possible
|
| 67 |
+
unconditional_prompt = prompt
|
| 68 |
+
else:
|
| 69 |
+
unconditional_prompt = prompt
|
| 70 |
+
if isinstance(unconditional_prompt, str):
|
| 71 |
+
unconditional_prompt = self.tokenizer.encode(unconditional_prompt)
|
| 72 |
+
# Create unconditional sequence first (so we can reference it from conditional)
|
| 73 |
+
uncond_seq = Sequence(unconditional_prompt, sampling_params, is_unconditional=True)
|
| 74 |
+
# Create conditional sequence with reference to unconditional
|
| 75 |
+
cond_seq = Sequence(prompt, sampling_params, is_unconditional=False, conditional_seq=uncond_seq)
|
| 76 |
+
uncond_seq.paired_seq = cond_seq # Link them bidirectionally
|
| 77 |
+
# Add both sequences to scheduler
|
| 78 |
+
self.scheduler.add(cond_seq)
|
| 79 |
+
self.scheduler.add(uncond_seq)
|
| 80 |
+
else:
|
| 81 |
+
seq = Sequence(prompt, sampling_params)
|
| 82 |
+
self.scheduler.add(seq)
|
| 83 |
+
|
| 84 |
+
def step(self):
|
| 85 |
+
seqs, is_prefill = self.scheduler.schedule()
|
| 86 |
+
token_ids = self.model_runner.call("run", seqs, is_prefill)
|
| 87 |
+
self.scheduler.postprocess(seqs, token_ids)
|
| 88 |
+
# Only output conditional sequences (unconditional sequences are just for CFG computation)
|
| 89 |
+
output_seqs = [seq for seq in seqs if seq.is_finished and (seq.cfg_scale <= 1.0 or not seq.is_unconditional)]
|
| 90 |
+
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in output_seqs]
|
| 91 |
+
num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len([s for s in seqs if not s.is_unconditional])
|
| 92 |
+
return outputs, num_tokens
|
| 93 |
+
|
| 94 |
+
def is_finished(self):
|
| 95 |
+
return self.scheduler.is_finished()
|
| 96 |
+
|
| 97 |
+
def reset(self):
|
| 98 |
+
"""
|
| 99 |
+
Reset the scheduler state and release all allocated blocks.
|
| 100 |
+
This should be called when an exception occurs during generation to prevent
|
| 101 |
+
KV cache block leaks that can cause 'deque index out of range' errors.
|
| 102 |
+
"""
|
| 103 |
+
# Deallocate all running sequences
|
| 104 |
+
while self.scheduler.running:
|
| 105 |
+
seq = self.scheduler.running.popleft()
|
| 106 |
+
if seq.block_table: # Only deallocate if blocks are allocated
|
| 107 |
+
self.scheduler.block_manager.deallocate(seq)
|
| 108 |
+
|
| 109 |
+
# Deallocate all waiting sequences (they might have blocks from preemption)
|
| 110 |
+
while self.scheduler.waiting:
|
| 111 |
+
seq = self.scheduler.waiting.popleft()
|
| 112 |
+
if seq.block_table:
|
| 113 |
+
self.scheduler.block_manager.deallocate(seq)
|
| 114 |
+
|
| 115 |
+
def generate(
|
| 116 |
+
self,
|
| 117 |
+
prompts: list[str] | list[list[int]],
|
| 118 |
+
sampling_params: SamplingParams | list[SamplingParams],
|
| 119 |
+
use_tqdm: bool = True,
|
| 120 |
+
unconditional_prompts: list[str] | list[list[int]] | None = None,
|
| 121 |
+
) -> list[str]:
|
| 122 |
+
# Serialize access to the engine to prevent concurrent corruption of
|
| 123 |
+
# scheduler state, block manager, CUDA graph buffers, and KV cache.
|
| 124 |
+
# This is the primary defense against the intermittent CUDA device-side
|
| 125 |
+
# assertion error that occurs in concurrent serving scenarios.
|
| 126 |
+
with self._generate_lock:
|
| 127 |
+
return self._generate_impl(prompts, sampling_params, use_tqdm, unconditional_prompts)
|
| 128 |
+
|
| 129 |
+
def _generate_impl(
|
| 130 |
+
self,
|
| 131 |
+
prompts: list[str] | list[list[int]],
|
| 132 |
+
sampling_params: SamplingParams | list[SamplingParams],
|
| 133 |
+
use_tqdm: bool = True,
|
| 134 |
+
unconditional_prompts: list[str] | list[list[int]] | None = None,
|
| 135 |
+
) -> list[str]:
|
| 136 |
+
# Clean up any residual state from previous interrupted generations
|
| 137 |
+
# This prevents 'deque index out of range' errors from accumulated block leaks
|
| 138 |
+
if not self.is_finished():
|
| 139 |
+
self.reset()
|
| 140 |
+
|
| 141 |
+
if use_tqdm:
|
| 142 |
+
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
|
| 143 |
+
if not isinstance(sampling_params, list):
|
| 144 |
+
sampling_params = [sampling_params] * len(prompts)
|
| 145 |
+
if unconditional_prompts is None:
|
| 146 |
+
unconditional_prompts = [None] * len(prompts)
|
| 147 |
+
for prompt, sp, uncond_prompt in zip(prompts, sampling_params, unconditional_prompts):
|
| 148 |
+
self.add_request(prompt, sp, uncond_prompt)
|
| 149 |
+
outputs = {}
|
| 150 |
+
prefill_throughput = decode_throughput = 0.
|
| 151 |
+
try:
|
| 152 |
+
while not self.is_finished():
|
| 153 |
+
t = perf_counter()
|
| 154 |
+
output, num_tokens = self.step()
|
| 155 |
+
if use_tqdm:
|
| 156 |
+
if num_tokens > 0:
|
| 157 |
+
prefill_throughput = num_tokens / (perf_counter() - t)
|
| 158 |
+
else:
|
| 159 |
+
decode_throughput = -num_tokens / (perf_counter() - t)
|
| 160 |
+
pbar.set_postfix({
|
| 161 |
+
"Prefill": f"{int(prefill_throughput)}tok/s",
|
| 162 |
+
"Decode": f"{int(decode_throughput)}tok/s",
|
| 163 |
+
})
|
| 164 |
+
for seq_id, token_ids in output:
|
| 165 |
+
outputs[seq_id] = token_ids
|
| 166 |
+
if use_tqdm:
|
| 167 |
+
pbar.update(1)
|
| 168 |
+
except Exception:
|
| 169 |
+
# Clean up on exception to prevent block leaks
|
| 170 |
+
self.reset()
|
| 171 |
+
raise
|
| 172 |
+
finally:
|
| 173 |
+
if use_tqdm:
|
| 174 |
+
pbar.close()
|
| 175 |
+
|
| 176 |
+
outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
|
| 177 |
+
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
|
| 178 |
+
return outputs
|
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py
ADDED
|
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import torch
|
| 3 |
+
import torch.distributed as dist
|
| 4 |
+
from multiprocessing.synchronize import Event
|
| 5 |
+
from multiprocessing.shared_memory import SharedMemory
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
from nanovllm.config import Config
|
| 9 |
+
from nanovllm.engine.sequence import Sequence
|
| 10 |
+
from nanovllm.models.qwen3 import Qwen3ForCausalLM
|
| 11 |
+
from nanovllm.layers.sampler import Sampler
|
| 12 |
+
from nanovllm.utils.context import set_context, get_context, reset_context
|
| 13 |
+
from nanovllm.utils.loader import load_model
|
| 14 |
+
|
| 15 |
+
import socket
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def find_available_port(start_port: int = 2333, max_attempts: int = 100) -> int:
|
| 19 |
+
"""Find an available port starting from start_port.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
start_port: The starting port number to check
|
| 23 |
+
max_attempts: Maximum number of ports to try
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
An available port number
|
| 27 |
+
|
| 28 |
+
Raises:
|
| 29 |
+
RuntimeError: If no available port is found within max_attempts
|
| 30 |
+
"""
|
| 31 |
+
for i in range(max_attempts):
|
| 32 |
+
port = start_port + i
|
| 33 |
+
try:
|
| 34 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 35 |
+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
| 36 |
+
s.bind(('localhost', port))
|
| 37 |
+
return port
|
| 38 |
+
except OSError:
|
| 39 |
+
# Port is in use, try next one
|
| 40 |
+
continue
|
| 41 |
+
raise RuntimeError(f"Could not find an available port starting from {start_port} after {max_attempts} attempts")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ModelRunner:
|
| 45 |
+
|
| 46 |
+
def __init__(self, config: Config, rank: int, event: Event | list[Event]):
|
| 47 |
+
# Enable capturing scalar outputs to avoid graph breaks from Tensor.item() calls
|
| 48 |
+
torch._dynamo.config.capture_scalar_outputs = True
|
| 49 |
+
|
| 50 |
+
self.config = config
|
| 51 |
+
hf_config = config.hf_config
|
| 52 |
+
self.block_size = config.kvcache_block_size
|
| 53 |
+
self.enforce_eager = config.enforce_eager
|
| 54 |
+
self.world_size = config.tensor_parallel_size
|
| 55 |
+
self.rank = rank
|
| 56 |
+
self.event = event
|
| 57 |
+
dist_port = find_available_port()
|
| 58 |
+
print(f"[debug]dist_port: {dist_port}")
|
| 59 |
+
# Use gloo backend on Windows, nccl on Linux/other platforms
|
| 60 |
+
backend = "gloo" if sys.platform == "win32" else "nccl"
|
| 61 |
+
dist.init_process_group(backend, f"tcp://127.0.0.1:{dist_port}", world_size=self.world_size, rank=rank)
|
| 62 |
+
torch.cuda.set_device(rank)
|
| 63 |
+
default_dtype = torch.get_default_dtype()
|
| 64 |
+
# Use dtype instead of deprecated torch_dtype
|
| 65 |
+
config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', torch.float32))
|
| 66 |
+
torch.set_default_dtype(config_dtype)
|
| 67 |
+
torch.set_default_device("cuda")
|
| 68 |
+
self.model = Qwen3ForCausalLM(hf_config)
|
| 69 |
+
load_model(self.model, config.model)
|
| 70 |
+
self.sampler = Sampler()
|
| 71 |
+
|
| 72 |
+
# Pre-allocate buffers for sampling (optimization: avoid repeated tensor creation)
|
| 73 |
+
# Must be called before warmup_model() since it uses these buffers
|
| 74 |
+
self._allocate_sample_buffers()
|
| 75 |
+
|
| 76 |
+
self.warmup_model()
|
| 77 |
+
self.allocate_kv_cache()
|
| 78 |
+
if not self.enforce_eager:
|
| 79 |
+
self.capture_cudagraph()
|
| 80 |
+
|
| 81 |
+
torch.set_default_device("cpu")
|
| 82 |
+
torch.set_default_dtype(default_dtype)
|
| 83 |
+
|
| 84 |
+
if self.world_size > 1:
|
| 85 |
+
if rank == 0:
|
| 86 |
+
self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
|
| 87 |
+
dist.barrier()
|
| 88 |
+
else:
|
| 89 |
+
dist.barrier()
|
| 90 |
+
self.shm = SharedMemory(name="nanovllm")
|
| 91 |
+
self.loop()
|
| 92 |
+
|
| 93 |
+
def _allocate_sample_buffers(self):
|
| 94 |
+
"""Pre-allocate reusable buffers for sampling to avoid repeated tensor creation."""
|
| 95 |
+
max_bs = self.config.max_num_seqs
|
| 96 |
+
max_tokens = self.config.max_num_batched_tokens
|
| 97 |
+
max_num_blocks = (self.config.max_model_len + self.block_size - 1) // self.block_size
|
| 98 |
+
|
| 99 |
+
# Pre-allocate pinned memory buffers on CPU for fast transfer
|
| 100 |
+
# Must explicitly specify device="cpu" since default device may be "cuda"
|
| 101 |
+
self._cpu_temperatures = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
|
| 102 |
+
self._cpu_cfg_scales = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
|
| 103 |
+
self._cpu_top_ks = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 104 |
+
self._cpu_top_ps = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
|
| 105 |
+
self._cpu_repetition_penalties = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
|
| 106 |
+
|
| 107 |
+
# Pre-allocate decode buffers on CPU with pinned memory
|
| 108 |
+
self._cpu_input_ids = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
|
| 109 |
+
self._cpu_positions = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
|
| 110 |
+
self._cpu_slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 111 |
+
self._cpu_context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 112 |
+
|
| 113 |
+
# Pre-allocate prefill buffers on CPU with pinned memory (optimization to avoid repeated tensor creation)
|
| 114 |
+
self._cpu_prefill_input_ids = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
|
| 115 |
+
self._cpu_prefill_positions = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
|
| 116 |
+
self._cpu_prefill_cu_seqlens = torch.zeros(max_bs + 1, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 117 |
+
self._cpu_prefill_slot_mapping = torch.zeros(max_tokens, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 118 |
+
|
| 119 |
+
# Pre-allocate block tables buffer (shared by both decode and prefill)
|
| 120 |
+
self._cpu_block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 121 |
+
|
| 122 |
+
# Pre-allocate buffer for sequence token IDs (used in logits processor and sampler)
|
| 123 |
+
# Max length is max_model_len since sequences can be that long
|
| 124 |
+
self._seq_token_ids_buffer = torch.zeros(max_bs, self.config.max_model_len, dtype=torch.int64, device="cpu", pin_memory=True)
|
| 125 |
+
|
| 126 |
+
def exit(self):
|
| 127 |
+
if self.world_size > 1:
|
| 128 |
+
self.shm.close()
|
| 129 |
+
dist.barrier()
|
| 130 |
+
if self.rank == 0:
|
| 131 |
+
self.shm.unlink()
|
| 132 |
+
if not self.enforce_eager:
|
| 133 |
+
del self.graphs, self.graph_pool
|
| 134 |
+
torch.cuda.synchronize()
|
| 135 |
+
dist.destroy_process_group()
|
| 136 |
+
|
| 137 |
+
def loop(self):
|
| 138 |
+
while True:
|
| 139 |
+
method_name, args = self.read_shm()
|
| 140 |
+
self.call(method_name, *args)
|
| 141 |
+
if method_name == "exit":
|
| 142 |
+
break
|
| 143 |
+
|
| 144 |
+
def read_shm(self):
|
| 145 |
+
assert self.world_size > 1 and self.rank > 0
|
| 146 |
+
self.event.wait()
|
| 147 |
+
n = int.from_bytes(self.shm.buf[0:4], "little")
|
| 148 |
+
method_name, *args = pickle.loads(self.shm.buf[4:n+4])
|
| 149 |
+
self.event.clear()
|
| 150 |
+
return method_name, args
|
| 151 |
+
|
| 152 |
+
def write_shm(self, method_name, *args):
|
| 153 |
+
assert self.world_size > 1 and self.rank == 0
|
| 154 |
+
data = pickle.dumps([method_name, *args])
|
| 155 |
+
n = len(data)
|
| 156 |
+
self.shm.buf[0:4] = n.to_bytes(4, "little")
|
| 157 |
+
self.shm.buf[4:n+4] = data
|
| 158 |
+
for event in self.event:
|
| 159 |
+
event.set()
|
| 160 |
+
|
| 161 |
+
def call(self, method_name, *args):
|
| 162 |
+
if self.world_size > 1 and self.rank == 0:
|
| 163 |
+
self.write_shm(method_name, *args)
|
| 164 |
+
method = getattr(self, method_name, None)
|
| 165 |
+
return method(*args)
|
| 166 |
+
|
| 167 |
+
def warmup_model(self):
|
| 168 |
+
torch.cuda.empty_cache()
|
| 169 |
+
torch.cuda.reset_peak_memory_stats()
|
| 170 |
+
max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
|
| 171 |
+
num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
|
| 172 |
+
seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
|
| 173 |
+
self.run(seqs, True)
|
| 174 |
+
torch.cuda.empty_cache()
|
| 175 |
+
|
| 176 |
+
def allocate_kv_cache(self):
|
| 177 |
+
config = self.config
|
| 178 |
+
hf_config = config.hf_config
|
| 179 |
+
free, total = torch.cuda.mem_get_info()
|
| 180 |
+
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
|
| 181 |
+
num_kv_heads = hf_config.num_key_value_heads // self.world_size
|
| 182 |
+
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
|
| 183 |
+
# Use dtype instead of deprecated torch_dtype
|
| 184 |
+
config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', torch.float32))
|
| 185 |
+
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * config_dtype.itemsize
|
| 186 |
+
|
| 187 |
+
# Calculate available memory for KV cache
|
| 188 |
+
# After warmup_model, empty_cache has been called, so current represents model memory only
|
| 189 |
+
# Use free memory but respect the gpu_memory_utilization limit
|
| 190 |
+
target_total_usage = total * config.gpu_memory_utilization
|
| 191 |
+
available_for_kv_cache = min(free * 0.9, target_total_usage - current)
|
| 192 |
+
|
| 193 |
+
# Ensure we have positive memory available
|
| 194 |
+
if available_for_kv_cache <= 0:
|
| 195 |
+
available_for_kv_cache = free * 0.5 # Fallback to 50% of free memory
|
| 196 |
+
|
| 197 |
+
config.num_kvcache_blocks = max(1, int(available_for_kv_cache) // block_bytes)
|
| 198 |
+
if config.num_kvcache_blocks <= 0:
|
| 199 |
+
raise RuntimeError(
|
| 200 |
+
f"Insufficient GPU memory for KV cache. "
|
| 201 |
+
f"Free: {free / 1024**3:.2f} GB, Current: {current / 1024**3:.2f} GB, "
|
| 202 |
+
f"Available for KV: {available_for_kv_cache / 1024**3:.2f} GB, "
|
| 203 |
+
f"Block size: {block_bytes / 1024**2:.2f} MB"
|
| 204 |
+
)
|
| 205 |
+
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
|
| 206 |
+
layer_id = 0
|
| 207 |
+
for module in self.model.modules():
|
| 208 |
+
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
|
| 209 |
+
module.k_cache = self.kv_cache[0, layer_id]
|
| 210 |
+
module.v_cache = self.kv_cache[1, layer_id]
|
| 211 |
+
layer_id += 1
|
| 212 |
+
|
| 213 |
+
def prepare_block_tables(self, seqs: list[Sequence]):
|
| 214 |
+
max_len = max(len(seq.block_table) for seq in seqs)
|
| 215 |
+
block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
|
| 216 |
+
block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
| 217 |
+
return block_tables
|
| 218 |
+
|
| 219 |
+
def prepare_prefill(self, seqs: list[Sequence]):
|
| 220 |
+
input_ids = []
|
| 221 |
+
positions = []
|
| 222 |
+
cu_seqlens_q = [0]
|
| 223 |
+
cu_seqlens_k = [0]
|
| 224 |
+
max_seqlen_q = 0
|
| 225 |
+
max_seqlen_k = 0
|
| 226 |
+
slot_mapping = []
|
| 227 |
+
block_tables = None
|
| 228 |
+
for seq in seqs:
|
| 229 |
+
seqlen = len(seq)
|
| 230 |
+
input_ids.extend(seq[seq.num_cached_tokens:])
|
| 231 |
+
positions.extend(list(range(seq.num_cached_tokens, seqlen)))
|
| 232 |
+
seqlen_q = seqlen - seq.num_cached_tokens
|
| 233 |
+
seqlen_k = seqlen
|
| 234 |
+
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
|
| 235 |
+
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
|
| 236 |
+
max_seqlen_q = max(seqlen_q, max_seqlen_q)
|
| 237 |
+
max_seqlen_k = max(seqlen_k, max_seqlen_k)
|
| 238 |
+
if not seq.block_table: # warmup
|
| 239 |
+
continue
|
| 240 |
+
for i in range(seq.num_cached_blocks, seq.num_blocks):
|
| 241 |
+
start = seq.block_table[i] * self.block_size
|
| 242 |
+
if i != seq.num_blocks - 1:
|
| 243 |
+
end = start + self.block_size
|
| 244 |
+
else:
|
| 245 |
+
end = start + seq.last_block_num_tokens
|
| 246 |
+
slot_mapping.extend(list(range(start, end)))
|
| 247 |
+
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
|
| 248 |
+
block_tables = self.prepare_block_tables(seqs)
|
| 249 |
+
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
| 250 |
+
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
| 251 |
+
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
| 252 |
+
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
| 253 |
+
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
| 254 |
+
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
|
| 255 |
+
return input_ids, positions
|
| 256 |
+
|
| 257 |
+
def prepare_decode(self, seqs: list[Sequence]):
|
| 258 |
+
"""Optimized decode preparation using pre-allocated buffers."""
|
| 259 |
+
bs = len(seqs)
|
| 260 |
+
|
| 261 |
+
# Use pre-allocated CPU buffers
|
| 262 |
+
for i, seq in enumerate(seqs):
|
| 263 |
+
self._cpu_input_ids[i] = seq.last_token
|
| 264 |
+
self._cpu_positions[i] = len(seq) - 1
|
| 265 |
+
self._cpu_context_lens[i] = len(seq)
|
| 266 |
+
self._cpu_slot_mapping[i] = seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1
|
| 267 |
+
|
| 268 |
+
# Transfer to GPU using sliced views
|
| 269 |
+
input_ids = self._cpu_input_ids[:bs].cuda(non_blocking=True)
|
| 270 |
+
positions = self._cpu_positions[:bs].cuda(non_blocking=True)
|
| 271 |
+
slot_mapping = self._cpu_slot_mapping[:bs].cuda(non_blocking=True)
|
| 272 |
+
context_lens = self._cpu_context_lens[:bs].cuda(non_blocking=True)
|
| 273 |
+
block_tables = self.prepare_block_tables(seqs)
|
| 274 |
+
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
|
| 275 |
+
return input_ids, positions
|
| 276 |
+
|
| 277 |
+
def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
|
| 278 |
+
"""Optimized sample preparation using pre-allocated buffers."""
|
| 279 |
+
if is_cfg_batch:
|
| 280 |
+
num_seqs = len(seqs) // 2
|
| 281 |
+
target_seqs = seqs[:num_seqs]
|
| 282 |
+
else:
|
| 283 |
+
num_seqs = len(seqs)
|
| 284 |
+
target_seqs = seqs
|
| 285 |
+
|
| 286 |
+
# Fill pre-allocated CPU buffers
|
| 287 |
+
top_ks_is_zero = True
|
| 288 |
+
top_ps_is_one = True
|
| 289 |
+
repetition_penalties_is_one = True
|
| 290 |
+
for i, seq in enumerate(target_seqs):
|
| 291 |
+
self._cpu_temperatures[i] = seq.temperature
|
| 292 |
+
self._cpu_cfg_scales[i] = seq.cfg_scale
|
| 293 |
+
self._cpu_top_ks[i] = seq.top_k if seq.top_k is not None else 0
|
| 294 |
+
if seq.top_k is not None and seq.top_k > 0:
|
| 295 |
+
top_ks_is_zero = False
|
| 296 |
+
self._cpu_top_ps[i] = seq.top_p if seq.top_p is not None else 1.0
|
| 297 |
+
if seq.top_p is not None and seq.top_p == 1.0:
|
| 298 |
+
top_ps_is_one = False
|
| 299 |
+
self._cpu_repetition_penalties[i] = seq.repetition_penalty if seq.repetition_penalty is not None else 1.0
|
| 300 |
+
if seq.repetition_penalty is not None and seq.repetition_penalty == 1.0:
|
| 301 |
+
repetition_penalties_is_one = False
|
| 302 |
+
|
| 303 |
+
# Transfer to GPU using sliced views (single batched transfer)
|
| 304 |
+
temperatures = self._cpu_temperatures[:num_seqs].cuda(non_blocking=True)
|
| 305 |
+
cfg_scales = self._cpu_cfg_scales[:num_seqs].cuda(non_blocking=True)
|
| 306 |
+
top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True) if not top_ks_is_zero else None
|
| 307 |
+
top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True) if not top_ps_is_one else None
|
| 308 |
+
repetition_penalties = self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True) if not repetition_penalties_is_one else None
|
| 309 |
+
|
| 310 |
+
return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
|
| 311 |
+
|
| 312 |
+
@torch.inference_mode()
|
| 313 |
+
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
|
| 314 |
+
if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
|
| 315 |
+
return self.model.compute_logits(self.model(input_ids, positions))
|
| 316 |
+
else:
|
| 317 |
+
bs = input_ids.size(0)
|
| 318 |
+
context = get_context()
|
| 319 |
+
|
| 320 |
+
# Check if block_tables size exceeds pre-allocated buffer size
|
| 321 |
+
# This can happen when conditional and unconditional sequences have different lengths
|
| 322 |
+
# in CFG mode, causing block_tables to have more columns than expected
|
| 323 |
+
max_num_blocks = self.graph_vars["block_tables"].size(1)
|
| 324 |
+
if context.block_tables.size(1) > max_num_blocks:
|
| 325 |
+
# Fall back to eager mode when block_tables is too large for CUDA graph
|
| 326 |
+
return self.model.compute_logits(self.model(input_ids, positions))
|
| 327 |
+
|
| 328 |
+
# Fix: Also check if block_tables row count matches batch size
|
| 329 |
+
# Dimension mismatch can cause CUDA illegal memory access during graph replay
|
| 330 |
+
if context.block_tables.size(0) != bs:
|
| 331 |
+
# Fall back to eager mode when block_tables row count doesn't match batch size
|
| 332 |
+
return self.model.compute_logits(self.model(input_ids, positions))
|
| 333 |
+
|
| 334 |
+
# Fix: Verify slot_mapping and context_lens dimensions match batch size
|
| 335 |
+
if context.slot_mapping.size(0) != bs or context.context_lens.size(0) != bs:
|
| 336 |
+
# Fall back to eager mode when dimensions don't match
|
| 337 |
+
return self.model.compute_logits(self.model(input_ids, positions))
|
| 338 |
+
|
| 339 |
+
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
|
| 340 |
+
graph_vars = self.graph_vars
|
| 341 |
+
graph_vars["input_ids"][:bs] = input_ids
|
| 342 |
+
graph_vars["positions"][:bs] = positions
|
| 343 |
+
graph_vars["slot_mapping"].fill_(-1)
|
| 344 |
+
graph_vars["slot_mapping"][:bs] = context.slot_mapping
|
| 345 |
+
graph_vars["context_lens"].zero_()
|
| 346 |
+
graph_vars["context_lens"][:bs] = context.context_lens
|
| 347 |
+
# Clear block_tables first to ensure no stale data from previous runs
|
| 348 |
+
graph_vars["block_tables"][:bs].fill_(-1)
|
| 349 |
+
graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
|
| 350 |
+
graph.replay()
|
| 351 |
+
return self.model.compute_logits(graph_vars["outputs"][:bs])
|
| 352 |
+
|
| 353 |
+
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
|
| 354 |
+
"""Run model forward and sampling. For CFG sequences, batch is structured as:
|
| 355 |
+
[cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
|
| 356 |
+
where uncond_seqi is the paired unconditional sequence of cond_seqi."""
|
| 357 |
+
# Check if this is a CFG batch (contains paired conditional and unconditional sequences)
|
| 358 |
+
is_cfg_batch = seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None
|
| 359 |
+
if is_cfg_batch:
|
| 360 |
+
# CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
|
| 361 |
+
num_cond = len(seqs) // 2
|
| 362 |
+
cond_seqs = seqs[:num_cond]
|
| 363 |
+
# uncond_seqs = seqs[num_cond:]
|
| 364 |
+
|
| 365 |
+
# Prepare inputs for both conditional and unconditional (they're already in the batch)
|
| 366 |
+
input_ids, positions = (self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs))
|
| 367 |
+
sample_params = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else None
|
| 368 |
+
if sample_params is not None:
|
| 369 |
+
temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
|
| 370 |
+
else:
|
| 371 |
+
temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
|
| 372 |
+
|
| 373 |
+
# Run model forward (processes entire batch: cond + uncond)
|
| 374 |
+
logits_all = self.run_model(input_ids, positions, is_prefill)
|
| 375 |
+
reset_context()
|
| 376 |
+
|
| 377 |
+
if self.rank == 0:
|
| 378 |
+
# Split logits: first half is conditional, second half is unconditional
|
| 379 |
+
logits_cond = logits_all[:num_cond]
|
| 380 |
+
logits_uncond = logits_all[num_cond:]
|
| 381 |
+
|
| 382 |
+
# Apply repetition penalty to conditional logits (before CFG)
|
| 383 |
+
if repetition_penalties is not None:
|
| 384 |
+
for i, seq in enumerate(cond_seqs):
|
| 385 |
+
penalty = repetition_penalties[i].item()
|
| 386 |
+
if penalty != 1.0:
|
| 387 |
+
# Only penalize completion tokens (not prompt tokens)
|
| 388 |
+
completion_tokens = torch.tensor(seq.completion_token_ids, device=logits_cond.device)
|
| 389 |
+
if len(completion_tokens) > 0:
|
| 390 |
+
# Create token mask: mark tokens that appeared in completion
|
| 391 |
+
token_mask = torch.zeros(logits_cond.shape[1], dtype=torch.bool, device=logits_cond.device)
|
| 392 |
+
token_mask[completion_tokens] = True
|
| 393 |
+
|
| 394 |
+
# Apply standard repetition penalty formula (matching transformers implementation):
|
| 395 |
+
# For tokens in completion: if score < 0 then score * penalty, else score / penalty
|
| 396 |
+
penalty_scores = torch.where(
|
| 397 |
+
logits_cond[i] < 0,
|
| 398 |
+
logits_cond[i] * penalty,
|
| 399 |
+
logits_cond[i] / penalty
|
| 400 |
+
)
|
| 401 |
+
# Only apply penalty to tokens that appeared in completion
|
| 402 |
+
logits_cond[i] = torch.where(token_mask, penalty_scores, logits_cond[i])
|
| 403 |
+
|
| 404 |
+
# Apply CFG formula: logits_cfg = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
|
| 405 |
+
cfg_scales_tensor = cfg_scales.unsqueeze(1) # [num_cond, 1]
|
| 406 |
+
logits_cfg = logits_uncond + cfg_scales_tensor * (logits_cond - logits_uncond)
|
| 407 |
+
|
| 408 |
+
# Apply logits processor for constrained decoding (if any sequence has one)
|
| 409 |
+
for i, seq in enumerate(cond_seqs):
|
| 410 |
+
if seq.logits_processor is not None:
|
| 411 |
+
# Create input_ids tensor for this sequence
|
| 412 |
+
seq_input_ids = torch.tensor([seq.token_ids], device=logits_cfg.device)
|
| 413 |
+
# Apply processor to this sequence's logits
|
| 414 |
+
logits_cfg[i:i+1] = seq.logits_processor(seq_input_ids, logits_cfg[i:i+1])
|
| 415 |
+
|
| 416 |
+
# Prepare input_ids for sampler (for repetition penalty, though we already applied it)
|
| 417 |
+
# cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)
|
| 418 |
+
|
| 419 |
+
# Sample from CFG logits
|
| 420 |
+
token_ids_cfg = self.sampler(
|
| 421 |
+
logits_cfg,
|
| 422 |
+
temperatures,
|
| 423 |
+
top_ks=top_ks if top_ks is not None else None,
|
| 424 |
+
top_ps=top_ps if top_ps is not None else None,
|
| 425 |
+
repetition_penalties=None, # Already applied above
|
| 426 |
+
# input_ids=cond_input_ids,
|
| 427 |
+
).tolist()
|
| 428 |
+
|
| 429 |
+
# Update logits processor state after sampling
|
| 430 |
+
# NOTE: Only update for the first sequence since all sequences share the same processor
|
| 431 |
+
# Updating multiple times would cause duplicate state updates (e.g., codes_count += N instead of += 1)
|
| 432 |
+
if cond_seqs and cond_seqs[0].logits_processor_update_state is not None:
|
| 433 |
+
cond_seqs[0].logits_processor_update_state(token_ids_cfg[0])
|
| 434 |
+
|
| 435 |
+
# Return token_ids (will be applied to both conditional and unconditional sequences)
|
| 436 |
+
return token_ids_cfg
|
| 437 |
+
else:
|
| 438 |
+
return None
|
| 439 |
+
else:
|
| 440 |
+
# Normal batch (non-CFG)
|
| 441 |
+
input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
|
| 442 |
+
else self.prepare_decode(seqs))
|
| 443 |
+
sample_params = self.prepare_sample(seqs, is_cfg_batch=False) if self.rank == 0 else None
|
| 444 |
+
if sample_params is not None:
|
| 445 |
+
temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
|
| 446 |
+
else:
|
| 447 |
+
temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
|
| 448 |
+
logits = self.run_model(input_ids, positions, is_prefill)
|
| 449 |
+
reset_context()
|
| 450 |
+
|
| 451 |
+
if self.rank == 0:
|
| 452 |
+
# Apply repetition penalty to logits
|
| 453 |
+
if repetition_penalties is not None:
|
| 454 |
+
for i, seq in enumerate(seqs):
|
| 455 |
+
penalty = repetition_penalties[i].item()
|
| 456 |
+
if penalty != 1.0:
|
| 457 |
+
# Only penalize completion tokens (not prompt tokens)
|
| 458 |
+
completion_tokens = torch.tensor(seq.completion_token_ids, device=logits.device)
|
| 459 |
+
if len(completion_tokens) > 0:
|
| 460 |
+
# Create token mask: mark tokens that appeared in completion
|
| 461 |
+
token_mask = torch.zeros(logits.shape[1], dtype=torch.bool, device=logits.device)
|
| 462 |
+
token_mask[completion_tokens] = True
|
| 463 |
+
|
| 464 |
+
# Apply standard repetition penalty formula (matching transformers implementation):
|
| 465 |
+
# For tokens in completion: if score < 0 then score * penalty, else score / penalty
|
| 466 |
+
penalty_scores = torch.where(
|
| 467 |
+
logits[i] < 0,
|
| 468 |
+
logits[i] * penalty,
|
| 469 |
+
logits[i] / penalty
|
| 470 |
+
)
|
| 471 |
+
# Only apply penalty to tokens that appeared in completion
|
| 472 |
+
logits[i] = torch.where(token_mask, penalty_scores, logits[i])
|
| 473 |
+
|
| 474 |
+
# Apply logits processor for constrained decoding (if any sequence has one)
|
| 475 |
+
# Clone logits to avoid in-place update issues in inference mode
|
| 476 |
+
logits = logits.clone()
|
| 477 |
+
for i, seq in enumerate(seqs):
|
| 478 |
+
if seq.logits_processor is not None:
|
| 479 |
+
# Create input_ids tensor for this sequence
|
| 480 |
+
seq_input_ids = torch.tensor([seq.token_ids], device=logits.device)
|
| 481 |
+
# Apply processor to this sequence's logits (clone to avoid inference mode issues)
|
| 482 |
+
processed = seq.logits_processor(seq_input_ids, logits[i:i+1].clone())
|
| 483 |
+
logits[i] = processed[0]
|
| 484 |
+
|
| 485 |
+
# Prepare input_ids for sampler
|
| 486 |
+
# seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
|
| 487 |
+
|
| 488 |
+
token_ids = self.sampler(
|
| 489 |
+
logits,
|
| 490 |
+
temperatures,
|
| 491 |
+
top_ks=top_ks if top_ks is not None else None,
|
| 492 |
+
top_ps=top_ps if top_ps is not None else None,
|
| 493 |
+
repetition_penalties=None, # Already applied above
|
| 494 |
+
# input_ids=seq_input_ids,
|
| 495 |
+
).tolist()
|
| 496 |
+
|
| 497 |
+
# Update logits processor state after sampling
|
| 498 |
+
# NOTE: Only update for the first sequence since all sequences may share the same processor
|
| 499 |
+
# (when using a single SamplingParams for batch generation)
|
| 500 |
+
# Updating multiple times would cause duplicate state updates (e.g., codes_count += N instead of += 1)
|
| 501 |
+
if seqs and seqs[0].logits_processor_update_state is not None:
|
| 502 |
+
seqs[0].logits_processor_update_state(token_ids[0])
|
| 503 |
+
|
| 504 |
+
return token_ids
|
| 505 |
+
else:
|
| 506 |
+
return None
|
| 507 |
+
|
| 508 |
+
@torch.inference_mode()
|
| 509 |
+
def capture_cudagraph(self):
|
| 510 |
+
config = self.config
|
| 511 |
+
hf_config = config.hf_config
|
| 512 |
+
max_bs = min(self.config.max_num_seqs, 512)
|
| 513 |
+
max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
|
| 514 |
+
input_ids = torch.zeros(max_bs, dtype=torch.int64)
|
| 515 |
+
positions = torch.zeros(max_bs, dtype=torch.int64)
|
| 516 |
+
slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
|
| 517 |
+
context_lens = torch.zeros(max_bs, dtype=torch.int32)
|
| 518 |
+
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
|
| 519 |
+
outputs = torch.zeros(max_bs, hf_config.hidden_size)
|
| 520 |
+
self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
|
| 521 |
+
self.graphs = {}
|
| 522 |
+
self.graph_pool = None
|
| 523 |
+
|
| 524 |
+
for bs in reversed(self.graph_bs):
|
| 525 |
+
graph = torch.cuda.CUDAGraph()
|
| 526 |
+
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
|
| 527 |
+
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
|
| 528 |
+
with torch.cuda.graph(graph, self.graph_pool):
|
| 529 |
+
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
|
| 530 |
+
if self.graph_pool is None:
|
| 531 |
+
self.graph_pool = graph.pool()
|
| 532 |
+
self.graphs[bs] = graph
|
| 533 |
+
torch.cuda.synchronize()
|
| 534 |
+
reset_context()
|
| 535 |
+
|
| 536 |
+
self.graph_vars = dict(
|
| 537 |
+
input_ids=input_ids,
|
| 538 |
+
positions=positions,
|
| 539 |
+
slot_mapping=slot_mapping,
|
| 540 |
+
context_lens=context_lens,
|
| 541 |
+
block_tables=block_tables,
|
| 542 |
+
outputs=outputs,
|
| 543 |
+
)
|
acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import deque
|
| 2 |
+
|
| 3 |
+
from nanovllm.config import Config
|
| 4 |
+
from nanovllm.engine.sequence import Sequence, SequenceStatus
|
| 5 |
+
from nanovllm.engine.block_manager import BlockManager
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Scheduler:
|
| 9 |
+
|
| 10 |
+
def __init__(self, config: Config):
|
| 11 |
+
self.max_num_seqs = config.max_num_seqs
|
| 12 |
+
self.max_num_batched_tokens = config.max_num_batched_tokens
|
| 13 |
+
self.eos = config.eos
|
| 14 |
+
self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
|
| 15 |
+
self.waiting: deque[Sequence] = deque()
|
| 16 |
+
self.running: deque[Sequence] = deque()
|
| 17 |
+
|
| 18 |
+
def is_finished(self):
|
| 19 |
+
return not self.waiting and not self.running
|
| 20 |
+
|
| 21 |
+
def add(self, seq: Sequence):
|
| 22 |
+
self.waiting.append(seq)
|
| 23 |
+
|
| 24 |
+
def schedule(self) -> tuple[list[Sequence], bool]:
|
| 25 |
+
# prefill
|
| 26 |
+
scheduled_seqs = []
|
| 27 |
+
num_seqs = 0
|
| 28 |
+
num_batched_tokens = 0
|
| 29 |
+
processed_seqs = set() # Track processed sequences to handle CFG pairs
|
| 30 |
+
|
| 31 |
+
while self.waiting and num_seqs < self.max_num_seqs:
|
| 32 |
+
seq = self.waiting[0]
|
| 33 |
+
|
| 34 |
+
# For CFG sequences, ensure conditional and unconditional are scheduled together
|
| 35 |
+
if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
|
| 36 |
+
# This is a conditional sequence, need to schedule its paired unconditional sequence too
|
| 37 |
+
paired_seq = seq.paired_seq
|
| 38 |
+
if paired_seq.status != SequenceStatus.WAITING:
|
| 39 |
+
# Paired sequence not in waiting, skip this conditional sequence for now
|
| 40 |
+
break
|
| 41 |
+
|
| 42 |
+
# Calculate tokens for both sequences
|
| 43 |
+
total_tokens = (len(seq) - seq.num_cached_tokens) + (len(paired_seq) - paired_seq.num_cached_tokens)
|
| 44 |
+
|
| 45 |
+
# FIX: Check if we have enough blocks for BOTH sequences combined
|
| 46 |
+
# The old check was wrong: it checked each sequence independently,
|
| 47 |
+
# but didn't account for the total blocks needed by both
|
| 48 |
+
total_blocks_needed = seq.num_blocks + paired_seq.num_blocks
|
| 49 |
+
can_allocate_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
|
| 50 |
+
|
| 51 |
+
if num_batched_tokens + total_tokens > self.max_num_batched_tokens or not can_allocate_both:
|
| 52 |
+
break
|
| 53 |
+
|
| 54 |
+
# Schedule both sequences: conditional first, then unconditional
|
| 55 |
+
for s in [seq, paired_seq]:
|
| 56 |
+
num_seqs += 1
|
| 57 |
+
self.block_manager.allocate(s)
|
| 58 |
+
num_batched_tokens += len(s) - s.num_cached_tokens
|
| 59 |
+
s.status = SequenceStatus.RUNNING
|
| 60 |
+
self.waiting.remove(s)
|
| 61 |
+
self.running.append(s)
|
| 62 |
+
scheduled_seqs.append(s)
|
| 63 |
+
processed_seqs.add(s.seq_id)
|
| 64 |
+
else:
|
| 65 |
+
# Normal sequence or unconditional sequence (already processed with its conditional)
|
| 66 |
+
if seq.seq_id in processed_seqs:
|
| 67 |
+
# Skip if already processed as part of a CFG pair
|
| 68 |
+
self.waiting.popleft()
|
| 69 |
+
continue
|
| 70 |
+
|
| 71 |
+
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
|
| 72 |
+
break
|
| 73 |
+
num_seqs += 1
|
| 74 |
+
self.block_manager.allocate(seq)
|
| 75 |
+
num_batched_tokens += len(seq) - seq.num_cached_tokens
|
| 76 |
+
seq.status = SequenceStatus.RUNNING
|
| 77 |
+
self.waiting.popleft()
|
| 78 |
+
self.running.append(seq)
|
| 79 |
+
scheduled_seqs.append(seq)
|
| 80 |
+
|
| 81 |
+
if scheduled_seqs:
|
| 82 |
+
# For CFG batches, ensure conditional sequences come before their unconditional pairs
|
| 83 |
+
cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
|
| 84 |
+
cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
|
| 85 |
+
non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
|
| 86 |
+
|
| 87 |
+
# Reorder: non-CFG, then CFG conditional, then CFG unconditional
|
| 88 |
+
scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
|
| 89 |
+
return scheduled_seqs, True
|
| 90 |
+
|
| 91 |
+
# decode
|
| 92 |
+
processed_seqs = set()
|
| 93 |
+
temp_running = list(self.running) # Work with a copy
|
| 94 |
+
|
| 95 |
+
while temp_running and num_seqs < self.max_num_seqs:
|
| 96 |
+
seq = temp_running.pop(0)
|
| 97 |
+
|
| 98 |
+
# For CFG sequences, ensure conditional and unconditional are scheduled together
|
| 99 |
+
if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
|
| 100 |
+
paired_seq = seq.paired_seq
|
| 101 |
+
if paired_seq not in temp_running:
|
| 102 |
+
# Paired sequence not available, skip for now
|
| 103 |
+
continue
|
| 104 |
+
|
| 105 |
+
# Remove paired_seq from temp_running
|
| 106 |
+
temp_running.remove(paired_seq)
|
| 107 |
+
|
| 108 |
+
# FIX: Check if we have enough blocks for BOTH sequences to append
|
| 109 |
+
# Each sequence needs 1 block when at block boundary (len % block_size == 1)
|
| 110 |
+
block_size = self.block_manager.block_size
|
| 111 |
+
blocks_needed_seq = 1 if len(seq) % block_size == 1 else 0
|
| 112 |
+
blocks_needed_paired = 1 if len(paired_seq) % block_size == 1 else 0
|
| 113 |
+
total_blocks_needed = blocks_needed_seq + blocks_needed_paired
|
| 114 |
+
can_append_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
|
| 115 |
+
|
| 116 |
+
if not can_append_both:
|
| 117 |
+
# Try preempting other sequences
|
| 118 |
+
preempted = False
|
| 119 |
+
while not can_append_both and temp_running:
|
| 120 |
+
other_seq = temp_running.pop(0)
|
| 121 |
+
if other_seq != seq and other_seq != paired_seq:
|
| 122 |
+
self.preempt(other_seq)
|
| 123 |
+
# Recalculate with the same correct logic
|
| 124 |
+
can_append_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
|
| 125 |
+
preempted = True
|
| 126 |
+
else:
|
| 127 |
+
temp_running.append(other_seq)
|
| 128 |
+
break
|
| 129 |
+
|
| 130 |
+
if not can_append_both:
|
| 131 |
+
# Can't schedule this pair right now
|
| 132 |
+
temp_running.append(seq)
|
| 133 |
+
temp_running.append(paired_seq)
|
| 134 |
+
continue
|
| 135 |
+
|
| 136 |
+
# Schedule both sequences
|
| 137 |
+
for s in [seq, paired_seq]:
|
| 138 |
+
num_seqs += 1
|
| 139 |
+
self.block_manager.may_append(s)
|
| 140 |
+
scheduled_seqs.append(s)
|
| 141 |
+
processed_seqs.add(s.seq_id)
|
| 142 |
+
# Remove from actual running list if scheduled
|
| 143 |
+
if s in self.running:
|
| 144 |
+
self.running.remove(s)
|
| 145 |
+
else:
|
| 146 |
+
# Normal sequence or unconditional (already processed)
|
| 147 |
+
if seq.seq_id in processed_seqs:
|
| 148 |
+
continue
|
| 149 |
+
|
| 150 |
+
while not self.block_manager.can_append(seq):
|
| 151 |
+
if temp_running:
|
| 152 |
+
other_seq = temp_running.pop(0)
|
| 153 |
+
if other_seq != seq:
|
| 154 |
+
self.preempt(other_seq)
|
| 155 |
+
else:
|
| 156 |
+
temp_running.append(other_seq)
|
| 157 |
+
break
|
| 158 |
+
else:
|
| 159 |
+
self.preempt(seq)
|
| 160 |
+
if seq in self.running:
|
| 161 |
+
self.running.remove(seq)
|
| 162 |
+
break
|
| 163 |
+
else:
|
| 164 |
+
num_seqs += 1
|
| 165 |
+
self.block_manager.may_append(seq)
|
| 166 |
+
scheduled_seqs.append(seq)
|
| 167 |
+
if seq in self.running:
|
| 168 |
+
self.running.remove(seq)
|
| 169 |
+
|
| 170 |
+
assert scheduled_seqs
|
| 171 |
+
|
| 172 |
+
# For CFG batches in decode, ensure conditional sequences come before unconditional
|
| 173 |
+
cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
|
| 174 |
+
cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
|
| 175 |
+
non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
|
| 176 |
+
scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
|
| 177 |
+
|
| 178 |
+
self.running.extendleft(reversed(scheduled_seqs))
|
| 179 |
+
return scheduled_seqs, False
|
| 180 |
+
|
| 181 |
+
def preempt(self, seq: Sequence):
|
| 182 |
+
seq.status = SequenceStatus.WAITING
|
| 183 |
+
self.block_manager.deallocate(seq)
|
| 184 |
+
self.waiting.appendleft(seq)
|
| 185 |
+
|
| 186 |
+
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
|
| 187 |
+
# Check if this is a CFG batch
|
| 188 |
+
is_cfg_batch = False
|
| 189 |
+
if len(seqs) > 0 and seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
|
| 190 |
+
num_cond = len(seqs) // 2
|
| 191 |
+
is_cfg_batch = (num_cond > 0 and
|
| 192 |
+
not seqs[0].is_unconditional and
|
| 193 |
+
seqs[num_cond].is_unconditional)
|
| 194 |
+
|
| 195 |
+
if is_cfg_batch:
|
| 196 |
+
# CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
|
| 197 |
+
# token_ids correspond to conditional sequences only (sampled from CFG logits)
|
| 198 |
+
num_cond = len(seqs) // 2
|
| 199 |
+
cond_seqs = seqs[:num_cond]
|
| 200 |
+
uncond_seqs = seqs[num_cond:]
|
| 201 |
+
|
| 202 |
+
# Apply the same sampled token to both conditional and unconditional sequences
|
| 203 |
+
for i, (cond_seq, uncond_seq, token_id) in enumerate(zip(cond_seqs, uncond_seqs, token_ids)):
|
| 204 |
+
cond_seq.append_token(token_id)
|
| 205 |
+
uncond_seq.append_token(token_id) # Same token for unconditional
|
| 206 |
+
|
| 207 |
+
# Check if either sequence is finished
|
| 208 |
+
cond_finished = ((not cond_seq.ignore_eos and token_id == self.eos) or
|
| 209 |
+
cond_seq.num_completion_tokens == cond_seq.max_tokens)
|
| 210 |
+
uncond_finished = ((not uncond_seq.ignore_eos and token_id == self.eos) or
|
| 211 |
+
uncond_seq.num_completion_tokens == uncond_seq.max_tokens)
|
| 212 |
+
|
| 213 |
+
if cond_finished or uncond_finished:
|
| 214 |
+
# Mark both as finished
|
| 215 |
+
cond_seq.status = SequenceStatus.FINISHED
|
| 216 |
+
uncond_seq.status = SequenceStatus.FINISHED
|
| 217 |
+
self.block_manager.deallocate(cond_seq)
|
| 218 |
+
self.block_manager.deallocate(uncond_seq)
|
| 219 |
+
if cond_seq in self.running:
|
| 220 |
+
self.running.remove(cond_seq)
|
| 221 |
+
if uncond_seq in self.running:
|
| 222 |
+
self.running.remove(uncond_seq)
|
| 223 |
+
else:
|
| 224 |
+
# Normal batch
|
| 225 |
+
for seq, token_id in zip(seqs, token_ids):
|
| 226 |
+
seq.append_token(token_id)
|
| 227 |
+
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
|
| 228 |
+
seq.status = SequenceStatus.FINISHED
|
| 229 |
+
self.block_manager.deallocate(seq)
|
| 230 |
+
self.running.remove(seq)
|
acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import copy
|
| 2 |
+
from enum import Enum, auto
|
| 3 |
+
from itertools import count
|
| 4 |
+
from typing import Optional, Callable, Any
|
| 5 |
+
|
| 6 |
+
from nanovllm.sampling_params import SamplingParams
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SequenceStatus(Enum):
|
| 10 |
+
WAITING = auto()
|
| 11 |
+
RUNNING = auto()
|
| 12 |
+
FINISHED = auto()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Sequence:
|
| 16 |
+
block_size = 256
|
| 17 |
+
counter = count()
|
| 18 |
+
|
| 19 |
+
def __init__(self, token_ids: list[int], sampling_params = SamplingParams(), is_unconditional: bool = False, conditional_seq = None):
|
| 20 |
+
self.seq_id = next(Sequence.counter)
|
| 21 |
+
self.status = SequenceStatus.WAITING
|
| 22 |
+
self.token_ids = copy(token_ids)
|
| 23 |
+
self.last_token = token_ids[-1]
|
| 24 |
+
self.num_tokens = len(self.token_ids)
|
| 25 |
+
self.num_prompt_tokens = len(token_ids)
|
| 26 |
+
self.num_cached_tokens = 0
|
| 27 |
+
self.block_table = []
|
| 28 |
+
self.temperature = sampling_params.temperature
|
| 29 |
+
self.max_tokens = sampling_params.max_tokens
|
| 30 |
+
self.ignore_eos = sampling_params.ignore_eos
|
| 31 |
+
self.cfg_scale = sampling_params.cfg_scale
|
| 32 |
+
self.top_k = sampling_params.top_k
|
| 33 |
+
self.top_p = sampling_params.top_p
|
| 34 |
+
self.repetition_penalty = sampling_params.repetition_penalty
|
| 35 |
+
# For CFG: mark if this is an unconditional sequence
|
| 36 |
+
self.is_unconditional = is_unconditional
|
| 37 |
+
# For CFG: reference to the corresponding conditional sequence (if this is unconditional)
|
| 38 |
+
# For conditional sequences, this points to the unconditional sequence
|
| 39 |
+
self.paired_seq = conditional_seq # For conditional seq, points to uncond; for uncond seq, points to cond
|
| 40 |
+
# For constrained decoding: logits processor and state update callback
|
| 41 |
+
self.logits_processor: Optional[Any] = sampling_params.logits_processor
|
| 42 |
+
self.logits_processor_update_state: Optional[Callable[[int], None]] = sampling_params.logits_processor_update_state
|
| 43 |
+
|
| 44 |
+
def __len__(self):
|
| 45 |
+
return self.num_tokens
|
| 46 |
+
|
| 47 |
+
def __getitem__(self, key):
|
| 48 |
+
return self.token_ids[key]
|
| 49 |
+
|
| 50 |
+
@property
|
| 51 |
+
def is_finished(self):
|
| 52 |
+
return self.status == SequenceStatus.FINISHED
|
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def num_completion_tokens(self):
|
| 56 |
+
return self.num_tokens - self.num_prompt_tokens
|
| 57 |
+
|
| 58 |
+
@property
|
| 59 |
+
def prompt_token_ids(self):
|
| 60 |
+
return self.token_ids[:self.num_prompt_tokens]
|
| 61 |
+
|
| 62 |
+
@property
|
| 63 |
+
def completion_token_ids(self):
|
| 64 |
+
return self.token_ids[self.num_prompt_tokens:]
|
| 65 |
+
|
| 66 |
+
@property
|
| 67 |
+
def num_cached_blocks(self):
|
| 68 |
+
return self.num_cached_tokens // self.block_size
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def num_blocks(self):
|
| 72 |
+
return (self.num_tokens + self.block_size - 1) // self.block_size
|
| 73 |
+
|
| 74 |
+
@property
|
| 75 |
+
def last_block_num_tokens(self):
|
| 76 |
+
return self.num_tokens - (self.num_blocks - 1) * self.block_size
|
| 77 |
+
|
| 78 |
+
def block(self, i):
|
| 79 |
+
assert 0 <= i < self.num_blocks
|
| 80 |
+
return self.token_ids[i*self.block_size: (i+1)*self.block_size]
|
| 81 |
+
|
| 82 |
+
def append_token(self, token_id: int):
|
| 83 |
+
self.token_ids.append(token_id)
|
| 84 |
+
self.last_token = token_id
|
| 85 |
+
self.num_tokens += 1
|
| 86 |
+
|
| 87 |
+
def __getstate__(self):
|
| 88 |
+
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
|
| 89 |
+
self.token_ids if self.num_completion_tokens == 0 else self.last_token)
|
| 90 |
+
|
| 91 |
+
def __setstate__(self, state):
|
| 92 |
+
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
|
| 93 |
+
if self.num_completion_tokens == 0:
|
| 94 |
+
self.token_ids = state[-1]
|
| 95 |
+
else:
|
| 96 |
+
self.last_token = state[-1]
|
acestep/third_parts/nano-vllm/nanovllm/layers/activation.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SiluAndMul(nn.Module):
|
| 7 |
+
|
| 8 |
+
def __init__(self):
|
| 9 |
+
super().__init__()
|
| 10 |
+
|
| 11 |
+
@torch.compile
|
| 12 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 13 |
+
x, y = x.chunk(2, -1)
|
| 14 |
+
return F.silu(x) * y
|
acestep/third_parts/nano-vllm/nanovllm/layers/attention.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import triton
|
| 4 |
+
import triton.language as tl
|
| 5 |
+
|
| 6 |
+
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
| 7 |
+
from nanovllm.utils.context import get_context
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@triton.jit
|
| 11 |
+
def store_kvcache_kernel(
|
| 12 |
+
key_ptr,
|
| 13 |
+
key_stride,
|
| 14 |
+
value_ptr,
|
| 15 |
+
value_stride,
|
| 16 |
+
k_cache_ptr,
|
| 17 |
+
v_cache_ptr,
|
| 18 |
+
slot_mapping_ptr,
|
| 19 |
+
D: tl.constexpr,
|
| 20 |
+
):
|
| 21 |
+
idx = tl.program_id(0)
|
| 22 |
+
slot = tl.load(slot_mapping_ptr + idx)
|
| 23 |
+
if slot == -1: return
|
| 24 |
+
key_offsets = idx * key_stride + tl.arange(0, D)
|
| 25 |
+
value_offsets = idx * value_stride + tl.arange(0, D)
|
| 26 |
+
key = tl.load(key_ptr + key_offsets)
|
| 27 |
+
value = tl.load(value_ptr + value_offsets)
|
| 28 |
+
cache_offsets = slot * D + tl.arange(0, D)
|
| 29 |
+
tl.store(k_cache_ptr + cache_offsets, key)
|
| 30 |
+
tl.store(v_cache_ptr + cache_offsets, value)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
|
| 34 |
+
N, num_heads, head_dim = key.shape
|
| 35 |
+
D = num_heads * head_dim
|
| 36 |
+
assert key.stride(-1) == 1 and value.stride(-1) == 1
|
| 37 |
+
assert key.stride(1) == head_dim and value.stride(1) == head_dim
|
| 38 |
+
assert k_cache.stride(1) == D and v_cache.stride(1) == D
|
| 39 |
+
assert slot_mapping.numel() == N
|
| 40 |
+
store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Attention(nn.Module):
|
| 44 |
+
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
num_heads,
|
| 48 |
+
head_dim,
|
| 49 |
+
scale,
|
| 50 |
+
num_kv_heads,
|
| 51 |
+
):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.num_heads = num_heads
|
| 54 |
+
self.head_dim = head_dim
|
| 55 |
+
self.scale = scale
|
| 56 |
+
self.num_kv_heads = num_kv_heads
|
| 57 |
+
self.k_cache = self.v_cache = torch.tensor([])
|
| 58 |
+
|
| 59 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
| 60 |
+
context = get_context()
|
| 61 |
+
k_cache, v_cache = self.k_cache, self.v_cache
|
| 62 |
+
if k_cache.numel() and v_cache.numel():
|
| 63 |
+
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
| 64 |
+
if context.is_prefill:
|
| 65 |
+
if context.block_tables is not None: # prefix cache
|
| 66 |
+
k, v = k_cache, v_cache
|
| 67 |
+
o = flash_attn_varlen_func(q, k, v,
|
| 68 |
+
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
| 69 |
+
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
| 70 |
+
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
| 71 |
+
else: # decode
|
| 72 |
+
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
| 73 |
+
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
| 74 |
+
softmax_scale=self.scale, causal=True)
|
| 75 |
+
return o
|
acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
|
| 6 |
+
from nanovllm.utils.context import get_context
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class VocabParallelEmbedding(nn.Module):
|
| 10 |
+
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
num_embeddings: int,
|
| 14 |
+
embedding_dim: int,
|
| 15 |
+
):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.tp_rank = dist.get_rank()
|
| 18 |
+
self.tp_size = dist.get_world_size()
|
| 19 |
+
assert num_embeddings % self.tp_size == 0
|
| 20 |
+
self.num_embeddings = num_embeddings
|
| 21 |
+
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
|
| 22 |
+
self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
|
| 23 |
+
self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
|
| 24 |
+
self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
|
| 25 |
+
self.weight.weight_loader = self.weight_loader
|
| 26 |
+
|
| 27 |
+
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
| 28 |
+
param_data = param.data
|
| 29 |
+
shard_size = param_data.size(0)
|
| 30 |
+
start_idx = self.tp_rank * shard_size
|
| 31 |
+
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
|
| 32 |
+
param_data.copy_(loaded_weight)
|
| 33 |
+
|
| 34 |
+
def forward(self, x: torch.Tensor):
|
| 35 |
+
if self.tp_size > 1:
|
| 36 |
+
mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
|
| 37 |
+
x = mask * (x - self.vocab_start_idx)
|
| 38 |
+
y = F.embedding(x, self.weight)
|
| 39 |
+
if self.tp_size > 1:
|
| 40 |
+
y = mask.unsqueeze(1) * y
|
| 41 |
+
dist.all_reduce(y)
|
| 42 |
+
return y
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class ParallelLMHead(VocabParallelEmbedding):
|
| 46 |
+
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
num_embeddings: int,
|
| 50 |
+
embedding_dim: int,
|
| 51 |
+
bias: bool = False,
|
| 52 |
+
):
|
| 53 |
+
assert not bias
|
| 54 |
+
super().__init__(num_embeddings, embedding_dim)
|
| 55 |
+
|
| 56 |
+
def forward(self, x: torch.Tensor):
|
| 57 |
+
context = get_context()
|
| 58 |
+
if context.is_prefill:
|
| 59 |
+
last_indices = context.cu_seqlens_q[1:] - 1
|
| 60 |
+
x = x[last_indices].contiguous()
|
| 61 |
+
logits = F.linear(x, self.weight)
|
| 62 |
+
if self.tp_size > 1:
|
| 63 |
+
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
|
| 64 |
+
dist.gather(logits, all_logits, 0)
|
| 65 |
+
logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
|
| 66 |
+
return logits
|
acestep/third_parts/nano-vllm/nanovllm/layers/layernorm.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class RMSNorm(nn.Module):
|
| 6 |
+
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
hidden_size: int,
|
| 10 |
+
eps: float = 1e-6,
|
| 11 |
+
) -> None:
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.eps = eps
|
| 14 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 15 |
+
|
| 16 |
+
@torch.compile
|
| 17 |
+
def rms_forward(
|
| 18 |
+
self,
|
| 19 |
+
x: torch.Tensor,
|
| 20 |
+
) -> torch.Tensor:
|
| 21 |
+
orig_dtype = x.dtype
|
| 22 |
+
x = x.float()
|
| 23 |
+
var = x.pow(2).mean(dim=-1, keepdim=True)
|
| 24 |
+
x.mul_(torch.rsqrt(var + self.eps))
|
| 25 |
+
x = x.to(orig_dtype).mul_(self.weight)
|
| 26 |
+
return x
|
| 27 |
+
|
| 28 |
+
@torch.compile
|
| 29 |
+
def add_rms_forward(
|
| 30 |
+
self,
|
| 31 |
+
x: torch.Tensor,
|
| 32 |
+
residual: torch.Tensor,
|
| 33 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 34 |
+
orig_dtype = x.dtype
|
| 35 |
+
x = x.float().add_(residual.float())
|
| 36 |
+
residual = x.to(orig_dtype)
|
| 37 |
+
var = x.pow(2).mean(dim=-1, keepdim=True)
|
| 38 |
+
x.mul_(torch.rsqrt(var + self.eps))
|
| 39 |
+
x = x.to(orig_dtype).mul_(self.weight)
|
| 40 |
+
return x, residual
|
| 41 |
+
|
| 42 |
+
def forward(
|
| 43 |
+
self,
|
| 44 |
+
x: torch.Tensor,
|
| 45 |
+
residual: torch.Tensor | None = None,
|
| 46 |
+
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
| 47 |
+
if residual is None:
|
| 48 |
+
return self.rms_forward(x)
|
| 49 |
+
else:
|
| 50 |
+
return self.add_rms_forward(x, residual)
|
acestep/third_parts/nano-vllm/nanovllm/layers/linear.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def divide(numerator, denominator):
|
| 8 |
+
assert numerator % denominator == 0
|
| 9 |
+
return numerator // denominator
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LinearBase(nn.Module):
|
| 13 |
+
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
input_size: int,
|
| 17 |
+
output_size: int,
|
| 18 |
+
bias: bool = False,
|
| 19 |
+
tp_dim: int | None = None,
|
| 20 |
+
):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.tp_dim = tp_dim
|
| 23 |
+
self.tp_rank = dist.get_rank()
|
| 24 |
+
self.tp_size = dist.get_world_size()
|
| 25 |
+
self.weight = nn.Parameter(torch.empty(output_size, input_size))
|
| 26 |
+
self.weight.weight_loader = self.weight_loader
|
| 27 |
+
if bias:
|
| 28 |
+
self.bias = nn.Parameter(torch.empty(output_size))
|
| 29 |
+
self.bias.weight_loader = self.weight_loader
|
| 30 |
+
else:
|
| 31 |
+
self.register_parameter("bias", None)
|
| 32 |
+
|
| 33 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 34 |
+
raise NotImplementedError
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ReplicatedLinear(LinearBase):
|
| 38 |
+
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
input_size: int,
|
| 42 |
+
output_size: int,
|
| 43 |
+
bias: bool = False,
|
| 44 |
+
):
|
| 45 |
+
super().__init__(input_size, output_size, bias)
|
| 46 |
+
|
| 47 |
+
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
| 48 |
+
param.data.copy_(loaded_weight)
|
| 49 |
+
|
| 50 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 51 |
+
return F.linear(x, self.weight, self.bias)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ColumnParallelLinear(LinearBase):
|
| 55 |
+
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
input_size: int,
|
| 59 |
+
output_size: int,
|
| 60 |
+
bias: bool = False,
|
| 61 |
+
):
|
| 62 |
+
tp_size = dist.get_world_size()
|
| 63 |
+
super().__init__(input_size, divide(output_size, tp_size), bias, 0)
|
| 64 |
+
|
| 65 |
+
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
| 66 |
+
param_data = param.data
|
| 67 |
+
shard_size = param_data.size(self.tp_dim)
|
| 68 |
+
start_idx = self.tp_rank * shard_size
|
| 69 |
+
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
|
| 70 |
+
param_data.copy_(loaded_weight)
|
| 71 |
+
|
| 72 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 73 |
+
return F.linear(x, self.weight, self.bias)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class MergedColumnParallelLinear(ColumnParallelLinear):
|
| 77 |
+
|
| 78 |
+
def __init__(
|
| 79 |
+
self,
|
| 80 |
+
input_size: int,
|
| 81 |
+
output_sizes: list[int],
|
| 82 |
+
bias: bool = False,
|
| 83 |
+
):
|
| 84 |
+
self.output_sizes = output_sizes
|
| 85 |
+
super().__init__(input_size, sum(output_sizes), bias)
|
| 86 |
+
|
| 87 |
+
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
|
| 88 |
+
param_data = param.data
|
| 89 |
+
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
|
| 90 |
+
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
|
| 91 |
+
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
|
| 92 |
+
loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
|
| 93 |
+
param_data.copy_(loaded_weight)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class QKVParallelLinear(ColumnParallelLinear):
|
| 97 |
+
|
| 98 |
+
def __init__(
|
| 99 |
+
self,
|
| 100 |
+
hidden_size: int,
|
| 101 |
+
head_size: int,
|
| 102 |
+
total_num_heads: int,
|
| 103 |
+
total_num_kv_heads: int | None = None,
|
| 104 |
+
bias: bool = False,
|
| 105 |
+
):
|
| 106 |
+
tp_size = dist.get_world_size()
|
| 107 |
+
total_num_kv_heads = total_num_kv_heads or total_num_heads
|
| 108 |
+
self.head_size = head_size
|
| 109 |
+
self.num_heads = divide(total_num_heads, tp_size)
|
| 110 |
+
self.num_kv_heads = divide(total_num_kv_heads, tp_size)
|
| 111 |
+
output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
|
| 112 |
+
super().__init__(hidden_size, output_size, bias)
|
| 113 |
+
|
| 114 |
+
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
|
| 115 |
+
param_data = param.data
|
| 116 |
+
assert loaded_shard_id in ["q", "k", "v"]
|
| 117 |
+
if loaded_shard_id == "q":
|
| 118 |
+
shard_size = self.num_heads * self.head_size
|
| 119 |
+
shard_offset = 0
|
| 120 |
+
elif loaded_shard_id == "k":
|
| 121 |
+
shard_size = self.num_kv_heads * self.head_size
|
| 122 |
+
shard_offset = self.num_heads * self.head_size
|
| 123 |
+
else:
|
| 124 |
+
shard_size = self.num_kv_heads * self.head_size
|
| 125 |
+
shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
|
| 126 |
+
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
|
| 127 |
+
loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
|
| 128 |
+
param_data.copy_(loaded_weight)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class RowParallelLinear(LinearBase):
|
| 132 |
+
|
| 133 |
+
def __init__(
|
| 134 |
+
self,
|
| 135 |
+
input_size: int,
|
| 136 |
+
output_size: int,
|
| 137 |
+
bias: bool = False,
|
| 138 |
+
):
|
| 139 |
+
tp_size = dist.get_world_size()
|
| 140 |
+
super().__init__(divide(input_size, tp_size), output_size, bias, 1)
|
| 141 |
+
|
| 142 |
+
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
| 143 |
+
param_data = param.data
|
| 144 |
+
shard_size = param_data.size(self.tp_dim)
|
| 145 |
+
start_idx = self.tp_rank * shard_size
|
| 146 |
+
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
|
| 147 |
+
param_data.copy_(loaded_weight)
|
| 148 |
+
|
| 149 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 150 |
+
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
|
| 151 |
+
if self.tp_size > 1:
|
| 152 |
+
dist.all_reduce(y)
|
| 153 |
+
return y
|
acestep/third_parts/nano-vllm/nanovllm/layers/rotary_embedding.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import lru_cache
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def apply_rotary_emb(
|
| 7 |
+
x: torch.Tensor,
|
| 8 |
+
cos: torch.Tensor,
|
| 9 |
+
sin: torch.Tensor,
|
| 10 |
+
) -> torch.Tensor:
|
| 11 |
+
x1, x2 = torch.chunk(x.float(), 2, dim=-1)
|
| 12 |
+
y1 = x1 * cos - x2 * sin
|
| 13 |
+
y2 = x2 * cos + x1 * sin
|
| 14 |
+
return torch.cat((y1, y2), dim=-1).to(x.dtype)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class RotaryEmbedding(nn.Module):
|
| 18 |
+
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
head_size: int,
|
| 22 |
+
rotary_dim: int,
|
| 23 |
+
max_position_embeddings: int,
|
| 24 |
+
base: float,
|
| 25 |
+
) -> None:
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.head_size = head_size
|
| 28 |
+
assert rotary_dim == head_size
|
| 29 |
+
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
|
| 30 |
+
t = torch.arange(max_position_embeddings, dtype=torch.float)
|
| 31 |
+
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
| 32 |
+
cos = freqs.cos()
|
| 33 |
+
sin = freqs.sin()
|
| 34 |
+
cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
|
| 35 |
+
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
| 36 |
+
|
| 37 |
+
@torch.compile
|
| 38 |
+
def forward(
|
| 39 |
+
self,
|
| 40 |
+
positions: torch.Tensor,
|
| 41 |
+
query: torch.Tensor,
|
| 42 |
+
key: torch.Tensor,
|
| 43 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 44 |
+
cos_sin = self.cos_sin_cache[positions]
|
| 45 |
+
cos, sin = cos_sin.chunk(2, dim=-1)
|
| 46 |
+
query = apply_rotary_emb(query, cos, sin)
|
| 47 |
+
key = apply_rotary_emb(key, cos, sin)
|
| 48 |
+
return query, key
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@lru_cache(1)
|
| 52 |
+
def get_rope(
|
| 53 |
+
head_size: int,
|
| 54 |
+
rotary_dim: int,
|
| 55 |
+
max_position: int,
|
| 56 |
+
base: float,
|
| 57 |
+
rope_scaling: dict | None = None,
|
| 58 |
+
):
|
| 59 |
+
assert rope_scaling is None
|
| 60 |
+
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
|
| 61 |
+
return rotary_emb
|
acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def apply_top_k_top_p(
|
| 7 |
+
logits: torch.Tensor,
|
| 8 |
+
k: Optional[torch.Tensor],
|
| 9 |
+
p: Optional[torch.Tensor],
|
| 10 |
+
) -> torch.Tensor:
|
| 11 |
+
"""Apply top-k and top-p masks to the logits (vLLM style).
|
| 12 |
+
|
| 13 |
+
The logits tensor is updated in-place.
|
| 14 |
+
"""
|
| 15 |
+
if p is None:
|
| 16 |
+
if k is None:
|
| 17 |
+
return logits
|
| 18 |
+
# Avoid sorting vocab for top-k only case
|
| 19 |
+
return apply_top_k_only(logits, k)
|
| 20 |
+
|
| 21 |
+
# Need to sort for top-p
|
| 22 |
+
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
|
| 23 |
+
|
| 24 |
+
if k is not None:
|
| 25 |
+
# Apply top-k first
|
| 26 |
+
vocab_size = logits_sort.size(1)
|
| 27 |
+
# Clamp k to valid range
|
| 28 |
+
k_clamped = k.clamp(1, vocab_size).long()
|
| 29 |
+
top_k_mask_idx = vocab_size - k_clamped # shape: [B]
|
| 30 |
+
# Get the threshold value for each batch
|
| 31 |
+
top_k_thresh = logits_sort.gather(1, top_k_mask_idx.unsqueeze(1))
|
| 32 |
+
top_k_mask = logits_sort < top_k_thresh
|
| 33 |
+
logits_sort.masked_fill_(top_k_mask, float('-inf'))
|
| 34 |
+
|
| 35 |
+
# Apply top-p
|
| 36 |
+
probs_sort = logits_sort.softmax(dim=-1)
|
| 37 |
+
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) # reuse buffer
|
| 38 |
+
top_p_mask = probs_sum <= (1.0 - p.unsqueeze(1))
|
| 39 |
+
# Ensure at least one token is kept
|
| 40 |
+
top_p_mask[:, -1] = False
|
| 41 |
+
logits_sort.masked_fill_(top_p_mask, float('-inf'))
|
| 42 |
+
|
| 43 |
+
# Re-sort back to original positions
|
| 44 |
+
logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
|
| 45 |
+
return logits
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def apply_top_k_only(
|
| 49 |
+
logits: torch.Tensor,
|
| 50 |
+
k: torch.Tensor,
|
| 51 |
+
) -> torch.Tensor:
|
| 52 |
+
"""Apply top-k mask without sorting the entire vocab (vLLM style).
|
| 53 |
+
|
| 54 |
+
This is much faster than sorting for top-k only cases.
|
| 55 |
+
The logits tensor is updated in-place.
|
| 56 |
+
"""
|
| 57 |
+
vocab_size = logits.shape[1]
|
| 58 |
+
# Handle cases where k >= vocab_size (no filtering needed)
|
| 59 |
+
no_top_k_mask = (k <= 0) | (k >= vocab_size)
|
| 60 |
+
# Set invalid k to 1 so we can still gather
|
| 61 |
+
k_safe = k.masked_fill(no_top_k_mask, 1).long()
|
| 62 |
+
# NOTE: This int() causes CPU-GPU sync, but torch.topk requires Python int
|
| 63 |
+
max_top_k = int(k_safe.max().clamp(max=vocab_size))
|
| 64 |
+
|
| 65 |
+
# Get top-k values for all batches
|
| 66 |
+
# topk.values has shape [batch_size, max_top_k]
|
| 67 |
+
topk_values = logits.topk(max_top_k, dim=1).values
|
| 68 |
+
|
| 69 |
+
# Convert k to 0-based index: we want the k-th largest value (index k-1)
|
| 70 |
+
# Clamp to valid range for gather
|
| 71 |
+
k_index = (k_safe - 1).clamp(0, max_top_k - 1).unsqueeze(1) # shape: [B, 1]
|
| 72 |
+
# Gather the threshold value (the k-th largest)
|
| 73 |
+
top_k_thresh = topk_values.gather(1, k_index)
|
| 74 |
+
|
| 75 |
+
# For rows with no top-k filtering, set threshold to -inf so nothing gets masked
|
| 76 |
+
top_k_thresh.masked_fill_(no_top_k_mask.unsqueeze(1), float('-inf'))
|
| 77 |
+
|
| 78 |
+
# Mask all values below the threshold
|
| 79 |
+
logits.masked_fill_(logits < top_k_thresh, float('-inf'))
|
| 80 |
+
return logits
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Sampler(nn.Module):
|
| 84 |
+
|
| 85 |
+
def __init__(self):
|
| 86 |
+
super().__init__()
|
| 87 |
+
|
| 88 |
+
@torch.compile
|
| 89 |
+
def forward(
|
| 90 |
+
self,
|
| 91 |
+
logits: torch.Tensor,
|
| 92 |
+
temperatures: torch.Tensor,
|
| 93 |
+
top_ks: Optional[torch.Tensor] = None,
|
| 94 |
+
top_ps: Optional[torch.Tensor] = None,
|
| 95 |
+
repetition_penalties: Optional[torch.Tensor] = None,
|
| 96 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 97 |
+
):
|
| 98 |
+
"""
|
| 99 |
+
Sample tokens from logits with optional top-k and top-p filtering.
|
| 100 |
+
|
| 101 |
+
Condition checking is done OUTSIDE the compiled function to avoid
|
| 102 |
+
graph breaks from .any() calls.
|
| 103 |
+
"""
|
| 104 |
+
# Apply temperature
|
| 105 |
+
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
|
| 106 |
+
|
| 107 |
+
logits = apply_top_k_top_p(
|
| 108 |
+
logits,
|
| 109 |
+
top_ks,
|
| 110 |
+
top_ps,
|
| 111 |
+
)
|
| 112 |
+
probs = torch.softmax(logits, dim=-1)
|
| 113 |
+
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
|
| 114 |
+
return sample_tokens
|
acestep/third_parts/nano-vllm/nanovllm/llm.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from nanovllm.engine.llm_engine import LLMEngine
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class LLM(LLMEngine):
|
| 5 |
+
pass
|