# ACE-Step Custom — custom edition with bug fixes (build a602628)
"""
Gradio UI Training Tab Module
Contains the dataset builder and LoRA training interface components.
"""
import os
import gradio as gr
from acestep.gradio_ui.i18n import t
from acestep.constants import DEBUG_TRAINING
def _create_dataset_builder_tab() -> dict:
    """Build the "Dataset Builder" sub-tab.

    Covers the full dataset workflow: load-or-scan (Step 1), AI auto-labeling
    (Step 2), per-sample review/editing (Step 3), saving to JSON (Step 4), and
    tensor preprocessing (Step 5). Must be called inside an open ``gr.Tabs()``
    context so the tab attaches to the correct parent.

    Returns:
        Dict mapping component names (the keys the event-wiring code expects)
        to the Gradio components created here.
    """
    with gr.Tab(t("training.tab_dataset_builder")):
        # ========== Step 1: Load Existing OR Scan New ==========
        gr.HTML(f"""
        <div style="padding: 10px; margin-bottom: 10px; border: 1px solid #4a4a6a; border-radius: 8px; background: linear-gradient(135deg, #2a2a4a 0%, #1a1a3a 100%);">
            <h3 style="margin: 0 0 5px 0;">{t("training.quick_start_title")}</h3>
            <p style="margin: 0; color: #aaa;">Choose one: <b>Load existing dataset</b> OR <b>Scan new directory</b></p>
        </div>
        """)
        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("<h4>📂 Load Existing Dataset</h4>")
                with gr.Row():
                    load_json_path = gr.Textbox(
                        label=t("training.load_dataset_label"),
                        placeholder="./datasets/my_lora_dataset.json",
                        info=t("training.load_dataset_info"),
                        scale=3,
                    )
                    load_json_btn = gr.Button(t("training.load_btn"), variant="primary", scale=1)
                load_json_status = gr.Textbox(
                    label=t("training.load_status"),
                    interactive=False,
                )
            with gr.Column(scale=1):
                gr.HTML("<h4>🔍 Scan New Directory</h4>")
                with gr.Row():
                    audio_directory = gr.Textbox(
                        label=t("training.scan_label"),
                        placeholder="/path/to/your/audio/folder",
                        info=t("training.scan_info"),
                        scale=3,
                    )
                    scan_btn = gr.Button(t("training.scan_btn"), variant="secondary", scale=1)
                scan_status = gr.Textbox(
                    label=t("training.scan_status"),
                    interactive=False,
                )
        gr.HTML("<hr>")
        with gr.Row():
            with gr.Column(scale=2):
                # Read-only overview of scanned files; edits go through Step 3.
                audio_files_table = gr.Dataframe(
                    headers=["#", "Filename", "Duration", "Lyrics", "Labeled", "BPM", "Key", "Caption"],
                    datatype=["number", "str", "str", "str", "str", "str", "str", "str"],
                    label=t("training.found_audio_files"),
                    interactive=False,
                    wrap=True,
                )
            with gr.Column(scale=1):
                gr.HTML(f"<h3>⚙️ {t('training.dataset_settings_header')}</h3>")
                dataset_name = gr.Textbox(
                    label=t("training.dataset_name"),
                    value="my_lora_dataset",
                    placeholder=t("training.dataset_name_placeholder"),
                )
                all_instrumental = gr.Checkbox(
                    label=t("training.all_instrumental"),
                    value=True,
                    info=t("training.all_instrumental_info"),
                )
                format_lyrics = gr.Checkbox(
                    label="Format Lyrics (LM)",
                    value=False,
                    info="Use LM to format/structure user-provided lyrics from .txt files (coming soon)",
                    interactive=False,  # Disabled for now - model update needed
                )
                transcribe_lyrics = gr.Checkbox(
                    label="Transcribe Lyrics (LM)",
                    value=False,
                    info="Use LM to transcribe lyrics from audio (coming soon)",
                    interactive=False,  # Disabled for now - model update needed
                )
                custom_tag = gr.Textbox(
                    label=t("training.custom_tag"),
                    placeholder="e.g., 8bit_retro, my_style",
                    info=t("training.custom_tag_info"),
                )
                tag_position = gr.Radio(
                    choices=[
                        (t("training.tag_prepend"), "prepend"),
                        (t("training.tag_append"), "append"),
                        (t("training.tag_replace"), "replace"),
                    ],
                    value="replace",
                    label=t("training.tag_position"),
                    info=t("training.tag_position_info"),
                )
                genre_ratio = gr.Slider(
                    minimum=0,
                    maximum=100,
                    step=10,
                    value=0,
                    label=t("training.genre_ratio"),
                    info=t("training.genre_ratio_info"),
                )
        # ========== Step 2: AI auto-labeling ==========
        gr.HTML(f"<hr><h3>🤖 {t('training.step2_title')}</h3>")
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("""
                Click the button below to automatically generate metadata for all audio files using AI:
                - **Caption**: Music style, genre, mood description
                - **BPM**: Beats per minute
                - **Key**: Musical key (e.g., C Major, Am)
                - **Time Signature**: 4/4, 3/4, etc.
                """)
                skip_metas = gr.Checkbox(
                    label=t("training.skip_metas"),
                    value=False,
                    info=t("training.skip_metas_info"),
                )
                only_unlabeled = gr.Checkbox(
                    label=t("training.only_unlabeled"),
                    value=False,
                    info=t("training.only_unlabeled_info"),
                )
            with gr.Column(scale=1):
                auto_label_btn = gr.Button(
                    t("training.auto_label_btn"),
                    variant="primary",
                    size="lg",
                )
                label_progress = gr.Textbox(
                    label=t("training.label_progress"),
                    interactive=False,
                    lines=2,
                )
        # ========== Step 3: Per-sample review and editing ==========
        gr.HTML(f"<hr><h3>👀 {t('training.step3_title')}</h3>")
        with gr.Row():
            with gr.Column(scale=1):
                # maximum=0 is a placeholder; event handlers resize the range
                # once a dataset is loaded/scanned.
                sample_selector = gr.Slider(
                    minimum=0,
                    maximum=0,
                    step=1,
                    value=0,
                    label=t("training.select_sample"),
                    info=t("training.select_sample_info"),
                )
                preview_audio = gr.Audio(
                    label=t("training.audio_preview"),
                    type="filepath",
                    interactive=False,
                )
                preview_filename = gr.Textbox(
                    label=t("training.filename"),
                    interactive=False,
                )
            with gr.Column(scale=2):
                with gr.Row():
                    edit_caption = gr.Textbox(
                        label=t("training.caption"),
                        lines=3,
                        placeholder="Music description...",
                    )
                with gr.Row():
                    edit_genre = gr.Textbox(
                        label=t("training.genre"),
                        lines=1,
                        placeholder="pop, electronic, dance...",
                    )
                    prompt_override = gr.Dropdown(
                        choices=["Use Global Ratio", "Caption", "Genre"],
                        value="Use Global Ratio",
                        label=t("training.prompt_override_label"),
                        info=t("training.prompt_override_info"),
                    )
                with gr.Row():
                    edit_lyrics = gr.Textbox(
                        label=t("training.lyrics_editable_label"),
                        lines=6,
                        placeholder="[Verse 1]\nLyrics here...\n\n[Chorus]\n...",
                    )
                    raw_lyrics_display = gr.Textbox(
                        label=t("training.raw_lyrics_label"),
                        lines=6,
                        placeholder=t("training.no_lyrics_placeholder"),
                        interactive=False,  # Read-only, can copy but not edit
                        visible=False,  # Hidden when no raw lyrics
                    )
                has_raw_lyrics_state = gr.State(False)  # Track visibility of raw_lyrics_display
                with gr.Row():
                    edit_bpm = gr.Number(
                        label=t("training.bpm"),
                        precision=0,
                    )
                    edit_keyscale = gr.Textbox(
                        label=t("training.key_label"),
                        placeholder=t("training.key_placeholder"),
                    )
                    edit_timesig = gr.Dropdown(
                        choices=["", "2", "3", "4", "6", "N/A"],
                        label=t("training.time_sig"),
                    )
                    edit_duration = gr.Number(
                        label=t("training.duration_s"),
                        precision=1,
                        interactive=False,
                    )
                with gr.Row():
                    edit_language = gr.Dropdown(
                        choices=["instrumental", "en", "zh", "ja", "ko", "es", "fr", "de", "pt", "ru", "unknown"],
                        value="instrumental",
                        label=t("training.language"),
                    )
                    edit_instrumental = gr.Checkbox(
                        label=t("training.instrumental"),
                        value=True,
                    )
                save_edit_btn = gr.Button(t("training.save_changes_btn"), variant="secondary")
                edit_status = gr.Textbox(
                    label=t("training.edit_status"),
                    interactive=False,
                )
        # ========== Step 4: Save dataset JSON ==========
        gr.HTML(f"<hr><h3>💾 {t('training.step4_title')}</h3>")
        with gr.Row():
            with gr.Column(scale=3):
                save_path = gr.Textbox(
                    label=t("training.save_path"),
                    value="./datasets/my_lora_dataset.json",
                    placeholder="./datasets/dataset_name.json",
                    info=t("training.save_path_info"),
                )
            with gr.Column(scale=1):
                save_dataset_btn = gr.Button(
                    t("training.save_dataset_btn"),
                    variant="primary",
                    size="lg",
                )
                save_status = gr.Textbox(
                    label=t("training.save_status"),
                    interactive=False,
                    lines=2,
                )
        # ========== Step 5: Preprocess to tensors ==========
        gr.HTML(f"<hr><h3>⚡ {t('training.step5_title')}</h3>")
        gr.Markdown("""
        **Preprocessing converts your dataset to pre-computed tensors for fast training.**
        You can either:
        - Use the dataset from Steps 1-4 above, **OR**
        - Load an existing dataset JSON file (if you've already saved one)
        """)
        with gr.Row():
            with gr.Column(scale=3):
                load_existing_dataset_path = gr.Textbox(
                    label=t("training.load_existing_label"),
                    placeholder="./datasets/my_lora_dataset.json",
                    info=t("training.load_existing_info"),
                )
            with gr.Column(scale=1):
                load_existing_dataset_btn = gr.Button(
                    t("training.load_dataset_btn"),
                    variant="secondary",
                    size="lg",
                )
                load_existing_status = gr.Textbox(
                    label=t("training.load_status"),
                    interactive=False,
                )
        gr.Markdown("""
        This step:
        - Encodes audio to VAE latents
        - Encodes captions and lyrics to text embeddings
        - Runs the condition encoder
        - Saves all tensors to `.pt` files
        ⚠️ **This requires the model to be loaded and may take a few minutes.**
        """)
        with gr.Row():
            with gr.Column(scale=3):
                preprocess_output_dir = gr.Textbox(
                    label=t("training.tensor_output_dir"),
                    value="./datasets/preprocessed_tensors",
                    placeholder="./datasets/preprocessed_tensors",
                    info=t("training.tensor_output_info"),
                )
            with gr.Column(scale=1):
                preprocess_btn = gr.Button(
                    t("training.preprocess_btn"),
                    variant="primary",
                    size="lg",
                )
                preprocess_progress = gr.Textbox(
                    label=t("training.preprocess_progress"),
                    interactive=False,
                    lines=3,
                )
    return {
        # Step 1: load or scan
        "load_json_path": load_json_path,
        "load_json_btn": load_json_btn,
        "load_json_status": load_json_status,
        "audio_directory": audio_directory,
        "scan_btn": scan_btn,
        "scan_status": scan_status,
        "audio_files_table": audio_files_table,
        "dataset_name": dataset_name,
        "all_instrumental": all_instrumental,
        "format_lyrics": format_lyrics,
        "transcribe_lyrics": transcribe_lyrics,
        "custom_tag": custom_tag,
        "tag_position": tag_position,
        # Step 2: auto-labeling
        "skip_metas": skip_metas,
        "only_unlabeled": only_unlabeled,
        "auto_label_btn": auto_label_btn,
        "label_progress": label_progress,
        # Step 3: review/edit
        "sample_selector": sample_selector,
        "preview_audio": preview_audio,
        "preview_filename": preview_filename,
        "edit_caption": edit_caption,
        "edit_genre": edit_genre,
        "prompt_override": prompt_override,
        "genre_ratio": genre_ratio,
        "edit_lyrics": edit_lyrics,
        "raw_lyrics_display": raw_lyrics_display,
        "has_raw_lyrics_state": has_raw_lyrics_state,
        "edit_bpm": edit_bpm,
        "edit_keyscale": edit_keyscale,
        "edit_timesig": edit_timesig,
        "edit_duration": edit_duration,
        "edit_language": edit_language,
        "edit_instrumental": edit_instrumental,
        "save_edit_btn": save_edit_btn,
        "edit_status": edit_status,
        # Step 4: save
        "save_path": save_path,
        "save_dataset_btn": save_dataset_btn,
        "save_status": save_status,
        # Step 5: preprocessing
        "load_existing_dataset_path": load_existing_dataset_path,
        "load_existing_dataset_btn": load_existing_dataset_btn,
        "load_existing_status": load_existing_status,
        "preprocess_output_dir": preprocess_output_dir,
        "preprocess_btn": preprocess_btn,
        "preprocess_progress": preprocess_progress,
    }


