""" Gradio UI Training Tab Module Contains the dataset builder and LoRA training interface components. """ import os import gradio as gr from acestep.gradio_ui.i18n import t def create_training_section(dit_handler, llm_handler, init_params=None) -> dict: """Create the training tab section with dataset builder and training controls. Args: dit_handler: DiT handler instance llm_handler: LLM handler instance init_params: Dictionary containing initialization parameters and state. If None, service will not be pre-initialized. Returns: Dictionary of Gradio components for event handling """ # Check if running in service mode (hide training tab) service_mode = init_params is not None and init_params.get('service_mode', False) with gr.Tab("🎓 LoRA Training", visible=not service_mode): gr.HTML("""

đŸŽĩ LoRA Training for ACE-Step

Build datasets from your audio files and train custom LoRA adapters

""") with gr.Tabs(): # ==================== Dataset Builder Tab ==================== with gr.Tab("📁 Dataset Builder"): # ========== Load Existing OR Scan New ========== gr.HTML("""

🚀 Quick Start

Choose one: Load existing dataset OR Scan new directory

""") with gr.Row(): with gr.Column(scale=1): gr.HTML("

📂 Load Existing Dataset

") with gr.Row(): load_json_path = gr.Textbox( label="Dataset JSON Path", placeholder="./datasets/my_lora_dataset.json", info="Load a previously saved dataset", scale=3, ) load_json_btn = gr.Button("📂 Load", variant="primary", scale=1) load_json_status = gr.Textbox( label="Load Status", interactive=False, ) with gr.Column(scale=1): gr.HTML("

🔍 Scan New Directory

") with gr.Row(): audio_directory = gr.Textbox( label="Audio Directory Path", placeholder="/path/to/your/audio/folder", info="Scan for audio files (wav, mp3, flac, ogg, opus)", scale=3, ) scan_btn = gr.Button("🔍 Scan", variant="secondary", scale=1) scan_status = gr.Textbox( label="Scan Status", interactive=False, ) gr.HTML("
") with gr.Row(): with gr.Column(scale=2): # Audio files table audio_files_table = gr.Dataframe( headers=["#", "Filename", "Duration", "Labeled", "BPM", "Key", "Caption"], datatype=["number", "str", "str", "str", "str", "str", "str"], label="Found Audio Files", interactive=False, wrap=True, ) with gr.Column(scale=1): gr.HTML("

âš™ī¸ Dataset Settings

") dataset_name = gr.Textbox( label="Dataset Name", value="my_lora_dataset", placeholder="Enter dataset name", ) all_instrumental = gr.Checkbox( label="All Instrumental", value=True, info="Check if all tracks are instrumental (no vocals)", ) need_lyrics = gr.Checkbox( label="Transcribe Lyrics", value=False, info="Attempt to transcribe lyrics (slower)", interactive=False, # Disabled for now ) custom_tag = gr.Textbox( label="Custom Activation Tag", placeholder="e.g., 8bit_retro, my_style", info="Unique tag to activate this LoRA's style", ) tag_position = gr.Radio( choices=[ ("Prepend (tag, caption)", "prepend"), ("Append (caption, tag)", "append"), ("Replace caption", "replace"), ], value="replace", label="Tag Position", info="Where to place the custom tag in the caption", ) gr.HTML("

🤖 Step 2: Auto-Label with AI

") with gr.Row(): with gr.Column(scale=3): gr.Markdown(""" Click the button below to automatically generate metadata for all audio files using AI: - **Caption**: Music style, genre, mood description - **BPM**: Beats per minute - **Key**: Musical key (e.g., C Major, Am) - **Time Signature**: 4/4, 3/4, etc. """) skip_metas = gr.Checkbox( label="Skip Metas (No LLM)", value=False, info="Skip AI labeling. BPM/Key/Time Signature will be N/A, Language will be 'unknown' for instrumental", ) with gr.Column(scale=1): auto_label_btn = gr.Button( "đŸˇī¸ Auto-Label All", variant="primary", size="lg", ) label_progress = gr.Textbox( label="Labeling Progress", interactive=False, lines=2, ) gr.HTML("

👀 Step 3: Preview & Edit

") with gr.Row(): with gr.Column(scale=1): sample_selector = gr.Slider( minimum=0, maximum=0, step=1, value=0, label="Select Sample #", info="Choose a sample to preview and edit", ) preview_audio = gr.Audio( label="Audio Preview", type="filepath", interactive=False, ) preview_filename = gr.Textbox( label="Filename", interactive=False, ) with gr.Column(scale=2): with gr.Row(): edit_caption = gr.Textbox( label="Caption", lines=3, placeholder="Music description...", ) with gr.Row(): edit_lyrics = gr.Textbox( label="Lyrics", lines=4, placeholder="[Verse 1]\nLyrics here...\n\n[Chorus]\n...", ) with gr.Row(): edit_bpm = gr.Number( label="BPM", precision=0, ) edit_keyscale = gr.Textbox( label="Key", placeholder="C Major", ) edit_timesig = gr.Dropdown( choices=["", "2", "3", "4", "6"], label="Time Signature", ) edit_duration = gr.Number( label="Duration (s)", precision=1, interactive=False, ) with gr.Row(): edit_language = gr.Dropdown( choices=["instrumental", "en", "zh", "ja", "ko", "es", "fr", "de", "pt", "ru", "unknown"], value="instrumental", label="Language", ) edit_instrumental = gr.Checkbox( label="Instrumental", value=True, ) save_edit_btn = gr.Button("💾 Save Changes", variant="secondary") edit_status = gr.Textbox( label="Edit Status", interactive=False, ) gr.HTML("

💾 Step 4: Save Dataset

") with gr.Row(): with gr.Column(scale=3): save_path = gr.Textbox( label="Save Path", value="./datasets/my_lora_dataset.json", placeholder="./datasets/dataset_name.json", info="Path where the dataset JSON will be saved", ) with gr.Column(scale=1): save_dataset_btn = gr.Button( "💾 Save Dataset", variant="primary", size="lg", ) save_status = gr.Textbox( label="Save Status", interactive=False, lines=2, ) gr.HTML("

⚡ Step 5: Preprocess to Tensors

") gr.Markdown(""" **Preprocessing converts your dataset to pre-computed tensors for fast training.** You can either: - Use the dataset from Steps 1-4 above, **OR** - Load an existing dataset JSON file (if you've already saved one) """) with gr.Row(): with gr.Column(scale=3): load_existing_dataset_path = gr.Textbox( label="Load Existing Dataset (Optional)", placeholder="./datasets/my_lora_dataset.json", info="Path to a previously saved dataset JSON file", ) with gr.Column(scale=1): load_existing_dataset_btn = gr.Button( "📂 Load Dataset", variant="secondary", size="lg", ) load_existing_status = gr.Textbox( label="Load Status", interactive=False, ) gr.Markdown(""" This step: - Encodes audio to VAE latents - Encodes captions and lyrics to text embeddings - Runs the condition encoder - Saves all tensors to `.pt` files âš ī¸ **This requires the model to be loaded and may take a few minutes.** """) with gr.Row(): with gr.Column(scale=3): preprocess_output_dir = gr.Textbox( label="Tensor Output Directory", value="./datasets/preprocessed_tensors", placeholder="./datasets/preprocessed_tensors", info="Directory to save preprocessed tensor files", ) with gr.Column(scale=1): preprocess_btn = gr.Button( "⚡ Preprocess", variant="primary", size="lg", ) preprocess_progress = gr.Textbox( label="Preprocessing Progress", interactive=False, lines=3, ) # ==================== Training Tab ==================== with gr.Tab("🚀 Train LoRA"): with gr.Row(): with gr.Column(scale=2): gr.HTML("

📊 Preprocessed Dataset Selection

") gr.Markdown(""" Select the directory containing preprocessed tensor files (`.pt` files). These are created in the "Dataset Builder" tab using the "Preprocess" button. """) training_tensor_dir = gr.Textbox( label="Preprocessed Tensors Directory", placeholder="./datasets/preprocessed_tensors", value="./datasets/preprocessed_tensors", info="Directory containing preprocessed .pt tensor files", ) load_dataset_btn = gr.Button("📂 Load Dataset", variant="secondary") training_dataset_info = gr.Textbox( label="Dataset Info", interactive=False, lines=3, ) with gr.Column(scale=1): gr.HTML("

âš™ī¸ LoRA Settings

") lora_rank = gr.Slider( minimum=4, maximum=256, step=4, value=64, label="LoRA Rank (r)", info="Higher = more capacity, more memory", ) lora_alpha = gr.Slider( minimum=4, maximum=512, step=4, value=128, label="LoRA Alpha", info="Scaling factor (typically 2x rank)", ) lora_dropout = gr.Slider( minimum=0.0, maximum=0.5, step=0.05, value=0.1, label="LoRA Dropout", ) gr.HTML("

đŸŽ›ī¸ Training Parameters

") with gr.Row(): learning_rate = gr.Number( label="Learning Rate", value=1e-4, info="Start with 1e-4, adjust if needed", ) train_epochs = gr.Slider( minimum=100, maximum=4000, step=100, value=500, label="Max Epochs", ) train_batch_size = gr.Slider( minimum=1, maximum=8, step=1, value=1, label="Batch Size", info="Increase if you have enough VRAM", ) gradient_accumulation = gr.Slider( minimum=1, maximum=16, step=1, value=1, label="Gradient Accumulation", info="Effective batch = batch_size × accumulation", ) with gr.Row(): save_every_n_epochs = gr.Slider( minimum=50, maximum=1000, step=50, value=200, label="Save Every N Epochs", ) training_shift = gr.Slider( minimum=1.0, maximum=5.0, step=0.5, value=3.0, label="Shift", info="Timestep shift for turbo model", ) training_seed = gr.Number( label="Seed", value=42, precision=0, ) with gr.Row(): lora_output_dir = gr.Textbox( label="Output Directory", value="./lora_output", placeholder="./lora_output", info="Directory to save trained LoRA weights", ) gr.HTML("
") with gr.Row(): with gr.Column(scale=1): start_training_btn = gr.Button( "🚀 Start Training", variant="primary", size="lg", ) with gr.Column(scale=1): stop_training_btn = gr.Button( "âšī¸ Stop Training", variant="stop", size="lg", ) training_progress = gr.Textbox( label="Training Progress", interactive=False, lines=2, ) with gr.Row(): training_log = gr.Textbox( label="Training Log", interactive=False, lines=10, max_lines=15, scale=1, ) training_loss_plot = gr.LinePlot( x="step", y="loss", title="Training Loss", x_title="Step", y_title="Loss", scale=1, ) gr.HTML("

đŸ“Ļ Export LoRA

") with gr.Row(): export_path = gr.Textbox( label="Export Path", value="./lora_output/final_lora", placeholder="./lora_output/my_lora", ) export_lora_btn = gr.Button("đŸ“Ļ Export LoRA", variant="secondary") export_status = gr.Textbox( label="Export Status", interactive=False, ) # Store dataset builder state dataset_builder_state = gr.State(None) training_state = gr.State({"is_training": False, "should_stop": False}) return { # Dataset Builder - Load or Scan "load_json_path": load_json_path, "load_json_btn": load_json_btn, "load_json_status": load_json_status, "audio_directory": audio_directory, "scan_btn": scan_btn, "scan_status": scan_status, "audio_files_table": audio_files_table, "dataset_name": dataset_name, "all_instrumental": all_instrumental, "need_lyrics": need_lyrics, "custom_tag": custom_tag, "tag_position": tag_position, "skip_metas": skip_metas, "auto_label_btn": auto_label_btn, "label_progress": label_progress, "sample_selector": sample_selector, "preview_audio": preview_audio, "preview_filename": preview_filename, "edit_caption": edit_caption, "edit_lyrics": edit_lyrics, "edit_bpm": edit_bpm, "edit_keyscale": edit_keyscale, "edit_timesig": edit_timesig, "edit_duration": edit_duration, "edit_language": edit_language, "edit_instrumental": edit_instrumental, "save_edit_btn": save_edit_btn, "edit_status": edit_status, "save_path": save_path, "save_dataset_btn": save_dataset_btn, "save_status": save_status, # Preprocessing "load_existing_dataset_path": load_existing_dataset_path, "load_existing_dataset_btn": load_existing_dataset_btn, "load_existing_status": load_existing_status, "preprocess_output_dir": preprocess_output_dir, "preprocess_btn": preprocess_btn, "preprocess_progress": preprocess_progress, "dataset_builder_state": dataset_builder_state, # Training "training_tensor_dir": training_tensor_dir, "load_dataset_btn": load_dataset_btn, "training_dataset_info": training_dataset_info, "lora_rank": lora_rank, "lora_alpha": lora_alpha, "lora_dropout": lora_dropout, "learning_rate": learning_rate, "train_epochs": train_epochs, "train_batch_size": train_batch_size, "gradient_accumulation": gradient_accumulation, "save_every_n_epochs": save_every_n_epochs, "training_shift": training_shift, "training_seed": training_seed, "lora_output_dir": lora_output_dir, "start_training_btn": start_training_btn, "stop_training_btn": stop_training_btn, "training_progress": training_progress, "training_log": training_log, "training_loss_plot": training_loss_plot, "export_path": export_path, "export_lora_btn": export_lora_btn, "export_status": export_status, "training_state": training_state, }