|
|
""" |
|
|
Gradio UI Training Tab Module |
|
|
|
|
|
Contains the dataset builder and LoRA training interface components. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import gradio as gr |
|
|
from acestep.gradio_ui.i18n import t |
|
|
|
|
|
|
|
|
def create_training_section(dit_handler, llm_handler, init_params=None) -> dict: |
|
|
"""Create the training tab section with dataset builder and training controls. |
|
|
|
|
|
Args: |
|
|
dit_handler: DiT handler instance |
|
|
llm_handler: LLM handler instance |
|
|
init_params: Dictionary containing initialization parameters and state. |
|
|
If None, service will not be pre-initialized. |
|
|
|
|
|
Returns: |
|
|
Dictionary of Gradio components for event handling |
|
|
""" |
|
|
|
|
|
service_mode = init_params is not None and init_params.get('service_mode', False) |
|
|
|
|
|
with gr.Tab("π LoRA Training", visible=not service_mode): |
|
|
gr.HTML(""" |
|
|
<div style="text-align: center; padding: 10px; margin-bottom: 15px;"> |
|
|
<h2>π΅ LoRA Training for ACE-Step</h2> |
|
|
<p>Build datasets from your audio files and train custom LoRA adapters</p> |
|
|
</div> |
|
|
""") |
|
|
|
|
|
with gr.Tabs(): |
|
|
|
|
|
with gr.Tab("π Dataset Builder"): |
|
|
|
|
|
gr.HTML(""" |
|
|
<div style="padding: 10px; margin-bottom: 10px; border: 1px solid #4a4a6a; border-radius: 8px; background: linear-gradient(135deg, #2a2a4a 0%, #1a1a3a 100%);"> |
|
|
<h3 style="margin: 0 0 5px 0;">π Quick Start</h3> |
|
|
<p style="margin: 0; color: #aaa;">Choose one: <b>Load existing dataset</b> OR <b>Scan new directory</b></p> |
|
|
</div> |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
gr.HTML("<h4>π Load Existing Dataset</h4>") |
|
|
with gr.Row(): |
|
|
load_json_path = gr.Textbox( |
|
|
label="Dataset JSON Path", |
|
|
placeholder="./datasets/my_lora_dataset.json", |
|
|
info="Load a previously saved dataset", |
|
|
scale=3, |
|
|
) |
|
|
load_json_btn = gr.Button("π Load", variant="primary", scale=1) |
|
|
load_json_status = gr.Textbox( |
|
|
label="Load Status", |
|
|
interactive=False, |
|
|
) |
|
|
|
|
|
with gr.Column(scale=1): |
|
|
gr.HTML("<h4>π Scan New Directory</h4>") |
|
|
with gr.Row(): |
|
|
audio_directory = gr.Textbox( |
|
|
label="Audio Directory Path", |
|
|
placeholder="/path/to/your/audio/folder", |
|
|
info="Scan for audio files (wav, mp3, flac, ogg, opus)", |
|
|
scale=3, |
|
|
) |
|
|
scan_btn = gr.Button("π Scan", variant="secondary", scale=1) |
|
|
scan_status = gr.Textbox( |
|
|
label="Scan Status", |
|
|
interactive=False, |
|
|
) |
|
|
|
|
|
gr.HTML("<hr>") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=2): |
|
|
|
|
|
|
|
|
audio_files_table = gr.Dataframe( |
|
|
headers=["#", "Filename", "Duration", "Labeled", "BPM", "Key", "Caption"], |
|
|
datatype=["number", "str", "str", "str", "str", "str", "str"], |
|
|
label="Found Audio Files", |
|
|
interactive=False, |
|
|
wrap=True, |
|
|
) |
|
|
|
|
|
with gr.Column(scale=1): |
|
|
gr.HTML("<h3>βοΈ Dataset Settings</h3>") |
|
|
|
|
|
dataset_name = gr.Textbox( |
|
|
label="Dataset Name", |
|
|
value="my_lora_dataset", |
|
|
placeholder="Enter dataset name", |
|
|
) |
|
|
|
|
|
all_instrumental = gr.Checkbox( |
|
|
label="All Instrumental", |
|
|
value=True, |
|
|
info="Check if all tracks are instrumental (no vocals)", |
|
|
) |
|
|
|
|
|
need_lyrics = gr.Checkbox( |
|
|
label="Transcribe Lyrics", |
|
|
value=False, |
|
|
info="Attempt to transcribe lyrics (slower)", |
|
|
interactive=False, |
|
|
) |
|
|
|
|
|
custom_tag = gr.Textbox( |
|
|
label="Custom Activation Tag", |
|
|
placeholder="e.g., 8bit_retro, my_style", |
|
|
info="Unique tag to activate this LoRA's style", |
|
|
) |
|
|
|
|
|
tag_position = gr.Radio( |
|
|
choices=[ |
|
|
("Prepend (tag, caption)", "prepend"), |
|
|
("Append (caption, tag)", "append"), |
|
|
("Replace caption", "replace"), |
|
|
], |
|
|
value="replace", |
|
|
label="Tag Position", |
|
|
info="Where to place the custom tag in the caption", |
|
|
) |
|
|
|
|
|
gr.HTML("<hr><h3>π€ Step 2: Auto-Label with AI</h3>") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=3): |
|
|
gr.Markdown(""" |
|
|
Click the button below to automatically generate metadata for all audio files using AI: |
|
|
- **Caption**: Music style, genre, mood description |
|
|
- **BPM**: Beats per minute |
|
|
- **Key**: Musical key (e.g., C Major, Am) |
|
|
- **Time Signature**: 4/4, 3/4, etc. |
|
|
""") |
|
|
skip_metas = gr.Checkbox( |
|
|
label="Skip Metas (No LLM)", |
|
|
value=False, |
|
|
info="Skip AI labeling. BPM/Key/Time Signature will be N/A, Language will be 'unknown' for instrumental", |
|
|
) |
|
|
with gr.Column(scale=1): |
|
|
auto_label_btn = gr.Button( |
|
|
"π·οΈ Auto-Label All", |
|
|
variant="primary", |
|
|
size="lg", |
|
|
) |
|
|
|
|
|
label_progress = gr.Textbox( |
|
|
label="Labeling Progress", |
|
|
interactive=False, |
|
|
lines=2, |
|
|
) |
|
|
|
|
|
gr.HTML("<hr><h3>π Step 3: Preview & Edit</h3>") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
sample_selector = gr.Slider( |
|
|
minimum=0, |
|
|
maximum=0, |
|
|
step=1, |
|
|
value=0, |
|
|
label="Select Sample #", |
|
|
info="Choose a sample to preview and edit", |
|
|
) |
|
|
|
|
|
preview_audio = gr.Audio( |
|
|
label="Audio Preview", |
|
|
type="filepath", |
|
|
interactive=False, |
|
|
) |
|
|
|
|
|
preview_filename = gr.Textbox( |
|
|
label="Filename", |
|
|
interactive=False, |
|
|
) |
|
|
|
|
|
with gr.Column(scale=2): |
|
|
with gr.Row(): |
|
|
edit_caption = gr.Textbox( |
|
|
label="Caption", |
|
|
lines=3, |
|
|
placeholder="Music description...", |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
edit_lyrics = gr.Textbox( |
|
|
label="Lyrics", |
|
|
lines=4, |
|
|
placeholder="[Verse 1]\nLyrics here...\n\n[Chorus]\n...", |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
edit_bpm = gr.Number( |
|
|
label="BPM", |
|
|
precision=0, |
|
|
) |
|
|
edit_keyscale = gr.Textbox( |
|
|
label="Key", |
|
|
placeholder="C Major", |
|
|
) |
|
|
edit_timesig = gr.Dropdown( |
|
|
choices=["", "2", "3", "4", "6"], |
|
|
label="Time Signature", |
|
|
) |
|
|
edit_duration = gr.Number( |
|
|
label="Duration (s)", |
|
|
precision=1, |
|
|
interactive=False, |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
edit_language = gr.Dropdown( |
|
|
choices=["instrumental", "en", "zh", "ja", "ko", "es", "fr", "de", "pt", "ru", "unknown"], |
|
|
value="instrumental", |
|
|
label="Language", |
|
|
) |
|
|
edit_instrumental = gr.Checkbox( |
|
|
label="Instrumental", |
|
|
value=True, |
|
|
) |
|
|
save_edit_btn = gr.Button("πΎ Save Changes", variant="secondary") |
|
|
|
|
|
edit_status = gr.Textbox( |
|
|
label="Edit Status", |
|
|
interactive=False, |
|
|
) |
|
|
|
|
|
gr.HTML("<hr><h3>πΎ Step 4: Save Dataset</h3>") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=3): |
|
|
save_path = gr.Textbox( |
|
|
label="Save Path", |
|
|
value="./datasets/my_lora_dataset.json", |
|
|
placeholder="./datasets/dataset_name.json", |
|
|
info="Path where the dataset JSON will be saved", |
|
|
) |
|
|
with gr.Column(scale=1): |
|
|
save_dataset_btn = gr.Button( |
|
|
"πΎ Save Dataset", |
|
|
variant="primary", |
|
|
size="lg", |
|
|
) |
|
|
|
|
|
save_status = gr.Textbox( |
|
|
label="Save Status", |
|
|
interactive=False, |
|
|
lines=2, |
|
|
) |
|
|
|
|
|
gr.HTML("<hr><h3>β‘ Step 5: Preprocess to Tensors</h3>") |
|
|
|
|
|
gr.Markdown(""" |
|
|
**Preprocessing converts your dataset to pre-computed tensors for fast training.** |
|
|
|
|
|
You can either: |
|
|
- Use the dataset from Steps 1-4 above, **OR** |
|
|
- Load an existing dataset JSON file (if you've already saved one) |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=3): |
|
|
load_existing_dataset_path = gr.Textbox( |
|
|
label="Load Existing Dataset (Optional)", |
|
|
placeholder="./datasets/my_lora_dataset.json", |
|
|
info="Path to a previously saved dataset JSON file", |
|
|
) |
|
|
with gr.Column(scale=1): |
|
|
load_existing_dataset_btn = gr.Button( |
|
|
"π Load Dataset", |
|
|
variant="secondary", |
|
|
size="lg", |
|
|
) |
|
|
|
|
|
load_existing_status = gr.Textbox( |
|
|
label="Load Status", |
|
|
interactive=False, |
|
|
) |
|
|
|
|
|
gr.Markdown(""" |
|
|
This step: |
|
|
- Encodes audio to VAE latents |
|
|
- Encodes captions and lyrics to text embeddings |
|
|
- Runs the condition encoder |
|
|
- Saves all tensors to `.pt` files |
|
|
|
|
|
β οΈ **This requires the model to be loaded and may take a few minutes.** |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=3): |
|
|
preprocess_output_dir = gr.Textbox( |
|
|
label="Tensor Output Directory", |
|
|
value="./datasets/preprocessed_tensors", |
|
|
placeholder="./datasets/preprocessed_tensors", |
|
|
info="Directory to save preprocessed tensor files", |
|
|
) |
|
|
with gr.Column(scale=1): |
|
|
preprocess_btn = gr.Button( |
|
|
"β‘ Preprocess", |
|
|
variant="primary", |
|
|
size="lg", |
|
|
) |
|
|
|
|
|
preprocess_progress = gr.Textbox( |
|
|
label="Preprocessing Progress", |
|
|
interactive=False, |
|
|
lines=3, |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Tab("π Train LoRA"): |
|
|
with gr.Row(): |
|
|
with gr.Column(scale=2): |
|
|
gr.HTML("<h3>π Preprocessed Dataset Selection</h3>") |
|
|
|
|
|
gr.Markdown(""" |
|
|
Select the directory containing preprocessed tensor files (`.pt` files). |
|
|
These are created in the "Dataset Builder" tab using the "Preprocess" button. |
|
|
""") |
|
|
|
|
|
training_tensor_dir = gr.Textbox( |
|
|
label="Preprocessed Tensors Directory", |
|
|
placeholder="./datasets/preprocessed_tensors", |
|
|
value="./datasets/preprocessed_tensors", |
|
|
info="Directory containing preprocessed .pt tensor files", |
|
|
) |
|
|
|
|
|
load_dataset_btn = gr.Button("π Load Dataset", variant="secondary") |
|
|
|
|
|
training_dataset_info = gr.Textbox( |
|
|
label="Dataset Info", |
|
|
interactive=False, |
|
|
lines=3, |
|
|
) |
|
|
|
|
|
with gr.Column(scale=1): |
|
|
gr.HTML("<h3>βοΈ LoRA Settings</h3>") |
|
|
|
|
|
lora_rank = gr.Slider( |
|
|
minimum=4, |
|
|
maximum=256, |
|
|
step=4, |
|
|
value=64, |
|
|
label="LoRA Rank (r)", |
|
|
info="Higher = more capacity, more memory", |
|
|
) |
|
|
|
|
|
lora_alpha = gr.Slider( |
|
|
minimum=4, |
|
|
maximum=512, |
|
|
step=4, |
|
|
value=128, |
|
|
label="LoRA Alpha", |
|
|
info="Scaling factor (typically 2x rank)", |
|
|
) |
|
|
|
|
|
lora_dropout = gr.Slider( |
|
|
minimum=0.0, |
|
|
maximum=0.5, |
|
|
step=0.05, |
|
|
value=0.1, |
|
|
label="LoRA Dropout", |
|
|
) |
|
|
|
|
|
gr.HTML("<hr><h3>ποΈ Training Parameters</h3>") |
|
|
|
|
|
with gr.Row(): |
|
|
learning_rate = gr.Number( |
|
|
label="Learning Rate", |
|
|
value=1e-4, |
|
|
info="Start with 1e-4, adjust if needed", |
|
|
) |
|
|
|
|
|
train_epochs = gr.Slider( |
|
|
minimum=100, |
|
|
maximum=4000, |
|
|
step=100, |
|
|
value=500, |
|
|
label="Max Epochs", |
|
|
) |
|
|
|
|
|
train_batch_size = gr.Slider( |
|
|
minimum=1, |
|
|
maximum=8, |
|
|
step=1, |
|
|
value=1, |
|
|
label="Batch Size", |
|
|
info="Increase if you have enough VRAM", |
|
|
) |
|
|
|
|
|
gradient_accumulation = gr.Slider( |
|
|
minimum=1, |
|
|
maximum=16, |
|
|
step=1, |
|
|
value=1, |
|
|
label="Gradient Accumulation", |
|
|
info="Effective batch = batch_size Γ accumulation", |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
save_every_n_epochs = gr.Slider( |
|
|
minimum=50, |
|
|
maximum=1000, |
|
|
step=50, |
|
|
value=200, |
|
|
label="Save Every N Epochs", |
|
|
) |
|
|
|
|
|
training_shift = gr.Slider( |
|
|
minimum=1.0, |
|
|
maximum=5.0, |
|
|
step=0.5, |
|
|
value=3.0, |
|
|
label="Shift", |
|
|
info="Timestep shift for turbo model", |
|
|
) |
|
|
|
|
|
training_seed = gr.Number( |
|
|
label="Seed", |
|
|
value=42, |
|
|
precision=0, |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
lora_output_dir = gr.Textbox( |
|
|
label="Output Directory", |
|
|
value="./lora_output", |
|
|
placeholder="./lora_output", |
|
|
info="Directory to save trained LoRA weights", |
|
|
) |
|
|
|
|
|
gr.HTML("<hr>") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
start_training_btn = gr.Button( |
|
|
"π Start Training", |
|
|
variant="primary", |
|
|
size="lg", |
|
|
) |
|
|
with gr.Column(scale=1): |
|
|
stop_training_btn = gr.Button( |
|
|
"βΉοΈ Stop Training", |
|
|
variant="stop", |
|
|
size="lg", |
|
|
) |
|
|
|
|
|
training_progress = gr.Textbox( |
|
|
label="Training Progress", |
|
|
interactive=False, |
|
|
lines=2, |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
training_log = gr.Textbox( |
|
|
label="Training Log", |
|
|
interactive=False, |
|
|
lines=10, |
|
|
max_lines=15, |
|
|
scale=1, |
|
|
) |
|
|
training_loss_plot = gr.LinePlot( |
|
|
x="step", |
|
|
y="loss", |
|
|
title="Training Loss", |
|
|
x_title="Step", |
|
|
y_title="Loss", |
|
|
scale=1, |
|
|
) |
|
|
|
|
|
gr.HTML("<hr><h3>π¦ Export LoRA</h3>") |
|
|
|
|
|
with gr.Row(): |
|
|
export_path = gr.Textbox( |
|
|
label="Export Path", |
|
|
value="./lora_output/final_lora", |
|
|
placeholder="./lora_output/my_lora", |
|
|
) |
|
|
export_lora_btn = gr.Button("π¦ Export LoRA", variant="secondary") |
|
|
|
|
|
export_status = gr.Textbox( |
|
|
label="Export Status", |
|
|
interactive=False, |
|
|
) |
|
|
|
|
|
|
|
|
dataset_builder_state = gr.State(None) |
|
|
training_state = gr.State({"is_training": False, "should_stop": False}) |
|
|
|
|
|
return { |
|
|
|
|
|
"load_json_path": load_json_path, |
|
|
"load_json_btn": load_json_btn, |
|
|
"load_json_status": load_json_status, |
|
|
"audio_directory": audio_directory, |
|
|
"scan_btn": scan_btn, |
|
|
"scan_status": scan_status, |
|
|
"audio_files_table": audio_files_table, |
|
|
"dataset_name": dataset_name, |
|
|
"all_instrumental": all_instrumental, |
|
|
"need_lyrics": need_lyrics, |
|
|
"custom_tag": custom_tag, |
|
|
"tag_position": tag_position, |
|
|
"skip_metas": skip_metas, |
|
|
"auto_label_btn": auto_label_btn, |
|
|
"label_progress": label_progress, |
|
|
"sample_selector": sample_selector, |
|
|
"preview_audio": preview_audio, |
|
|
"preview_filename": preview_filename, |
|
|
"edit_caption": edit_caption, |
|
|
"edit_lyrics": edit_lyrics, |
|
|
"edit_bpm": edit_bpm, |
|
|
"edit_keyscale": edit_keyscale, |
|
|
"edit_timesig": edit_timesig, |
|
|
"edit_duration": edit_duration, |
|
|
"edit_language": edit_language, |
|
|
"edit_instrumental": edit_instrumental, |
|
|
"save_edit_btn": save_edit_btn, |
|
|
"edit_status": edit_status, |
|
|
"save_path": save_path, |
|
|
"save_dataset_btn": save_dataset_btn, |
|
|
"save_status": save_status, |
|
|
|
|
|
"load_existing_dataset_path": load_existing_dataset_path, |
|
|
"load_existing_dataset_btn": load_existing_dataset_btn, |
|
|
"load_existing_status": load_existing_status, |
|
|
"preprocess_output_dir": preprocess_output_dir, |
|
|
"preprocess_btn": preprocess_btn, |
|
|
"preprocess_progress": preprocess_progress, |
|
|
"dataset_builder_state": dataset_builder_state, |
|
|
|
|
|
"training_tensor_dir": training_tensor_dir, |
|
|
"load_dataset_btn": load_dataset_btn, |
|
|
"training_dataset_info": training_dataset_info, |
|
|
"lora_rank": lora_rank, |
|
|
"lora_alpha": lora_alpha, |
|
|
"lora_dropout": lora_dropout, |
|
|
"learning_rate": learning_rate, |
|
|
"train_epochs": train_epochs, |
|
|
"train_batch_size": train_batch_size, |
|
|
"gradient_accumulation": gradient_accumulation, |
|
|
"save_every_n_epochs": save_every_n_epochs, |
|
|
"training_shift": training_shift, |
|
|
"training_seed": training_seed, |
|
|
"lora_output_dir": lora_output_dir, |
|
|
"start_training_btn": start_training_btn, |
|
|
"stop_training_btn": stop_training_btn, |
|
|
"training_progress": training_progress, |
|
|
"training_log": training_log, |
|
|
"training_loss_plot": training_loss_plot, |
|
|
"export_path": export_path, |
|
|
"export_lora_btn": export_lora_btn, |
|
|
"export_status": export_status, |
|
|
"training_state": training_state, |
|
|
} |
|
|
|