def _create_training_tab(epoch_min: int, epoch_step: int, epoch_default: int) -> dict:
    """Build the "Train LoRA" sub-tab.

    Must be called inside an open ``gr.Tabs()`` context.

    Args:
        epoch_min: Minimum value of the max-epochs slider (1 in debug mode,
            100 otherwise).
        epoch_step: Step of the max-epochs slider (1 in debug mode, 100
            otherwise).
        epoch_default: Initial value of the max-epochs slider (1 in debug
            mode, 1000 otherwise).

    Returns:
        Dict mapping component names (the keys the event-wiring code expects)
        to the Gradio components created here.
    """
    with gr.Tab(t("training.tab_train_lora")):
        with gr.Row():
            with gr.Column(scale=2):
                gr.HTML(f"<h3>📊 {t('training.train_section_tensors')}</h3>")
                gr.Markdown("""
                Select the directory containing preprocessed tensor files (`.pt` files).
                These are created in the "Dataset Builder" tab using the "Preprocess" button.
                """)
                training_tensor_dir = gr.Textbox(
                    label=t("training.preprocessed_tensors_dir"),
                    placeholder="./datasets/preprocessed_tensors",
                    value="./datasets/preprocessed_tensors",
                    info=t("training.preprocessed_tensors_info"),
                )
                load_dataset_btn = gr.Button(t("training.load_dataset_btn"), variant="secondary")
                training_dataset_info = gr.Textbox(
                    label=t("training.dataset_info"),
                    interactive=False,
                    lines=3,
                )
            with gr.Column(scale=1):
                gr.HTML(f"<h3>⚙️ {t('training.train_section_lora')}</h3>")
                lora_rank = gr.Slider(
                    minimum=4,
                    maximum=256,
                    step=4,
                    value=64,
                    label=t("training.lora_rank"),
                    info=t("training.lora_rank_info"),
                )
                lora_alpha = gr.Slider(
                    minimum=4,
                    maximum=512,
                    step=4,
                    value=128,
                    label=t("training.lora_alpha"),
                    info=t("training.lora_alpha_info"),
                )
                lora_dropout = gr.Slider(
                    minimum=0.0,
                    maximum=0.5,
                    step=0.05,
                    value=0.1,
                    label=t("training.lora_dropout"),
                )
        gr.HTML(f"<hr><h3>🎛️ {t('training.train_section_params')}</h3>")
        with gr.Row():
            learning_rate = gr.Number(
                label=t("training.learning_rate"),
                value=3e-4,
                info=t("training.learning_rate_info"),
            )
            # Range/step/default are relaxed when DEBUG_TRAINING is enabled.
            train_epochs = gr.Slider(
                minimum=epoch_min,
                maximum=4000,
                step=epoch_step,
                value=epoch_default,
                label=t("training.max_epochs"),
            )
            train_batch_size = gr.Slider(
                minimum=1,
                maximum=8,
                step=1,
                value=1,
                label=t("training.batch_size"),
                info=t("training.batch_size_info"),
            )
            gradient_accumulation = gr.Slider(
                minimum=1,
                maximum=16,
                step=1,
                value=1,
                label=t("training.gradient_accumulation"),
                info=t("training.gradient_accumulation_info"),
            )
        with gr.Row():
            save_every_n_epochs = gr.Slider(
                minimum=50,
                maximum=1000,
                step=50,
                value=200,
                label=t("training.save_every_n_epochs"),
            )
            training_shift = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                step=0.5,
                value=3.0,
                label=t("training.shift"),
                info=t("training.shift_info"),
            )
            training_seed = gr.Number(
                label=t("training.seed"),
                value=42,
                precision=0,
            )
        with gr.Row():
            lora_output_dir = gr.Textbox(
                label=t("training.output_dir"),
                value="./lora_output",
                placeholder="./lora_output",
                info=t("training.output_dir_info"),
            )
        with gr.Row():
            resume_checkpoint_dir = gr.Textbox(
                label="Resume Checkpoint (optional)",
                placeholder="./lora_output/checkpoints/epoch_200",
                info="Directory of a saved LoRA checkpoint to resume from",
            )
        gr.HTML("<hr>")
        with gr.Row():
            with gr.Column(scale=1):
                start_training_btn = gr.Button(
                    t("training.start_training_btn"),
                    variant="primary",
                    size="lg",
                )
            with gr.Column(scale=1):
                stop_training_btn = gr.Button(
                    t("training.stop_training_btn"),
                    variant="stop",
                    size="lg",
                )
        training_progress = gr.Textbox(
            label=t("training.training_progress"),
            interactive=False,
            lines=2,
        )
        with gr.Row():
            training_log = gr.Textbox(
                label=t("training.training_log"),
                interactive=False,
                lines=10,
                max_lines=15,
                scale=1,
            )
            training_loss_plot = gr.LinePlot(
                x="step",
                y="loss",
                title=t("training.training_loss_title"),
                x_title=t("training.step"),
                y_title=t("training.loss"),
                scale=1,
            )
        # ===== Export trained LoRA =====
        gr.HTML(f"<hr><h3>📦 {t('training.export_header')}</h3>")
        with gr.Row():
            export_path = gr.Textbox(
                label=t("training.export_path"),
                value="./lora_output/final_lora",
                placeholder="./lora_output/my_lora",
            )
            export_lora_btn = gr.Button(t("training.export_lora_btn"), variant="secondary")
        export_status = gr.Textbox(
            label=t("training.export_status"),
            interactive=False,
        )
    return {
        "training_tensor_dir": training_tensor_dir,
        "load_dataset_btn": load_dataset_btn,
        "training_dataset_info": training_dataset_info,
        "lora_rank": lora_rank,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "learning_rate": learning_rate,
        "train_epochs": train_epochs,
        "train_batch_size": train_batch_size,
        "gradient_accumulation": gradient_accumulation,
        "save_every_n_epochs": save_every_n_epochs,
        "training_shift": training_shift,
        "training_seed": training_seed,
        "lora_output_dir": lora_output_dir,
        "resume_checkpoint_dir": resume_checkpoint_dir,
        "start_training_btn": start_training_btn,
        "stop_training_btn": stop_training_btn,
        "training_progress": training_progress,
        "training_log": training_log,
        "training_loss_plot": training_loss_plot,
        "export_path": export_path,
        "export_lora_btn": export_lora_btn,
        "export_status": export_status,
    }


