Spaces:
Running
on
A100
Running
on
A100
support lora trianing & inter
Browse files- .gitignore +3 -1
- acestep/gradio_ui/events/__init__.py +268 -0
- acestep/gradio_ui/events/training_handlers.py +644 -0
- acestep/gradio_ui/interfaces/__init__.py +8 -1
- acestep/gradio_ui/interfaces/generation.py +31 -0
- acestep/gradio_ui/interfaces/training.py +558 -0
- acestep/handler.py +140 -0
- acestep/training/__init__.py +61 -0
- acestep/training/configs.py +107 -0
- acestep/training/data_module.py +465 -0
- acestep/training/dataset_builder.py +755 -0
- acestep/training/lora_utils.py +305 -0
- acestep/training/trainer.py +503 -0
- requirements.txt +4 -0
.gitignore
CHANGED
|
@@ -221,4 +221,6 @@ feishu_bot/
|
|
| 221 |
tmp*
|
| 222 |
torchinductor_root/
|
| 223 |
scripts/
|
| 224 |
-
checkpoints_legacy/
|
|
|
|
|
|
|
|
|
| 221 |
tmp*
|
| 222 |
torchinductor_root/
|
| 223 |
scripts/
|
| 224 |
+
checkpoints_legacy/
|
| 225 |
+
lora_output/
|
| 226 |
+
datasets/
|
acestep/gradio_ui/events/__init__.py
CHANGED
|
@@ -8,6 +8,7 @@ from typing import Optional
|
|
| 8 |
# Import handler modules
|
| 9 |
from . import generation_handlers as gen_h
|
| 10 |
from . import results_handlers as res_h
|
|
|
|
| 11 |
from acestep.gradio_ui.i18n import t
|
| 12 |
|
| 13 |
|
|
@@ -69,6 +70,32 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 69 |
]
|
| 70 |
)
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
# ========== UI Visibility Updates ==========
|
| 73 |
generation_section["init_llm_checkbox"].change(
|
| 74 |
fn=gen_h.update_negative_prompt_visibility,
|
|
@@ -859,3 +886,244 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 859 |
results_section[f"generated_audio_{lrc_idx}"], # Only updates subtitles, not value
|
| 860 |
]
|
| 861 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
# Import handler modules
|
| 9 |
from . import generation_handlers as gen_h
|
| 10 |
from . import results_handlers as res_h
|
| 11 |
+
from . import training_handlers as train_h
|
| 12 |
from acestep.gradio_ui.i18n import t
|
| 13 |
|
| 14 |
|
|
|
|
| 70 |
]
|
| 71 |
)
|
| 72 |
|
| 73 |
+
# ========== LoRA Handlers ==========
|
| 74 |
+
generation_section["load_lora_btn"].click(
|
| 75 |
+
fn=dit_handler.load_lora,
|
| 76 |
+
inputs=[generation_section["lora_path"]],
|
| 77 |
+
outputs=[generation_section["lora_status"]]
|
| 78 |
+
).then(
|
| 79 |
+
# Update checkbox to enabled state after loading
|
| 80 |
+
fn=lambda: gr.update(value=True),
|
| 81 |
+
outputs=[generation_section["use_lora_checkbox"]]
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
generation_section["unload_lora_btn"].click(
|
| 85 |
+
fn=dit_handler.unload_lora,
|
| 86 |
+
outputs=[generation_section["lora_status"]]
|
| 87 |
+
).then(
|
| 88 |
+
# Update checkbox to disabled state after unloading
|
| 89 |
+
fn=lambda: gr.update(value=False),
|
| 90 |
+
outputs=[generation_section["use_lora_checkbox"]]
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
generation_section["use_lora_checkbox"].change(
|
| 94 |
+
fn=dit_handler.set_use_lora,
|
| 95 |
+
inputs=[generation_section["use_lora_checkbox"]],
|
| 96 |
+
outputs=[generation_section["lora_status"]]
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
# ========== UI Visibility Updates ==========
|
| 100 |
generation_section["init_llm_checkbox"].change(
|
| 101 |
fn=gen_h.update_negative_prompt_visibility,
|
|
|
|
| 886 |
results_section[f"generated_audio_{lrc_idx}"], # Only updates subtitles, not value
|
| 887 |
]
|
| 888 |
)
|
| 889 |
+
|
| 890 |
+
|
| 891 |
+
def setup_training_event_handlers(demo, dit_handler, llm_handler, training_section):
|
| 892 |
+
"""Setup event handlers for the training tab (dataset builder and LoRA training)"""
|
| 893 |
+
|
| 894 |
+
# ========== Load Existing Dataset (Top Section) ==========
|
| 895 |
+
|
| 896 |
+
# Load existing dataset JSON at the top of Dataset Builder
|
| 897 |
+
training_section["load_json_btn"].click(
|
| 898 |
+
fn=train_h.load_existing_dataset_for_preprocess,
|
| 899 |
+
inputs=[
|
| 900 |
+
training_section["load_json_path"],
|
| 901 |
+
training_section["dataset_builder_state"],
|
| 902 |
+
],
|
| 903 |
+
outputs=[
|
| 904 |
+
training_section["load_json_status"],
|
| 905 |
+
training_section["audio_files_table"],
|
| 906 |
+
training_section["sample_selector"],
|
| 907 |
+
training_section["dataset_builder_state"],
|
| 908 |
+
# Also update preview fields with first sample
|
| 909 |
+
training_section["preview_audio"],
|
| 910 |
+
training_section["preview_filename"],
|
| 911 |
+
training_section["edit_caption"],
|
| 912 |
+
training_section["edit_lyrics"],
|
| 913 |
+
training_section["edit_bpm"],
|
| 914 |
+
training_section["edit_keyscale"],
|
| 915 |
+
training_section["edit_timesig"],
|
| 916 |
+
training_section["edit_duration"],
|
| 917 |
+
training_section["edit_language"],
|
| 918 |
+
training_section["edit_instrumental"],
|
| 919 |
+
]
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
# ========== Dataset Builder Handlers ==========
|
| 923 |
+
|
| 924 |
+
# Scan directory for audio files
|
| 925 |
+
training_section["scan_btn"].click(
|
| 926 |
+
fn=lambda dir, name, tag, pos, instr, state: train_h.scan_directory(
|
| 927 |
+
dir, name, tag, pos, instr, state
|
| 928 |
+
),
|
| 929 |
+
inputs=[
|
| 930 |
+
training_section["audio_directory"],
|
| 931 |
+
training_section["dataset_name"],
|
| 932 |
+
training_section["custom_tag"],
|
| 933 |
+
training_section["tag_position"],
|
| 934 |
+
training_section["all_instrumental"],
|
| 935 |
+
training_section["dataset_builder_state"],
|
| 936 |
+
],
|
| 937 |
+
outputs=[
|
| 938 |
+
training_section["audio_files_table"],
|
| 939 |
+
training_section["scan_status"],
|
| 940 |
+
training_section["sample_selector"],
|
| 941 |
+
training_section["dataset_builder_state"],
|
| 942 |
+
]
|
| 943 |
+
)
|
| 944 |
+
|
| 945 |
+
# Auto-label all samples
|
| 946 |
+
training_section["auto_label_btn"].click(
|
| 947 |
+
fn=lambda state, skip: train_h.auto_label_all(dit_handler, llm_handler, state, skip),
|
| 948 |
+
inputs=[
|
| 949 |
+
training_section["dataset_builder_state"],
|
| 950 |
+
training_section["skip_metas"],
|
| 951 |
+
],
|
| 952 |
+
outputs=[
|
| 953 |
+
training_section["audio_files_table"],
|
| 954 |
+
training_section["label_progress"],
|
| 955 |
+
training_section["dataset_builder_state"],
|
| 956 |
+
]
|
| 957 |
+
)
|
| 958 |
+
|
| 959 |
+
# Sample selector change - update preview
|
| 960 |
+
training_section["sample_selector"].change(
|
| 961 |
+
fn=train_h.get_sample_preview,
|
| 962 |
+
inputs=[
|
| 963 |
+
training_section["sample_selector"],
|
| 964 |
+
training_section["dataset_builder_state"],
|
| 965 |
+
],
|
| 966 |
+
outputs=[
|
| 967 |
+
training_section["preview_audio"],
|
| 968 |
+
training_section["preview_filename"],
|
| 969 |
+
training_section["edit_caption"],
|
| 970 |
+
training_section["edit_lyrics"],
|
| 971 |
+
training_section["edit_bpm"],
|
| 972 |
+
training_section["edit_keyscale"],
|
| 973 |
+
training_section["edit_timesig"],
|
| 974 |
+
training_section["edit_duration"],
|
| 975 |
+
training_section["edit_language"],
|
| 976 |
+
training_section["edit_instrumental"],
|
| 977 |
+
]
|
| 978 |
+
)
|
| 979 |
+
|
| 980 |
+
# Save sample edit
|
| 981 |
+
training_section["save_edit_btn"].click(
|
| 982 |
+
fn=train_h.save_sample_edit,
|
| 983 |
+
inputs=[
|
| 984 |
+
training_section["sample_selector"],
|
| 985 |
+
training_section["edit_caption"],
|
| 986 |
+
training_section["edit_lyrics"],
|
| 987 |
+
training_section["edit_bpm"],
|
| 988 |
+
training_section["edit_keyscale"],
|
| 989 |
+
training_section["edit_timesig"],
|
| 990 |
+
training_section["edit_language"],
|
| 991 |
+
training_section["edit_instrumental"],
|
| 992 |
+
training_section["dataset_builder_state"],
|
| 993 |
+
],
|
| 994 |
+
outputs=[
|
| 995 |
+
training_section["audio_files_table"],
|
| 996 |
+
training_section["edit_status"],
|
| 997 |
+
training_section["dataset_builder_state"],
|
| 998 |
+
]
|
| 999 |
+
)
|
| 1000 |
+
|
| 1001 |
+
# Update settings when changed
|
| 1002 |
+
for trigger in [training_section["custom_tag"], training_section["tag_position"], training_section["all_instrumental"]]:
|
| 1003 |
+
trigger.change(
|
| 1004 |
+
fn=train_h.update_settings,
|
| 1005 |
+
inputs=[
|
| 1006 |
+
training_section["custom_tag"],
|
| 1007 |
+
training_section["tag_position"],
|
| 1008 |
+
training_section["all_instrumental"],
|
| 1009 |
+
training_section["dataset_builder_state"],
|
| 1010 |
+
],
|
| 1011 |
+
outputs=[training_section["dataset_builder_state"]]
|
| 1012 |
+
)
|
| 1013 |
+
|
| 1014 |
+
# Save dataset
|
| 1015 |
+
training_section["save_dataset_btn"].click(
|
| 1016 |
+
fn=train_h.save_dataset,
|
| 1017 |
+
inputs=[
|
| 1018 |
+
training_section["save_path"],
|
| 1019 |
+
training_section["dataset_name"],
|
| 1020 |
+
training_section["dataset_builder_state"],
|
| 1021 |
+
],
|
| 1022 |
+
outputs=[training_section["save_status"]]
|
| 1023 |
+
)
|
| 1024 |
+
|
| 1025 |
+
# ========== Preprocess Handlers ==========
|
| 1026 |
+
|
| 1027 |
+
# Load existing dataset JSON for preprocessing
|
| 1028 |
+
# This also updates the preview section so users can view/edit samples
|
| 1029 |
+
training_section["load_existing_dataset_btn"].click(
|
| 1030 |
+
fn=train_h.load_existing_dataset_for_preprocess,
|
| 1031 |
+
inputs=[
|
| 1032 |
+
training_section["load_existing_dataset_path"],
|
| 1033 |
+
training_section["dataset_builder_state"],
|
| 1034 |
+
],
|
| 1035 |
+
outputs=[
|
| 1036 |
+
training_section["load_existing_status"],
|
| 1037 |
+
training_section["audio_files_table"],
|
| 1038 |
+
training_section["sample_selector"],
|
| 1039 |
+
training_section["dataset_builder_state"],
|
| 1040 |
+
# Also update preview fields with first sample
|
| 1041 |
+
training_section["preview_audio"],
|
| 1042 |
+
training_section["preview_filename"],
|
| 1043 |
+
training_section["edit_caption"],
|
| 1044 |
+
training_section["edit_lyrics"],
|
| 1045 |
+
training_section["edit_bpm"],
|
| 1046 |
+
training_section["edit_keyscale"],
|
| 1047 |
+
training_section["edit_timesig"],
|
| 1048 |
+
training_section["edit_duration"],
|
| 1049 |
+
training_section["edit_language"],
|
| 1050 |
+
training_section["edit_instrumental"],
|
| 1051 |
+
]
|
| 1052 |
+
)
|
| 1053 |
+
|
| 1054 |
+
# Preprocess dataset to tensor files
|
| 1055 |
+
training_section["preprocess_btn"].click(
|
| 1056 |
+
fn=lambda output_dir, state: train_h.preprocess_dataset(
|
| 1057 |
+
output_dir, dit_handler, state
|
| 1058 |
+
),
|
| 1059 |
+
inputs=[
|
| 1060 |
+
training_section["preprocess_output_dir"],
|
| 1061 |
+
training_section["dataset_builder_state"],
|
| 1062 |
+
],
|
| 1063 |
+
outputs=[training_section["preprocess_progress"]]
|
| 1064 |
+
)
|
| 1065 |
+
|
| 1066 |
+
# ========== Training Tab Handlers ==========
|
| 1067 |
+
|
| 1068 |
+
# Load preprocessed tensor dataset
|
| 1069 |
+
training_section["load_dataset_btn"].click(
|
| 1070 |
+
fn=train_h.load_training_dataset,
|
| 1071 |
+
inputs=[training_section["training_tensor_dir"]],
|
| 1072 |
+
outputs=[training_section["training_dataset_info"]]
|
| 1073 |
+
)
|
| 1074 |
+
|
| 1075 |
+
# Start training from preprocessed tensors
|
| 1076 |
+
def training_wrapper(tensor_dir, r, a, d, lr, ep, bs, ga, se, sh, sd, od, ts):
|
| 1077 |
+
try:
|
| 1078 |
+
for progress, log, plot, state in train_h.start_training(
|
| 1079 |
+
tensor_dir, dit_handler, r, a, d, lr, ep, bs, ga, se, sh, sd, od, ts
|
| 1080 |
+
):
|
| 1081 |
+
yield progress, log, plot, state
|
| 1082 |
+
except Exception as e:
|
| 1083 |
+
logger.exception("Training wrapper error")
|
| 1084 |
+
yield f"❌ Error: {str(e)}", str(e), None, ts
|
| 1085 |
+
|
| 1086 |
+
training_section["start_training_btn"].click(
|
| 1087 |
+
fn=training_wrapper,
|
| 1088 |
+
inputs=[
|
| 1089 |
+
training_section["training_tensor_dir"],
|
| 1090 |
+
training_section["lora_rank"],
|
| 1091 |
+
training_section["lora_alpha"],
|
| 1092 |
+
training_section["lora_dropout"],
|
| 1093 |
+
training_section["learning_rate"],
|
| 1094 |
+
training_section["train_epochs"],
|
| 1095 |
+
training_section["train_batch_size"],
|
| 1096 |
+
training_section["gradient_accumulation"],
|
| 1097 |
+
training_section["save_every_n_epochs"],
|
| 1098 |
+
training_section["training_shift"],
|
| 1099 |
+
training_section["training_seed"],
|
| 1100 |
+
training_section["lora_output_dir"],
|
| 1101 |
+
training_section["training_state"],
|
| 1102 |
+
],
|
| 1103 |
+
outputs=[
|
| 1104 |
+
training_section["training_progress"],
|
| 1105 |
+
training_section["training_log"],
|
| 1106 |
+
training_section["training_loss_plot"],
|
| 1107 |
+
training_section["training_state"],
|
| 1108 |
+
]
|
| 1109 |
+
)
|
| 1110 |
+
|
| 1111 |
+
# Stop training
|
| 1112 |
+
training_section["stop_training_btn"].click(
|
| 1113 |
+
fn=train_h.stop_training,
|
| 1114 |
+
inputs=[training_section["training_state"]],
|
| 1115 |
+
outputs=[
|
| 1116 |
+
training_section["training_progress"],
|
| 1117 |
+
training_section["training_state"],
|
| 1118 |
+
]
|
| 1119 |
+
)
|
| 1120 |
+
|
| 1121 |
+
# Export LoRA
|
| 1122 |
+
training_section["export_lora_btn"].click(
|
| 1123 |
+
fn=train_h.export_lora,
|
| 1124 |
+
inputs=[
|
| 1125 |
+
training_section["export_path"],
|
| 1126 |
+
training_section["lora_output_dir"],
|
| 1127 |
+
],
|
| 1128 |
+
outputs=[training_section["export_status"]]
|
| 1129 |
+
)
|
acestep/gradio_ui/events/training_handlers.py
ADDED
|
@@ -0,0 +1,644 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Event Handlers for Training Tab
|
| 3 |
+
|
| 4 |
+
Contains all event handler functions for the dataset builder and training UI.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
from typing import Any, Dict, List, Tuple, Optional
|
| 10 |
+
from loguru import logger
|
| 11 |
+
import gradio as gr
|
| 12 |
+
|
| 13 |
+
from acestep.training.dataset_builder import DatasetBuilder, AudioSample
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def create_dataset_builder() -> DatasetBuilder:
|
| 17 |
+
"""Create a new DatasetBuilder instance."""
|
| 18 |
+
return DatasetBuilder()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def scan_directory(
|
| 22 |
+
audio_dir: str,
|
| 23 |
+
dataset_name: str,
|
| 24 |
+
custom_tag: str,
|
| 25 |
+
tag_position: str,
|
| 26 |
+
all_instrumental: bool,
|
| 27 |
+
builder_state: Optional[DatasetBuilder],
|
| 28 |
+
) -> Tuple[Any, str, Any, DatasetBuilder]:
|
| 29 |
+
"""Scan a directory for audio files.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
Tuple of (table_data, status, slider_update, builder_state)
|
| 33 |
+
"""
|
| 34 |
+
if not audio_dir or not audio_dir.strip():
|
| 35 |
+
return [], "❌ Please enter a directory path", gr.Slider(maximum=0, value=0), builder_state
|
| 36 |
+
|
| 37 |
+
# Create or use existing builder
|
| 38 |
+
builder = builder_state if builder_state else DatasetBuilder()
|
| 39 |
+
|
| 40 |
+
# Set metadata before scanning
|
| 41 |
+
builder.metadata.name = dataset_name
|
| 42 |
+
builder.metadata.custom_tag = custom_tag
|
| 43 |
+
builder.metadata.tag_position = tag_position
|
| 44 |
+
builder.metadata.all_instrumental = all_instrumental
|
| 45 |
+
|
| 46 |
+
# Scan directory
|
| 47 |
+
samples, status = builder.scan_directory(audio_dir.strip())
|
| 48 |
+
|
| 49 |
+
if not samples:
|
| 50 |
+
return [], status, gr.Slider(maximum=0, value=0), builder
|
| 51 |
+
|
| 52 |
+
# Set instrumental and tag for all samples
|
| 53 |
+
builder.set_all_instrumental(all_instrumental)
|
| 54 |
+
if custom_tag:
|
| 55 |
+
builder.set_custom_tag(custom_tag, tag_position)
|
| 56 |
+
|
| 57 |
+
# Get table data
|
| 58 |
+
table_data = builder.get_samples_dataframe_data()
|
| 59 |
+
|
| 60 |
+
# Calculate slider max and return as Slider update
|
| 61 |
+
slider_max = max(0, len(samples) - 1)
|
| 62 |
+
|
| 63 |
+
return table_data, status, gr.Slider(maximum=slider_max, value=0), builder
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def auto_label_all(
|
| 67 |
+
dit_handler,
|
| 68 |
+
llm_handler,
|
| 69 |
+
builder_state: Optional[DatasetBuilder],
|
| 70 |
+
skip_metas: bool = False,
|
| 71 |
+
progress=None,
|
| 72 |
+
) -> Tuple[List[List[Any]], str, DatasetBuilder]:
|
| 73 |
+
"""Auto-label all samples in the dataset.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
dit_handler: DiT handler for audio processing
|
| 77 |
+
llm_handler: LLM handler for caption generation
|
| 78 |
+
builder_state: Dataset builder state
|
| 79 |
+
skip_metas: If True, skip LLM labeling. BPM/Key/TimeSig = N/A, Language = unknown for instrumental
|
| 80 |
+
progress: Progress callback
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
Tuple of (table_data, status, builder_state)
|
| 84 |
+
"""
|
| 85 |
+
if builder_state is None:
|
| 86 |
+
return [], "❌ Please scan a directory first", builder_state
|
| 87 |
+
|
| 88 |
+
if not builder_state.samples:
|
| 89 |
+
return [], "❌ No samples to label. Please scan a directory first.", builder_state
|
| 90 |
+
|
| 91 |
+
# If skip_metas is True, just set default values without LLM
|
| 92 |
+
if skip_metas:
|
| 93 |
+
for sample in builder_state.samples:
|
| 94 |
+
sample.bpm = None # Will display as N/A
|
| 95 |
+
sample.keyscale = "N/A"
|
| 96 |
+
sample.timesignature = "N/A"
|
| 97 |
+
# For instrumental, language should be "unknown"
|
| 98 |
+
if sample.is_instrumental:
|
| 99 |
+
sample.language = "unknown"
|
| 100 |
+
else:
|
| 101 |
+
sample.language = "unknown"
|
| 102 |
+
# Use custom tag as caption if set, otherwise use filename
|
| 103 |
+
if builder_state.metadata.custom_tag:
|
| 104 |
+
sample.caption = builder_state.metadata.custom_tag
|
| 105 |
+
else:
|
| 106 |
+
sample.caption = sample.filename
|
| 107 |
+
|
| 108 |
+
table_data = builder_state.get_samples_dataframe_data()
|
| 109 |
+
return table_data, f"✅ Skipped AI labeling. {len(builder_state.samples)} samples set with default values.", builder_state
|
| 110 |
+
|
| 111 |
+
# Check if handlers are initialized
|
| 112 |
+
if dit_handler is None or dit_handler.model is None:
|
| 113 |
+
return builder_state.get_samples_dataframe_data(), "❌ Model not initialized. Please initialize the service first.", builder_state
|
| 114 |
+
|
| 115 |
+
if llm_handler is None or not llm_handler.llm_initialized:
|
| 116 |
+
return builder_state.get_samples_dataframe_data(), "❌ LLM not initialized. Please initialize the service with LLM enabled.", builder_state
|
| 117 |
+
|
| 118 |
+
def progress_callback(msg):
|
| 119 |
+
if progress:
|
| 120 |
+
try:
|
| 121 |
+
progress(msg)
|
| 122 |
+
except:
|
| 123 |
+
pass
|
| 124 |
+
|
| 125 |
+
# Label all samples
|
| 126 |
+
samples, status = builder_state.label_all_samples(
|
| 127 |
+
dit_handler=dit_handler,
|
| 128 |
+
llm_handler=llm_handler,
|
| 129 |
+
progress_callback=progress_callback,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Get updated table data
|
| 133 |
+
table_data = builder_state.get_samples_dataframe_data()
|
| 134 |
+
|
| 135 |
+
return table_data, status, builder_state
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def get_sample_preview(
|
| 139 |
+
sample_idx: int,
|
| 140 |
+
builder_state: Optional[DatasetBuilder],
|
| 141 |
+
) -> Tuple[str, str, str, str, Optional[int], str, str, float, str, bool]:
|
| 142 |
+
"""Get preview data for a specific sample.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
Tuple of (audio_path, filename, caption, lyrics, bpm, keyscale, timesig, duration, language, instrumental)
|
| 146 |
+
"""
|
| 147 |
+
if builder_state is None or not builder_state.samples:
|
| 148 |
+
return None, "", "", "", None, "", "", 0.0, "instrumental", True
|
| 149 |
+
|
| 150 |
+
idx = int(sample_idx)
|
| 151 |
+
if idx < 0 or idx >= len(builder_state.samples):
|
| 152 |
+
return None, "", "", "", None, "", "", 0.0, "instrumental", True
|
| 153 |
+
|
| 154 |
+
sample = builder_state.samples[idx]
|
| 155 |
+
|
| 156 |
+
return (
|
| 157 |
+
sample.audio_path,
|
| 158 |
+
sample.filename,
|
| 159 |
+
sample.caption,
|
| 160 |
+
sample.lyrics,
|
| 161 |
+
sample.bpm,
|
| 162 |
+
sample.keyscale,
|
| 163 |
+
sample.timesignature,
|
| 164 |
+
sample.duration,
|
| 165 |
+
sample.language,
|
| 166 |
+
sample.is_instrumental,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def save_sample_edit(
|
| 171 |
+
sample_idx: int,
|
| 172 |
+
caption: str,
|
| 173 |
+
lyrics: str,
|
| 174 |
+
bpm: Optional[int],
|
| 175 |
+
keyscale: str,
|
| 176 |
+
timesig: str,
|
| 177 |
+
language: str,
|
| 178 |
+
is_instrumental: bool,
|
| 179 |
+
builder_state: Optional[DatasetBuilder],
|
| 180 |
+
) -> Tuple[List[List[Any]], str, DatasetBuilder]:
|
| 181 |
+
"""Save edits to a sample.
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
Tuple of (table_data, status, builder_state)
|
| 185 |
+
"""
|
| 186 |
+
if builder_state is None:
|
| 187 |
+
return [], "❌ No dataset loaded", builder_state
|
| 188 |
+
|
| 189 |
+
idx = int(sample_idx)
|
| 190 |
+
|
| 191 |
+
# Update sample
|
| 192 |
+
sample, status = builder_state.update_sample(
|
| 193 |
+
idx,
|
| 194 |
+
caption=caption,
|
| 195 |
+
lyrics=lyrics if not is_instrumental else "[Instrumental]",
|
| 196 |
+
bpm=int(bpm) if bpm else None,
|
| 197 |
+
keyscale=keyscale,
|
| 198 |
+
timesignature=timesig,
|
| 199 |
+
language="instrumental" if is_instrumental else language,
|
| 200 |
+
is_instrumental=is_instrumental,
|
| 201 |
+
labeled=True,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# Get updated table data
|
| 205 |
+
table_data = builder_state.get_samples_dataframe_data()
|
| 206 |
+
|
| 207 |
+
return table_data, status, builder_state
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def update_settings(
|
| 211 |
+
custom_tag: str,
|
| 212 |
+
tag_position: str,
|
| 213 |
+
all_instrumental: bool,
|
| 214 |
+
builder_state: Optional[DatasetBuilder],
|
| 215 |
+
) -> DatasetBuilder:
|
| 216 |
+
"""Update dataset settings.
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
Updated builder_state
|
| 220 |
+
"""
|
| 221 |
+
if builder_state is None:
|
| 222 |
+
return builder_state
|
| 223 |
+
|
| 224 |
+
if custom_tag:
|
| 225 |
+
builder_state.set_custom_tag(custom_tag, tag_position)
|
| 226 |
+
|
| 227 |
+
builder_state.set_all_instrumental(all_instrumental)
|
| 228 |
+
|
| 229 |
+
return builder_state
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def save_dataset(
|
| 233 |
+
save_path: str,
|
| 234 |
+
dataset_name: str,
|
| 235 |
+
builder_state: Optional[DatasetBuilder],
|
| 236 |
+
) -> str:
|
| 237 |
+
"""Save the dataset to a JSON file.
|
| 238 |
+
|
| 239 |
+
Returns:
|
| 240 |
+
Status message
|
| 241 |
+
"""
|
| 242 |
+
if builder_state is None:
|
| 243 |
+
return "❌ No dataset to save. Please scan a directory first."
|
| 244 |
+
|
| 245 |
+
if not builder_state.samples:
|
| 246 |
+
return "❌ No samples in dataset."
|
| 247 |
+
|
| 248 |
+
if not save_path or not save_path.strip():
|
| 249 |
+
return "❌ Please enter a save path."
|
| 250 |
+
|
| 251 |
+
# Check if any samples are labeled
|
| 252 |
+
labeled_count = builder_state.get_labeled_count()
|
| 253 |
+
if labeled_count == 0:
|
| 254 |
+
return "⚠️ Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway..."
|
| 255 |
+
|
| 256 |
+
return builder_state.save_dataset(save_path.strip(), dataset_name)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def load_existing_dataset_for_preprocess(
|
| 260 |
+
dataset_path: str,
|
| 261 |
+
builder_state: Optional[DatasetBuilder],
|
| 262 |
+
) -> Tuple[str, Any, Any, DatasetBuilder, str, str, str, str, Optional[int], str, str, float, str, bool]:
|
| 263 |
+
"""Load an existing dataset JSON file for preprocessing.
|
| 264 |
+
|
| 265 |
+
This allows users to load a previously saved dataset and proceed to preprocessing
|
| 266 |
+
without having to re-scan and re-label.
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
Tuple of (status, table_data, slider_update, builder_state,
|
| 270 |
+
audio_path, filename, caption, lyrics, bpm, keyscale, timesig, duration, language, instrumental)
|
| 271 |
+
"""
|
| 272 |
+
empty_preview = (None, "", "", "", None, "", "", 0.0, "instrumental", True)
|
| 273 |
+
|
| 274 |
+
if not dataset_path or not dataset_path.strip():
|
| 275 |
+
return ("❌ Please enter a dataset path", [], gr.Slider(maximum=0, value=0), builder_state) + empty_preview
|
| 276 |
+
|
| 277 |
+
dataset_path = dataset_path.strip()
|
| 278 |
+
|
| 279 |
+
if not os.path.exists(dataset_path):
|
| 280 |
+
return (f"❌ Dataset not found: {dataset_path}", [], gr.Slider(maximum=0, value=0), builder_state) + empty_preview
|
| 281 |
+
|
| 282 |
+
# Create new builder (don't reuse old state when loading a file)
|
| 283 |
+
builder = DatasetBuilder()
|
| 284 |
+
|
| 285 |
+
# Load the dataset
|
| 286 |
+
samples, status = builder.load_dataset(dataset_path)
|
| 287 |
+
|
| 288 |
+
if not samples:
|
| 289 |
+
return (status, [], gr.Slider(maximum=0, value=0), builder) + empty_preview
|
| 290 |
+
|
| 291 |
+
# Get table data
|
| 292 |
+
table_data = builder.get_samples_dataframe_data()
|
| 293 |
+
|
| 294 |
+
# Calculate slider max
|
| 295 |
+
slider_max = max(0, len(samples) - 1)
|
| 296 |
+
|
| 297 |
+
# Create info text
|
| 298 |
+
labeled_count = builder.get_labeled_count()
|
| 299 |
+
info = f"✅ Loaded dataset: {builder.metadata.name}\n"
|
| 300 |
+
info += f"📊 Samples: {len(samples)} ({labeled_count} labeled)\n"
|
| 301 |
+
info += f"🏷️ Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
|
| 302 |
+
info += "📝 Ready for preprocessing! You can also edit samples below."
|
| 303 |
+
|
| 304 |
+
# Get first sample preview
|
| 305 |
+
first_sample = builder.samples[0]
|
| 306 |
+
preview = (
|
| 307 |
+
first_sample.audio_path,
|
| 308 |
+
first_sample.filename,
|
| 309 |
+
first_sample.caption,
|
| 310 |
+
first_sample.lyrics,
|
| 311 |
+
first_sample.bpm,
|
| 312 |
+
first_sample.keyscale,
|
| 313 |
+
first_sample.timesignature,
|
| 314 |
+
first_sample.duration,
|
| 315 |
+
first_sample.language,
|
| 316 |
+
first_sample.is_instrumental,
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
return (info, table_data, gr.Slider(maximum=slider_max, value=0), builder) + preview
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def preprocess_dataset(
    output_dir: str,
    dit_handler,
    builder_state: Optional["DatasetBuilder"],
    progress=None,
) -> str:
    """Preprocess dataset to tensor files for fast training.

    This converts audio files to VAE latents and text to embeddings by
    delegating to ``builder_state.preprocess_to_tensors``.

    Args:
        output_dir: Directory where the preprocessed ``.pt`` files are written.
        dit_handler: Initialized DiT handler; must expose a non-None ``model``.
        builder_state: The DatasetBuilder holding scanned/labeled samples,
            or None if nothing has been loaded yet.
        progress: Optional callable invoked with progress-message strings.

    Returns:
        Status message (error messages are prefixed with ❌).
    """
    # Validation order mirrors the UI workflow: dataset loaded -> has samples
    # -> labeled -> output path -> model ready.
    if builder_state is None:
        return "❌ No dataset loaded. Please scan a directory first."

    if not builder_state.samples:
        return "❌ No samples in dataset."

    labeled_count = builder_state.get_labeled_count()
    if labeled_count == 0:
        return "❌ No labeled samples. Please auto-label or manually label samples first."

    if not output_dir or not output_dir.strip():
        return "❌ Please enter an output directory."

    if dit_handler is None or dit_handler.model is None:
        return "❌ Model not initialized. Please initialize the service first."

    def progress_callback(msg):
        # Progress reporting is best-effort: a broken progress sink must never
        # abort preprocessing. Catch Exception (not bare except) so
        # KeyboardInterrupt/SystemExit still propagate.
        if progress:
            try:
                progress(msg)
            except Exception:
                pass

    # Run preprocessing; the returned output paths are not surfaced in the UI,
    # only the status string is.
    output_paths, status = builder_state.preprocess_to_tensors(
        dit_handler=dit_handler,
        output_dir=output_dir.strip(),
        progress_callback=progress_callback,
    )

    return status
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def load_training_dataset(
    tensor_dir: str,
) -> str:
    """Load a preprocessed tensor dataset for training.

    Validates the directory, then reads ``manifest.json`` (if present) to
    report the sample count and dataset metadata. Falls back to counting
    ``.pt`` files directly when the manifest is missing or unreadable.

    Args:
        tensor_dir: Path to the directory of preprocessed tensors.

    Returns:
        Info text about the dataset (error messages are prefixed with ❌).
    """
    if not tensor_dir or not tensor_dir.strip():
        return "❌ Please enter a tensor directory path"

    tensor_dir = tensor_dir.strip()

    if not os.path.exists(tensor_dir):
        return f"❌ Directory not found: {tensor_dir}"

    if not os.path.isdir(tensor_dir):
        return f"❌ Not a directory: {tensor_dir}"

    # Check for manifest
    manifest_path = os.path.join(tensor_dir, "manifest.json")
    if os.path.exists(manifest_path):
        try:
            # Explicit encoding: manifest is written as UTF-8; without it the
            # read depends on the platform locale (breaks on Windows cp1252).
            with open(manifest_path, 'r', encoding='utf-8') as f:
                manifest = json.load(f)

            num_samples = manifest.get("num_samples", 0)
            metadata = manifest.get("metadata", {})
            name = metadata.get("name", "Unknown")
            custom_tag = metadata.get("custom_tag", "")

            info = f"✅ Loaded preprocessed dataset: {name}\n"
            info += f"📊 Samples: {num_samples} preprocessed tensors\n"
            info += f"🏷️ Custom Tag: {custom_tag or '(none)'}"

            return info
        except Exception as e:
            # A corrupt manifest is non-fatal: log and fall through to the
            # raw .pt file count below.
            logger.warning(f"Failed to read manifest: {e}")

    # Fallback: count .pt files
    pt_files = [f for f in os.listdir(tensor_dir) if f.endswith('.pt')]

    if not pt_files:
        return f"❌ No .pt tensor files found in {tensor_dir}"

    info = f"✅ Found {len(pt_files)} tensor files in {tensor_dir}\n"
    info += "⚠️ No manifest.json found - using all .pt files"

    return info
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
# Training handlers
|
| 420 |
+
|
| 421 |
+
import time
|
| 422 |
+
import re
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def _format_duration(seconds):
|
| 426 |
+
"""Format seconds to human readable string."""
|
| 427 |
+
seconds = int(seconds)
|
| 428 |
+
if seconds < 60:
|
| 429 |
+
return f"{seconds}s"
|
| 430 |
+
elif seconds < 3600:
|
| 431 |
+
return f"{seconds // 60}m {seconds % 60}s"
|
| 432 |
+
else:
|
| 433 |
+
return f"{seconds // 3600}h {(seconds % 3600) // 60}m"
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def start_training(
    tensor_dir: str,
    dit_handler,
    lora_rank: int,
    lora_alpha: int,
    lora_dropout: float,
    learning_rate: float,
    train_epochs: int,
    train_batch_size: int,
    gradient_accumulation: int,
    save_every_n_epochs: int,
    training_shift: float,
    training_seed: int,
    lora_output_dir: str,
    training_state: Dict,
    progress=None,
):
    """Start LoRA training from preprocessed tensors.

    This is a generator function that yields progress updates.

    Each yield is a 4-tuple consumed by the Gradio UI:
    (status_text, log_text, loss_dataframe, training_state).
    ``training_state`` is mutated in place and shared with the Stop button
    handler: ``is_training`` marks an active run, ``should_stop`` requests
    a cooperative stop that is checked once per trainer iteration.

    Args:
        tensor_dir: Directory of preprocessed ``.pt`` tensor files.
        dit_handler: Initialized DiT handler; must expose a non-None ``model``.
        lora_rank / lora_alpha / lora_dropout: LoRA adapter hyperparameters.
        learning_rate / train_epochs / train_batch_size /
        gradient_accumulation / save_every_n_epochs / training_shift /
        training_seed: Training hyperparameters forwarded to TrainingConfig.
        lora_output_dir: Directory where checkpoints and final weights go.
        training_state: Shared mutable dict used for cross-handler signalling.
        progress: Unused here; kept for Gradio handler-signature compatibility.
    """
    if not tensor_dir or not tensor_dir.strip():
        yield "❌ Please enter a tensor directory path", "", None, training_state
        return

    tensor_dir = tensor_dir.strip()

    if not os.path.exists(tensor_dir):
        yield f"❌ Tensor directory not found: {tensor_dir}", "", None, training_state
        return

    if dit_handler is None or dit_handler.model is None:
        yield "❌ Model not initialized. Please initialize the service first.", "", None, training_state
        return

    # Check for required training dependencies.
    # These imports are a deliberate availability probe; the names themselves
    # are not used in this function.
    try:
        from lightning.fabric import Fabric
        from peft import get_peft_model, LoraConfig
    except ImportError as e:
        yield f"❌ Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
        return

    training_state["is_training"] = True
    training_state["should_stop"] = False

    try:
        # Imported lazily so the UI can start without the training extras installed.
        from acestep.training.trainer import LoRATrainer
        from acestep.training.configs import LoRAConfig as LoRAConfigClass, TrainingConfig

        # Create configs
        lora_config = LoRAConfigClass(
            r=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
        )

        training_config = TrainingConfig(
            shift=training_shift,
            learning_rate=learning_rate,
            batch_size=train_batch_size,
            gradient_accumulation_steps=gradient_accumulation,
            max_epochs=train_epochs,
            save_every_n_epochs=save_every_n_epochs,
            seed=training_seed,
            output_dir=lora_output_dir,
        )

        import pandas as pd

        # Initialize training log and loss history.
        # The (0, 0.0) seed row gives gr.LinePlot a non-empty frame to render
        # before the first real loss arrives.
        log_lines = []
        loss_data = pd.DataFrame({"step": [0], "loss": [0.0]})

        # Start timer
        start_time = time.time()

        yield f"🚀 Starting training from {tensor_dir}...", "", loss_data, training_state

        # Create trainer
        trainer = LoRATrainer(
            dit_handler=dit_handler,
            lora_config=lora_config,
            training_config=training_config,
        )

        # Collect loss history
        step_list = []
        loss_list = []

        # Train with progress updates using preprocessed tensors
        for step, loss, status in trainer.train_from_preprocessed(tensor_dir, training_state):
            # Calculate elapsed time and ETA
            elapsed_seconds = time.time() - start_time
            time_info = f"⏱️ Elapsed: {_format_duration(elapsed_seconds)}"

            # Parse "Epoch x/y" from status to calculate ETA
            # (linear extrapolation: assumes roughly constant time per epoch).
            match = re.search(r"Epoch\s+(\d+)/(\d+)", str(status))
            if match:
                current_ep = int(match.group(1))
                total_ep = int(match.group(2))
                if current_ep > 0:
                    eta_seconds = (elapsed_seconds / current_ep) * (total_ep - current_ep)
                    time_info += f" | ETA: ~{_format_duration(eta_seconds)}"

            # Display status with time info
            display_status = f"{status}\n{time_info}"

            # Terminal log
            log_msg = f"[{_format_duration(elapsed_seconds)}] Step {step}: {status}"
            logger.info(log_msg)

            # Add to UI log (rolling window of the last 15 messages)
            log_lines.append(status)
            if len(log_lines) > 15:
                log_lines = log_lines[-15:]
            log_text = "\n".join(log_lines)

            # Track loss for plot (only valid values).
            # `loss == loss` is False only for NaN (IEEE 754), so this skips
            # NaN losses without importing math/numpy.
            if step > 0 and loss is not None and loss == loss:  # Check for NaN
                step_list.append(step)
                loss_list.append(float(loss))
                loss_data = pd.DataFrame({"step": step_list, "loss": loss_list})

            yield display_status, log_text, loss_data, training_state

            # Cooperative stop: the Stop button sets should_stop; we exit the
            # loop after the current iteration's yield.
            if training_state.get("should_stop", False):
                logger.info("⏹️ Training stopped by user")
                log_lines.append("⏹️ Training stopped by user")
                yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
                break

        # NOTE(review): this completion path also runs after a user-requested
        # stop (the loop exits via break), so a stopped run still reports
        # "Training completed" — confirm whether that is intended.
        total_time = time.time() - start_time
        training_state["is_training"] = False
        completion_msg = f"✅ Training completed! Total time: {_format_duration(total_time)}"

        logger.info(completion_msg)
        log_lines.append(completion_msg)

        yield completion_msg, "\n".join(log_lines[-15:]), loss_data, training_state

    except Exception as e:
        logger.exception("Training error")
        training_state["is_training"] = False
        # pandas is re-imported here because the failure may have occurred
        # before the `import pandas as pd` in the try block ran.
        import pandas as pd
        empty_df = pd.DataFrame({"step": [], "loss": []})
        yield f"❌ Error: {str(e)}", str(e), empty_df, training_state
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def stop_training(training_state: Dict) -> Tuple[str, Dict]:
    """Request that an in-progress training run stop at its next checkpoint.

    Sets the shared ``should_stop`` flag, which the training generator
    polls once per iteration.

    Returns:
        Tuple of (status message, training_state)
    """
    if training_state.get("is_training", False):
        training_state["should_stop"] = True
        return "⏹️ Stopping training...", training_state
    return "⚠️ No training in progress", training_state
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def export_lora(
    export_path: str,
    lora_output_dir: str,
) -> str:
    """Export the trained LoRA weights.

    Copies the ``final`` adapter directory if present, otherwise the most
    recent ``epoch_N`` checkpoint, to ``export_path`` (replacing any
    existing directory there).

    Args:
        export_path: Destination directory for the exported adapter.
        lora_output_dir: Training output root containing ``final`` and/or
            ``checkpoints``.

    Returns:
        Status message (error messages are prefixed with ❌).
    """
    if not export_path or not export_path.strip():
        return "❌ Please enter an export path"

    # Check if there's a trained model to export
    final_dir = os.path.join(lora_output_dir, "final")
    checkpoint_dir = os.path.join(lora_output_dir, "checkpoints")

    # Prefer final, fallback to checkpoints
    if os.path.exists(final_dir):
        source_path = final_dir
    elif os.path.exists(checkpoint_dir):
        # Find the latest checkpoint. Only consider well-formed "epoch_<int>"
        # names: a stray directory like "epoch_final" would otherwise crash
        # the int() sort key with ValueError.
        checkpoints = [
            d for d in os.listdir(checkpoint_dir)
            if d.startswith("epoch_") and d.split("_")[1].isdigit()
        ]
        if not checkpoints:
            return "❌ No checkpoints found"

        checkpoints.sort(key=lambda x: int(x.split("_")[1]))
        latest = checkpoints[-1]
        source_path = os.path.join(checkpoint_dir, latest)
    else:
        return f"❌ No trained model found in {lora_output_dir}"

    try:
        import shutil

        export_path = export_path.strip()
        # Ensure the parent directory exists; a bare filename has no dirname,
        # so fall back to the current directory.
        os.makedirs(os.path.dirname(export_path) or ".", exist_ok=True)

        # Replace any previous export at the destination.
        if os.path.exists(export_path):
            shutil.rmtree(export_path)

        shutil.copytree(source_path, export_path)

        return f"✅ LoRA exported to {export_path}"

    except Exception as e:
        logger.exception("Export error")
        return f"❌ Export failed: {str(e)}"
|
acestep/gradio_ui/interfaces/__init__.py
CHANGED
|
@@ -7,7 +7,8 @@ from acestep.gradio_ui.i18n import get_i18n, t
|
|
| 7 |
from acestep.gradio_ui.interfaces.dataset import create_dataset_section
|
| 8 |
from acestep.gradio_ui.interfaces.generation import create_generation_section
|
| 9 |
from acestep.gradio_ui.interfaces.result import create_results_section
|
| 10 |
-
from acestep.gradio_ui.
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None, language='en') -> gr.Blocks:
|
|
@@ -76,7 +77,13 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_para
|
|
| 76 |
# Results Section
|
| 77 |
results_section = create_results_section(dit_handler)
|
| 78 |
|
|
|
|
|
|
|
|
|
|
| 79 |
# Connect event handlers
|
| 80 |
setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section)
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
return demo
|
|
|
|
| 7 |
from acestep.gradio_ui.interfaces.dataset import create_dataset_section
|
| 8 |
from acestep.gradio_ui.interfaces.generation import create_generation_section
|
| 9 |
from acestep.gradio_ui.interfaces.result import create_results_section
|
| 10 |
+
from acestep.gradio_ui.interfaces.training import create_training_section
|
| 11 |
+
from acestep.gradio_ui.events import setup_event_handlers, setup_training_event_handlers
|
| 12 |
|
| 13 |
|
| 14 |
def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None, language='en') -> gr.Blocks:
|
|
|
|
| 77 |
# Results Section
|
| 78 |
results_section = create_results_section(dit_handler)
|
| 79 |
|
| 80 |
+
# Training Section (LoRA training and dataset builder)
|
| 81 |
+
training_section = create_training_section(dit_handler, llm_handler)
|
| 82 |
+
|
| 83 |
# Connect event handlers
|
| 84 |
setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section)
|
| 85 |
+
|
| 86 |
+
# Connect training event handlers
|
| 87 |
+
setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)
|
| 88 |
|
| 89 |
return demo
|
acestep/gradio_ui/interfaces/generation.py
CHANGED
|
@@ -144,6 +144,31 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 144 |
# Set init_status value from init_params if pre-initialized
|
| 145 |
init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
|
| 146 |
init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
# Inputs
|
| 149 |
with gr.Row():
|
|
@@ -653,6 +678,12 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
|
|
| 653 |
"use_flash_attention_checkbox": use_flash_attention_checkbox,
|
| 654 |
"offload_to_cpu_checkbox": offload_to_cpu_checkbox,
|
| 655 |
"offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 656 |
"task_type": task_type,
|
| 657 |
"instruction_display_gen": instruction_display_gen,
|
| 658 |
"track_name": track_name,
|
|
|
|
| 144 |
# Set init_status value from init_params if pre-initialized
|
| 145 |
init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
|
| 146 |
init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
|
| 147 |
+
|
| 148 |
+
# LoRA Configuration Section
|
| 149 |
+
gr.HTML("<hr><h4>🔧 LoRA Adapter</h4>")
|
| 150 |
+
with gr.Row():
|
| 151 |
+
lora_path = gr.Textbox(
|
| 152 |
+
label="LoRA Path",
|
| 153 |
+
placeholder="./lora_output/final/adapter",
|
| 154 |
+
info="Path to trained LoRA adapter directory",
|
| 155 |
+
scale=3,
|
| 156 |
+
)
|
| 157 |
+
load_lora_btn = gr.Button("📥 Load LoRA", variant="secondary", scale=1)
|
| 158 |
+
unload_lora_btn = gr.Button("🗑️ Unload", variant="secondary", scale=1)
|
| 159 |
+
with gr.Row():
|
| 160 |
+
use_lora_checkbox = gr.Checkbox(
|
| 161 |
+
label="Use LoRA",
|
| 162 |
+
value=False,
|
| 163 |
+
info="Enable LoRA adapter for inference",
|
| 164 |
+
scale=1,
|
| 165 |
+
)
|
| 166 |
+
lora_status = gr.Textbox(
|
| 167 |
+
label="LoRA Status",
|
| 168 |
+
value="No LoRA loaded",
|
| 169 |
+
interactive=False,
|
| 170 |
+
scale=2,
|
| 171 |
+
)
|
| 172 |
|
| 173 |
# Inputs
|
| 174 |
with gr.Row():
|
|
|
|
| 678 |
"use_flash_attention_checkbox": use_flash_attention_checkbox,
|
| 679 |
"offload_to_cpu_checkbox": offload_to_cpu_checkbox,
|
| 680 |
"offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
|
| 681 |
+
# LoRA components
|
| 682 |
+
"lora_path": lora_path,
|
| 683 |
+
"load_lora_btn": load_lora_btn,
|
| 684 |
+
"unload_lora_btn": unload_lora_btn,
|
| 685 |
+
"use_lora_checkbox": use_lora_checkbox,
|
| 686 |
+
"lora_status": lora_status,
|
| 687 |
"task_type": task_type,
|
| 688 |
"instruction_display_gen": instruction_display_gen,
|
| 689 |
"track_name": track_name,
|
acestep/gradio_ui/interfaces/training.py
ADDED
|
@@ -0,0 +1,558 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Training Tab Module
|
| 3 |
+
|
| 4 |
+
Contains the dataset builder and LoRA training interface components.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from acestep.gradio_ui.i18n import t
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_training_section(dit_handler, llm_handler) -> dict:
|
| 13 |
+
"""Create the training tab section with dataset builder and training controls.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
dit_handler: DiT handler instance
|
| 17 |
+
llm_handler: LLM handler instance
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
Dictionary of Gradio components for event handling
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
with gr.Tab("🎓 LoRA Training"):
|
| 24 |
+
gr.HTML("""
|
| 25 |
+
<div style="text-align: center; padding: 10px; margin-bottom: 15px;">
|
| 26 |
+
<h2>🎵 LoRA Training for ACE-Step</h2>
|
| 27 |
+
<p>Build datasets from your audio files and train custom LoRA adapters</p>
|
| 28 |
+
</div>
|
| 29 |
+
""")
|
| 30 |
+
|
| 31 |
+
with gr.Tabs():
|
| 32 |
+
# ==================== Dataset Builder Tab ====================
|
| 33 |
+
with gr.Tab("📁 Dataset Builder"):
|
| 34 |
+
# ========== Load Existing OR Scan New ==========
|
| 35 |
+
gr.HTML("""
|
| 36 |
+
<div style="padding: 10px; margin-bottom: 10px; border: 1px solid #4a4a6a; border-radius: 8px; background: linear-gradient(135deg, #2a2a4a 0%, #1a1a3a 100%);">
|
| 37 |
+
<h3 style="margin: 0 0 5px 0;">🚀 Quick Start</h3>
|
| 38 |
+
<p style="margin: 0; color: #aaa;">Choose one: <b>Load existing dataset</b> OR <b>Scan new directory</b></p>
|
| 39 |
+
</div>
|
| 40 |
+
""")
|
| 41 |
+
|
| 42 |
+
with gr.Row():
|
| 43 |
+
with gr.Column(scale=1):
|
| 44 |
+
gr.HTML("<h4>📂 Load Existing Dataset</h4>")
|
| 45 |
+
with gr.Row():
|
| 46 |
+
load_json_path = gr.Textbox(
|
| 47 |
+
label="Dataset JSON Path",
|
| 48 |
+
placeholder="./datasets/my_lora_dataset.json",
|
| 49 |
+
info="Load a previously saved dataset",
|
| 50 |
+
scale=3,
|
| 51 |
+
)
|
| 52 |
+
load_json_btn = gr.Button("📂 Load", variant="primary", scale=1)
|
| 53 |
+
load_json_status = gr.Textbox(
|
| 54 |
+
label="Load Status",
|
| 55 |
+
interactive=False,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
with gr.Column(scale=1):
|
| 59 |
+
gr.HTML("<h4>🔍 Scan New Directory</h4>")
|
| 60 |
+
with gr.Row():
|
| 61 |
+
audio_directory = gr.Textbox(
|
| 62 |
+
label="Audio Directory Path",
|
| 63 |
+
placeholder="/path/to/your/audio/folder",
|
| 64 |
+
info="Scan for audio files (wav, mp3, flac, ogg, opus)",
|
| 65 |
+
scale=3,
|
| 66 |
+
)
|
| 67 |
+
scan_btn = gr.Button("🔍 Scan", variant="secondary", scale=1)
|
| 68 |
+
scan_status = gr.Textbox(
|
| 69 |
+
label="Scan Status",
|
| 70 |
+
interactive=False,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
gr.HTML("<hr>")
|
| 74 |
+
|
| 75 |
+
with gr.Row():
|
| 76 |
+
with gr.Column(scale=2):
|
| 77 |
+
|
| 78 |
+
# Audio files table
|
| 79 |
+
audio_files_table = gr.Dataframe(
|
| 80 |
+
headers=["#", "Filename", "Duration", "Labeled", "BPM", "Key", "Caption"],
|
| 81 |
+
datatype=["number", "str", "str", "str", "str", "str", "str"],
|
| 82 |
+
label="Found Audio Files",
|
| 83 |
+
interactive=False,
|
| 84 |
+
wrap=True,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
with gr.Column(scale=1):
|
| 88 |
+
gr.HTML("<h3>⚙️ Dataset Settings</h3>")
|
| 89 |
+
|
| 90 |
+
dataset_name = gr.Textbox(
|
| 91 |
+
label="Dataset Name",
|
| 92 |
+
value="my_lora_dataset",
|
| 93 |
+
placeholder="Enter dataset name",
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
all_instrumental = gr.Checkbox(
|
| 97 |
+
label="All Instrumental",
|
| 98 |
+
value=True,
|
| 99 |
+
info="Check if all tracks are instrumental (no vocals)",
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
need_lyrics = gr.Checkbox(
|
| 103 |
+
label="Transcribe Lyrics",
|
| 104 |
+
value=False,
|
| 105 |
+
info="Attempt to transcribe lyrics (slower)",
|
| 106 |
+
interactive=False, # Disabled for now
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
custom_tag = gr.Textbox(
|
| 110 |
+
label="Custom Activation Tag",
|
| 111 |
+
placeholder="e.g., 8bit_retro, my_style",
|
| 112 |
+
info="Unique tag to activate this LoRA's style",
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
tag_position = gr.Radio(
|
| 116 |
+
choices=[
|
| 117 |
+
("Prepend (tag, caption)", "prepend"),
|
| 118 |
+
("Append (caption, tag)", "append"),
|
| 119 |
+
("Replace caption", "replace"),
|
| 120 |
+
],
|
| 121 |
+
value="replace",
|
| 122 |
+
label="Tag Position",
|
| 123 |
+
info="Where to place the custom tag in the caption",
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
gr.HTML("<hr><h3>🤖 Step 2: Auto-Label with AI</h3>")
|
| 127 |
+
|
| 128 |
+
with gr.Row():
|
| 129 |
+
with gr.Column(scale=3):
|
| 130 |
+
gr.Markdown("""
|
| 131 |
+
Click the button below to automatically generate metadata for all audio files using AI:
|
| 132 |
+
- **Caption**: Music style, genre, mood description
|
| 133 |
+
- **BPM**: Beats per minute
|
| 134 |
+
- **Key**: Musical key (e.g., C Major, Am)
|
| 135 |
+
- **Time Signature**: 4/4, 3/4, etc.
|
| 136 |
+
""")
|
| 137 |
+
skip_metas = gr.Checkbox(
|
| 138 |
+
label="Skip Metas (No LLM)",
|
| 139 |
+
value=False,
|
| 140 |
+
info="Skip AI labeling. BPM/Key/Time Signature will be N/A, Language will be 'unknown' for instrumental",
|
| 141 |
+
)
|
| 142 |
+
with gr.Column(scale=1):
|
| 143 |
+
auto_label_btn = gr.Button(
|
| 144 |
+
"🏷️ Auto-Label All",
|
| 145 |
+
variant="primary",
|
| 146 |
+
size="lg",
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
label_progress = gr.Textbox(
|
| 150 |
+
label="Labeling Progress",
|
| 151 |
+
interactive=False,
|
| 152 |
+
lines=2,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
gr.HTML("<hr><h3>👀 Step 3: Preview & Edit</h3>")
|
| 156 |
+
|
| 157 |
+
with gr.Row():
|
| 158 |
+
with gr.Column(scale=1):
|
| 159 |
+
sample_selector = gr.Slider(
|
| 160 |
+
minimum=0,
|
| 161 |
+
maximum=0,
|
| 162 |
+
step=1,
|
| 163 |
+
value=0,
|
| 164 |
+
label="Select Sample #",
|
| 165 |
+
info="Choose a sample to preview and edit",
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
preview_audio = gr.Audio(
|
| 169 |
+
label="Audio Preview",
|
| 170 |
+
type="filepath",
|
| 171 |
+
interactive=False,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
preview_filename = gr.Textbox(
|
| 175 |
+
label="Filename",
|
| 176 |
+
interactive=False,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
with gr.Column(scale=2):
|
| 180 |
+
with gr.Row():
|
| 181 |
+
edit_caption = gr.Textbox(
|
| 182 |
+
label="Caption",
|
| 183 |
+
lines=3,
|
| 184 |
+
placeholder="Music description...",
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
with gr.Row():
|
| 188 |
+
edit_lyrics = gr.Textbox(
|
| 189 |
+
label="Lyrics",
|
| 190 |
+
lines=4,
|
| 191 |
+
placeholder="[Verse 1]\nLyrics here...\n\n[Chorus]\n...",
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
with gr.Row():
|
| 195 |
+
edit_bpm = gr.Number(
|
| 196 |
+
label="BPM",
|
| 197 |
+
precision=0,
|
| 198 |
+
)
|
| 199 |
+
edit_keyscale = gr.Textbox(
|
| 200 |
+
label="Key",
|
| 201 |
+
placeholder="C Major",
|
| 202 |
+
)
|
| 203 |
+
edit_timesig = gr.Dropdown(
|
| 204 |
+
choices=["", "2", "3", "4", "6"],
|
| 205 |
+
label="Time Signature",
|
| 206 |
+
)
|
| 207 |
+
edit_duration = gr.Number(
|
| 208 |
+
label="Duration (s)",
|
| 209 |
+
precision=1,
|
| 210 |
+
interactive=False,
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
with gr.Row():
|
| 214 |
+
edit_language = gr.Dropdown(
|
| 215 |
+
choices=["instrumental", "en", "zh", "ja", "ko", "es", "fr", "de", "pt", "ru", "unknown"],
|
| 216 |
+
value="instrumental",
|
| 217 |
+
label="Language",
|
| 218 |
+
)
|
| 219 |
+
edit_instrumental = gr.Checkbox(
|
| 220 |
+
label="Instrumental",
|
| 221 |
+
value=True,
|
| 222 |
+
)
|
| 223 |
+
save_edit_btn = gr.Button("💾 Save Changes", variant="secondary")
|
| 224 |
+
|
| 225 |
+
edit_status = gr.Textbox(
|
| 226 |
+
label="Edit Status",
|
| 227 |
+
interactive=False,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
gr.HTML("<hr><h3>💾 Step 4: Save Dataset</h3>")
|
| 231 |
+
|
| 232 |
+
with gr.Row():
|
| 233 |
+
with gr.Column(scale=3):
|
| 234 |
+
save_path = gr.Textbox(
|
| 235 |
+
label="Save Path",
|
| 236 |
+
value="./datasets/my_lora_dataset.json",
|
| 237 |
+
placeholder="./datasets/dataset_name.json",
|
| 238 |
+
info="Path where the dataset JSON will be saved",
|
| 239 |
+
)
|
| 240 |
+
with gr.Column(scale=1):
|
| 241 |
+
save_dataset_btn = gr.Button(
|
| 242 |
+
"💾 Save Dataset",
|
| 243 |
+
variant="primary",
|
| 244 |
+
size="lg",
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
save_status = gr.Textbox(
|
| 248 |
+
label="Save Status",
|
| 249 |
+
interactive=False,
|
| 250 |
+
lines=2,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
gr.HTML("<hr><h3>⚡ Step 5: Preprocess to Tensors</h3>")
|
| 254 |
+
|
| 255 |
+
gr.Markdown("""
|
| 256 |
+
**Preprocessing converts your dataset to pre-computed tensors for fast training.**
|
| 257 |
+
|
| 258 |
+
You can either:
|
| 259 |
+
- Use the dataset from Steps 1-4 above, **OR**
|
| 260 |
+
- Load an existing dataset JSON file (if you've already saved one)
|
| 261 |
+
""")
|
| 262 |
+
|
| 263 |
+
with gr.Row():
|
| 264 |
+
with gr.Column(scale=3):
|
| 265 |
+
load_existing_dataset_path = gr.Textbox(
|
| 266 |
+
label="Load Existing Dataset (Optional)",
|
| 267 |
+
placeholder="./datasets/my_lora_dataset.json",
|
| 268 |
+
info="Path to a previously saved dataset JSON file",
|
| 269 |
+
)
|
| 270 |
+
with gr.Column(scale=1):
|
| 271 |
+
load_existing_dataset_btn = gr.Button(
|
| 272 |
+
"📂 Load Dataset",
|
| 273 |
+
variant="secondary",
|
| 274 |
+
size="lg",
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
load_existing_status = gr.Textbox(
|
| 278 |
+
label="Load Status",
|
| 279 |
+
interactive=False,
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
gr.Markdown("""
|
| 283 |
+
This step:
|
| 284 |
+
- Encodes audio to VAE latents
|
| 285 |
+
- Encodes captions and lyrics to text embeddings
|
| 286 |
+
- Runs the condition encoder
|
| 287 |
+
- Saves all tensors to `.pt` files
|
| 288 |
+
|
| 289 |
+
⚠️ **This requires the model to be loaded and may take a few minutes.**
|
| 290 |
+
""")
|
| 291 |
+
|
| 292 |
+
with gr.Row():
|
| 293 |
+
with gr.Column(scale=3):
|
| 294 |
+
preprocess_output_dir = gr.Textbox(
|
| 295 |
+
label="Tensor Output Directory",
|
| 296 |
+
value="./datasets/preprocessed_tensors",
|
| 297 |
+
placeholder="./datasets/preprocessed_tensors",
|
| 298 |
+
info="Directory to save preprocessed tensor files",
|
| 299 |
+
)
|
| 300 |
+
with gr.Column(scale=1):
|
| 301 |
+
preprocess_btn = gr.Button(
|
| 302 |
+
"⚡ Preprocess",
|
| 303 |
+
variant="primary",
|
| 304 |
+
size="lg",
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
preprocess_progress = gr.Textbox(
|
| 308 |
+
label="Preprocessing Progress",
|
| 309 |
+
interactive=False,
|
| 310 |
+
lines=3,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# ==================== Training Tab ====================
|
| 314 |
+
with gr.Tab("🚀 Train LoRA"):
|
| 315 |
+
with gr.Row():
|
| 316 |
+
with gr.Column(scale=2):
|
| 317 |
+
gr.HTML("<h3>📊 Preprocessed Dataset Selection</h3>")
|
| 318 |
+
|
| 319 |
+
gr.Markdown("""
|
| 320 |
+
Select the directory containing preprocessed tensor files (`.pt` files).
|
| 321 |
+
These are created in the "Dataset Builder" tab using the "Preprocess" button.
|
| 322 |
+
""")
|
| 323 |
+
|
| 324 |
+
training_tensor_dir = gr.Textbox(
|
| 325 |
+
label="Preprocessed Tensors Directory",
|
| 326 |
+
placeholder="./datasets/preprocessed_tensors",
|
| 327 |
+
value="./datasets/preprocessed_tensors",
|
| 328 |
+
info="Directory containing preprocessed .pt tensor files",
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
load_dataset_btn = gr.Button("📂 Load Dataset", variant="secondary")
|
| 332 |
+
|
| 333 |
+
training_dataset_info = gr.Textbox(
|
| 334 |
+
label="Dataset Info",
|
| 335 |
+
interactive=False,
|
| 336 |
+
lines=3,
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
with gr.Column(scale=1):
|
| 340 |
+
gr.HTML("<h3>⚙️ LoRA Settings</h3>")
|
| 341 |
+
|
| 342 |
+
lora_rank = gr.Slider(
|
| 343 |
+
minimum=4,
|
| 344 |
+
maximum=256,
|
| 345 |
+
step=4,
|
| 346 |
+
value=64,
|
| 347 |
+
label="LoRA Rank (r)",
|
| 348 |
+
info="Higher = more capacity, more memory",
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
lora_alpha = gr.Slider(
|
| 352 |
+
minimum=4,
|
| 353 |
+
maximum=512,
|
| 354 |
+
step=4,
|
| 355 |
+
value=128,
|
| 356 |
+
label="LoRA Alpha",
|
| 357 |
+
info="Scaling factor (typically 2x rank)",
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
lora_dropout = gr.Slider(
|
| 361 |
+
minimum=0.0,
|
| 362 |
+
maximum=0.5,
|
| 363 |
+
step=0.05,
|
| 364 |
+
value=0.1,
|
| 365 |
+
label="LoRA Dropout",
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
gr.HTML("<hr><h3>🎛️ Training Parameters</h3>")
|
| 369 |
+
|
| 370 |
+
with gr.Row():
|
| 371 |
+
learning_rate = gr.Number(
|
| 372 |
+
label="Learning Rate",
|
| 373 |
+
value=1e-4,
|
| 374 |
+
info="Start with 1e-4, adjust if needed",
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
train_epochs = gr.Slider(
|
| 378 |
+
minimum=100,
|
| 379 |
+
maximum=4000,
|
| 380 |
+
step=100,
|
| 381 |
+
value=500,
|
| 382 |
+
label="Max Epochs",
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
train_batch_size = gr.Slider(
|
| 386 |
+
minimum=1,
|
| 387 |
+
maximum=8,
|
| 388 |
+
step=1,
|
| 389 |
+
value=1,
|
| 390 |
+
label="Batch Size",
|
| 391 |
+
info="Increase if you have enough VRAM",
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
gradient_accumulation = gr.Slider(
|
| 395 |
+
minimum=1,
|
| 396 |
+
maximum=16,
|
| 397 |
+
step=1,
|
| 398 |
+
value=1,
|
| 399 |
+
label="Gradient Accumulation",
|
| 400 |
+
info="Effective batch = batch_size × accumulation",
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
with gr.Row():
|
| 404 |
+
save_every_n_epochs = gr.Slider(
|
| 405 |
+
minimum=50,
|
| 406 |
+
maximum=1000,
|
| 407 |
+
step=50,
|
| 408 |
+
value=200,
|
| 409 |
+
label="Save Every N Epochs",
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
training_shift = gr.Slider(
|
| 413 |
+
minimum=1.0,
|
| 414 |
+
maximum=5.0,
|
| 415 |
+
step=0.5,
|
| 416 |
+
value=3.0,
|
| 417 |
+
label="Shift",
|
| 418 |
+
info="Timestep shift for turbo model",
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
training_seed = gr.Number(
|
| 422 |
+
label="Seed",
|
| 423 |
+
value=42,
|
| 424 |
+
precision=0,
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
with gr.Row():
|
| 428 |
+
lora_output_dir = gr.Textbox(
|
| 429 |
+
label="Output Directory",
|
| 430 |
+
value="./lora_output",
|
| 431 |
+
placeholder="./lora_output",
|
| 432 |
+
info="Directory to save trained LoRA weights",
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
gr.HTML("<hr>")
|
| 436 |
+
|
| 437 |
+
with gr.Row():
|
| 438 |
+
with gr.Column(scale=1):
|
| 439 |
+
start_training_btn = gr.Button(
|
| 440 |
+
"🚀 Start Training",
|
| 441 |
+
variant="primary",
|
| 442 |
+
size="lg",
|
| 443 |
+
)
|
| 444 |
+
with gr.Column(scale=1):
|
| 445 |
+
stop_training_btn = gr.Button(
|
| 446 |
+
"⏹️ Stop Training",
|
| 447 |
+
variant="stop",
|
| 448 |
+
size="lg",
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
training_progress = gr.Textbox(
|
| 452 |
+
label="Training Progress",
|
| 453 |
+
interactive=False,
|
| 454 |
+
lines=2,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
with gr.Row():
|
| 458 |
+
training_log = gr.Textbox(
|
| 459 |
+
label="Training Log",
|
| 460 |
+
interactive=False,
|
| 461 |
+
lines=10,
|
| 462 |
+
max_lines=15,
|
| 463 |
+
scale=1,
|
| 464 |
+
)
|
| 465 |
+
training_loss_plot = gr.LinePlot(
|
| 466 |
+
x="step",
|
| 467 |
+
y="loss",
|
| 468 |
+
title="Training Loss",
|
| 469 |
+
x_title="Step",
|
| 470 |
+
y_title="Loss",
|
| 471 |
+
scale=1,
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
gr.HTML("<hr><h3>📦 Export LoRA</h3>")
|
| 475 |
+
|
| 476 |
+
with gr.Row():
|
| 477 |
+
export_path = gr.Textbox(
|
| 478 |
+
label="Export Path",
|
| 479 |
+
value="./lora_output/final_lora",
|
| 480 |
+
placeholder="./lora_output/my_lora",
|
| 481 |
+
)
|
| 482 |
+
export_lora_btn = gr.Button("📦 Export LoRA", variant="secondary")
|
| 483 |
+
|
| 484 |
+
export_status = gr.Textbox(
|
| 485 |
+
label="Export Status",
|
| 486 |
+
interactive=False,
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
# Store dataset builder state
|
| 490 |
+
dataset_builder_state = gr.State(None)
|
| 491 |
+
training_state = gr.State({"is_training": False, "should_stop": False})
|
| 492 |
+
|
| 493 |
+
return {
|
| 494 |
+
# Dataset Builder - Load or Scan
|
| 495 |
+
"load_json_path": load_json_path,
|
| 496 |
+
"load_json_btn": load_json_btn,
|
| 497 |
+
"load_json_status": load_json_status,
|
| 498 |
+
"audio_directory": audio_directory,
|
| 499 |
+
"scan_btn": scan_btn,
|
| 500 |
+
"scan_status": scan_status,
|
| 501 |
+
"audio_files_table": audio_files_table,
|
| 502 |
+
"dataset_name": dataset_name,
|
| 503 |
+
"all_instrumental": all_instrumental,
|
| 504 |
+
"need_lyrics": need_lyrics,
|
| 505 |
+
"custom_tag": custom_tag,
|
| 506 |
+
"tag_position": tag_position,
|
| 507 |
+
"skip_metas": skip_metas,
|
| 508 |
+
"auto_label_btn": auto_label_btn,
|
| 509 |
+
"label_progress": label_progress,
|
| 510 |
+
"sample_selector": sample_selector,
|
| 511 |
+
"preview_audio": preview_audio,
|
| 512 |
+
"preview_filename": preview_filename,
|
| 513 |
+
"edit_caption": edit_caption,
|
| 514 |
+
"edit_lyrics": edit_lyrics,
|
| 515 |
+
"edit_bpm": edit_bpm,
|
| 516 |
+
"edit_keyscale": edit_keyscale,
|
| 517 |
+
"edit_timesig": edit_timesig,
|
| 518 |
+
"edit_duration": edit_duration,
|
| 519 |
+
"edit_language": edit_language,
|
| 520 |
+
"edit_instrumental": edit_instrumental,
|
| 521 |
+
"save_edit_btn": save_edit_btn,
|
| 522 |
+
"edit_status": edit_status,
|
| 523 |
+
"save_path": save_path,
|
| 524 |
+
"save_dataset_btn": save_dataset_btn,
|
| 525 |
+
"save_status": save_status,
|
| 526 |
+
# Preprocessing
|
| 527 |
+
"load_existing_dataset_path": load_existing_dataset_path,
|
| 528 |
+
"load_existing_dataset_btn": load_existing_dataset_btn,
|
| 529 |
+
"load_existing_status": load_existing_status,
|
| 530 |
+
"preprocess_output_dir": preprocess_output_dir,
|
| 531 |
+
"preprocess_btn": preprocess_btn,
|
| 532 |
+
"preprocess_progress": preprocess_progress,
|
| 533 |
+
"dataset_builder_state": dataset_builder_state,
|
| 534 |
+
# Training
|
| 535 |
+
"training_tensor_dir": training_tensor_dir,
|
| 536 |
+
"load_dataset_btn": load_dataset_btn,
|
| 537 |
+
"training_dataset_info": training_dataset_info,
|
| 538 |
+
"lora_rank": lora_rank,
|
| 539 |
+
"lora_alpha": lora_alpha,
|
| 540 |
+
"lora_dropout": lora_dropout,
|
| 541 |
+
"learning_rate": learning_rate,
|
| 542 |
+
"train_epochs": train_epochs,
|
| 543 |
+
"train_batch_size": train_batch_size,
|
| 544 |
+
"gradient_accumulation": gradient_accumulation,
|
| 545 |
+
"save_every_n_epochs": save_every_n_epochs,
|
| 546 |
+
"training_shift": training_shift,
|
| 547 |
+
"training_seed": training_seed,
|
| 548 |
+
"lora_output_dir": lora_output_dir,
|
| 549 |
+
"start_training_btn": start_training_btn,
|
| 550 |
+
"stop_training_btn": stop_training_btn,
|
| 551 |
+
"training_progress": training_progress,
|
| 552 |
+
"training_log": training_log,
|
| 553 |
+
"training_loss_plot": training_loss_plot,
|
| 554 |
+
"export_path": export_path,
|
| 555 |
+
"export_lora_btn": export_lora_btn,
|
| 556 |
+
"export_status": export_status,
|
| 557 |
+
"training_state": training_state,
|
| 558 |
+
}
|
acestep/handler.py
CHANGED
|
@@ -3,6 +3,10 @@ Business Logic Handler
|
|
| 3 |
Encapsulates all data processing and business logic as a bridge between model and UI
|
| 4 |
"""
|
| 5 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import math
|
| 7 |
from copy import deepcopy
|
| 8 |
import tempfile
|
|
@@ -70,6 +74,11 @@ class AceStepHandler:
|
|
| 70 |
self.offload_to_cpu = False
|
| 71 |
self.offload_dit_to_cpu = False
|
| 72 |
self.current_offload_cost = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
def get_available_checkpoints(self) -> str:
|
| 75 |
"""Return project root directory path"""
|
|
@@ -114,6 +123,137 @@ class AceStepHandler:
|
|
| 114 |
return False
|
| 115 |
return getattr(self.config, 'is_turbo', False)
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
def initialize_service(
|
| 118 |
self,
|
| 119 |
project_root: str,
|
|
|
|
| 3 |
Encapsulates all data processing and business logic as a bridge between model and UI
|
| 4 |
"""
|
| 5 |
import os
|
| 6 |
+
|
| 7 |
+
# Disable tokenizers parallelism to avoid fork warning
|
| 8 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 9 |
+
|
| 10 |
import math
|
| 11 |
from copy import deepcopy
|
| 12 |
import tempfile
|
|
|
|
| 74 |
self.offload_to_cpu = False
|
| 75 |
self.offload_dit_to_cpu = False
|
| 76 |
self.current_offload_cost = 0.0
|
| 77 |
+
|
| 78 |
+
# LoRA state
|
| 79 |
+
self.lora_loaded = False
|
| 80 |
+
self.use_lora = False
|
| 81 |
+
self._base_decoder = None # Backup of original decoder
|
| 82 |
|
| 83 |
def get_available_checkpoints(self) -> str:
|
| 84 |
"""Return project root directory path"""
|
|
|
|
| 123 |
return False
|
| 124 |
return getattr(self.config, 'is_turbo', False)
|
| 125 |
|
| 126 |
+
def load_lora(self, lora_path: str) -> str:
    """Load a PEFT LoRA adapter into the decoder.

    On the first load, the pristine decoder is deep-copied and kept in
    ``self._base_decoder`` so the adapter can later be unloaded. On
    subsequent loads, the backup is restored first so adapters never stack.

    Args:
        lora_path: Path to the LoRA adapter directory (containing adapter_config.json)

    Returns:
        Status message
    """
    if self.model is None:
        return "❌ Model not initialized. Please initialize service first."

    if not lora_path or not lora_path.strip():
        return "❌ Please provide a LoRA path."

    lora_path = lora_path.strip()

    # Check if path exists
    if not os.path.exists(lora_path):
        return f"❌ LoRA path not found: {lora_path}"

    # Check if it's a valid PEFT adapter directory
    config_file = os.path.join(lora_path, "adapter_config.json")
    if not os.path.exists(config_file):
        return f"❌ Invalid LoRA adapter: adapter_config.json not found in {lora_path}"

    try:
        from peft import PeftModel, PeftConfig
    except ImportError:
        return "❌ PEFT library not installed. Please install with: pip install peft"

    try:
        # BUGFIX: `import copy` used to live only inside the first-load branch,
        # so any second call hit `copy.deepcopy` in the else-branch with the
        # name unbound (NameError, swallowed by the except below). Hoist it so
        # both branches can use it.
        import copy

        if self._base_decoder is None:
            # First load: back up the original decoder for later restore.
            self._base_decoder = copy.deepcopy(self.model.decoder)
            logger.info("Base decoder backed up")
        else:
            # Restore base decoder before loading new LoRA
            self.model.decoder = copy.deepcopy(self._base_decoder)
            logger.info("Restored base decoder before loading new LoRA")

        # Load PEFT adapter (inference-only)
        logger.info(f"Loading LoRA adapter from {lora_path}")
        self.model.decoder = PeftModel.from_pretrained(
            self.model.decoder,
            lora_path,
            is_trainable=False,
        )
        self.model.decoder = self.model.decoder.to(self.device).to(self.dtype)
        self.model.decoder.eval()

        self.lora_loaded = True
        self.use_lora = True  # Enable LoRA by default after loading

        logger.info(f"LoRA adapter loaded successfully from {lora_path}")
        return f"✅ LoRA loaded from {lora_path}"

    except Exception as e:
        logger.exception("Failed to load LoRA adapter")
        return f"❌ Failed to load LoRA: {str(e)}"
|
| 187 |
+
|
| 188 |
+
def unload_lora(self) -> str:
    """Detach the LoRA adapter and swap the backed-up base decoder back in.

    Returns:
        Human-readable status message.
    """
    # Guard clauses: nothing to do / nothing to restore from.
    if not self.lora_loaded:
        return "⚠️ No LoRA adapter loaded."
    if self._base_decoder is None:
        return "❌ Base decoder backup not found. Cannot restore."

    try:
        import copy

        # Swap in a fresh deep copy of the pristine decoder, keeping the
        # backup itself untouched for any future load/unload cycle.
        restored = copy.deepcopy(self._base_decoder)
        self.model.decoder = restored.to(self.device).to(self.dtype)
        self.model.decoder.eval()

        self.lora_loaded = False
        self.use_lora = False

        logger.info("LoRA unloaded, base decoder restored")
        return "✅ LoRA unloaded, using base model"
    except Exception as e:
        logger.exception("Failed to unload LoRA")
        return f"❌ Failed to unload LoRA: {str(e)}"
|
| 216 |
+
|
| 217 |
+
def set_use_lora(self, use_lora: bool) -> str:
    """Enable or disable the loaded LoRA adapter for inference.

    Args:
        use_lora: True to route inference through the adapter, False to use
            the base model weights.

    Returns:
        Human-readable status message.
    """
    if use_lora and not self.lora_loaded:
        return "❌ No LoRA adapter loaded. Please load a LoRA first."

    self.use_lora = use_lora

    # When a PEFT-wrapped decoder exposes the adapter toggles, flip them so
    # the change takes effect without reloading weights.
    if self.lora_loaded and hasattr(self.model.decoder, 'disable_adapter_layers'):
        try:
            if use_lora:
                self.model.decoder.enable_adapter_layers()
                logger.info("LoRA adapter enabled")
            else:
                self.model.decoder.disable_adapter_layers()
                logger.info("LoRA adapter disabled")
        except Exception as e:
            logger.warning(f"Could not toggle adapter layers: {e}")

    return f"✅ LoRA {'enabled' if use_lora else 'disabled'}"
|
| 245 |
+
|
| 246 |
+
def get_lora_status(self) -> Dict[str, Any]:
    """Report the current LoRA state.

    Returns:
        Dict with "loaded" (adapter weights present) and "active"
        (adapter currently used for inference).
    """
    return dict(loaded=self.lora_loaded, active=self.use_lora)
|
| 256 |
+
|
| 257 |
def initialize_service(
|
| 258 |
self,
|
| 259 |
project_root: str,
|
acestep/training/__init__.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACE-Step Training Module
|
| 3 |
+
|
| 4 |
+
This module provides LoRA training functionality for ACE-Step models,
|
| 5 |
+
including dataset building, audio labeling, and training utilities.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from acestep.training.dataset_builder import DatasetBuilder, AudioSample
|
| 9 |
+
from acestep.training.configs import LoRAConfig, TrainingConfig
|
| 10 |
+
from acestep.training.lora_utils import (
|
| 11 |
+
inject_lora_into_dit,
|
| 12 |
+
save_lora_weights,
|
| 13 |
+
load_lora_weights,
|
| 14 |
+
merge_lora_weights,
|
| 15 |
+
check_peft_available,
|
| 16 |
+
)
|
| 17 |
+
from acestep.training.data_module import (
|
| 18 |
+
# Preprocessed (recommended)
|
| 19 |
+
PreprocessedTensorDataset,
|
| 20 |
+
PreprocessedDataModule,
|
| 21 |
+
collate_preprocessed_batch,
|
| 22 |
+
# Legacy (raw audio)
|
| 23 |
+
AceStepTrainingDataset,
|
| 24 |
+
AceStepDataModule,
|
| 25 |
+
collate_training_batch,
|
| 26 |
+
load_dataset_from_json,
|
| 27 |
+
)
|
| 28 |
+
from acestep.training.trainer import LoRATrainer, PreprocessedLoRAModule, LIGHTNING_AVAILABLE
|
| 29 |
+
|
| 30 |
+
def check_lightning_available():
    """Report whether the Lightning training backend could be imported."""
    return bool(LIGHTNING_AVAILABLE)
|
| 33 |
+
|
| 34 |
+
__all__ = [
|
| 35 |
+
# Dataset Builder
|
| 36 |
+
"DatasetBuilder",
|
| 37 |
+
"AudioSample",
|
| 38 |
+
# Configs
|
| 39 |
+
"LoRAConfig",
|
| 40 |
+
"TrainingConfig",
|
| 41 |
+
# LoRA Utils
|
| 42 |
+
"inject_lora_into_dit",
|
| 43 |
+
"save_lora_weights",
|
| 44 |
+
"load_lora_weights",
|
| 45 |
+
"merge_lora_weights",
|
| 46 |
+
"check_peft_available",
|
| 47 |
+
# Data Module (Preprocessed - Recommended)
|
| 48 |
+
"PreprocessedTensorDataset",
|
| 49 |
+
"PreprocessedDataModule",
|
| 50 |
+
"collate_preprocessed_batch",
|
| 51 |
+
# Data Module (Legacy)
|
| 52 |
+
"AceStepTrainingDataset",
|
| 53 |
+
"AceStepDataModule",
|
| 54 |
+
"collate_training_batch",
|
| 55 |
+
"load_dataset_from_json",
|
| 56 |
+
# Trainer
|
| 57 |
+
"LoRATrainer",
|
| 58 |
+
"PreprocessedLoRAModule",
|
| 59 |
+
"check_lightning_available",
|
| 60 |
+
"LIGHTNING_AVAILABLE",
|
| 61 |
+
]
|
acestep/training/configs.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training Configuration Classes
|
| 3 |
+
|
| 4 |
+
Contains dataclasses for LoRA and training configurations.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from typing import List, Optional
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class LoRAConfig:
    """Hyper-parameters for LoRA (Low-Rank Adaptation) fine-tuning.

    Attributes:
        r: Rank of the low-rank update matrices.
        alpha: Scaling factor (effective scale is alpha / r).
        dropout: Dropout probability applied inside LoRA layers.
        target_modules: Names of sub-modules that receive LoRA adapters.
        bias: Bias training mode ("none", "all", or "lora_only").
    """
    r: int = 8
    alpha: int = 16
    dropout: float = 0.1
    target_modules: List[str] = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj"
    ])
    bias: str = "none"

    def to_dict(self):
        """Render the config as keyword arguments for a PEFT LoraConfig."""
        return dict(
            r=self.r,
            lora_alpha=self.alpha,
            lora_dropout=self.dropout,
            target_modules=self.target_modules,
            bias=self.bias,
        )
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
class TrainingConfig:
    """Settings that drive the LoRA training loop.

    Training uses:
    - BFloat16 precision (only supported precision)
    - Discrete timesteps from the turbo shift=3.0 schedule (8 steps)
    - Randomly samples one of 8 timesteps per training step:
      [1.0, 0.9545, 0.9, 0.8333, 0.75, 0.6429, 0.5, 0.3]

    Attributes:
        shift: Timestep shift factor (fixed at 3.0 for the turbo model).
        num_inference_steps: Number of inference steps (fixed at 8 for turbo).
        learning_rate: Initial learning rate.
        batch_size: Training batch size.
        gradient_accumulation_steps: Number of gradient accumulation steps.
        max_epochs: Maximum number of training epochs.
        save_every_n_epochs: Save a checkpoint every N epochs.
        warmup_steps: Warmup steps for the learning-rate scheduler.
        weight_decay: Optimizer weight decay.
        max_grad_norm: Gradient-norm clipping threshold.
        mixed_precision: Always "bf16" (only supported precision).
        seed: Random seed for reproducibility.
        output_dir: Directory for checkpoints and logs.
        num_workers: DataLoader worker count.
        pin_memory: Whether DataLoaders pin host memory.
        log_every_n_steps: Logging cadence in optimizer steps.
    """
    # Fixed for the turbo model
    shift: float = 3.0
    num_inference_steps: int = 8
    learning_rate: float = 1e-4
    batch_size: int = 1
    gradient_accumulation_steps: int = 4
    max_epochs: int = 100
    save_every_n_epochs: int = 10
    warmup_steps: int = 100
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    mixed_precision: str = "bf16"  # only bf16 supported
    seed: int = 42
    output_dir: str = "./lora_output"

    # Data loading
    num_workers: int = 4
    pin_memory: bool = True

    # Logging
    log_every_n_steps: int = 10

    def to_dict(self):
        """Render the config as a plain dictionary (keys match field names)."""
        field_names = (
            "shift", "num_inference_steps", "learning_rate", "batch_size",
            "gradient_accumulation_steps", "max_epochs", "save_every_n_epochs",
            "warmup_steps", "weight_decay", "max_grad_norm", "mixed_precision",
            "seed", "output_dir", "num_workers", "pin_memory",
            "log_every_n_steps",
        )
        return {name: getattr(self, name) for name in field_names}
|
acestep/training/data_module.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch Lightning DataModule for LoRA Training
|
| 3 |
+
|
| 4 |
+
Handles data loading and preprocessing for training ACE-Step LoRA adapters.
|
| 5 |
+
Supports both raw audio loading and preprocessed tensor loading.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import json
|
| 10 |
+
import random
|
| 11 |
+
from typing import Optional, List, Dict, Any, Tuple
|
| 12 |
+
from loguru import logger
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torchaudio
|
| 16 |
+
from torch.utils.data import Dataset, DataLoader
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
from lightning.pytorch import LightningDataModule
|
| 20 |
+
LIGHTNING_AVAILABLE = True
|
| 21 |
+
except ImportError:
|
| 22 |
+
LIGHTNING_AVAILABLE = False
|
| 23 |
+
logger.warning("Lightning not installed. Training module will not be available.")
|
| 24 |
+
# Create a dummy class for type hints
|
| 25 |
+
class LightningDataModule:
|
| 26 |
+
pass
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ============================================================================
|
| 30 |
+
# Preprocessed Tensor Dataset (Recommended for Training)
|
| 31 |
+
# ============================================================================
|
| 32 |
+
|
| 33 |
+
class PreprocessedTensorDataset(Dataset):
    """Dataset over pre-computed training tensors stored as .pt files.

    Each file is expected to hold the fully pre-computed inputs for one
    sample, so no VAE/text encoder is needed at training time:
    - target_latents: VAE-encoded audio [T, 64]
    - encoder_hidden_states: Condition encoder output [L, D]
    - encoder_attention_mask: Condition mask [L]
    - context_latents: Source context [T, 65]
    - attention_mask: Audio latent mask [T]
    """

    def __init__(self, tensor_dir: str):
        """Index the .pt files under ``tensor_dir``.

        Args:
            tensor_dir: Directory containing preprocessed .pt files and,
                optionally, a manifest.json listing sample paths.
        """
        self.tensor_dir = tensor_dir
        self.sample_paths = []

        manifest_path = os.path.join(tensor_dir, "manifest.json")
        if os.path.exists(manifest_path):
            # Prefer the manifest's explicit sample list when available.
            with open(manifest_path, 'r') as f:
                self.sample_paths = json.load(f).get("samples", [])
        else:
            # No manifest: fall back to every .pt file in the directory.
            self.sample_paths = [
                os.path.join(tensor_dir, name)
                for name in os.listdir(tensor_dir)
                if name.endswith('.pt') and name != "manifest.json"
            ]

        # Drop entries whose files have gone missing since the manifest was written.
        self.valid_paths = [p for p in self.sample_paths if os.path.exists(p)]
        missing = len(self.sample_paths) - len(self.valid_paths)
        if missing:
            logger.warning(f"Some tensor files not found: {missing} missing")

        logger.info(f"PreprocessedTensorDataset: {len(self.valid_paths)} samples from {tensor_dir}")

    def __len__(self) -> int:
        return len(self.valid_paths)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Load one preprocessed sample from disk (always onto CPU).

        Returns:
            Dict of the pre-computed tensors for training, plus its metadata
            dict (empty when the file carries none).
        """
        data = torch.load(self.valid_paths[idx], map_location='cpu')
        return {
            "target_latents": data["target_latents"],              # [T, 64]
            "attention_mask": data["attention_mask"],              # [T]
            "encoder_hidden_states": data["encoder_hidden_states"],  # [L, D]
            "encoder_attention_mask": data["encoder_attention_mask"],  # [L]
            "context_latents": data["context_latents"],            # [T, 65]
            "metadata": data.get("metadata", {}),
        }
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def collate_preprocessed_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Collate preprocessed tensor samples into a zero-padded batch.

    Variable-length tensors are padded along dim 0 to the longest sample in
    the batch: audio-side tensors (target_latents, attention_mask,
    context_latents) to the longest latent length, text-side tensors
    (encoder_hidden_states, encoder_attention_mask) to the longest encoder
    length.

    FIX: padding is now created with each tensor's own dtype and device.
    The previous implementation used ``torch.zeros(...)`` with the default
    float32 dtype, which silently promoted reduced-precision (bf16/fp16)
    latents to float32 when concatenated.

    Args:
        batch: List of sample dicts as produced by PreprocessedTensorDataset.

    Returns:
        Dict with stacked tensors:
            target_latents [B, T, 64], attention_mask [B, T],
            encoder_hidden_states [B, L, D], encoder_attention_mask [B, L],
            context_latents [B, T, 65], and the per-sample "metadata" list.
    """
    max_latent_len = max(s["target_latents"].shape[0] for s in batch)
    max_encoder_len = max(s["encoder_hidden_states"].shape[0] for s in batch)

    def _pad_dim0(t: torch.Tensor, length: int) -> torch.Tensor:
        """Zero-pad t along dim 0 up to `length`, preserving dtype/device."""
        if t.shape[0] >= length:
            return t
        pad_shape = (length - t.shape[0],) + tuple(t.shape[1:])
        pad = torch.zeros(pad_shape, dtype=t.dtype, device=t.device)
        return torch.cat([t, pad], dim=0)

    return {
        "target_latents": torch.stack(
            [_pad_dim0(s["target_latents"], max_latent_len) for s in batch]
        ),  # [B, T, 64]
        "attention_mask": torch.stack(
            [_pad_dim0(s["attention_mask"], max_latent_len) for s in batch]
        ),  # [B, T]
        "encoder_hidden_states": torch.stack(
            [_pad_dim0(s["encoder_hidden_states"], max_encoder_len) for s in batch]
        ),  # [B, L, D]
        "encoder_attention_mask": torch.stack(
            [_pad_dim0(s["encoder_attention_mask"], max_encoder_len) for s in batch]
        ),  # [B, L]
        "context_latents": torch.stack(
            [_pad_dim0(s["context_latents"], max_latent_len) for s in batch]
        ),  # [B, T, 65]
        "metadata": [s["metadata"] for s in batch],
    }
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class PreprocessedDataModule(LightningDataModule if LIGHTNING_AVAILABLE else object):
    """Lightning DataModule serving preprocessed tensor files.

    This is the recommended DataModule for training: it loads pre-computed
    tensors directly, so no VAE, text encoder, or condition encoder is needed
    at training time.
    """

    def __init__(
        self,
        tensor_dir: str,
        batch_size: int = 1,
        num_workers: int = 4,
        pin_memory: bool = True,
        val_split: float = 0.0,
    ):
        """Initialize the data module.

        Args:
            tensor_dir: Directory containing preprocessed .pt files.
            batch_size: Training batch size.
            num_workers: Number of data loading workers.
            pin_memory: Whether to pin memory for faster GPU transfer.
            val_split: Fraction of data for validation (0 = no validation).
        """
        if LIGHTNING_AVAILABLE:
            super().__init__()

        self.tensor_dir = tensor_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.val_split = val_split

        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage: Optional[str] = None):
        """Build the train (and optional validation) datasets for 'fit'."""
        if stage not in (None, 'fit'):
            return

        dataset = PreprocessedTensorDataset(self.tensor_dir)
        total = len(dataset)

        # Hold out a validation subset only when requested and possible.
        if self.val_split > 0 and total > 1:
            n_val = max(1, int(total * self.val_split))
            self.train_dataset, self.val_dataset = torch.utils.data.random_split(
                dataset, [total - n_val, n_val]
            )
        else:
            self.train_dataset = dataset
            self.val_dataset = None

    def train_dataloader(self) -> DataLoader:
        """Shuffled training dataloader over the preprocessed tensors."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=collate_preprocessed_batch,
            drop_last=True,
        )

    def val_dataloader(self) -> Optional[DataLoader]:
        """Validation dataloader, or None when no validation split exists."""
        if self.val_dataset is None:
            return None

        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=collate_preprocessed_batch,
        )
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# ============================================================================
|
| 247 |
+
# Raw Audio Dataset (Legacy - for backward compatibility)
|
| 248 |
+
# ============================================================================
|
| 249 |
+
|
| 250 |
+
class AceStepTrainingDataset(Dataset):
    """Dataset yielding raw audio for ACE-Step LoRA training.

    DEPRECATED: Use PreprocessedTensorDataset instead for better performance.

    Audio is normalized automatically on load:
    - resampled to 48 kHz
    - forced to stereo (mono duplicated, extra channels dropped)
    - truncated to ``max_duration`` seconds, zero-padded up to 5 seconds minimum
    """

    def __init__(
        self,
        samples: List[Dict[str, Any]],
        dit_handler,
        max_duration: float = 240.0,
        target_sample_rate: int = 48000,
    ):
        """Store configuration and filter out invalid samples."""
        self.samples = samples
        self.dit_handler = dit_handler
        self.max_duration = max_duration
        self.target_sample_rate = target_sample_rate

        self.valid_samples = self._validate_samples()
        logger.info(f"Dataset initialized with {len(self.valid_samples)} valid samples")

    def _validate_samples(self) -> List[Dict[str, Any]]:
        """Keep only samples whose audio file exists and that carry a caption."""
        kept = []
        for i, entry in enumerate(self.samples):
            audio_path = entry.get("audio_path", "")
            if not audio_path or not os.path.exists(audio_path):
                logger.warning(f"Sample {i}: Audio file not found: {audio_path}")
                continue
            if not entry.get("caption"):
                logger.warning(f"Sample {i}: Missing caption")
                continue
            kept.append(entry)
        return kept

    def __len__(self) -> int:
        return len(self.valid_samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Load, normalize, and return one training sample."""
        entry = self.valid_samples[idx]
        audio_path = entry["audio_path"]

        waveform, sr = torchaudio.load(audio_path)

        # Resample to the target rate (48 kHz).
        if sr != self.target_sample_rate:
            waveform = torchaudio.transforms.Resample(sr, self.target_sample_rate)(waveform)

        # Force exactly two channels.
        if waveform.shape[0] == 1:
            waveform = waveform.repeat(2, 1)
        elif waveform.shape[0] > 2:
            waveform = waveform[:2, :]

        # Clamp length: truncate to max_duration, pad up to the 5-second floor.
        limit = int(self.max_duration * self.target_sample_rate)
        if waveform.shape[1] > limit:
            waveform = waveform[:, :limit]

        floor = int(5.0 * self.target_sample_rate)
        if waveform.shape[1] < floor:
            waveform = torch.nn.functional.pad(waveform, (0, floor - waveform.shape[1]))

        meta = {
            "caption": entry.get("caption", ""),
            "lyrics": entry.get("lyrics", "[Instrumental]"),
            "bpm": entry.get("bpm"),
            "keyscale": entry.get("keyscale", ""),
            "timesignature": entry.get("timesignature", ""),
            "duration": entry.get("duration", waveform.shape[1] / self.target_sample_rate),
            "language": entry.get("language", "instrumental"),
            "is_instrumental": entry.get("is_instrumental", True),
        }
        return {
            "audio": waveform,
            "caption": meta["caption"],
            "lyrics": meta["lyrics"],
            "metadata": meta,
            "audio_path": audio_path,
        }
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def collate_training_batch(batch: List[Dict]) -> Dict[str, Any]:
    """Collate raw-audio samples into a padded batch (legacy path).

    Audio tensors are right-padded with zeros to the longest clip in the
    batch; the attention mask is 1 over real samples and 0 over padding.
    """
    target_len = max(s["audio"].shape[1] for s in batch)

    audios = []
    masks = []
    for s in batch:
        wav = s["audio"]
        n = wav.shape[1]

        if n < target_len:
            wav = torch.nn.functional.pad(wav, (0, target_len - n))
        audios.append(wav)

        mask = torch.ones(target_len)
        mask[n:] = 0  # no-op when the clip already spans target_len
        masks.append(mask)

    return {
        "audio": torch.stack(audios),
        "attention_mask": torch.stack(masks),
        "captions": [s["caption"] for s in batch],
        "lyrics": [s["lyrics"] for s in batch],
        "metadata": [s["metadata"] for s in batch],
        "audio_paths": [s["audio_path"] for s in batch],
    }
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class AceStepDataModule(LightningDataModule if LIGHTNING_AVAILABLE else object):
    """Lightning DataModule that loads raw audio (legacy).

    DEPRECATED: Use PreprocessedDataModule for better training performance.
    """

    def __init__(
        self,
        samples: List[Dict[str, Any]],
        dit_handler,
        batch_size: int = 1,
        num_workers: int = 4,
        pin_memory: bool = True,
        max_duration: float = 240.0,
        val_split: float = 0.0,
    ):
        if LIGHTNING_AVAILABLE:
            super().__init__()

        self.samples = samples
        self.dit_handler = dit_handler
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.max_duration = max_duration
        self.val_split = val_split

        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage: Optional[str] = None):
        """Build raw-audio datasets, optionally holding out a validation split."""
        if stage not in (None, 'fit'):
            return

        if self.val_split > 0 and len(self.samples) > 1:
            n_val = max(1, int(len(self.samples) * self.val_split))

            # Shuffle indices so the held-out split is random.
            order = list(range(len(self.samples)))
            random.shuffle(order)

            val_part = [self.samples[i] for i in order[:n_val]]
            train_part = [self.samples[i] for i in order[n_val:]]

            self.train_dataset = AceStepTrainingDataset(
                train_part, self.dit_handler, self.max_duration
            )
            self.val_dataset = AceStepTrainingDataset(
                val_part, self.dit_handler, self.max_duration
            )
        else:
            self.train_dataset = AceStepTrainingDataset(
                self.samples, self.dit_handler, self.max_duration
            )
            self.val_dataset = None

    def train_dataloader(self) -> DataLoader:
        """Shuffled training dataloader over raw audio samples."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=collate_training_batch,
            drop_last=True,
        )

    def val_dataloader(self) -> Optional[DataLoader]:
        """Validation dataloader, or None when no validation split exists."""
        if self.val_dataset is None:
            return None

        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=collate_training_batch,
        )
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def load_dataset_from_json(json_path: str) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Read a dataset JSON file and return its (samples, metadata) pair.

    Missing keys default to an empty list / empty dict.
    """
    with open(json_path, 'r', encoding='utf-8') as fh:
        payload = json.load(fh)

    return payload.get("samples", []), payload.get("metadata", {})
|
acestep/training/dataset_builder.py
ADDED
|
@@ -0,0 +1,755 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset Builder for LoRA Training
|
| 3 |
+
|
| 4 |
+
Provides functionality to:
|
| 5 |
+
1. Scan directories for audio files
|
| 6 |
+
2. Auto-label audio using LLM
|
| 7 |
+
3. Preview and edit metadata
|
| 8 |
+
4. Save datasets in JSON format
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import json
|
| 13 |
+
import uuid
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
from dataclasses import dataclass, field, asdict
|
| 16 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torchaudio
|
| 21 |
+
from loguru import logger
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Supported audio formats
|
| 25 |
+
SUPPORTED_AUDIO_FORMATS = {'.wav', '.mp3', '.flac', '.ogg', '.opus'}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class AudioSample:
    """A single audio file plus the metadata used for LoRA training.

    Attributes:
        id: Unique identifier (8-char UUID prefix, auto-assigned when empty).
        audio_path: Path to the audio file.
        filename: Original filename.
        caption: Generated or user-provided caption describing the music.
        lyrics: Lyrics, or "[Instrumental]" for instrumental tracks.
        bpm: Beats per minute, when known.
        keyscale: Musical key (e.g. "C Major", "Am").
        timesignature: Time signature (e.g. "4" for 4/4).
        duration: Duration in seconds.
        language: Vocal language, or "instrumental".
        is_instrumental: Whether the track has no vocals.
        custom_tag: User-defined activation tag for the LoRA.
        labeled: Whether the sample has been labeled.
    """
    id: str = ""
    audio_path: str = ""
    filename: str = ""
    caption: str = ""
    lyrics: str = "[Instrumental]"
    bpm: Optional[int] = None
    keyscale: str = ""
    timesignature: str = ""
    duration: float = 0.0
    language: str = "instrumental"
    is_instrumental: bool = True
    custom_tag: str = ""
    labeled: bool = False

    def __post_init__(self):
        # Assign a short random identifier when none was provided.
        if not self.id:
            self.id = str(uuid.uuid4())[:8]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dictionary."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AudioSample":
        """Rebuild an AudioSample from a dictionary."""
        return cls(**data)

    def get_full_caption(self, tag_position: str = "prepend") -> str:
        """Return the caption with the custom tag applied.

        Args:
            tag_position: Where to place the tag ("prepend", "append",
                "replace"); any other value leaves the caption untouched.

        Returns:
            Caption with custom tag applied.
        """
        if not self.custom_tag:
            return self.caption

        if tag_position == "replace":
            return self.custom_tag
        if tag_position == "prepend":
            return f"{self.custom_tag}, {self.caption}" if self.caption else self.custom_tag
        if tag_position == "append":
            return f"{self.caption}, {self.custom_tag}" if self.caption else self.custom_tag
        return self.caption
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@dataclass
class DatasetMetadata:
    """Metadata describing an entire dataset.

    Attributes:
        name: Dataset name.
        custom_tag: Default custom tag applied to all samples.
        tag_position: Where to place the custom tag ("prepend", "append",
            "replace").
        created_at: ISO-format creation timestamp (auto-filled when empty).
        num_samples: Number of samples in the dataset.
        all_instrumental: Whether every track is instrumental.
    """
    name: str = "untitled_dataset"
    custom_tag: str = ""
    tag_position: str = "prepend"
    created_at: str = ""
    num_samples: int = 0
    all_instrumental: bool = True

    def __post_init__(self):
        # Stamp the creation time if the caller did not supply one.
        if not self.created_at:
            self.created_at = datetime.now().isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dictionary."""
        return asdict(self)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class DatasetBuilder:
|
| 125 |
+
"""Builder for creating training datasets from audio files.
|
| 126 |
+
|
| 127 |
+
This class handles:
|
| 128 |
+
- Scanning directories for audio files
|
| 129 |
+
- Auto-labeling using LLM
|
| 130 |
+
- Managing sample metadata
|
| 131 |
+
- Saving/loading datasets
|
| 132 |
+
"""
|
| 133 |
+
|
| 134 |
+
def __init__(self):
|
| 135 |
+
"""Initialize the dataset builder."""
|
| 136 |
+
self.samples: List[AudioSample] = []
|
| 137 |
+
self.metadata = DatasetMetadata()
|
| 138 |
+
self._current_dir: str = ""
|
| 139 |
+
|
| 140 |
+
def scan_directory(self, directory: str) -> Tuple[List[AudioSample], str]:
|
| 141 |
+
"""Scan a directory for audio files.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
directory: Path to directory containing audio files
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
Tuple of (list of AudioSample objects, status message)
|
| 148 |
+
"""
|
| 149 |
+
if not os.path.exists(directory):
|
| 150 |
+
return [], f"❌ Directory not found: {directory}"
|
| 151 |
+
|
| 152 |
+
if not os.path.isdir(directory):
|
| 153 |
+
return [], f"❌ Not a directory: {directory}"
|
| 154 |
+
|
| 155 |
+
self._current_dir = directory
|
| 156 |
+
self.samples = []
|
| 157 |
+
|
| 158 |
+
# Scan for audio files
|
| 159 |
+
audio_files = []
|
| 160 |
+
for root, dirs, files in os.walk(directory):
|
| 161 |
+
for file in files:
|
| 162 |
+
ext = os.path.splitext(file)[1].lower()
|
| 163 |
+
if ext in SUPPORTED_AUDIO_FORMATS:
|
| 164 |
+
audio_files.append(os.path.join(root, file))
|
| 165 |
+
|
| 166 |
+
if not audio_files:
|
| 167 |
+
return [], f"❌ No audio files found in {directory}\nSupported formats: {', '.join(SUPPORTED_AUDIO_FORMATS)}"
|
| 168 |
+
|
| 169 |
+
# Sort files by name
|
| 170 |
+
audio_files.sort()
|
| 171 |
+
|
| 172 |
+
# Create AudioSample objects
|
| 173 |
+
for audio_path in audio_files:
|
| 174 |
+
try:
|
| 175 |
+
# Get duration
|
| 176 |
+
duration = self._get_audio_duration(audio_path)
|
| 177 |
+
|
| 178 |
+
sample = AudioSample(
|
| 179 |
+
audio_path=audio_path,
|
| 180 |
+
filename=os.path.basename(audio_path),
|
| 181 |
+
duration=duration,
|
| 182 |
+
is_instrumental=self.metadata.all_instrumental,
|
| 183 |
+
custom_tag=self.metadata.custom_tag,
|
| 184 |
+
)
|
| 185 |
+
self.samples.append(sample)
|
| 186 |
+
except Exception as e:
|
| 187 |
+
logger.warning(f"Failed to process {audio_path}: {e}")
|
| 188 |
+
|
| 189 |
+
self.metadata.num_samples = len(self.samples)
|
| 190 |
+
|
| 191 |
+
status = f"✅ Found {len(self.samples)} audio files in {directory}"
|
| 192 |
+
return self.samples, status
|
| 193 |
+
|
| 194 |
+
def _get_audio_duration(self, audio_path: str) -> float:
|
| 195 |
+
"""Get the duration of an audio file in seconds.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
audio_path: Path to audio file
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
Duration in seconds
|
| 202 |
+
"""
|
| 203 |
+
try:
|
| 204 |
+
info = torchaudio.info(audio_path)
|
| 205 |
+
return info.num_frames / info.sample_rate
|
| 206 |
+
except Exception as e:
|
| 207 |
+
logger.warning(f"Failed to get duration for {audio_path}: {e}")
|
| 208 |
+
return 0.0
|
| 209 |
+
|
| 210 |
+
def label_sample(
|
| 211 |
+
self,
|
| 212 |
+
sample_idx: int,
|
| 213 |
+
dit_handler,
|
| 214 |
+
llm_handler,
|
| 215 |
+
progress_callback=None,
|
| 216 |
+
) -> Tuple[AudioSample, str]:
|
| 217 |
+
"""Label a single sample using the LLM.
|
| 218 |
+
|
| 219 |
+
Args:
|
| 220 |
+
sample_idx: Index of sample to label
|
| 221 |
+
dit_handler: DiT handler for audio encoding
|
| 222 |
+
llm_handler: LLM handler for caption generation
|
| 223 |
+
progress_callback: Optional callback for progress updates
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
Tuple of (updated AudioSample, status message)
|
| 227 |
+
"""
|
| 228 |
+
if sample_idx < 0 or sample_idx >= len(self.samples):
|
| 229 |
+
return None, f"❌ Invalid sample index: {sample_idx}"
|
| 230 |
+
|
| 231 |
+
sample = self.samples[sample_idx]
|
| 232 |
+
|
| 233 |
+
try:
|
| 234 |
+
if progress_callback:
|
| 235 |
+
progress_callback(f"Processing: {sample.filename}")
|
| 236 |
+
|
| 237 |
+
# Step 1: Load and encode audio to get audio codes
|
| 238 |
+
audio_codes = self._get_audio_codes(sample.audio_path, dit_handler)
|
| 239 |
+
|
| 240 |
+
if not audio_codes:
|
| 241 |
+
return sample, f"❌ Failed to encode audio: {sample.filename}"
|
| 242 |
+
|
| 243 |
+
if progress_callback:
|
| 244 |
+
progress_callback(f"Generating metadata for: {sample.filename}")
|
| 245 |
+
|
| 246 |
+
# Step 2: Use LLM to understand the audio
|
| 247 |
+
metadata, status = llm_handler.understand_audio_from_codes(
|
| 248 |
+
audio_codes=audio_codes,
|
| 249 |
+
temperature=0.7,
|
| 250 |
+
use_constrained_decoding=True,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
if not metadata:
|
| 254 |
+
return sample, f"❌ LLM labeling failed: {status}"
|
| 255 |
+
|
| 256 |
+
# Step 3: Update sample with generated metadata
|
| 257 |
+
sample.caption = metadata.get('caption', '')
|
| 258 |
+
sample.bpm = self._parse_int(metadata.get('bpm'))
|
| 259 |
+
sample.keyscale = metadata.get('keyscale', '')
|
| 260 |
+
sample.timesignature = metadata.get('timesignature', '')
|
| 261 |
+
sample.language = metadata.get('vocal_language', 'instrumental')
|
| 262 |
+
|
| 263 |
+
# Handle lyrics based on instrumental flag
|
| 264 |
+
if sample.is_instrumental:
|
| 265 |
+
sample.lyrics = "[Instrumental]"
|
| 266 |
+
sample.language = "instrumental"
|
| 267 |
+
else:
|
| 268 |
+
sample.lyrics = metadata.get('lyrics', '')
|
| 269 |
+
|
| 270 |
+
# NOTE: Duration is NOT overwritten from LM metadata.
|
| 271 |
+
# We keep the real audio duration obtained from torchaudio during scan.
|
| 272 |
+
|
| 273 |
+
sample.labeled = True
|
| 274 |
+
self.samples[sample_idx] = sample
|
| 275 |
+
|
| 276 |
+
return sample, f"✅ Labeled: {sample.filename}"
|
| 277 |
+
|
| 278 |
+
except Exception as e:
|
| 279 |
+
logger.exception(f"Error labeling sample {sample.filename}")
|
| 280 |
+
return sample, f"❌ Error: {str(e)}"
|
| 281 |
+
|
| 282 |
+
def label_all_samples(
|
| 283 |
+
self,
|
| 284 |
+
dit_handler,
|
| 285 |
+
llm_handler,
|
| 286 |
+
progress_callback=None,
|
| 287 |
+
) -> Tuple[List[AudioSample], str]:
|
| 288 |
+
"""Label all samples in the dataset.
|
| 289 |
+
|
| 290 |
+
Args:
|
| 291 |
+
dit_handler: DiT handler for audio encoding
|
| 292 |
+
llm_handler: LLM handler for caption generation
|
| 293 |
+
progress_callback: Optional callback for progress updates
|
| 294 |
+
|
| 295 |
+
Returns:
|
| 296 |
+
Tuple of (list of updated samples, status message)
|
| 297 |
+
"""
|
| 298 |
+
if not self.samples:
|
| 299 |
+
return [], "❌ No samples to label. Please scan a directory first."
|
| 300 |
+
|
| 301 |
+
success_count = 0
|
| 302 |
+
fail_count = 0
|
| 303 |
+
|
| 304 |
+
for i, sample in enumerate(self.samples):
|
| 305 |
+
if progress_callback:
|
| 306 |
+
progress_callback(f"Labeling {i+1}/{len(self.samples)}: {sample.filename}")
|
| 307 |
+
|
| 308 |
+
_, status = self.label_sample(i, dit_handler, llm_handler, progress_callback)
|
| 309 |
+
|
| 310 |
+
if "✅" in status:
|
| 311 |
+
success_count += 1
|
| 312 |
+
else:
|
| 313 |
+
fail_count += 1
|
| 314 |
+
|
| 315 |
+
status_msg = f"✅ Labeled {success_count}/{len(self.samples)} samples"
|
| 316 |
+
if fail_count > 0:
|
| 317 |
+
status_msg += f" ({fail_count} failed)"
|
| 318 |
+
|
| 319 |
+
return self.samples, status_msg
|
| 320 |
+
|
| 321 |
+
def _get_audio_codes(self, audio_path: str, dit_handler) -> Optional[str]:
|
| 322 |
+
"""Encode audio to get semantic codes for LLM understanding.
|
| 323 |
+
|
| 324 |
+
Args:
|
| 325 |
+
audio_path: Path to audio file
|
| 326 |
+
dit_handler: DiT handler with VAE and tokenizer
|
| 327 |
+
|
| 328 |
+
Returns:
|
| 329 |
+
Audio codes string or None if failed
|
| 330 |
+
"""
|
| 331 |
+
try:
|
| 332 |
+
# Check if handler has required methods
|
| 333 |
+
if not hasattr(dit_handler, 'convert_src_audio_to_codes'):
|
| 334 |
+
logger.error("DiT handler missing convert_src_audio_to_codes method")
|
| 335 |
+
return None
|
| 336 |
+
|
| 337 |
+
# Use handler's method to convert audio to codes
|
| 338 |
+
codes_string = dit_handler.convert_src_audio_to_codes(audio_path)
|
| 339 |
+
|
| 340 |
+
if codes_string and not codes_string.startswith("❌"):
|
| 341 |
+
return codes_string
|
| 342 |
+
else:
|
| 343 |
+
logger.warning(f"Failed to convert audio to codes: {codes_string}")
|
| 344 |
+
return None
|
| 345 |
+
|
| 346 |
+
except Exception as e:
|
| 347 |
+
logger.exception(f"Error encoding audio {audio_path}")
|
| 348 |
+
return None
|
| 349 |
+
|
| 350 |
+
def _parse_int(self, value: Any) -> Optional[int]:
|
| 351 |
+
"""Safely parse an integer value."""
|
| 352 |
+
if value is None or value == "N/A" or value == "":
|
| 353 |
+
return None
|
| 354 |
+
try:
|
| 355 |
+
return int(value)
|
| 356 |
+
except (ValueError, TypeError):
|
| 357 |
+
return None
|
| 358 |
+
|
| 359 |
+
def update_sample(self, sample_idx: int, **kwargs) -> Tuple[AudioSample, str]:
|
| 360 |
+
"""Update a sample's metadata.
|
| 361 |
+
|
| 362 |
+
Args:
|
| 363 |
+
sample_idx: Index of sample to update
|
| 364 |
+
**kwargs: Fields to update
|
| 365 |
+
|
| 366 |
+
Returns:
|
| 367 |
+
Tuple of (updated sample, status message)
|
| 368 |
+
"""
|
| 369 |
+
if sample_idx < 0 or sample_idx >= len(self.samples):
|
| 370 |
+
return None, f"❌ Invalid sample index: {sample_idx}"
|
| 371 |
+
|
| 372 |
+
sample = self.samples[sample_idx]
|
| 373 |
+
|
| 374 |
+
for key, value in kwargs.items():
|
| 375 |
+
if hasattr(sample, key):
|
| 376 |
+
setattr(sample, key, value)
|
| 377 |
+
|
| 378 |
+
self.samples[sample_idx] = sample
|
| 379 |
+
return sample, f"✅ Updated: {sample.filename}"
|
| 380 |
+
|
| 381 |
+
def set_custom_tag(self, custom_tag: str, tag_position: str = "prepend"):
|
| 382 |
+
"""Set the custom tag for all samples.
|
| 383 |
+
|
| 384 |
+
Args:
|
| 385 |
+
custom_tag: Custom activation tag
|
| 386 |
+
tag_position: Where to place tag ("prepend", "append", "replace")
|
| 387 |
+
"""
|
| 388 |
+
self.metadata.custom_tag = custom_tag
|
| 389 |
+
self.metadata.tag_position = tag_position
|
| 390 |
+
|
| 391 |
+
for sample in self.samples:
|
| 392 |
+
sample.custom_tag = custom_tag
|
| 393 |
+
|
| 394 |
+
def set_all_instrumental(self, is_instrumental: bool):
|
| 395 |
+
"""Set instrumental flag for all samples.
|
| 396 |
+
|
| 397 |
+
Args:
|
| 398 |
+
is_instrumental: Whether all tracks are instrumental
|
| 399 |
+
"""
|
| 400 |
+
self.metadata.all_instrumental = is_instrumental
|
| 401 |
+
|
| 402 |
+
for sample in self.samples:
|
| 403 |
+
sample.is_instrumental = is_instrumental
|
| 404 |
+
if is_instrumental:
|
| 405 |
+
sample.lyrics = "[Instrumental]"
|
| 406 |
+
sample.language = "instrumental"
|
| 407 |
+
|
| 408 |
+
def get_sample_count(self) -> int:
|
| 409 |
+
"""Get the number of samples in the dataset."""
|
| 410 |
+
return len(self.samples)
|
| 411 |
+
|
| 412 |
+
def get_labeled_count(self) -> int:
|
| 413 |
+
"""Get the number of labeled samples."""
|
| 414 |
+
return sum(1 for s in self.samples if s.labeled)
|
| 415 |
+
|
| 416 |
+
def save_dataset(self, output_path: str, dataset_name: str = None) -> str:
    """Serialize the dataset (metadata + samples) to a JSON file.

    Captions are rewritten through ``get_full_caption`` so the custom
    activation tag is baked into the saved captions. Metadata is stamped
    with the sample count and the current timestamp before writing.

    Args:
        output_path: Path to save the dataset JSON.
        dataset_name: Optional name for the dataset.

    Returns:
        Human-readable status message (success or failure).
    """
    if not self.samples:
        return "❌ No samples to save"

    if dataset_name:
        self.metadata.name = dataset_name

    self.metadata.num_samples = len(self.samples)
    self.metadata.created_at = datetime.now().isoformat()

    # Serialize metadata first, then each sample with its tagged caption.
    payload = {"metadata": self.metadata.to_dict(), "samples": []}
    tag_position = self.metadata.tag_position
    for sample in self.samples:
        record = sample.to_dict()
        # Bake the custom tag into the caption according to its position.
        record["caption"] = sample.get_full_caption(tag_position)
        payload["samples"].append(record)

    try:
        parent = os.path.dirname(output_path) or "."
        os.makedirs(parent, exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)

        return f"✅ Dataset saved to {output_path}\n{len(self.samples)} samples, tag: '{self.metadata.custom_tag}'"
    except Exception as e:
        logger.exception("Error saving dataset")
        return f"❌ Failed to save dataset: {str(e)}"
def load_dataset(self, dataset_path: str) -> Tuple[List[AudioSample], str]:
    """Restore dataset metadata and samples from a JSON file.

    Replaces ``self.metadata`` (when a metadata block is present, with
    per-field defaults) and rebuilds ``self.samples`` from the serialized
    sample dicts.

    Args:
        dataset_path: Path to the dataset JSON file.

    Returns:
        Tuple of (list of samples, status message).
    """
    if not os.path.exists(dataset_path):
        return [], f"❌ Dataset not found: {dataset_path}"

    try:
        with open(dataset_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Rebuild metadata when present, defaulting each missing field.
        if "metadata" in data:
            meta_dict = data["metadata"]
            self.metadata = DatasetMetadata(
                name=meta_dict.get("name", "untitled"),
                custom_tag=meta_dict.get("custom_tag", ""),
                tag_position=meta_dict.get("tag_position", "prepend"),
                created_at=meta_dict.get("created_at", ""),
                num_samples=meta_dict.get("num_samples", 0),
                all_instrumental=meta_dict.get("all_instrumental", True),
            )

        # Rebuild the in-memory sample list from serialized dicts.
        self.samples = [
            AudioSample.from_dict(sample_dict)
            for sample_dict in data.get("samples", [])
        ]

        return self.samples, f"✅ Loaded {len(self.samples)} samples from {dataset_path}"

    except Exception as e:
        logger.exception("Error loading dataset")
        return [], f"❌ Failed to load dataset: {str(e)}"
def get_samples_dataframe_data(self) -> List[List[Any]]:
    """Get samples data in a format suitable for Gradio DataFrame.

    Each row is: [index, filename, "<duration>s", labeled flag (✅/❌),
    bpm or "-", keyscale or "-", caption truncated to 50 chars].

    Returns:
        List of rows for DataFrame display.
    """
    rows = []
    for i, sample in enumerate(self.samples):
        # Guard before len(): an unset caption may be None/empty, and the
        # original len(sample.caption) would raise TypeError on None.
        caption = sample.caption or "-"
        if len(caption) > 50:
            caption = caption[:50] + "..."
        rows.append([
            i,
            sample.filename,
            f"{sample.duration:.1f}s",
            "✅" if sample.labeled else "❌",
            sample.bpm or "-",
            sample.keyscale or "-",
            caption,
        ])
    return rows
def to_training_format(self) -> List[Dict[str, Any]]:
    """Convert labeled samples into training-ready dictionaries.

    Unlabeled samples are skipped. Captions are produced through
    ``get_full_caption`` so the custom tag placement is applied.

    Returns:
        List of training sample dictionaries.
    """
    tag_position = self.metadata.tag_position
    return [
        {
            "audio_path": s.audio_path,
            "caption": s.get_full_caption(tag_position),
            "lyrics": s.lyrics,
            "bpm": s.bpm,
            "keyscale": s.keyscale,
            "timesignature": s.timesignature,
            "duration": s.duration,
            "language": s.language,
            "is_instrumental": s.is_instrumental,
        }
        for s in self.samples
        if s.labeled
    ]
def preprocess_to_tensors(
    self,
    dit_handler,
    output_dir: str,
    max_duration: float = 240.0,
    progress_callback=None,
) -> Tuple[List[str], str]:
    """Preprocess all labeled samples to tensor files for efficient training.

    This method pre-computes all tensors needed by the DiT decoder:
    - target_latents: VAE-encoded audio
    - encoder_hidden_states: Condition encoder output
    - context_latents: Source context (silence_latent + zeros for text2music)

    Failures are per-sample: a sample that raises is logged, counted as
    failed, and skipped; the rest continue. A manifest.json listing the
    successful outputs is always written at the end.

    Args:
        dit_handler: Initialized DiT handler with model, VAE, and text encoder
        output_dir: Directory to save preprocessed .pt files
        max_duration: Maximum audio duration in seconds (default 240s = 4 min)
        progress_callback: Optional callback for progress updates

    Returns:
        Tuple of (list of output paths, status message)
    """
    if not self.samples:
        return [], "❌ No samples to preprocess"

    labeled_samples = [s for s in self.samples if s.labeled]
    if not labeled_samples:
        return [], "❌ No labeled samples to preprocess"

    # Validate handler
    if dit_handler is None or dit_handler.model is None:
        return [], "❌ Model not initialized. Please initialize the service first."

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    output_paths = []
    success_count = 0
    fail_count = 0

    # Get model and components
    model = dit_handler.model
    vae = dit_handler.vae
    text_encoder = dit_handler.text_encoder
    text_tokenizer = dit_handler.text_tokenizer
    # NOTE(review): silence_latent appears to be a [1, T_max, 64] latent of
    # silence used as the text2music source context — confirm against handler.
    silence_latent = dit_handler.silence_latent
    device = dit_handler.device
    dtype = dit_handler.dtype

    target_sample_rate = 48000

    for i, sample in enumerate(labeled_samples):
        try:
            if progress_callback:
                progress_callback(f"Preprocessing {i+1}/{len(labeled_samples)}: {sample.filename}")

            # Step 1: Load and preprocess audio to stereo @ 48kHz
            audio, sr = torchaudio.load(sample.audio_path)

            # Resample if needed
            if sr != target_sample_rate:
                resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
                audio = resampler(audio)

            # Convert to stereo
            if audio.shape[0] == 1:
                audio = audio.repeat(2, 1)
            elif audio.shape[0] > 2:
                audio = audio[:2, :]

            # Truncate to max duration
            max_samples = int(max_duration * target_sample_rate)
            if audio.shape[1] > max_samples:
                audio = audio[:, :max_samples]

            # Add batch dimension: [2, T] -> [1, 2, T]
            audio = audio.unsqueeze(0).to(device).to(vae.dtype)

            # Step 2: VAE encode audio to get target_latents
            with torch.no_grad():
                latent = vae.encode(audio).latent_dist.sample()
                # [1, 64, T_latent] -> [1, T_latent, 64]
                target_latents = latent.transpose(1, 2).to(dtype)

            latent_length = target_latents.shape[1]

            # Step 3: Create attention mask (all ones for valid audio)
            attention_mask = torch.ones(1, latent_length, device=device, dtype=dtype)

            # Step 4: Encode caption text
            caption = sample.get_full_caption(self.metadata.tag_position)
            text_inputs = text_tokenizer(
                caption,
                padding="max_length",
                max_length=256,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids.to(device)
            text_attention_mask = text_inputs.attention_mask.to(device).to(dtype)

            with torch.no_grad():
                text_outputs = text_encoder(text_input_ids)
                text_hidden_states = text_outputs.last_hidden_state.to(dtype)

            # Step 5: Encode lyrics (tokens are embedded directly, not run
            # through the full encoder — only embed_tokens is used here)
            lyrics = sample.lyrics if sample.lyrics else "[Instrumental]"
            lyric_inputs = text_tokenizer(
                lyrics,
                padding="max_length",
                max_length=512,
                truncation=True,
                return_tensors="pt",
            )
            lyric_input_ids = lyric_inputs.input_ids.to(device)
            lyric_attention_mask = lyric_inputs.attention_mask.to(device).to(dtype)

            with torch.no_grad():
                lyric_hidden_states = text_encoder.embed_tokens(lyric_input_ids).to(dtype)

            # Step 6: Prepare refer_audio (empty for text2music)
            # Create minimal refer_audio placeholder
            refer_audio_hidden = torch.zeros(1, 1, 64, device=device, dtype=dtype)
            refer_audio_order_mask = torch.zeros(1, device=device, dtype=torch.long)

            # Step 7: Run model.encoder to get encoder_hidden_states
            with torch.no_grad():
                encoder_hidden_states, encoder_attention_mask = model.encoder(
                    text_hidden_states=text_hidden_states,
                    text_attention_mask=text_attention_mask,
                    lyric_hidden_states=lyric_hidden_states,
                    lyric_attention_mask=lyric_attention_mask,
                    refer_audio_acoustic_hidden_states_packed=refer_audio_hidden,
                    refer_audio_order_mask=refer_audio_order_mask,
                )

            # Step 8: Build context_latents for text2music
            # For text2music: src_latents = silence_latent, is_covers = 0
            # chunk_masks: 1 = generate, 0 = keep original
            # IMPORTANT: chunk_masks must have same shape as src_latents [B, T, 64]
            # For text2music, we want to generate the entire audio, so chunk_masks = all 1s
            src_latents = silence_latent[:, :latent_length, :].to(dtype)
            if src_latents.shape[0] < 1:
                src_latents = src_latents.expand(1, -1, -1)

            # Pad or truncate silence_latent to match latent_length
            if src_latents.shape[1] < latent_length:
                pad_len = latent_length - src_latents.shape[1]
                src_latents = torch.cat([
                    src_latents,
                    silence_latent[:, :pad_len, :].expand(1, -1, -1).to(dtype)
                ], dim=1)
            elif src_latents.shape[1] > latent_length:
                src_latents = src_latents[:, :latent_length, :]

            # chunk_masks = 1 means "generate this region", 0 = keep original
            # Shape must match src_latents: [B, T, 64] (NOT [B, T, 1])
            # For text2music, generate everything -> all 1s with shape [1, T, 64]
            chunk_masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype)
            # context_latents = [src_latents, chunk_masks] -> [B, T, 128]
            context_latents = torch.cat([src_latents, chunk_masks], dim=-1)

            # Step 9: Save all tensors to .pt file (squeeze batch dimension for storage)
            output_data = {
                "target_latents": target_latents.squeeze(0).cpu(),  # [T, 64]
                "attention_mask": attention_mask.squeeze(0).cpu(),  # [T]
                "encoder_hidden_states": encoder_hidden_states.squeeze(0).cpu(),  # [L, D]
                "encoder_attention_mask": encoder_attention_mask.squeeze(0).cpu(),  # [L]
                "context_latents": context_latents.squeeze(0).cpu(),  # [T, 128]: 64 src + 64 mask dims
                "metadata": {
                    "audio_path": sample.audio_path,
                    "filename": sample.filename,
                    "caption": caption,
                    "lyrics": lyrics,
                    "duration": sample.duration,
                    "bpm": sample.bpm,
                    "keyscale": sample.keyscale,
                    "timesignature": sample.timesignature,
                    "language": sample.language,
                    "is_instrumental": sample.is_instrumental,
                }
            }

            # Save with sample ID as filename
            output_path = os.path.join(output_dir, f"{sample.id}.pt")
            torch.save(output_data, output_path)
            output_paths.append(output_path)
            success_count += 1

        except Exception as e:
            logger.exception(f"Error preprocessing {sample.filename}")
            fail_count += 1
            if progress_callback:
                progress_callback(f"❌ Failed: {sample.filename}: {str(e)}")

    # Save manifest file listing all preprocessed samples
    manifest = {
        "metadata": self.metadata.to_dict(),
        "samples": output_paths,
        "num_samples": len(output_paths),
    }
    manifest_path = os.path.join(output_dir, "manifest.json")
    with open(manifest_path, 'w', encoding='utf-8') as f:
        json.dump(manifest, f, indent=2)

    status = f"✅ Preprocessed {success_count}/{len(labeled_samples)} samples to {output_dir}"
    if fail_count > 0:
        status += f" ({fail_count} failed)"

    return output_paths, status
acestep/training/lora_utils.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LoRA Utilities for ACE-Step
|
| 3 |
+
|
| 4 |
+
Provides utilities for injecting LoRA adapters into the DiT decoder model.
|
| 5 |
+
Uses PEFT (Parameter-Efficient Fine-Tuning) library for LoRA implementation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
from typing import Optional, List, Dict, Any, Tuple
|
| 10 |
+
from loguru import logger
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from peft import (
|
| 17 |
+
get_peft_model,
|
| 18 |
+
LoraConfig,
|
| 19 |
+
TaskType,
|
| 20 |
+
PeftModel,
|
| 21 |
+
PeftConfig,
|
| 22 |
+
)
|
| 23 |
+
PEFT_AVAILABLE = True
|
| 24 |
+
except ImportError:
|
| 25 |
+
PEFT_AVAILABLE = False
|
| 26 |
+
logger.warning("PEFT library not installed. LoRA training will not be available.")
|
| 27 |
+
|
| 28 |
+
from acestep.training.configs import LoRAConfig
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def check_peft_available() -> bool:
    """Report whether the optional PEFT dependency was importable."""
    return bool(PEFT_AVAILABLE)
def get_dit_target_modules(model) -> List[str]:
    """List decoder module names eligible for LoRA injection.

    Scans the model's DiT decoder (if present) and collects the names of
    ``nn.Linear`` attention projection layers (q/k/v/o projections).

    Args:
        model: The AceStepConditionGenerationModel

    Returns:
        List of module names suitable for LoRA
    """
    projection_keys = ('q_proj', 'k_proj', 'v_proj', 'o_proj')
    targets: List[str] = []

    if hasattr(model, 'decoder'):
        for name, module in model.decoder.named_modules():
            if isinstance(module, nn.Linear) and any(key in name for key in projection_keys):
                targets.append(name)

    return targets
def freeze_non_lora_parameters(model, freeze_encoder: bool = True) -> None:
    """Disable gradients on every parameter of *model* and log the tally.

    NOTE(review): despite the name, this freezes ALL parameters — LoRA
    adapters must be (re-)enabled elsewhere. The ``freeze_encoder`` flag is
    currently unused; it is kept only for interface compatibility.

    Args:
        model: The model to freeze parameters for
        freeze_encoder: Unused; retained for backward compatibility.
    """
    for param in model.parameters():
        param.requires_grad = False

    # Tally parameter counts for the log output (after the freeze above,
    # the trainable count is expected to be zero).
    total_params = 0
    trainable_params = 0
    for _, param in model.named_parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    logger.info(f"Frozen parameters: {total_params - trainable_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
def inject_lora_into_dit(
    model,
    lora_config: LoRAConfig,
) -> Tuple[Any, Dict[str, Any]]:
    """Attach LoRA adapters to the model's DiT decoder and freeze the rest.

    Wraps ``model.decoder`` in a PEFT LoRA model, then disables gradients
    on every parameter whose name does not contain ``lora_`` (encoder,
    tokenizer, detokenizer, and the base decoder weights).

    Args:
        model: The AceStepConditionGenerationModel
        lora_config: LoRA configuration

    Returns:
        Tuple of (peft_model, info_dict)

    Raises:
        ImportError: If the PEFT library is not installed.
    """
    if not PEFT_AVAILABLE:
        raise ImportError("PEFT library is required for LoRA training. Install with: pip install peft")

    # Build the PEFT LoRA configuration from our lightweight config object.
    peft_settings = LoraConfig(
        r=lora_config.r,
        lora_alpha=lora_config.alpha,
        lora_dropout=lora_config.dropout,
        target_modules=lora_config.target_modules,
        bias=lora_config.bias,
        task_type=TaskType.FEATURE_EXTRACTION,  # For diffusion models
    )

    # Wrap the DiT decoder in place with the LoRA adapter.
    model.decoder = get_peft_model(model.decoder, peft_settings)

    # Keep only the injected LoRA parameters trainable.
    for name, param in model.named_parameters():
        if 'lora_' not in name:
            param.requires_grad = False

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    info = {
        "total_params": total_params,
        "trainable_params": trainable_params,
        "trainable_ratio": trainable_params / total_params if total_params > 0 else 0,
        "lora_r": lora_config.r,
        "lora_alpha": lora_config.alpha,
        "target_modules": lora_config.target_modules,
    }

    logger.info(f"LoRA injected into DiT decoder:")
    logger.info(f"  Total parameters: {total_params:,}")
    logger.info(f"  Trainable parameters: {trainable_params:,} ({info['trainable_ratio']:.2%})")
    logger.info(f"  LoRA rank: {lora_config.r}, alpha: {lora_config.alpha}")

    return model, info
def save_lora_weights(
    model,
    output_dir: str,
    save_full_model: bool = False,
) -> str:
    """Persist LoRA adapter weights to disk.

    Preference order:
      1. PEFT adapter directory (when the decoder supports save_pretrained),
      2. full model state dict (when ``save_full_model`` is True),
      3. a .pt file holding only the LoRA parameters.

    Args:
        model: Model with LoRA adapters
        output_dir: Directory to save weights
        save_full_model: Whether to save the full model state dict

    Returns:
        Path to saved weights ("" if no LoRA parameters were found)
    """
    os.makedirs(output_dir, exist_ok=True)

    decoder = getattr(model, 'decoder', None)
    if decoder is not None and hasattr(decoder, 'save_pretrained'):
        # Preferred path: PEFT adapter directory.
        adapter_path = os.path.join(output_dir, "adapter")
        decoder.save_pretrained(adapter_path)
        logger.info(f"LoRA adapter saved to {adapter_path}")
        return adapter_path

    if save_full_model:
        # Fallback: full state dict (much larger file).
        model_path = os.path.join(output_dir, "model.pt")
        torch.save(model.state_dict(), model_path)
        logger.info(f"Full model state dict saved to {model_path}")
        return model_path

    # Last resort: extract only the LoRA parameters into a flat dict.
    lora_state_dict = {
        name: param.data.clone()
        for name, param in model.named_parameters()
        if 'lora_' in name
    }

    if not lora_state_dict:
        logger.warning("No LoRA parameters found to save!")
        return ""

    lora_path = os.path.join(output_dir, "lora_weights.pt")
    torch.save(lora_state_dict, lora_path)
    logger.info(f"LoRA weights saved to {lora_path}")
    return lora_path
def load_lora_weights(
    model,
    lora_path: str,
    lora_config: Optional[LoRAConfig] = None,
) -> Any:
    """Load LoRA adapter weights into the model.

    Supports two on-disk formats:
      - a PEFT adapter directory (loaded via ``PeftModel.from_pretrained``), or
      - a ``.pt`` state dict containing only LoRA parameters (requires
        ``lora_config`` so the LoRA structure can be injected first).

    Args:
        model: The base model (without LoRA)
        lora_path: Path to saved LoRA weights (adapter or .pt file)
        lora_config: LoRA configuration (required if loading from .pt file)

    Returns:
        Model with LoRA weights loaded

    Raises:
        FileNotFoundError: If ``lora_path`` does not exist.
        ImportError: If a PEFT adapter is given but PEFT is not installed.
        ValueError: If ``lora_config`` is missing for a .pt file, or the
            path has an unsupported format.
    """
    if not os.path.exists(lora_path):
        raise FileNotFoundError(f"LoRA weights not found: {lora_path}")

    # Check if it's a PEFT adapter directory
    if os.path.isdir(lora_path):
        if not PEFT_AVAILABLE:
            raise ImportError("PEFT library is required to load adapter. Install with: pip install peft")

        # Load PEFT adapter directly; PeftModel reads the adapter config
        # itself, so the previous separate PeftConfig.from_pretrained call
        # was a dead assignment and has been removed.
        model.decoder = PeftModel.from_pretrained(model.decoder, lora_path)
        logger.info(f"LoRA adapter loaded from {lora_path}")

    elif lora_path.endswith('.pt'):
        # Load from PyTorch state dict
        if lora_config is None:
            raise ValueError("lora_config is required when loading from .pt file")

        # Inject the LoRA structure first so the state-dict keys exist.
        model, _ = inject_lora_into_dit(model, lora_config)

        # weights_only=True: the file is expected to contain only tensors
        # (as written by save_lora_weights), and this prevents arbitrary
        # code execution from untrusted checkpoints.
        lora_state_dict = torch.load(lora_path, map_location='cpu', weights_only=True)

        # Copy weights in place; state_dict() tensors alias the model's.
        model_state = model.state_dict()
        for name, param in lora_state_dict.items():
            if name in model_state:
                model_state[name].copy_(param)
            else:
                logger.warning(f"Unexpected key in LoRA state dict: {name}")

        logger.info(f"LoRA weights loaded from {lora_path}")

    else:
        raise ValueError(f"Unsupported LoRA weight format: {lora_path}")

    return model
def merge_lora_weights(model) -> Any:
    """Fold LoRA deltas into the base decoder weights.

    After a successful merge the decoder is a plain module again and no
    longer requires PEFT at inference time. If the decoder does not expose
    ``merge_and_unload`` (no PEFT adapter attached), the model is returned
    unchanged with a warning.

    Args:
        model: Model with LoRA adapters

    Returns:
        Model with merged weights
    """
    decoder = getattr(model, 'decoder', None)
    if decoder is not None and hasattr(decoder, 'merge_and_unload'):
        model.decoder = decoder.merge_and_unload()
        logger.info("LoRA weights merged into base model")
    else:
        logger.warning("Model does not support LoRA merging")

    return model
def get_lora_info(model) -> Dict[str, Any]:
    """Summarize LoRA adapter usage in *model*.

    Args:
        model: Model to inspect

    Returns:
        Dict with keys ``has_lora``, ``lora_params``, ``total_params``,
        ``modules_with_lora`` and, when the model has any parameters,
        ``lora_ratio``.
    """
    total = 0
    lora_total = 0
    owners: List[str] = []

    for name, param in model.named_parameters():
        count = param.numel()
        total += count
        if 'lora_' in name:
            lora_total += count
            # Strip the trailing ".lora_*" suffix to recover the owning module.
            owner = name.rsplit('.lora_', 1)[0]
            if owner not in owners:
                owners.append(owner)

    info: Dict[str, Any] = {
        "has_lora": lora_total > 0,
        "lora_params": lora_total,
        "total_params": total,
        "modules_with_lora": owners,
    }
    if total > 0:
        info["lora_ratio"] = lora_total / total

    return info
acestep/training/trainer.py
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LoRA Trainer for ACE-Step
|
| 3 |
+
|
| 4 |
+
Lightning Fabric-based trainer for LoRA fine-tuning of ACE-Step DiT decoder.
|
| 5 |
+
Supports training from preprocessed tensor files for optimal performance.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import time
|
| 10 |
+
from typing import Optional, List, Dict, Any, Tuple, Generator
|
| 11 |
+
from loguru import logger
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from torch.optim import AdamW
|
| 17 |
+
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from lightning.fabric import Fabric
|
| 21 |
+
from lightning.fabric.loggers import TensorBoardLogger
|
| 22 |
+
LIGHTNING_AVAILABLE = True
|
| 23 |
+
except ImportError:
|
| 24 |
+
LIGHTNING_AVAILABLE = False
|
| 25 |
+
logger.warning("Lightning Fabric not installed. Training will use basic training loop.")
|
| 26 |
+
|
| 27 |
+
from acestep.training.configs import LoRAConfig, TrainingConfig
|
| 28 |
+
from acestep.training.lora_utils import inject_lora_into_dit, save_lora_weights, check_peft_available
|
| 29 |
+
from acestep.training.data_module import PreprocessedDataModule
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# Turbo model shift=3.0 discrete timesteps (8 steps, same as inference)
|
| 33 |
+
TURBO_SHIFT3_TIMESTEPS = [1.0, 0.9545454545454546, 0.9, 0.8333333333333334, 0.75, 0.6428571428571429, 0.5, 0.3]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def sample_discrete_timestep(bsz, device, dtype):
    """Sample timesteps from the discrete turbo shift=3 schedule.

    Each sample in the batch independently picks one of the 8 discrete
    timesteps used by the turbo model with shift=3.0.

    Args:
        bsz: Batch size.
        device: Device on which to allocate the sampled timesteps.
        dtype: Data type (should be bfloat16).

    Returns:
        Tuple ``(t, r)`` where both entries are the same sampled timestep
        tensor of shape ``[bsz]``.
    """
    # Draw one random schedule index per batch element.
    picks = torch.randint(0, len(TURBO_SHIFT3_TIMESTEPS), (bsz,), device=device)

    # Materialize the schedule on the target device/dtype and gather.
    schedule = torch.tensor(TURBO_SHIFT3_TIMESTEPS, device=device, dtype=dtype)
    sampled = schedule[picks]

    # In this training setup r mirrors t exactly.
    return sampled, sampled
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class PreprocessedLoRAModule(nn.Module):
    """LoRA Training Module using preprocessed tensors.

    This module trains only the DiT decoder with LoRA adapters.
    All inputs are pre-computed tensors - no VAE or text encoder needed!

    Training flow:
        1. Load pre-computed tensors (target_latents, encoder_hidden_states, context_latents)
        2. Sample noise and timestep
        3. Forward through decoder (with LoRA)
        4. Compute flow matching loss
    """

    def __init__(
        self,
        model: nn.Module,
        lora_config: LoRAConfig,
        training_config: TrainingConfig,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Initialize the training module.

        Args:
            model: The AceStepConditionGenerationModel
            lora_config: LoRA configuration
            training_config: Training configuration
            device: Device to use
            dtype: Data type to use
        """
        super().__init__()

        self.lora_config = lora_config
        self.training_config = training_config
        self.device = device
        self.dtype = dtype

        # Inject LoRA into the decoder only; without PEFT we fall back to
        # training the model as-is (whatever params already require grad).
        if check_peft_available():
            self.model, self.lora_info = inject_lora_into_dit(model, lora_config)
            logger.info(f"LoRA injected: {self.lora_info['trainable_params']:,} trainable params")
        else:
            self.model = model
            self.lora_info = {}
            logger.warning("PEFT not available, training without LoRA adapters")

        # Model config for flow matching
        self.config = model.config

        # Store training losses.
        # NOTE: grows unbounded over a run; acceptable for typical epoch
        # counts, but a long-running job may want to cap or window this.
        self.training_losses = []

    def training_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Single training step using preprocessed tensors.

        Note: This is a distilled turbo model, NO CFG is used.

        Args:
            batch: Dictionary containing pre-computed tensors:
                - target_latents: [B, T, 64] - VAE encoded audio
                - attention_mask: [B, T] - Valid audio mask
                - encoder_hidden_states: [B, L, D] - Condition encoder output
                - encoder_attention_mask: [B, L] - Condition mask
                - context_latents: [B, T, 128] - Source context

        Returns:
            Loss tensor (float32 for stable backward)
        """
        # Fix: derive the autocast device type from the configured device
        # instead of hard-coding 'cuda', so CPU (and other accelerator)
        # training does not crash. Handles both torch.device and str devices.
        if isinstance(self.device, torch.device):
            device_type = self.device.type
        else:
            device_type = str(self.device).split(":")[0]

        # Use autocast for bf16 mixed precision training
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            # Get tensors from batch (already on device from Fabric dataloader;
            # .to() is a no-op in that case, but keeps the basic loop correct)
            target_latents = batch["target_latents"].to(self.device)  # x0
            attention_mask = batch["attention_mask"].to(self.device)
            encoder_hidden_states = batch["encoder_hidden_states"].to(self.device)
            encoder_attention_mask = batch["encoder_attention_mask"].to(self.device)
            context_latents = batch["context_latents"].to(self.device)

            bsz = target_latents.shape[0]

            # Flow matching: sample noise x1 and interpolate with data x0
            x1 = torch.randn_like(target_latents)  # Noise
            x0 = target_latents  # Data

            # Sample timesteps from discrete turbo shift=3 schedule (8 steps)
            t, r = sample_discrete_timestep(bsz, self.device, torch.bfloat16)
            t_ = t.unsqueeze(-1).unsqueeze(-1)

            # Interpolate: x_t = t * x1 + (1 - t) * x0
            xt = t_ * x1 + (1.0 - t_) * x0

            # Forward through decoder (distilled turbo model, no CFG).
            # r == t here, but pass r explicitly since that is the value
            # sample_discrete_timestep produced for timestep_r.
            decoder_outputs = self.model.decoder(
                hidden_states=xt,
                timestep=t,
                timestep_r=r,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                context_latents=context_latents,
            )

            # Flow matching loss: predict the flow field v = x1 - x0
            flow = x1 - x0
            diffusion_loss = F.mse_loss(decoder_outputs[0], flow)

            # Convert loss to float32 for stable backward pass
            diffusion_loss = diffusion_loss.float()

        self.training_losses.append(diffusion_loss.item())

        return diffusion_loss
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class LoRATrainer:
    """High-level trainer for ACE-Step LoRA fine-tuning.

    Uses Lightning Fabric for distributed training and mixed precision.
    Supports training from preprocessed tensor directories.

    All training entry points are generators that yield
    ``(global_step, loss, status_message)`` tuples so a UI (e.g. Gradio)
    can stream progress while training runs.
    """

    def __init__(
        self,
        dit_handler,
        lora_config: LoRAConfig,
        training_config: TrainingConfig,
    ):
        """Initialize the trainer.

        Args:
            dit_handler: Initialized DiT handler (for model access)
            lora_config: LoRA configuration
            training_config: Training configuration
        """
        self.dit_handler = dit_handler
        self.lora_config = lora_config
        self.training_config = training_config

        # Populated lazily by train_from_preprocessed().
        self.module = None          # PreprocessedLoRAModule wrapping the model
        self.fabric = None          # Lightning Fabric instance (Fabric path only)
        self.is_training = False    # True while a training generator is active

    def train_from_preprocessed(
        self,
        tensor_dir: str,
        training_state: Optional[Dict] = None,
    ) -> Generator[Tuple[int, float, str], None, None]:
        """Train LoRA adapters from preprocessed tensor files.

        This is the recommended training method for best performance.

        Args:
            tensor_dir: Directory containing preprocessed .pt files
            training_state: Optional state dict for stopping control; setting
                ``training_state["should_stop"] = True`` from another thread
                stops training at the next batch boundary.

        Yields:
            Tuples of (step, loss, status_message)
        """
        self.is_training = True

        try:
            # Validate tensor directory
            if not os.path.exists(tensor_dir):
                yield 0, 0.0, f"❌ Tensor directory not found: {tensor_dir}"
                return

            # Create training module (injects LoRA into the handler's model)
            self.module = PreprocessedLoRAModule(
                model=self.dit_handler.model,
                lora_config=self.lora_config,
                training_config=self.training_config,
                device=self.dit_handler.device,
                dtype=self.dit_handler.dtype,
            )

            # Create data module over the preprocessed .pt files
            data_module = PreprocessedDataModule(
                tensor_dir=tensor_dir,
                batch_size=self.training_config.batch_size,
                num_workers=self.training_config.num_workers,
                pin_memory=self.training_config.pin_memory,
            )

            # Setup data
            data_module.setup('fit')

            if len(data_module.train_dataset) == 0:
                yield 0, 0.0, "❌ No valid samples found in tensor directory"
                return

            yield 0, 0.0, f"📂 Loaded {len(data_module.train_dataset)} preprocessed samples"

            # Prefer the Fabric loop (mixed precision, TensorBoard logging);
            # fall back to a plain PyTorch loop if lightning is not installed.
            if LIGHTNING_AVAILABLE:
                yield from self._train_with_fabric(data_module, training_state)
            else:
                yield from self._train_basic(data_module, training_state)

        except Exception as e:
            logger.exception("Training failed")
            yield 0, 0.0, f"❌ Training failed: {str(e)}"
        finally:
            self.is_training = False

    def _train_with_fabric(
        self,
        data_module: PreprocessedDataModule,
        training_state: Optional[Dict],
    ) -> Generator[Tuple[int, float, str], None, None]:
        """Train using Lightning Fabric.

        Sets up Fabric with bf16-mixed precision and TensorBoard logging,
        builds an AdamW optimizer with linear-warmup + cosine-restart LR
        schedule over the LoRA parameters only, then runs the epoch loop
        with gradient accumulation. Yields (step, loss, status) tuples and
        saves LoRA checkpoints per save_every_n_epochs plus a final save.
        """
        # Create output directory
        os.makedirs(self.training_config.output_dir, exist_ok=True)

        # Force BFloat16 precision (only supported precision for this model)
        precision = "bf16-mixed"

        # Create TensorBoard logger
        tb_logger = TensorBoardLogger(
            root_dir=self.training_config.output_dir,
            name="logs"
        )

        # Initialize Fabric (single device; accelerator auto-detected)
        self.fabric = Fabric(
            accelerator="auto",
            devices=1,
            precision=precision,
            loggers=[tb_logger],
        )
        self.fabric.launch()

        yield 0, 0.0, f"🚀 Starting training (precision: {precision})..."

        # Get dataloader
        train_loader = data_module.train_dataloader()

        # Setup optimizer - only LoRA parameters (everything else is frozen)
        trainable_params = [p for p in self.module.model.parameters() if p.requires_grad]

        if not trainable_params:
            yield 0, 0.0, "❌ No trainable parameters found!"
            return

        yield 0, 0.0, f"🎯 Training {sum(p.numel() for p in trainable_params):,} parameters"

        optimizer = AdamW(
            trainable_params,
            lr=self.training_config.learning_rate,
            weight_decay=self.training_config.weight_decay,
        )

        # Calculate total optimizer steps (batches / accumulation steps)
        total_steps = len(train_loader) * self.training_config.max_epochs // self.training_config.gradient_accumulation_steps
        # Cap warmup at 10% of the schedule so short runs still get a main phase
        warmup_steps = min(self.training_config.warmup_steps, max(1, total_steps // 10))

        # Scheduler: linear warmup followed by cosine annealing with restarts
        warmup_scheduler = LinearLR(
            optimizer,
            start_factor=0.1,
            end_factor=1.0,
            total_iters=warmup_steps,
        )

        main_scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=max(1, total_steps - warmup_steps),
            T_mult=1,
            eta_min=self.training_config.learning_rate * 0.01,
        )

        scheduler = SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, main_scheduler],
            milestones=[warmup_steps],
        )

        # Convert model to bfloat16 (entire model for consistent dtype)
        self.module.model = self.module.model.to(torch.bfloat16)

        # Setup with Fabric - only the decoder (which has LoRA)
        self.module.model.decoder, optimizer = self.fabric.setup(self.module.model.decoder, optimizer)
        train_loader = self.fabric.setup_dataloaders(train_loader)

        # Training loop state
        global_step = 0          # optimizer steps taken
        accumulation_step = 0    # micro-batches since last optimizer step
        accumulated_loss = 0.0   # sum of scaled losses in current window

        self.module.model.decoder.train()

        for epoch in range(self.training_config.max_epochs):
            epoch_loss = 0.0
            num_batches = 0
            epoch_start_time = time.time()

            for batch_idx, batch in enumerate(train_loader):
                # Check for stop signal
                if training_state and training_state.get("should_stop", False):
                    yield global_step, accumulated_loss / max(accumulation_step, 1), "⏹️ Training stopped by user"
                    return

                # Forward pass (loss scaled down for gradient accumulation)
                loss = self.module.training_step(batch)
                loss = loss / self.training_config.gradient_accumulation_steps

                # Backward pass
                self.fabric.backward(loss)
                accumulated_loss += loss.item()
                accumulation_step += 1

                # Optimizer step once a full accumulation window is complete.
                # NOTE(review): a partial window left over at epoch end is
                # neither stepped nor counted into epoch_loss; its gradients
                # carry into the next epoch — confirm this is intended.
                if accumulation_step >= self.training_config.gradient_accumulation_steps:
                    self.fabric.clip_gradients(
                        self.module.model.decoder,
                        optimizer,
                        max_norm=self.training_config.max_grad_norm,
                    )

                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

                    global_step += 1

                    # Log to TensorBoard
                    avg_loss = accumulated_loss / accumulation_step
                    self.fabric.log("train/loss", avg_loss, step=global_step)
                    self.fabric.log("train/lr", scheduler.get_last_lr()[0], step=global_step)

                    if global_step % self.training_config.log_every_n_steps == 0:
                        yield global_step, avg_loss, f"Epoch {epoch+1}/{self.training_config.max_epochs}, Step {global_step}, Loss: {avg_loss:.4f}"

                    epoch_loss += accumulated_loss
                    num_batches += 1
                    accumulated_loss = 0.0
                    accumulation_step = 0

            # End of epoch
            epoch_time = time.time() - epoch_start_time
            avg_epoch_loss = epoch_loss / max(num_batches, 1)

            self.fabric.log("train/epoch_loss", avg_epoch_loss, step=epoch + 1)
            yield global_step, avg_epoch_loss, f"✅ Epoch {epoch+1}/{self.training_config.max_epochs} in {epoch_time:.1f}s, Loss: {avg_epoch_loss:.4f}"

            # Save checkpoint
            if (epoch + 1) % self.training_config.save_every_n_epochs == 0:
                checkpoint_dir = os.path.join(self.training_config.output_dir, "checkpoints", f"epoch_{epoch+1}")
                save_lora_weights(self.module.model, checkpoint_dir)
                yield global_step, avg_epoch_loss, f"💾 Checkpoint saved at epoch {epoch+1}"

        # Save final model
        final_path = os.path.join(self.training_config.output_dir, "final")
        save_lora_weights(self.module.model, final_path)

        final_loss = self.module.training_losses[-1] if self.module.training_losses else 0.0
        yield global_step, final_loss, f"✅ Training complete! LoRA saved to {final_path}"

    def _train_basic(
        self,
        data_module: PreprocessedDataModule,
        training_state: Optional[Dict],
    ) -> Generator[Tuple[int, float, str], None, None]:
        """Basic training loop without Fabric.

        Mirrors _train_with_fabric (same optimizer, LR schedule, gradient
        accumulation and checkpointing) using plain PyTorch, for when
        lightning is not installed. No mixed-precision wrapper and no
        TensorBoard logging here.
        """
        yield 0, 0.0, "🚀 Starting basic training loop..."

        os.makedirs(self.training_config.output_dir, exist_ok=True)

        train_loader = data_module.train_dataloader()

        # Only LoRA parameters are trainable
        trainable_params = [p for p in self.module.model.parameters() if p.requires_grad]

        if not trainable_params:
            yield 0, 0.0, "❌ No trainable parameters found!"
            return

        optimizer = AdamW(
            trainable_params,
            lr=self.training_config.learning_rate,
            weight_decay=self.training_config.weight_decay,
        )

        # Same schedule as the Fabric path: linear warmup + cosine restarts
        total_steps = len(train_loader) * self.training_config.max_epochs // self.training_config.gradient_accumulation_steps
        warmup_steps = min(self.training_config.warmup_steps, max(1, total_steps // 10))

        warmup_scheduler = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps)
        main_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=max(1, total_steps - warmup_steps), T_mult=1, eta_min=self.training_config.learning_rate * 0.01)
        scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, main_scheduler], milestones=[warmup_steps])

        global_step = 0
        accumulation_step = 0
        accumulated_loss = 0.0

        self.module.model.decoder.train()

        for epoch in range(self.training_config.max_epochs):
            epoch_loss = 0.0
            num_batches = 0
            epoch_start_time = time.time()

            for batch in train_loader:
                # Cooperative stop check (see train_from_preprocessed)
                if training_state and training_state.get("should_stop", False):
                    yield global_step, accumulated_loss / max(accumulation_step, 1), "⏹️ Training stopped"
                    return

                loss = self.module.training_step(batch)
                loss = loss / self.training_config.gradient_accumulation_steps
                loss.backward()
                accumulated_loss += loss.item()
                accumulation_step += 1

                # NOTE(review): as in the Fabric loop, a partial accumulation
                # window at epoch end is not stepped — confirm intended.
                if accumulation_step >= self.training_config.gradient_accumulation_steps:
                    torch.nn.utils.clip_grad_norm_(trainable_params, self.training_config.max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

                    if global_step % self.training_config.log_every_n_steps == 0:
                        avg_loss = accumulated_loss / accumulation_step
                        yield global_step, avg_loss, f"Epoch {epoch+1}, Step {global_step}, Loss: {avg_loss:.4f}"

                    epoch_loss += accumulated_loss
                    num_batches += 1
                    accumulated_loss = 0.0
                    accumulation_step = 0

            epoch_time = time.time() - epoch_start_time
            avg_epoch_loss = epoch_loss / max(num_batches, 1)
            yield global_step, avg_epoch_loss, f"✅ Epoch {epoch+1}/{self.training_config.max_epochs} in {epoch_time:.1f}s"

            if (epoch + 1) % self.training_config.save_every_n_epochs == 0:
                checkpoint_dir = os.path.join(self.training_config.output_dir, "checkpoints", f"epoch_{epoch+1}")
                save_lora_weights(self.module.model, checkpoint_dir)
                yield global_step, avg_epoch_loss, f"💾 Checkpoint saved"

        final_path = os.path.join(self.training_config.output_dir, "final")
        save_lora_weights(self.module.model, final_path)
        final_loss = self.module.training_losses[-1] if self.module.training_losses else 0.0
        yield global_step, final_loss, f"✅ Training complete! LoRA saved to {final_path}"

    def stop(self):
        """Stop training.

        NOTE(review): this only flips the local flag; the training loops
        check ``training_state["should_stop"]``, not ``self.is_training`` —
        confirm callers use the training_state dict to actually stop.
        """
        self.is_training = False
|
requirements.txt
CHANGED
|
@@ -24,6 +24,10 @@ numba>=0.63.1
|
|
| 24 |
vector-quantize-pytorch>=1.27.15
|
| 25 |
torchcodec>=0.9.1
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# nano-vllm dependencies
|
| 28 |
triton-windows>=3.0.0,<3.4; sys_platform == 'win32'
|
| 29 |
triton>=3.0.0; sys_platform != 'win32'
|
|
|
|
| 24 |
vector-quantize-pytorch>=1.27.15
|
| 25 |
torchcodec>=0.9.1
|
| 26 |
|
| 27 |
+
# LoRA Training dependencies (optional)
|
| 28 |
+
peft>=0.7.0
|
| 29 |
+
lightning>=2.0.0
|
| 30 |
+
|
| 31 |
# nano-vllm dependencies
|
| 32 |
triton-windows>=3.0.0,<3.4; sys_platform == 'win32'
|
| 33 |
triton>=3.0.0; sys_platform != 'win32'
|