def create_training_section(dit_handler, llm_handler, init_params=None) -> dict:
    """Create the training tab section with dataset builder and training controls.

    Args:
        dit_handler: DiT handler instance (kept for interface compatibility;
            not referenced during UI construction).
        llm_handler: LLM handler instance (kept for interface compatibility;
            not referenced during UI construction).
        init_params: Dictionary containing initialization parameters and state.
            If None, service will not be pre-initialized.

    Returns:
        Dictionary of Gradio components for event handling.
    """
    # Check if running in service mode (hide training tab)
    service_mode = init_params is not None and init_params.get('service_mode', False)
    # DEBUG_TRAINING != "OFF" relaxes the epoch slider (min/step 1, default 1)
    # so short smoke-test runs are possible.
    debug_training_enabled = str(DEBUG_TRAINING).strip().upper() != "OFF"
    epoch_min = 1 if debug_training_enabled else 100
    epoch_step = 1 if debug_training_enabled else 100
    epoch_default = 1 if debug_training_enabled else 1000
    with gr.Tab(t("training.tab_title"), visible=not service_mode):
        gr.HTML("""
        <div style="text-align: center; padding: 10px; margin-bottom: 15px;">
            <h2>🎵 LoRA Training for ACE-Step</h2>
            <p>Build datasets from your audio files and train custom LoRA adapters</p>
        </div>
        """)
        with gr.Tabs():
            # ==================== Dataset Builder Tab ====================
            components = _create_dataset_builder_tab()
            # ==================== Training Tab ====================
            components.update(_create_training_tab(epoch_min, epoch_step, epoch_default))
        # Cross-tab state holders (not tied to a single sub-tab).
        components["dataset_builder_state"] = gr.State(None)
        components["training_state"] = gr.State({"is_training": False, "should_stop": False})
    return components