ChuxiJ committed on
Commit
3df46a2
·
1 Parent(s): 4f80622

support LoRA training & inference

Browse files
.gitignore CHANGED
@@ -221,4 +221,6 @@ feishu_bot/
221
  tmp*
222
  torchinductor_root/
223
  scripts/
224
- checkpoints_legacy/
 
 
 
221
  tmp*
222
  torchinductor_root/
223
  scripts/
224
+ checkpoints_legacy/
225
+ lora_output/
226
+ datasets/
acestep/gradio_ui/events/__init__.py CHANGED
@@ -8,6 +8,7 @@ from typing import Optional
8
  # Import handler modules
9
  from . import generation_handlers as gen_h
10
  from . import results_handlers as res_h
 
11
  from acestep.gradio_ui.i18n import t
12
 
13
 
@@ -69,6 +70,32 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
69
  ]
70
  )
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  # ========== UI Visibility Updates ==========
73
  generation_section["init_llm_checkbox"].change(
74
  fn=gen_h.update_negative_prompt_visibility,
@@ -859,3 +886,244 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
859
  results_section[f"generated_audio_{lrc_idx}"], # Only updates subtitles, not value
860
  ]
861
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  # Import handler modules
9
  from . import generation_handlers as gen_h
10
  from . import results_handlers as res_h
11
+ from . import training_handlers as train_h
12
  from acestep.gradio_ui.i18n import t
13
 
14
 
 
70
  ]
71
  )
72
 
73
+ # ========== LoRA Handlers ==========
74
+ generation_section["load_lora_btn"].click(
75
+ fn=dit_handler.load_lora,
76
+ inputs=[generation_section["lora_path"]],
77
+ outputs=[generation_section["lora_status"]]
78
+ ).then(
79
+ # Update checkbox to enabled state after loading
80
+ fn=lambda: gr.update(value=True),
81
+ outputs=[generation_section["use_lora_checkbox"]]
82
+ )
83
+
84
+ generation_section["unload_lora_btn"].click(
85
+ fn=dit_handler.unload_lora,
86
+ outputs=[generation_section["lora_status"]]
87
+ ).then(
88
+ # Update checkbox to disabled state after unloading
89
+ fn=lambda: gr.update(value=False),
90
+ outputs=[generation_section["use_lora_checkbox"]]
91
+ )
92
+
93
+ generation_section["use_lora_checkbox"].change(
94
+ fn=dit_handler.set_use_lora,
95
+ inputs=[generation_section["use_lora_checkbox"]],
96
+ outputs=[generation_section["lora_status"]]
97
+ )
98
+
99
  # ========== UI Visibility Updates ==========
100
  generation_section["init_llm_checkbox"].change(
101
  fn=gen_h.update_negative_prompt_visibility,
 
886
  results_section[f"generated_audio_{lrc_idx}"], # Only updates subtitles, not value
887
  ]
888
  )
889
+
890
+
891
+ def setup_training_event_handlers(demo, dit_handler, llm_handler, training_section):
892
+ """Setup event handlers for the training tab (dataset builder and LoRA training)"""
893
+
894
+ # ========== Load Existing Dataset (Top Section) ==========
895
+
896
+ # Load existing dataset JSON at the top of Dataset Builder
897
+ training_section["load_json_btn"].click(
898
+ fn=train_h.load_existing_dataset_for_preprocess,
899
+ inputs=[
900
+ training_section["load_json_path"],
901
+ training_section["dataset_builder_state"],
902
+ ],
903
+ outputs=[
904
+ training_section["load_json_status"],
905
+ training_section["audio_files_table"],
906
+ training_section["sample_selector"],
907
+ training_section["dataset_builder_state"],
908
+ # Also update preview fields with first sample
909
+ training_section["preview_audio"],
910
+ training_section["preview_filename"],
911
+ training_section["edit_caption"],
912
+ training_section["edit_lyrics"],
913
+ training_section["edit_bpm"],
914
+ training_section["edit_keyscale"],
915
+ training_section["edit_timesig"],
916
+ training_section["edit_duration"],
917
+ training_section["edit_language"],
918
+ training_section["edit_instrumental"],
919
+ ]
920
+ )
921
+
922
+ # ========== Dataset Builder Handlers ==========
923
+
924
+ # Scan directory for audio files
925
+ training_section["scan_btn"].click(
926
+ fn=lambda dir, name, tag, pos, instr, state: train_h.scan_directory(
927
+ dir, name, tag, pos, instr, state
928
+ ),
929
+ inputs=[
930
+ training_section["audio_directory"],
931
+ training_section["dataset_name"],
932
+ training_section["custom_tag"],
933
+ training_section["tag_position"],
934
+ training_section["all_instrumental"],
935
+ training_section["dataset_builder_state"],
936
+ ],
937
+ outputs=[
938
+ training_section["audio_files_table"],
939
+ training_section["scan_status"],
940
+ training_section["sample_selector"],
941
+ training_section["dataset_builder_state"],
942
+ ]
943
+ )
944
+
945
+ # Auto-label all samples
946
+ training_section["auto_label_btn"].click(
947
+ fn=lambda state, skip: train_h.auto_label_all(dit_handler, llm_handler, state, skip),
948
+ inputs=[
949
+ training_section["dataset_builder_state"],
950
+ training_section["skip_metas"],
951
+ ],
952
+ outputs=[
953
+ training_section["audio_files_table"],
954
+ training_section["label_progress"],
955
+ training_section["dataset_builder_state"],
956
+ ]
957
+ )
958
+
959
+ # Sample selector change - update preview
960
+ training_section["sample_selector"].change(
961
+ fn=train_h.get_sample_preview,
962
+ inputs=[
963
+ training_section["sample_selector"],
964
+ training_section["dataset_builder_state"],
965
+ ],
966
+ outputs=[
967
+ training_section["preview_audio"],
968
+ training_section["preview_filename"],
969
+ training_section["edit_caption"],
970
+ training_section["edit_lyrics"],
971
+ training_section["edit_bpm"],
972
+ training_section["edit_keyscale"],
973
+ training_section["edit_timesig"],
974
+ training_section["edit_duration"],
975
+ training_section["edit_language"],
976
+ training_section["edit_instrumental"],
977
+ ]
978
+ )
979
+
980
+ # Save sample edit
981
+ training_section["save_edit_btn"].click(
982
+ fn=train_h.save_sample_edit,
983
+ inputs=[
984
+ training_section["sample_selector"],
985
+ training_section["edit_caption"],
986
+ training_section["edit_lyrics"],
987
+ training_section["edit_bpm"],
988
+ training_section["edit_keyscale"],
989
+ training_section["edit_timesig"],
990
+ training_section["edit_language"],
991
+ training_section["edit_instrumental"],
992
+ training_section["dataset_builder_state"],
993
+ ],
994
+ outputs=[
995
+ training_section["audio_files_table"],
996
+ training_section["edit_status"],
997
+ training_section["dataset_builder_state"],
998
+ ]
999
+ )
1000
+
1001
+ # Update settings when changed
1002
+ for trigger in [training_section["custom_tag"], training_section["tag_position"], training_section["all_instrumental"]]:
1003
+ trigger.change(
1004
+ fn=train_h.update_settings,
1005
+ inputs=[
1006
+ training_section["custom_tag"],
1007
+ training_section["tag_position"],
1008
+ training_section["all_instrumental"],
1009
+ training_section["dataset_builder_state"],
1010
+ ],
1011
+ outputs=[training_section["dataset_builder_state"]]
1012
+ )
1013
+
1014
+ # Save dataset
1015
+ training_section["save_dataset_btn"].click(
1016
+ fn=train_h.save_dataset,
1017
+ inputs=[
1018
+ training_section["save_path"],
1019
+ training_section["dataset_name"],
1020
+ training_section["dataset_builder_state"],
1021
+ ],
1022
+ outputs=[training_section["save_status"]]
1023
+ )
1024
+
1025
+ # ========== Preprocess Handlers ==========
1026
+
1027
+ # Load existing dataset JSON for preprocessing
1028
+ # This also updates the preview section so users can view/edit samples
1029
+ training_section["load_existing_dataset_btn"].click(
1030
+ fn=train_h.load_existing_dataset_for_preprocess,
1031
+ inputs=[
1032
+ training_section["load_existing_dataset_path"],
1033
+ training_section["dataset_builder_state"],
1034
+ ],
1035
+ outputs=[
1036
+ training_section["load_existing_status"],
1037
+ training_section["audio_files_table"],
1038
+ training_section["sample_selector"],
1039
+ training_section["dataset_builder_state"],
1040
+ # Also update preview fields with first sample
1041
+ training_section["preview_audio"],
1042
+ training_section["preview_filename"],
1043
+ training_section["edit_caption"],
1044
+ training_section["edit_lyrics"],
1045
+ training_section["edit_bpm"],
1046
+ training_section["edit_keyscale"],
1047
+ training_section["edit_timesig"],
1048
+ training_section["edit_duration"],
1049
+ training_section["edit_language"],
1050
+ training_section["edit_instrumental"],
1051
+ ]
1052
+ )
1053
+
1054
+ # Preprocess dataset to tensor files
1055
+ training_section["preprocess_btn"].click(
1056
+ fn=lambda output_dir, state: train_h.preprocess_dataset(
1057
+ output_dir, dit_handler, state
1058
+ ),
1059
+ inputs=[
1060
+ training_section["preprocess_output_dir"],
1061
+ training_section["dataset_builder_state"],
1062
+ ],
1063
+ outputs=[training_section["preprocess_progress"]]
1064
+ )
1065
+
1066
+ # ========== Training Tab Handlers ==========
1067
+
1068
+ # Load preprocessed tensor dataset
1069
+ training_section["load_dataset_btn"].click(
1070
+ fn=train_h.load_training_dataset,
1071
+ inputs=[training_section["training_tensor_dir"]],
1072
+ outputs=[training_section["training_dataset_info"]]
1073
+ )
1074
+
1075
+ # Start training from preprocessed tensors
1076
+ def training_wrapper(tensor_dir, r, a, d, lr, ep, bs, ga, se, sh, sd, od, ts):
1077
+ try:
1078
+ for progress, log, plot, state in train_h.start_training(
1079
+ tensor_dir, dit_handler, r, a, d, lr, ep, bs, ga, se, sh, sd, od, ts
1080
+ ):
1081
+ yield progress, log, plot, state
1082
+ except Exception as e:
1083
+ logger.exception("Training wrapper error")
1084
+ yield f"❌ Error: {str(e)}", str(e), None, ts
1085
+
1086
+ training_section["start_training_btn"].click(
1087
+ fn=training_wrapper,
1088
+ inputs=[
1089
+ training_section["training_tensor_dir"],
1090
+ training_section["lora_rank"],
1091
+ training_section["lora_alpha"],
1092
+ training_section["lora_dropout"],
1093
+ training_section["learning_rate"],
1094
+ training_section["train_epochs"],
1095
+ training_section["train_batch_size"],
1096
+ training_section["gradient_accumulation"],
1097
+ training_section["save_every_n_epochs"],
1098
+ training_section["training_shift"],
1099
+ training_section["training_seed"],
1100
+ training_section["lora_output_dir"],
1101
+ training_section["training_state"],
1102
+ ],
1103
+ outputs=[
1104
+ training_section["training_progress"],
1105
+ training_section["training_log"],
1106
+ training_section["training_loss_plot"],
1107
+ training_section["training_state"],
1108
+ ]
1109
+ )
1110
+
1111
+ # Stop training
1112
+ training_section["stop_training_btn"].click(
1113
+ fn=train_h.stop_training,
1114
+ inputs=[training_section["training_state"]],
1115
+ outputs=[
1116
+ training_section["training_progress"],
1117
+ training_section["training_state"],
1118
+ ]
1119
+ )
1120
+
1121
+ # Export LoRA
1122
+ training_section["export_lora_btn"].click(
1123
+ fn=train_h.export_lora,
1124
+ inputs=[
1125
+ training_section["export_path"],
1126
+ training_section["lora_output_dir"],
1127
+ ],
1128
+ outputs=[training_section["export_status"]]
1129
+ )
acestep/gradio_ui/events/training_handlers.py ADDED
@@ -0,0 +1,644 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Event Handlers for Training Tab
3
+
4
+ Contains all event handler functions for the dataset builder and training UI.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ from typing import Any, Dict, List, Tuple, Optional
10
+ from loguru import logger
11
+ import gradio as gr
12
+
13
+ from acestep.training.dataset_builder import DatasetBuilder, AudioSample
14
+
15
+
16
+ def create_dataset_builder() -> DatasetBuilder:
17
+ """Create a new DatasetBuilder instance."""
18
+ return DatasetBuilder()
19
+
20
+
21
+ def scan_directory(
22
+ audio_dir: str,
23
+ dataset_name: str,
24
+ custom_tag: str,
25
+ tag_position: str,
26
+ all_instrumental: bool,
27
+ builder_state: Optional[DatasetBuilder],
28
+ ) -> Tuple[Any, str, Any, DatasetBuilder]:
29
+ """Scan a directory for audio files.
30
+
31
+ Returns:
32
+ Tuple of (table_data, status, slider_update, builder_state)
33
+ """
34
+ if not audio_dir or not audio_dir.strip():
35
+ return [], "❌ Please enter a directory path", gr.Slider(maximum=0, value=0), builder_state
36
+
37
+ # Create or use existing builder
38
+ builder = builder_state if builder_state else DatasetBuilder()
39
+
40
+ # Set metadata before scanning
41
+ builder.metadata.name = dataset_name
42
+ builder.metadata.custom_tag = custom_tag
43
+ builder.metadata.tag_position = tag_position
44
+ builder.metadata.all_instrumental = all_instrumental
45
+
46
+ # Scan directory
47
+ samples, status = builder.scan_directory(audio_dir.strip())
48
+
49
+ if not samples:
50
+ return [], status, gr.Slider(maximum=0, value=0), builder
51
+
52
+ # Set instrumental and tag for all samples
53
+ builder.set_all_instrumental(all_instrumental)
54
+ if custom_tag:
55
+ builder.set_custom_tag(custom_tag, tag_position)
56
+
57
+ # Get table data
58
+ table_data = builder.get_samples_dataframe_data()
59
+
60
+ # Calculate slider max and return as Slider update
61
+ slider_max = max(0, len(samples) - 1)
62
+
63
+ return table_data, status, gr.Slider(maximum=slider_max, value=0), builder
64
+
65
+
66
+ def auto_label_all(
67
+ dit_handler,
68
+ llm_handler,
69
+ builder_state: Optional[DatasetBuilder],
70
+ skip_metas: bool = False,
71
+ progress=None,
72
+ ) -> Tuple[List[List[Any]], str, DatasetBuilder]:
73
+ """Auto-label all samples in the dataset.
74
+
75
+ Args:
76
+ dit_handler: DiT handler for audio processing
77
+ llm_handler: LLM handler for caption generation
78
+ builder_state: Dataset builder state
79
+ skip_metas: If True, skip LLM labeling. BPM/Key/TimeSig = N/A, Language = unknown for instrumental
80
+ progress: Progress callback
81
+
82
+ Returns:
83
+ Tuple of (table_data, status, builder_state)
84
+ """
85
+ if builder_state is None:
86
+ return [], "❌ Please scan a directory first", builder_state
87
+
88
+ if not builder_state.samples:
89
+ return [], "❌ No samples to label. Please scan a directory first.", builder_state
90
+
91
+ # If skip_metas is True, just set default values without LLM
92
+ if skip_metas:
93
+ for sample in builder_state.samples:
94
+ sample.bpm = None # Will display as N/A
95
+ sample.keyscale = "N/A"
96
+ sample.timesignature = "N/A"
97
+ # For instrumental, language should be "unknown"
98
+ if sample.is_instrumental:
99
+ sample.language = "unknown"
100
+ else:
101
+ sample.language = "unknown"
102
+ # Use custom tag as caption if set, otherwise use filename
103
+ if builder_state.metadata.custom_tag:
104
+ sample.caption = builder_state.metadata.custom_tag
105
+ else:
106
+ sample.caption = sample.filename
107
+
108
+ table_data = builder_state.get_samples_dataframe_data()
109
+ return table_data, f"✅ Skipped AI labeling. {len(builder_state.samples)} samples set with default values.", builder_state
110
+
111
+ # Check if handlers are initialized
112
+ if dit_handler is None or dit_handler.model is None:
113
+ return builder_state.get_samples_dataframe_data(), "❌ Model not initialized. Please initialize the service first.", builder_state
114
+
115
+ if llm_handler is None or not llm_handler.llm_initialized:
116
+ return builder_state.get_samples_dataframe_data(), "❌ LLM not initialized. Please initialize the service with LLM enabled.", builder_state
117
+
118
+ def progress_callback(msg):
119
+ if progress:
120
+ try:
121
+ progress(msg)
122
+ except:
123
+ pass
124
+
125
+ # Label all samples
126
+ samples, status = builder_state.label_all_samples(
127
+ dit_handler=dit_handler,
128
+ llm_handler=llm_handler,
129
+ progress_callback=progress_callback,
130
+ )
131
+
132
+ # Get updated table data
133
+ table_data = builder_state.get_samples_dataframe_data()
134
+
135
+ return table_data, status, builder_state
136
+
137
+
138
+ def get_sample_preview(
139
+ sample_idx: int,
140
+ builder_state: Optional[DatasetBuilder],
141
+ ) -> Tuple[str, str, str, str, Optional[int], str, str, float, str, bool]:
142
+ """Get preview data for a specific sample.
143
+
144
+ Returns:
145
+ Tuple of (audio_path, filename, caption, lyrics, bpm, keyscale, timesig, duration, language, instrumental)
146
+ """
147
+ if builder_state is None or not builder_state.samples:
148
+ return None, "", "", "", None, "", "", 0.0, "instrumental", True
149
+
150
+ idx = int(sample_idx)
151
+ if idx < 0 or idx >= len(builder_state.samples):
152
+ return None, "", "", "", None, "", "", 0.0, "instrumental", True
153
+
154
+ sample = builder_state.samples[idx]
155
+
156
+ return (
157
+ sample.audio_path,
158
+ sample.filename,
159
+ sample.caption,
160
+ sample.lyrics,
161
+ sample.bpm,
162
+ sample.keyscale,
163
+ sample.timesignature,
164
+ sample.duration,
165
+ sample.language,
166
+ sample.is_instrumental,
167
+ )
168
+
169
+
170
+ def save_sample_edit(
171
+ sample_idx: int,
172
+ caption: str,
173
+ lyrics: str,
174
+ bpm: Optional[int],
175
+ keyscale: str,
176
+ timesig: str,
177
+ language: str,
178
+ is_instrumental: bool,
179
+ builder_state: Optional[DatasetBuilder],
180
+ ) -> Tuple[List[List[Any]], str, DatasetBuilder]:
181
+ """Save edits to a sample.
182
+
183
+ Returns:
184
+ Tuple of (table_data, status, builder_state)
185
+ """
186
+ if builder_state is None:
187
+ return [], "❌ No dataset loaded", builder_state
188
+
189
+ idx = int(sample_idx)
190
+
191
+ # Update sample
192
+ sample, status = builder_state.update_sample(
193
+ idx,
194
+ caption=caption,
195
+ lyrics=lyrics if not is_instrumental else "[Instrumental]",
196
+ bpm=int(bpm) if bpm else None,
197
+ keyscale=keyscale,
198
+ timesignature=timesig,
199
+ language="instrumental" if is_instrumental else language,
200
+ is_instrumental=is_instrumental,
201
+ labeled=True,
202
+ )
203
+
204
+ # Get updated table data
205
+ table_data = builder_state.get_samples_dataframe_data()
206
+
207
+ return table_data, status, builder_state
208
+
209
+
210
+ def update_settings(
211
+ custom_tag: str,
212
+ tag_position: str,
213
+ all_instrumental: bool,
214
+ builder_state: Optional[DatasetBuilder],
215
+ ) -> DatasetBuilder:
216
+ """Update dataset settings.
217
+
218
+ Returns:
219
+ Updated builder_state
220
+ """
221
+ if builder_state is None:
222
+ return builder_state
223
+
224
+ if custom_tag:
225
+ builder_state.set_custom_tag(custom_tag, tag_position)
226
+
227
+ builder_state.set_all_instrumental(all_instrumental)
228
+
229
+ return builder_state
230
+
231
+
232
+ def save_dataset(
233
+ save_path: str,
234
+ dataset_name: str,
235
+ builder_state: Optional[DatasetBuilder],
236
+ ) -> str:
237
+ """Save the dataset to a JSON file.
238
+
239
+ Returns:
240
+ Status message
241
+ """
242
+ if builder_state is None:
243
+ return "❌ No dataset to save. Please scan a directory first."
244
+
245
+ if not builder_state.samples:
246
+ return "❌ No samples in dataset."
247
+
248
+ if not save_path or not save_path.strip():
249
+ return "❌ Please enter a save path."
250
+
251
+ # Check if any samples are labeled
252
+ labeled_count = builder_state.get_labeled_count()
253
+ if labeled_count == 0:
254
+ return "⚠️ Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway..."
255
+
256
+ return builder_state.save_dataset(save_path.strip(), dataset_name)
257
+
258
+
259
+ def load_existing_dataset_for_preprocess(
260
+ dataset_path: str,
261
+ builder_state: Optional[DatasetBuilder],
262
+ ) -> Tuple[str, Any, Any, DatasetBuilder, str, str, str, str, Optional[int], str, str, float, str, bool]:
263
+ """Load an existing dataset JSON file for preprocessing.
264
+
265
+ This allows users to load a previously saved dataset and proceed to preprocessing
266
+ without having to re-scan and re-label.
267
+
268
+ Returns:
269
+ Tuple of (status, table_data, slider_update, builder_state,
270
+ audio_path, filename, caption, lyrics, bpm, keyscale, timesig, duration, language, instrumental)
271
+ """
272
+ empty_preview = (None, "", "", "", None, "", "", 0.0, "instrumental", True)
273
+
274
+ if not dataset_path or not dataset_path.strip():
275
+ return ("❌ Please enter a dataset path", [], gr.Slider(maximum=0, value=0), builder_state) + empty_preview
276
+
277
+ dataset_path = dataset_path.strip()
278
+
279
+ if not os.path.exists(dataset_path):
280
+ return (f"❌ Dataset not found: {dataset_path}", [], gr.Slider(maximum=0, value=0), builder_state) + empty_preview
281
+
282
+ # Create new builder (don't reuse old state when loading a file)
283
+ builder = DatasetBuilder()
284
+
285
+ # Load the dataset
286
+ samples, status = builder.load_dataset(dataset_path)
287
+
288
+ if not samples:
289
+ return (status, [], gr.Slider(maximum=0, value=0), builder) + empty_preview
290
+
291
+ # Get table data
292
+ table_data = builder.get_samples_dataframe_data()
293
+
294
+ # Calculate slider max
295
+ slider_max = max(0, len(samples) - 1)
296
+
297
+ # Create info text
298
+ labeled_count = builder.get_labeled_count()
299
+ info = f"✅ Loaded dataset: {builder.metadata.name}\n"
300
+ info += f"📊 Samples: {len(samples)} ({labeled_count} labeled)\n"
301
+ info += f"🏷️ Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
302
+ info += "📝 Ready for preprocessing! You can also edit samples below."
303
+
304
+ # Get first sample preview
305
+ first_sample = builder.samples[0]
306
+ preview = (
307
+ first_sample.audio_path,
308
+ first_sample.filename,
309
+ first_sample.caption,
310
+ first_sample.lyrics,
311
+ first_sample.bpm,
312
+ first_sample.keyscale,
313
+ first_sample.timesignature,
314
+ first_sample.duration,
315
+ first_sample.language,
316
+ first_sample.is_instrumental,
317
+ )
318
+
319
+ return (info, table_data, gr.Slider(maximum=slider_max, value=0), builder) + preview
320
+
321
+
322
+ def preprocess_dataset(
323
+ output_dir: str,
324
+ dit_handler,
325
+ builder_state: Optional[DatasetBuilder],
326
+ progress=None,
327
+ ) -> str:
328
+ """Preprocess dataset to tensor files for fast training.
329
+
330
+ This converts audio files to VAE latents and text to embeddings.
331
+
332
+ Returns:
333
+ Status message
334
+ """
335
+ if builder_state is None:
336
+ return "❌ No dataset loaded. Please scan a directory first."
337
+
338
+ if not builder_state.samples:
339
+ return "❌ No samples in dataset."
340
+
341
+ labeled_count = builder_state.get_labeled_count()
342
+ if labeled_count == 0:
343
+ return "❌ No labeled samples. Please auto-label or manually label samples first."
344
+
345
+ if not output_dir or not output_dir.strip():
346
+ return "❌ Please enter an output directory."
347
+
348
+ if dit_handler is None or dit_handler.model is None:
349
+ return "❌ Model not initialized. Please initialize the service first."
350
+
351
+ def progress_callback(msg):
352
+ if progress:
353
+ try:
354
+ progress(msg)
355
+ except:
356
+ pass
357
+
358
+ # Run preprocessing
359
+ output_paths, status = builder_state.preprocess_to_tensors(
360
+ dit_handler=dit_handler,
361
+ output_dir=output_dir.strip(),
362
+ progress_callback=progress_callback,
363
+ )
364
+
365
+ return status
366
+
367
+
368
+ def load_training_dataset(
369
+ tensor_dir: str,
370
+ ) -> str:
371
+ """Load a preprocessed tensor dataset for training.
372
+
373
+ Returns:
374
+ Info text about the dataset
375
+ """
376
+ if not tensor_dir or not tensor_dir.strip():
377
+ return "❌ Please enter a tensor directory path"
378
+
379
+ tensor_dir = tensor_dir.strip()
380
+
381
+ if not os.path.exists(tensor_dir):
382
+ return f"❌ Directory not found: {tensor_dir}"
383
+
384
+ if not os.path.isdir(tensor_dir):
385
+ return f"❌ Not a directory: {tensor_dir}"
386
+
387
+ # Check for manifest
388
+ manifest_path = os.path.join(tensor_dir, "manifest.json")
389
+ if os.path.exists(manifest_path):
390
+ try:
391
+ with open(manifest_path, 'r') as f:
392
+ manifest = json.load(f)
393
+
394
+ num_samples = manifest.get("num_samples", 0)
395
+ metadata = manifest.get("metadata", {})
396
+ name = metadata.get("name", "Unknown")
397
+ custom_tag = metadata.get("custom_tag", "")
398
+
399
+ info = f"✅ Loaded preprocessed dataset: {name}\n"
400
+ info += f"📊 Samples: {num_samples} preprocessed tensors\n"
401
+ info += f"🏷️ Custom Tag: {custom_tag or '(none)'}"
402
+
403
+ return info
404
+ except Exception as e:
405
+ logger.warning(f"Failed to read manifest: {e}")
406
+
407
+ # Fallback: count .pt files
408
+ pt_files = [f for f in os.listdir(tensor_dir) if f.endswith('.pt')]
409
+
410
+ if not pt_files:
411
+ return f"❌ No .pt tensor files found in {tensor_dir}"
412
+
413
+ info = f"✅ Found {len(pt_files)} tensor files in {tensor_dir}\n"
414
+ info += "⚠️ No manifest.json found - using all .pt files"
415
+
416
+ return info
417
+
418
+
419
+ # Training handlers
420
+
421
+ import time
422
+ import re
423
+
424
+
425
+ def _format_duration(seconds):
426
+ """Format seconds to human readable string."""
427
+ seconds = int(seconds)
428
+ if seconds < 60:
429
+ return f"{seconds}s"
430
+ elif seconds < 3600:
431
+ return f"{seconds // 60}m {seconds % 60}s"
432
+ else:
433
+ return f"{seconds // 3600}h {(seconds % 3600) // 60}m"
434
+
435
+
436
+ def start_training(
437
+ tensor_dir: str,
438
+ dit_handler,
439
+ lora_rank: int,
440
+ lora_alpha: int,
441
+ lora_dropout: float,
442
+ learning_rate: float,
443
+ train_epochs: int,
444
+ train_batch_size: int,
445
+ gradient_accumulation: int,
446
+ save_every_n_epochs: int,
447
+ training_shift: float,
448
+ training_seed: int,
449
+ lora_output_dir: str,
450
+ training_state: Dict,
451
+ progress=None,
452
+ ):
453
+ """Start LoRA training from preprocessed tensors.
454
+
455
+ This is a generator function that yields progress updates.
456
+ """
457
+ if not tensor_dir or not tensor_dir.strip():
458
+ yield "❌ Please enter a tensor directory path", "", None, training_state
459
+ return
460
+
461
+ tensor_dir = tensor_dir.strip()
462
+
463
+ if not os.path.exists(tensor_dir):
464
+ yield f"❌ Tensor directory not found: {tensor_dir}", "", None, training_state
465
+ return
466
+
467
+ if dit_handler is None or dit_handler.model is None:
468
+ yield "❌ Model not initialized. Please initialize the service first.", "", None, training_state
469
+ return
470
+
471
+ # Check for required training dependencies
472
+ try:
473
+ from lightning.fabric import Fabric
474
+ from peft import get_peft_model, LoraConfig
475
+ except ImportError as e:
476
+ yield f"❌ Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
477
+ return
478
+
479
+ training_state["is_training"] = True
480
+ training_state["should_stop"] = False
481
+
482
+ try:
483
+ from acestep.training.trainer import LoRATrainer
484
+ from acestep.training.configs import LoRAConfig as LoRAConfigClass, TrainingConfig
485
+
486
+ # Create configs
487
+ lora_config = LoRAConfigClass(
488
+ r=lora_rank,
489
+ alpha=lora_alpha,
490
+ dropout=lora_dropout,
491
+ )
492
+
493
+ training_config = TrainingConfig(
494
+ shift=training_shift,
495
+ learning_rate=learning_rate,
496
+ batch_size=train_batch_size,
497
+ gradient_accumulation_steps=gradient_accumulation,
498
+ max_epochs=train_epochs,
499
+ save_every_n_epochs=save_every_n_epochs,
500
+ seed=training_seed,
501
+ output_dir=lora_output_dir,
502
+ )
503
+
504
+ import pandas as pd
505
+
506
+ # Initialize training log and loss history
507
+ log_lines = []
508
+ loss_data = pd.DataFrame({"step": [0], "loss": [0.0]})
509
+
510
+ # Start timer
511
+ start_time = time.time()
512
+
513
+ yield f"🚀 Starting training from {tensor_dir}...", "", loss_data, training_state
514
+
515
+ # Create trainer
516
+ trainer = LoRATrainer(
517
+ dit_handler=dit_handler,
518
+ lora_config=lora_config,
519
+ training_config=training_config,
520
+ )
521
+
522
+ # Collect loss history
523
+ step_list = []
524
+ loss_list = []
525
+
526
+ # Train with progress updates using preprocessed tensors
527
+ for step, loss, status in trainer.train_from_preprocessed(tensor_dir, training_state):
528
+ # Calculate elapsed time and ETA
529
+ elapsed_seconds = time.time() - start_time
530
+ time_info = f"⏱️ Elapsed: {_format_duration(elapsed_seconds)}"
531
+
532
+ # Parse "Epoch x/y" from status to calculate ETA
533
+ match = re.search(r"Epoch\s+(\d+)/(\d+)", str(status))
534
+ if match:
535
+ current_ep = int(match.group(1))
536
+ total_ep = int(match.group(2))
537
+ if current_ep > 0:
538
+ eta_seconds = (elapsed_seconds / current_ep) * (total_ep - current_ep)
539
+ time_info += f" | ETA: ~{_format_duration(eta_seconds)}"
540
+
541
+ # Display status with time info
542
+ display_status = f"{status}\n{time_info}"
543
+
544
+ # Terminal log
545
+ log_msg = f"[{_format_duration(elapsed_seconds)}] Step {step}: {status}"
546
+ logger.info(log_msg)
547
+
548
+ # Add to UI log
549
+ log_lines.append(status)
550
+ if len(log_lines) > 15:
551
+ log_lines = log_lines[-15:]
552
+ log_text = "\n".join(log_lines)
553
+
554
+ # Track loss for plot (only valid values)
555
+ if step > 0 and loss is not None and loss == loss: # Check for NaN
556
+ step_list.append(step)
557
+ loss_list.append(float(loss))
558
+ loss_data = pd.DataFrame({"step": step_list, "loss": loss_list})
559
+
560
+ yield display_status, log_text, loss_data, training_state
561
+
562
+ if training_state.get("should_stop", False):
563
+ logger.info("⏹️ Training stopped by user")
564
+ log_lines.append("⏹️ Training stopped by user")
565
+ yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
566
+ break
567
+
568
+ total_time = time.time() - start_time
569
+ training_state["is_training"] = False
570
+ completion_msg = f"✅ Training completed! Total time: {_format_duration(total_time)}"
571
+
572
+ logger.info(completion_msg)
573
+ log_lines.append(completion_msg)
574
+
575
+ yield completion_msg, "\n".join(log_lines[-15:]), loss_data, training_state
576
+
577
+ except Exception as e:
578
+ logger.exception("Training error")
579
+ training_state["is_training"] = False
580
+ import pandas as pd
581
+ empty_df = pd.DataFrame({"step": [], "loss": []})
582
+ yield f"❌ Error: {str(e)}", str(e), empty_df, training_state
583
+
584
+
585
+ def stop_training(training_state: Dict) -> Tuple[str, Dict]:
586
+ """Stop the current training process.
587
+
588
+ Returns:
589
+ Tuple of (status, training_state)
590
+ """
591
+ if not training_state.get("is_training", False):
592
+ return "⚠️ No training in progress", training_state
593
+
594
+ training_state["should_stop"] = True
595
+ return "⏹️ Stopping training...", training_state
596
+
597
+
598
+ def export_lora(
599
+ export_path: str,
600
+ lora_output_dir: str,
601
+ ) -> str:
602
+ """Export the trained LoRA weights.
603
+
604
+ Returns:
605
+ Status message
606
+ """
607
+ if not export_path or not export_path.strip():
608
+ return "❌ Please enter an export path"
609
+
610
+ # Check if there's a trained model to export
611
+ final_dir = os.path.join(lora_output_dir, "final")
612
+ checkpoint_dir = os.path.join(lora_output_dir, "checkpoints")
613
+
614
+ # Prefer final, fallback to checkpoints
615
+ if os.path.exists(final_dir):
616
+ source_path = final_dir
617
+ elif os.path.exists(checkpoint_dir):
618
+ # Find the latest checkpoint
619
+ checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith("epoch_")]
620
+ if not checkpoints:
621
+ return "❌ No checkpoints found"
622
+
623
+ checkpoints.sort(key=lambda x: int(x.split("_")[1]))
624
+ latest = checkpoints[-1]
625
+ source_path = os.path.join(checkpoint_dir, latest)
626
+ else:
627
+ return f"❌ No trained model found in {lora_output_dir}"
628
+
629
+ try:
630
+ import shutil
631
+
632
+ export_path = export_path.strip()
633
+ os.makedirs(os.path.dirname(export_path) if os.path.dirname(export_path) else ".", exist_ok=True)
634
+
635
+ if os.path.exists(export_path):
636
+ shutil.rmtree(export_path)
637
+
638
+ shutil.copytree(source_path, export_path)
639
+
640
+ return f"✅ LoRA exported to {export_path}"
641
+
642
+ except Exception as e:
643
+ logger.exception("Export error")
644
+ return f"❌ Export failed: {str(e)}"
acestep/gradio_ui/interfaces/__init__.py CHANGED
@@ -7,7 +7,8 @@ from acestep.gradio_ui.i18n import get_i18n, t
7
  from acestep.gradio_ui.interfaces.dataset import create_dataset_section
8
  from acestep.gradio_ui.interfaces.generation import create_generation_section
9
  from acestep.gradio_ui.interfaces.result import create_results_section
10
- from acestep.gradio_ui.events import setup_event_handlers
 
11
 
12
 
13
  def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None, language='en') -> gr.Blocks:
@@ -76,7 +77,13 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_para
76
  # Results Section
77
  results_section = create_results_section(dit_handler)
78
 
 
 
 
79
  # Connect event handlers
80
  setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section)
 
 
 
81
 
82
  return demo
 
7
  from acestep.gradio_ui.interfaces.dataset import create_dataset_section
8
  from acestep.gradio_ui.interfaces.generation import create_generation_section
9
  from acestep.gradio_ui.interfaces.result import create_results_section
10
+ from acestep.gradio_ui.interfaces.training import create_training_section
11
+ from acestep.gradio_ui.events import setup_event_handlers, setup_training_event_handlers
12
 
13
 
14
  def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None, language='en') -> gr.Blocks:
 
77
  # Results Section
78
  results_section = create_results_section(dit_handler)
79
 
80
+ # Training Section (LoRA training and dataset builder)
81
+ training_section = create_training_section(dit_handler, llm_handler)
82
+
83
  # Connect event handlers
84
  setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section)
85
+
86
+ # Connect training event handlers
87
+ setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)
88
 
89
  return demo
acestep/gradio_ui/interfaces/generation.py CHANGED
@@ -144,6 +144,31 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
144
  # Set init_status value from init_params if pre-initialized
145
  init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
146
  init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
  # Inputs
149
  with gr.Row():
@@ -653,6 +678,12 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
653
  "use_flash_attention_checkbox": use_flash_attention_checkbox,
654
  "offload_to_cpu_checkbox": offload_to_cpu_checkbox,
655
  "offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
 
 
 
 
 
 
656
  "task_type": task_type,
657
  "instruction_display_gen": instruction_display_gen,
658
  "track_name": track_name,
 
144
  # Set init_status value from init_params if pre-initialized
145
  init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
146
  init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
147
+
148
+ # LoRA Configuration Section
149
+ gr.HTML("<hr><h4>🔧 LoRA Adapter</h4>")
150
+ with gr.Row():
151
+ lora_path = gr.Textbox(
152
+ label="LoRA Path",
153
+ placeholder="./lora_output/final/adapter",
154
+ info="Path to trained LoRA adapter directory",
155
+ scale=3,
156
+ )
157
+ load_lora_btn = gr.Button("📥 Load LoRA", variant="secondary", scale=1)
158
+ unload_lora_btn = gr.Button("🗑️ Unload", variant="secondary", scale=1)
159
+ with gr.Row():
160
+ use_lora_checkbox = gr.Checkbox(
161
+ label="Use LoRA",
162
+ value=False,
163
+ info="Enable LoRA adapter for inference",
164
+ scale=1,
165
+ )
166
+ lora_status = gr.Textbox(
167
+ label="LoRA Status",
168
+ value="No LoRA loaded",
169
+ interactive=False,
170
+ scale=2,
171
+ )
172
 
173
  # Inputs
174
  with gr.Row():
 
678
  "use_flash_attention_checkbox": use_flash_attention_checkbox,
679
  "offload_to_cpu_checkbox": offload_to_cpu_checkbox,
680
  "offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
681
+ # LoRA components
682
+ "lora_path": lora_path,
683
+ "load_lora_btn": load_lora_btn,
684
+ "unload_lora_btn": unload_lora_btn,
685
+ "use_lora_checkbox": use_lora_checkbox,
686
+ "lora_status": lora_status,
687
  "task_type": task_type,
688
  "instruction_display_gen": instruction_display_gen,
689
  "track_name": track_name,
acestep/gradio_ui/interfaces/training.py ADDED
@@ -0,0 +1,558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Training Tab Module
3
+
4
+ Contains the dataset builder and LoRA training interface components.
5
+ """
6
+
7
+ import os
8
+ import gradio as gr
9
+ from acestep.gradio_ui.i18n import t
10
+
11
+
12
+ def create_training_section(dit_handler, llm_handler) -> dict:
13
+ """Create the training tab section with dataset builder and training controls.
14
+
15
+ Args:
16
+ dit_handler: DiT handler instance
17
+ llm_handler: LLM handler instance
18
+
19
+ Returns:
20
+ Dictionary of Gradio components for event handling
21
+ """
22
+
23
+ with gr.Tab("🎓 LoRA Training"):
24
+ gr.HTML("""
25
+ <div style="text-align: center; padding: 10px; margin-bottom: 15px;">
26
+ <h2>🎵 LoRA Training for ACE-Step</h2>
27
+ <p>Build datasets from your audio files and train custom LoRA adapters</p>
28
+ </div>
29
+ """)
30
+
31
+ with gr.Tabs():
32
+ # ==================== Dataset Builder Tab ====================
33
+ with gr.Tab("📁 Dataset Builder"):
34
+ # ========== Load Existing OR Scan New ==========
35
+ gr.HTML("""
36
+ <div style="padding: 10px; margin-bottom: 10px; border: 1px solid #4a4a6a; border-radius: 8px; background: linear-gradient(135deg, #2a2a4a 0%, #1a1a3a 100%);">
37
+ <h3 style="margin: 0 0 5px 0;">🚀 Quick Start</h3>
38
+ <p style="margin: 0; color: #aaa;">Choose one: <b>Load existing dataset</b> OR <b>Scan new directory</b></p>
39
+ </div>
40
+ """)
41
+
42
+ with gr.Row():
43
+ with gr.Column(scale=1):
44
+ gr.HTML("<h4>📂 Load Existing Dataset</h4>")
45
+ with gr.Row():
46
+ load_json_path = gr.Textbox(
47
+ label="Dataset JSON Path",
48
+ placeholder="./datasets/my_lora_dataset.json",
49
+ info="Load a previously saved dataset",
50
+ scale=3,
51
+ )
52
+ load_json_btn = gr.Button("📂 Load", variant="primary", scale=1)
53
+ load_json_status = gr.Textbox(
54
+ label="Load Status",
55
+ interactive=False,
56
+ )
57
+
58
+ with gr.Column(scale=1):
59
+ gr.HTML("<h4>🔍 Scan New Directory</h4>")
60
+ with gr.Row():
61
+ audio_directory = gr.Textbox(
62
+ label="Audio Directory Path",
63
+ placeholder="/path/to/your/audio/folder",
64
+ info="Scan for audio files (wav, mp3, flac, ogg, opus)",
65
+ scale=3,
66
+ )
67
+ scan_btn = gr.Button("🔍 Scan", variant="secondary", scale=1)
68
+ scan_status = gr.Textbox(
69
+ label="Scan Status",
70
+ interactive=False,
71
+ )
72
+
73
+ gr.HTML("<hr>")
74
+
75
+ with gr.Row():
76
+ with gr.Column(scale=2):
77
+
78
+ # Audio files table
79
+ audio_files_table = gr.Dataframe(
80
+ headers=["#", "Filename", "Duration", "Labeled", "BPM", "Key", "Caption"],
81
+ datatype=["number", "str", "str", "str", "str", "str", "str"],
82
+ label="Found Audio Files",
83
+ interactive=False,
84
+ wrap=True,
85
+ )
86
+
87
+ with gr.Column(scale=1):
88
+ gr.HTML("<h3>⚙️ Dataset Settings</h3>")
89
+
90
+ dataset_name = gr.Textbox(
91
+ label="Dataset Name",
92
+ value="my_lora_dataset",
93
+ placeholder="Enter dataset name",
94
+ )
95
+
96
+ all_instrumental = gr.Checkbox(
97
+ label="All Instrumental",
98
+ value=True,
99
+ info="Check if all tracks are instrumental (no vocals)",
100
+ )
101
+
102
+ need_lyrics = gr.Checkbox(
103
+ label="Transcribe Lyrics",
104
+ value=False,
105
+ info="Attempt to transcribe lyrics (slower)",
106
+ interactive=False, # Disabled for now
107
+ )
108
+
109
+ custom_tag = gr.Textbox(
110
+ label="Custom Activation Tag",
111
+ placeholder="e.g., 8bit_retro, my_style",
112
+ info="Unique tag to activate this LoRA's style",
113
+ )
114
+
115
+ tag_position = gr.Radio(
116
+ choices=[
117
+ ("Prepend (tag, caption)", "prepend"),
118
+ ("Append (caption, tag)", "append"),
119
+ ("Replace caption", "replace"),
120
+ ],
121
+ value="replace",
122
+ label="Tag Position",
123
+ info="Where to place the custom tag in the caption",
124
+ )
125
+
126
+ gr.HTML("<hr><h3>🤖 Step 2: Auto-Label with AI</h3>")
127
+
128
+ with gr.Row():
129
+ with gr.Column(scale=3):
130
+ gr.Markdown("""
131
+ Click the button below to automatically generate metadata for all audio files using AI:
132
+ - **Caption**: Music style, genre, mood description
133
+ - **BPM**: Beats per minute
134
+ - **Key**: Musical key (e.g., C Major, Am)
135
+ - **Time Signature**: 4/4, 3/4, etc.
136
+ """)
137
+ skip_metas = gr.Checkbox(
138
+ label="Skip Metas (No LLM)",
139
+ value=False,
140
+ info="Skip AI labeling. BPM/Key/Time Signature will be N/A, Language will be 'unknown' for instrumental",
141
+ )
142
+ with gr.Column(scale=1):
143
+ auto_label_btn = gr.Button(
144
+ "🏷️ Auto-Label All",
145
+ variant="primary",
146
+ size="lg",
147
+ )
148
+
149
+ label_progress = gr.Textbox(
150
+ label="Labeling Progress",
151
+ interactive=False,
152
+ lines=2,
153
+ )
154
+
155
+ gr.HTML("<hr><h3>👀 Step 3: Preview & Edit</h3>")
156
+
157
+ with gr.Row():
158
+ with gr.Column(scale=1):
159
+ sample_selector = gr.Slider(
160
+ minimum=0,
161
+ maximum=0,
162
+ step=1,
163
+ value=0,
164
+ label="Select Sample #",
165
+ info="Choose a sample to preview and edit",
166
+ )
167
+
168
+ preview_audio = gr.Audio(
169
+ label="Audio Preview",
170
+ type="filepath",
171
+ interactive=False,
172
+ )
173
+
174
+ preview_filename = gr.Textbox(
175
+ label="Filename",
176
+ interactive=False,
177
+ )
178
+
179
+ with gr.Column(scale=2):
180
+ with gr.Row():
181
+ edit_caption = gr.Textbox(
182
+ label="Caption",
183
+ lines=3,
184
+ placeholder="Music description...",
185
+ )
186
+
187
+ with gr.Row():
188
+ edit_lyrics = gr.Textbox(
189
+ label="Lyrics",
190
+ lines=4,
191
+ placeholder="[Verse 1]\nLyrics here...\n\n[Chorus]\n...",
192
+ )
193
+
194
+ with gr.Row():
195
+ edit_bpm = gr.Number(
196
+ label="BPM",
197
+ precision=0,
198
+ )
199
+ edit_keyscale = gr.Textbox(
200
+ label="Key",
201
+ placeholder="C Major",
202
+ )
203
+ edit_timesig = gr.Dropdown(
204
+ choices=["", "2", "3", "4", "6"],
205
+ label="Time Signature",
206
+ )
207
+ edit_duration = gr.Number(
208
+ label="Duration (s)",
209
+ precision=1,
210
+ interactive=False,
211
+ )
212
+
213
+ with gr.Row():
214
+ edit_language = gr.Dropdown(
215
+ choices=["instrumental", "en", "zh", "ja", "ko", "es", "fr", "de", "pt", "ru", "unknown"],
216
+ value="instrumental",
217
+ label="Language",
218
+ )
219
+ edit_instrumental = gr.Checkbox(
220
+ label="Instrumental",
221
+ value=True,
222
+ )
223
+ save_edit_btn = gr.Button("💾 Save Changes", variant="secondary")
224
+
225
+ edit_status = gr.Textbox(
226
+ label="Edit Status",
227
+ interactive=False,
228
+ )
229
+
230
+ gr.HTML("<hr><h3>💾 Step 4: Save Dataset</h3>")
231
+
232
+ with gr.Row():
233
+ with gr.Column(scale=3):
234
+ save_path = gr.Textbox(
235
+ label="Save Path",
236
+ value="./datasets/my_lora_dataset.json",
237
+ placeholder="./datasets/dataset_name.json",
238
+ info="Path where the dataset JSON will be saved",
239
+ )
240
+ with gr.Column(scale=1):
241
+ save_dataset_btn = gr.Button(
242
+ "💾 Save Dataset",
243
+ variant="primary",
244
+ size="lg",
245
+ )
246
+
247
+ save_status = gr.Textbox(
248
+ label="Save Status",
249
+ interactive=False,
250
+ lines=2,
251
+ )
252
+
253
+ gr.HTML("<hr><h3>⚡ Step 5: Preprocess to Tensors</h3>")
254
+
255
+ gr.Markdown("""
256
+ **Preprocessing converts your dataset to pre-computed tensors for fast training.**
257
+
258
+ You can either:
259
+ - Use the dataset from Steps 1-4 above, **OR**
260
+ - Load an existing dataset JSON file (if you've already saved one)
261
+ """)
262
+
263
+ with gr.Row():
264
+ with gr.Column(scale=3):
265
+ load_existing_dataset_path = gr.Textbox(
266
+ label="Load Existing Dataset (Optional)",
267
+ placeholder="./datasets/my_lora_dataset.json",
268
+ info="Path to a previously saved dataset JSON file",
269
+ )
270
+ with gr.Column(scale=1):
271
+ load_existing_dataset_btn = gr.Button(
272
+ "📂 Load Dataset",
273
+ variant="secondary",
274
+ size="lg",
275
+ )
276
+
277
+ load_existing_status = gr.Textbox(
278
+ label="Load Status",
279
+ interactive=False,
280
+ )
281
+
282
+ gr.Markdown("""
283
+ This step:
284
+ - Encodes audio to VAE latents
285
+ - Encodes captions and lyrics to text embeddings
286
+ - Runs the condition encoder
287
+ - Saves all tensors to `.pt` files
288
+
289
+ ⚠️ **This requires the model to be loaded and may take a few minutes.**
290
+ """)
291
+
292
+ with gr.Row():
293
+ with gr.Column(scale=3):
294
+ preprocess_output_dir = gr.Textbox(
295
+ label="Tensor Output Directory",
296
+ value="./datasets/preprocessed_tensors",
297
+ placeholder="./datasets/preprocessed_tensors",
298
+ info="Directory to save preprocessed tensor files",
299
+ )
300
+ with gr.Column(scale=1):
301
+ preprocess_btn = gr.Button(
302
+ "⚡ Preprocess",
303
+ variant="primary",
304
+ size="lg",
305
+ )
306
+
307
+ preprocess_progress = gr.Textbox(
308
+ label="Preprocessing Progress",
309
+ interactive=False,
310
+ lines=3,
311
+ )
312
+
313
+ # ==================== Training Tab ====================
314
+ with gr.Tab("🚀 Train LoRA"):
315
+ with gr.Row():
316
+ with gr.Column(scale=2):
317
+ gr.HTML("<h3>📊 Preprocessed Dataset Selection</h3>")
318
+
319
+ gr.Markdown("""
320
+ Select the directory containing preprocessed tensor files (`.pt` files).
321
+ These are created in the "Dataset Builder" tab using the "Preprocess" button.
322
+ """)
323
+
324
+ training_tensor_dir = gr.Textbox(
325
+ label="Preprocessed Tensors Directory",
326
+ placeholder="./datasets/preprocessed_tensors",
327
+ value="./datasets/preprocessed_tensors",
328
+ info="Directory containing preprocessed .pt tensor files",
329
+ )
330
+
331
+ load_dataset_btn = gr.Button("📂 Load Dataset", variant="secondary")
332
+
333
+ training_dataset_info = gr.Textbox(
334
+ label="Dataset Info",
335
+ interactive=False,
336
+ lines=3,
337
+ )
338
+
339
+ with gr.Column(scale=1):
340
+ gr.HTML("<h3>⚙️ LoRA Settings</h3>")
341
+
342
+ lora_rank = gr.Slider(
343
+ minimum=4,
344
+ maximum=256,
345
+ step=4,
346
+ value=64,
347
+ label="LoRA Rank (r)",
348
+ info="Higher = more capacity, more memory",
349
+ )
350
+
351
+ lora_alpha = gr.Slider(
352
+ minimum=4,
353
+ maximum=512,
354
+ step=4,
355
+ value=128,
356
+ label="LoRA Alpha",
357
+ info="Scaling factor (typically 2x rank)",
358
+ )
359
+
360
+ lora_dropout = gr.Slider(
361
+ minimum=0.0,
362
+ maximum=0.5,
363
+ step=0.05,
364
+ value=0.1,
365
+ label="LoRA Dropout",
366
+ )
367
+
368
+ gr.HTML("<hr><h3>🎛️ Training Parameters</h3>")
369
+
370
+ with gr.Row():
371
+ learning_rate = gr.Number(
372
+ label="Learning Rate",
373
+ value=1e-4,
374
+ info="Start with 1e-4, adjust if needed",
375
+ )
376
+
377
+ train_epochs = gr.Slider(
378
+ minimum=100,
379
+ maximum=4000,
380
+ step=100,
381
+ value=500,
382
+ label="Max Epochs",
383
+ )
384
+
385
+ train_batch_size = gr.Slider(
386
+ minimum=1,
387
+ maximum=8,
388
+ step=1,
389
+ value=1,
390
+ label="Batch Size",
391
+ info="Increase if you have enough VRAM",
392
+ )
393
+
394
+ gradient_accumulation = gr.Slider(
395
+ minimum=1,
396
+ maximum=16,
397
+ step=1,
398
+ value=1,
399
+ label="Gradient Accumulation",
400
+ info="Effective batch = batch_size × accumulation",
401
+ )
402
+
403
+ with gr.Row():
404
+ save_every_n_epochs = gr.Slider(
405
+ minimum=50,
406
+ maximum=1000,
407
+ step=50,
408
+ value=200,
409
+ label="Save Every N Epochs",
410
+ )
411
+
412
+ training_shift = gr.Slider(
413
+ minimum=1.0,
414
+ maximum=5.0,
415
+ step=0.5,
416
+ value=3.0,
417
+ label="Shift",
418
+ info="Timestep shift for turbo model",
419
+ )
420
+
421
+ training_seed = gr.Number(
422
+ label="Seed",
423
+ value=42,
424
+ precision=0,
425
+ )
426
+
427
+ with gr.Row():
428
+ lora_output_dir = gr.Textbox(
429
+ label="Output Directory",
430
+ value="./lora_output",
431
+ placeholder="./lora_output",
432
+ info="Directory to save trained LoRA weights",
433
+ )
434
+
435
+ gr.HTML("<hr>")
436
+
437
+ with gr.Row():
438
+ with gr.Column(scale=1):
439
+ start_training_btn = gr.Button(
440
+ "🚀 Start Training",
441
+ variant="primary",
442
+ size="lg",
443
+ )
444
+ with gr.Column(scale=1):
445
+ stop_training_btn = gr.Button(
446
+ "⏹️ Stop Training",
447
+ variant="stop",
448
+ size="lg",
449
+ )
450
+
451
+ training_progress = gr.Textbox(
452
+ label="Training Progress",
453
+ interactive=False,
454
+ lines=2,
455
+ )
456
+
457
+ with gr.Row():
458
+ training_log = gr.Textbox(
459
+ label="Training Log",
460
+ interactive=False,
461
+ lines=10,
462
+ max_lines=15,
463
+ scale=1,
464
+ )
465
+ training_loss_plot = gr.LinePlot(
466
+ x="step",
467
+ y="loss",
468
+ title="Training Loss",
469
+ x_title="Step",
470
+ y_title="Loss",
471
+ scale=1,
472
+ )
473
+
474
+ gr.HTML("<hr><h3>📦 Export LoRA</h3>")
475
+
476
+ with gr.Row():
477
+ export_path = gr.Textbox(
478
+ label="Export Path",
479
+ value="./lora_output/final_lora",
480
+ placeholder="./lora_output/my_lora",
481
+ )
482
+ export_lora_btn = gr.Button("📦 Export LoRA", variant="secondary")
483
+
484
+ export_status = gr.Textbox(
485
+ label="Export Status",
486
+ interactive=False,
487
+ )
488
+
489
+ # Store dataset builder state
490
+ dataset_builder_state = gr.State(None)
491
+ training_state = gr.State({"is_training": False, "should_stop": False})
492
+
493
+ return {
494
+ # Dataset Builder - Load or Scan
495
+ "load_json_path": load_json_path,
496
+ "load_json_btn": load_json_btn,
497
+ "load_json_status": load_json_status,
498
+ "audio_directory": audio_directory,
499
+ "scan_btn": scan_btn,
500
+ "scan_status": scan_status,
501
+ "audio_files_table": audio_files_table,
502
+ "dataset_name": dataset_name,
503
+ "all_instrumental": all_instrumental,
504
+ "need_lyrics": need_lyrics,
505
+ "custom_tag": custom_tag,
506
+ "tag_position": tag_position,
507
+ "skip_metas": skip_metas,
508
+ "auto_label_btn": auto_label_btn,
509
+ "label_progress": label_progress,
510
+ "sample_selector": sample_selector,
511
+ "preview_audio": preview_audio,
512
+ "preview_filename": preview_filename,
513
+ "edit_caption": edit_caption,
514
+ "edit_lyrics": edit_lyrics,
515
+ "edit_bpm": edit_bpm,
516
+ "edit_keyscale": edit_keyscale,
517
+ "edit_timesig": edit_timesig,
518
+ "edit_duration": edit_duration,
519
+ "edit_language": edit_language,
520
+ "edit_instrumental": edit_instrumental,
521
+ "save_edit_btn": save_edit_btn,
522
+ "edit_status": edit_status,
523
+ "save_path": save_path,
524
+ "save_dataset_btn": save_dataset_btn,
525
+ "save_status": save_status,
526
+ # Preprocessing
527
+ "load_existing_dataset_path": load_existing_dataset_path,
528
+ "load_existing_dataset_btn": load_existing_dataset_btn,
529
+ "load_existing_status": load_existing_status,
530
+ "preprocess_output_dir": preprocess_output_dir,
531
+ "preprocess_btn": preprocess_btn,
532
+ "preprocess_progress": preprocess_progress,
533
+ "dataset_builder_state": dataset_builder_state,
534
+ # Training
535
+ "training_tensor_dir": training_tensor_dir,
536
+ "load_dataset_btn": load_dataset_btn,
537
+ "training_dataset_info": training_dataset_info,
538
+ "lora_rank": lora_rank,
539
+ "lora_alpha": lora_alpha,
540
+ "lora_dropout": lora_dropout,
541
+ "learning_rate": learning_rate,
542
+ "train_epochs": train_epochs,
543
+ "train_batch_size": train_batch_size,
544
+ "gradient_accumulation": gradient_accumulation,
545
+ "save_every_n_epochs": save_every_n_epochs,
546
+ "training_shift": training_shift,
547
+ "training_seed": training_seed,
548
+ "lora_output_dir": lora_output_dir,
549
+ "start_training_btn": start_training_btn,
550
+ "stop_training_btn": stop_training_btn,
551
+ "training_progress": training_progress,
552
+ "training_log": training_log,
553
+ "training_loss_plot": training_loss_plot,
554
+ "export_path": export_path,
555
+ "export_lora_btn": export_lora_btn,
556
+ "export_status": export_status,
557
+ "training_state": training_state,
558
+ }
acestep/handler.py CHANGED
@@ -3,6 +3,10 @@ Business Logic Handler
3
  Encapsulates all data processing and business logic as a bridge between model and UI
4
  """
5
  import os
 
 
 
 
6
  import math
7
  from copy import deepcopy
8
  import tempfile
@@ -70,6 +74,11 @@ class AceStepHandler:
70
  self.offload_to_cpu = False
71
  self.offload_dit_to_cpu = False
72
  self.current_offload_cost = 0.0
 
 
 
 
 
73
 
74
  def get_available_checkpoints(self) -> str:
75
  """Return project root directory path"""
@@ -114,6 +123,137 @@ class AceStepHandler:
114
  return False
115
  return getattr(self.config, 'is_turbo', False)
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  def initialize_service(
118
  self,
119
  project_root: str,
 
3
  Encapsulates all data processing and business logic as a bridge between model and UI
4
  """
5
  import os
6
+
7
+ # Disable tokenizers parallelism to avoid fork warning
8
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
9
+
10
  import math
11
  from copy import deepcopy
12
  import tempfile
 
74
  self.offload_to_cpu = False
75
  self.offload_dit_to_cpu = False
76
  self.current_offload_cost = 0.0
77
+
78
+ # LoRA state
79
+ self.lora_loaded = False
80
+ self.use_lora = False
81
+ self._base_decoder = None # Backup of original decoder
82
 
83
  def get_available_checkpoints(self) -> str:
84
  """Return project root directory path"""
 
123
  return False
124
  return getattr(self.config, 'is_turbo', False)
125
 
126
def load_lora(self, lora_path: str) -> str:
    """Load LoRA adapter into the decoder.

    Backs up the pristine decoder on first load so it can be restored by
    :meth:`unload_lora` or before attaching a different adapter.

    Args:
        lora_path: Path to the LoRA adapter directory (containing adapter_config.json)

    Returns:
        Status message
    """
    if self.model is None:
        return "❌ Model not initialized. Please initialize service first."

    if not lora_path or not lora_path.strip():
        return "❌ Please provide a LoRA path."

    lora_path = lora_path.strip()

    # Check if path exists
    if not os.path.exists(lora_path):
        return f"❌ LoRA path not found: {lora_path}"

    # Check if it's a valid PEFT adapter directory
    config_file = os.path.join(lora_path, "adapter_config.json")
    if not os.path.exists(config_file):
        return f"❌ Invalid LoRA adapter: adapter_config.json not found in {lora_path}"

    try:
        from peft import PeftModel, PeftConfig
    except ImportError:
        return "❌ PEFT library not installed. Please install with: pip install peft"

    try:
        # BUGFIX: `import copy` previously lived only inside the
        # `if self._base_decoder is None` branch, so loading a second
        # LoRA took the else branch and raised NameError. Import up front.
        import copy

        # Backup base decoder if not already backed up
        if self._base_decoder is None:
            self._base_decoder = copy.deepcopy(self.model.decoder)
            logger.info("Base decoder backed up")
        else:
            # Restore base decoder before loading new LoRA
            self.model.decoder = copy.deepcopy(self._base_decoder)
            logger.info("Restored base decoder before loading new LoRA")

        # Load PEFT adapter (frozen — inference only)
        logger.info(f"Loading LoRA adapter from {lora_path}")
        self.model.decoder = PeftModel.from_pretrained(
            self.model.decoder,
            lora_path,
            is_trainable=False,
        )
        self.model.decoder = self.model.decoder.to(self.device).to(self.dtype)
        self.model.decoder.eval()

        self.lora_loaded = True
        self.use_lora = True  # Enable LoRA by default after loading

        logger.info(f"LoRA adapter loaded successfully from {lora_path}")
        return f"✅ LoRA loaded from {lora_path}"

    except Exception as e:
        logger.exception("Failed to load LoRA adapter")
        return f"❌ Failed to load LoRA: {str(e)}"
187
+
188
def unload_lora(self) -> str:
    """Unload LoRA adapter and restore base decoder.

    Swaps the PEFT-wrapped decoder out for a fresh deep copy of the backup
    taken at first load, then clears the LoRA state flags.

    Returns:
        Status message
    """
    if not self.lora_loaded:
        return "⚠️ No LoRA adapter loaded."
    if self._base_decoder is None:
        return "❌ Base decoder backup not found. Cannot restore."

    try:
        import copy

        # Restore a pristine copy on the right device/dtype, in eval mode.
        restored = copy.deepcopy(self._base_decoder)
        restored = restored.to(self.device).to(self.dtype)
        restored.eval()
        self.model.decoder = restored

        self.lora_loaded = False
        self.use_lora = False

        logger.info("LoRA unloaded, base decoder restored")
        return "✅ LoRA unloaded, using base model"
    except Exception as e:
        logger.exception("Failed to unload LoRA")
        return f"❌ Failed to unload LoRA: {str(e)}"
216
+
217
def set_use_lora(self, use_lora: bool) -> str:
    """Toggle LoRA usage for inference.

    Args:
        use_lora: Whether the loaded LoRA adapter should be active.

    Returns:
        Status message
    """
    if use_lora and not self.lora_loaded:
        return "❌ No LoRA adapter loaded. Please load a LoRA first."

    self.use_lora = use_lora

    # When a PEFT adapter is attached, flip its layers on/off in place.
    decoder = self.model.decoder if self.lora_loaded else None
    if decoder is not None and hasattr(decoder, 'disable_adapter_layers'):
        try:
            if use_lora:
                decoder.enable_adapter_layers()
                logger.info("LoRA adapter enabled")
            else:
                decoder.disable_adapter_layers()
                logger.info("LoRA adapter disabled")
        except Exception as e:
            logger.warning(f"Could not toggle adapter layers: {e}")

    return f"✅ LoRA {'enabled' if use_lora else 'disabled'}"
245
+
246
def get_lora_status(self) -> Dict[str, Any]:
    """Get current LoRA status.

    Returns:
        Dict with ``loaded`` (adapter attached) and ``active`` (adapter
        used at inference time) booleans.
    """
    status = dict(loaded=self.lora_loaded)
    status["active"] = self.use_lora
    return status
256
+
257
  def initialize_service(
258
  self,
259
  project_root: str,
acestep/training/__init__.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACE-Step Training Module
3
+
4
+ This module provides LoRA training functionality for ACE-Step models,
5
+ including dataset building, audio labeling, and training utilities.
6
+ """
7
+
8
+ from acestep.training.dataset_builder import DatasetBuilder, AudioSample
9
+ from acestep.training.configs import LoRAConfig, TrainingConfig
10
+ from acestep.training.lora_utils import (
11
+ inject_lora_into_dit,
12
+ save_lora_weights,
13
+ load_lora_weights,
14
+ merge_lora_weights,
15
+ check_peft_available,
16
+ )
17
+ from acestep.training.data_module import (
18
+ # Preprocessed (recommended)
19
+ PreprocessedTensorDataset,
20
+ PreprocessedDataModule,
21
+ collate_preprocessed_batch,
22
+ # Legacy (raw audio)
23
+ AceStepTrainingDataset,
24
+ AceStepDataModule,
25
+ collate_training_batch,
26
+ load_dataset_from_json,
27
+ )
28
+ from acestep.training.trainer import LoRATrainer, PreprocessedLoRAModule, LIGHTNING_AVAILABLE
29
+
30
def check_lightning_available() -> bool:
    """Check if Lightning Fabric is available.

    Mirrors the ``LIGHTNING_AVAILABLE`` flag set at import time in
    ``acestep.training.trainer`` (True when the optional ``lightning``
    dependency was importable).
    """
    return LIGHTNING_AVAILABLE
33
+
34
+ __all__ = [
35
+ # Dataset Builder
36
+ "DatasetBuilder",
37
+ "AudioSample",
38
+ # Configs
39
+ "LoRAConfig",
40
+ "TrainingConfig",
41
+ # LoRA Utils
42
+ "inject_lora_into_dit",
43
+ "save_lora_weights",
44
+ "load_lora_weights",
45
+ "merge_lora_weights",
46
+ "check_peft_available",
47
+ # Data Module (Preprocessed - Recommended)
48
+ "PreprocessedTensorDataset",
49
+ "PreprocessedDataModule",
50
+ "collate_preprocessed_batch",
51
+ # Data Module (Legacy)
52
+ "AceStepTrainingDataset",
53
+ "AceStepDataModule",
54
+ "collate_training_batch",
55
+ "load_dataset_from_json",
56
+ # Trainer
57
+ "LoRATrainer",
58
+ "PreprocessedLoRAModule",
59
+ "check_lightning_available",
60
+ "LIGHTNING_AVAILABLE",
61
+ ]
acestep/training/configs.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training Configuration Classes
3
+
4
+ Contains dataclasses for LoRA and training configurations.
5
+ """
6
+
7
+ from dataclasses import dataclass, field
8
+ from typing import List, Optional
9
+
10
+
11
@dataclass
class LoRAConfig:
    """Configuration for LoRA (Low-Rank Adaptation) training.

    Attributes:
        r: Rank of the low-rank update matrices.
        alpha: Scaling factor; the effective scale is ``alpha / r``.
        dropout: Dropout probability applied inside LoRA layers.
        target_modules: Names of the modules LoRA is injected into.
        bias: Bias-training mode ("none", "all", or "lora_only").
    """
    r: int = 8
    alpha: int = 16
    dropout: float = 0.1
    target_modules: List[str] = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj"
    ])
    bias: str = "none"

    def to_dict(self):
        """Return the config as keyword arguments for PEFT's ``LoraConfig``."""
        return dict(
            r=self.r,
            lora_alpha=self.alpha,
            lora_dropout=self.dropout,
            target_modules=self.target_modules,
            bias=self.bias,
        )
39
+
40
+
41
@dataclass
class TrainingConfig:
    """Configuration for LoRA training process.

    Training uses:
    - BFloat16 precision (only supported precision)
    - Discrete timesteps from turbo shift=3.0 schedule (8 steps)
    - Randomly samples one of 8 timesteps per training step:
      [1.0, 0.9545, 0.9, 0.8333, 0.75, 0.6429, 0.5, 0.3]

    Attributes:
        shift: Timestep shift factor (fixed at 3.0 for turbo model).
        num_inference_steps: Number of inference steps (fixed at 8 for turbo).
        learning_rate: Initial learning rate.
        batch_size: Training batch size.
        gradient_accumulation_steps: Number of gradient accumulation steps.
        max_epochs: Maximum number of training epochs.
        save_every_n_epochs: Save checkpoint every N epochs.
        warmup_steps: Number of warmup steps for the LR scheduler.
        weight_decay: Weight decay for the optimizer.
        max_grad_norm: Maximum gradient norm for clipping.
        mixed_precision: Always "bf16" (only supported precision).
        seed: Random seed for reproducibility.
        output_dir: Directory to save checkpoints and logs.
        num_workers: DataLoader worker count.
        pin_memory: Whether DataLoader pins host memory.
        log_every_n_steps: Logging interval in optimizer steps.
    """
    # Fixed for turbo model
    shift: float = 3.0  # Fixed: turbo uses shift=3.0
    num_inference_steps: int = 8  # Fixed: turbo uses 8 steps
    learning_rate: float = 1e-4
    batch_size: int = 1
    gradient_accumulation_steps: int = 4
    max_epochs: int = 100
    save_every_n_epochs: int = 10
    warmup_steps: int = 100
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    mixed_precision: str = "bf16"  # Fixed: only bf16 supported
    seed: int = 42
    output_dir: str = "./lora_output"

    # Data loading
    num_workers: int = 4
    pin_memory: bool = True

    # Logging
    log_every_n_steps: int = 10

    def to_dict(self):
        """Return all settings as a plain dict (keys match field names)."""
        return dict(
            shift=self.shift,
            num_inference_steps=self.num_inference_steps,
            learning_rate=self.learning_rate,
            batch_size=self.batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            max_epochs=self.max_epochs,
            save_every_n_epochs=self.save_every_n_epochs,
            warmup_steps=self.warmup_steps,
            weight_decay=self.weight_decay,
            max_grad_norm=self.max_grad_norm,
            mixed_precision=self.mixed_precision,
            seed=self.seed,
            output_dir=self.output_dir,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            log_every_n_steps=self.log_every_n_steps,
        )
acestep/training/data_module.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch Lightning DataModule for LoRA Training
3
+
4
+ Handles data loading and preprocessing for training ACE-Step LoRA adapters.
5
+ Supports both raw audio loading and preprocessed tensor loading.
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import random
11
+ from typing import Optional, List, Dict, Any, Tuple
12
+ from loguru import logger
13
+
14
+ import torch
15
+ import torchaudio
16
+ from torch.utils.data import Dataset, DataLoader
17
+
18
+ try:
19
+ from lightning.pytorch import LightningDataModule
20
+ LIGHTNING_AVAILABLE = True
21
+ except ImportError:
22
+ LIGHTNING_AVAILABLE = False
23
+ logger.warning("Lightning not installed. Training module will not be available.")
24
+ # Create a dummy class for type hints
25
+ class LightningDataModule:
26
+ pass
27
+
28
+
29
+ # ============================================================================
30
+ # Preprocessed Tensor Dataset (Recommended for Training)
31
+ # ============================================================================
32
+
33
class PreprocessedTensorDataset(Dataset):
    """Dataset over preprocessed tensor files.

    Recommended for training: every tensor is pre-computed offline, so no
    VAE / text encoder is needed at training time. Each .pt file holds:
    - target_latents: VAE-encoded audio [T, 64]
    - encoder_hidden_states: Condition encoder output [L, D]
    - encoder_attention_mask: Condition mask [L]
    - context_latents: Source context [T, 65]
    - attention_mask: Audio latent mask [T]
    """

    def __init__(self, tensor_dir: str):
        """Index a directory of preprocessed .pt files.

        Args:
            tensor_dir: Directory containing preprocessed .pt files and an
                optional manifest.json listing the sample paths.
        """
        self.tensor_dir = tensor_dir

        manifest_path = os.path.join(tensor_dir, "manifest.json")
        if os.path.exists(manifest_path):
            # Preferred source of truth: explicit manifest.
            with open(manifest_path, 'r') as fh:
                self.sample_paths = json.load(fh).get("samples", [])
        else:
            # Fallback: take every .pt file in the directory.
            self.sample_paths = [
                os.path.join(tensor_dir, name)
                for name in os.listdir(tensor_dir)
                if name.endswith('.pt') and name != "manifest.json"
            ]

        # Drop entries whose file is missing on disk.
        self.valid_paths = [p for p in self.sample_paths if os.path.exists(p)]

        missing = len(self.sample_paths) - len(self.valid_paths)
        if missing:
            logger.warning(f"Some tensor files not found: {missing} missing")

        logger.info(f"PreprocessedTensorDataset: {len(self.valid_paths)} samples from {tensor_dir}")

    def __len__(self) -> int:
        return len(self.valid_paths)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Load one preprocessed tensor file.

        Returns:
            Dictionary with all pre-computed training tensors plus metadata.
        """
        payload = torch.load(self.valid_paths[idx], map_location='cpu')
        item = {
            key: payload[key]
            for key in (
                "target_latents",          # [T, 64]
                "attention_mask",          # [T]
                "encoder_hidden_states",   # [L, D]
                "encoder_attention_mask",  # [L]
                "context_latents",         # [T, 65]
            )
        }
        item["metadata"] = payload.get("metadata", {})
        return item
95
+
96
+
97
def _pad_dim0(tensor: torch.Tensor, length: int) -> torch.Tensor:
    """Zero-pad `tensor` along dim 0 up to `length`.

    Uses Tensor.new_zeros so the padding matches the input's dtype and
    device — the original built float32 zeros on CPU, which made torch.cat
    fail for bf16/fp16 latents.
    """
    missing = length - tensor.shape[0]
    if missing <= 0:
        return tensor
    pad = tensor.new_zeros((missing,) + tuple(tensor.shape[1:]))
    return torch.cat([tensor, pad], dim=0)


def collate_preprocessed_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Collate function for preprocessed tensor batches.

    Handles variable-length tensors by padding to the longest in the batch.
    Latent-length tensors (target/context/mask) pad to the longest T;
    encoder tensors pad to the longest L.

    Args:
        batch: List of sample dictionaries with pre-computed tensors

    Returns:
        Batched dictionary with all tensors stacked:
        target_latents [B, T, 64], attention_mask [B, T],
        encoder_hidden_states [B, L, D], encoder_attention_mask [B, L],
        context_latents [B, T, 65], metadata (list of dicts).
    """
    max_latent_len = max(s["target_latents"].shape[0] for s in batch)
    max_encoder_len = max(s["encoder_hidden_states"].shape[0] for s in batch)

    return {
        "target_latents": torch.stack(
            [_pad_dim0(s["target_latents"], max_latent_len) for s in batch]),
        "attention_mask": torch.stack(
            [_pad_dim0(s["attention_mask"], max_latent_len) for s in batch]),
        "encoder_hidden_states": torch.stack(
            [_pad_dim0(s["encoder_hidden_states"], max_encoder_len) for s in batch]),
        "encoder_attention_mask": torch.stack(
            [_pad_dim0(s["encoder_attention_mask"], max_encoder_len) for s in batch]),
        "context_latents": torch.stack(
            [_pad_dim0(s["context_latents"], max_latent_len) for s in batch]),
        "metadata": [s["metadata"] for s in batch],
    }
163
+
164
+
165
class PreprocessedDataModule(LightningDataModule if LIGHTNING_AVAILABLE else object):
    """DataModule serving preprocessed tensor files.

    Recommended for training: tensors are loaded as-is, so no VAE, text
    encoder, or condition encoder is required at training time.
    """

    def __init__(
        self,
        tensor_dir: str,
        batch_size: int = 1,
        num_workers: int = 4,
        pin_memory: bool = True,
        val_split: float = 0.0,
    ):
        """Store loader configuration; datasets are built in setup().

        Args:
            tensor_dir: Directory containing preprocessed .pt files
            batch_size: Training batch size
            num_workers: Number of data loading workers
            pin_memory: Whether to pin memory for faster GPU transfer
            val_split: Fraction of data for validation (0 = no validation)
        """
        if LIGHTNING_AVAILABLE:
            super().__init__()

        self.tensor_dir = tensor_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.val_split = val_split

        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage: Optional[str] = None):
        """Build the train (and optional val) datasets for the 'fit' stage."""
        if stage not in ('fit', None):
            return

        dataset = PreprocessedTensorDataset(self.tensor_dir)

        if self.val_split > 0 and len(dataset) > 1:
            n_val = max(1, int(len(dataset) * self.val_split))
            self.train_dataset, self.val_dataset = torch.utils.data.random_split(
                dataset, [len(dataset) - n_val, n_val]
            )
        else:
            self.train_dataset = dataset
            self.val_dataset = None

    def _make_loader(self, dataset, shuffle: bool, drop_last: bool) -> DataLoader:
        # Shared DataLoader construction for both splits.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=collate_preprocessed_batch,
            drop_last=drop_last,
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled training dataloader (drops the last partial batch)."""
        return self._make_loader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self) -> Optional[DataLoader]:
        """Validation dataloader, or None when no validation split exists."""
        if self.val_dataset is None:
            return None
        return self._make_loader(self.val_dataset, shuffle=False, drop_last=False)
244
+
245
+
246
+ # ============================================================================
247
+ # Raw Audio Dataset (Legacy - for backward compatibility)
248
+ # ============================================================================
249
+
250
class AceStepTrainingDataset(Dataset):
    """Raw-audio dataset for ACE-Step LoRA training.

    DEPRECATED: Use PreprocessedTensorDataset instead for better performance.

    Audio Format Requirements (handled automatically):
    - Sample rate: 48kHz (resampled if different)
    - Channels: Stereo (2 channels, mono is duplicated)
    - Max duration: 240 seconds (4 minutes)
    - Min duration: 5 seconds (padded if shorter)
    """

    def __init__(
        self,
        samples: List[Dict[str, Any]],
        dit_handler,
        max_duration: float = 240.0,
        target_sample_rate: int = 48000,
    ):
        """Store config and filter the sample list down to usable entries."""
        self.samples = samples
        self.dit_handler = dit_handler
        self.max_duration = max_duration
        self.target_sample_rate = target_sample_rate

        self.valid_samples = self._validate_samples()
        logger.info(f"Dataset initialized with {len(self.valid_samples)} valid samples")

    def _validate_samples(self) -> List[Dict[str, Any]]:
        """Keep only samples with an existing audio file and a caption."""
        kept = []
        for idx, sample in enumerate(self.samples):
            path = sample.get("audio_path", "")
            if not path or not os.path.exists(path):
                logger.warning(f"Sample {idx}: Audio file not found: {path}")
                continue
            if not sample.get("caption"):
                logger.warning(f"Sample {idx}: Missing caption")
                continue
            kept.append(sample)
        return kept

    def __len__(self) -> int:
        return len(self.valid_samples)

    def _prepare_audio(self, audio_path: str) -> torch.Tensor:
        """Load audio and normalize sample rate, channel count, and length."""
        waveform, sr = torchaudio.load(audio_path)

        # Resample to the target 48kHz rate.
        if sr != self.target_sample_rate:
            waveform = torchaudio.transforms.Resample(sr, self.target_sample_rate)(waveform)

        # Force stereo: duplicate mono, drop channels beyond the first two.
        if waveform.shape[0] == 1:
            waveform = waveform.repeat(2, 1)
        elif waveform.shape[0] > 2:
            waveform = waveform[:2, :]

        # Truncate to the max duration.
        upper = int(self.max_duration * self.target_sample_rate)
        if waveform.shape[1] > upper:
            waveform = waveform[:, :upper]

        # Pad up to the 5-second minimum.
        lower = int(5.0 * self.target_sample_rate)
        if waveform.shape[1] < lower:
            waveform = torch.nn.functional.pad(waveform, (0, lower - waveform.shape[1]))

        return waveform

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one raw-audio training sample with its metadata."""
        sample = self.valid_samples[idx]
        audio_path = sample["audio_path"]
        audio = self._prepare_audio(audio_path)

        caption = sample.get("caption", "")
        lyrics = sample.get("lyrics", "[Instrumental]")

        return {
            "audio": audio,
            "caption": caption,
            "lyrics": lyrics,
            "metadata": {
                "caption": caption,
                "lyrics": lyrics,
                "bpm": sample.get("bpm"),
                "keyscale": sample.get("keyscale", ""),
                "timesignature": sample.get("timesignature", ""),
                "duration": sample.get("duration", audio.shape[1] / self.target_sample_rate),
                "language": sample.get("language", "instrumental"),
                "is_instrumental": sample.get("is_instrumental", True),
            },
            "audio_path": audio_path,
        }
342
+
343
+
344
def collate_training_batch(batch: List[Dict]) -> Dict[str, Any]:
    """Collate function for raw audio batches (legacy).

    Pads every waveform to the longest in the batch and builds a 1/0
    attention mask marking real vs padded samples.
    """
    target_len = max(s["audio"].shape[1] for s in batch)

    audios, masks = [], []
    for sample in batch:
        audio = sample["audio"]
        n = audio.shape[1]

        shortfall = target_len - n
        if shortfall > 0:
            audio = torch.nn.functional.pad(audio, (0, shortfall))
        audios.append(audio)

        mask = torch.ones(target_len)
        mask[n:] = 0  # no-op slice when n == target_len
        masks.append(mask)

    return {
        "audio": torch.stack(audios),
        "attention_mask": torch.stack(masks),
        "captions": [s["caption"] for s in batch],
        "lyrics": [s["lyrics"] for s in batch],
        "metadata": [s["metadata"] for s in batch],
        "audio_paths": [s["audio_path"] for s in batch],
    }
374
+
375
+
376
class AceStepDataModule(LightningDataModule if LIGHTNING_AVAILABLE else object):
    """DataModule for raw audio loading (legacy).

    DEPRECATED: Use PreprocessedDataModule for better training performance.
    """

    def __init__(
        self,
        samples: List[Dict[str, Any]],
        dit_handler,
        batch_size: int = 1,
        num_workers: int = 4,
        pin_memory: bool = True,
        max_duration: float = 240.0,
        val_split: float = 0.0,
    ):
        """Store loader configuration; datasets are built in setup()."""
        if LIGHTNING_AVAILABLE:
            super().__init__()

        self.samples = samples
        self.dit_handler = dit_handler
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.max_duration = max_duration
        self.val_split = val_split

        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage: Optional[str] = None):
        """Build train/val datasets for the 'fit' stage (shuffled split)."""
        if stage not in ('fit', None):
            return

        if self.val_split > 0 and len(self.samples) > 1:
            n_val = max(1, int(len(self.samples) * self.val_split))

            indices = list(range(len(self.samples)))
            random.shuffle(indices)

            val_samples = [self.samples[i] for i in indices[:n_val]]
            train_samples = [self.samples[i] for i in indices[n_val:]]

            self.train_dataset = AceStepTrainingDataset(
                train_samples, self.dit_handler, self.max_duration
            )
            self.val_dataset = AceStepTrainingDataset(
                val_samples, self.dit_handler, self.max_duration
            )
        else:
            self.train_dataset = AceStepTrainingDataset(
                self.samples, self.dit_handler, self.max_duration
            )
            self.val_dataset = None

    def _make_loader(self, dataset, shuffle: bool, drop_last: bool) -> DataLoader:
        # Shared DataLoader construction for both splits.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=collate_training_batch,
            drop_last=drop_last,
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled training dataloader (drops the last partial batch)."""
        return self._make_loader(self.train_dataset, shuffle=True, drop_last=True)

    def val_dataloader(self) -> Optional[DataLoader]:
        """Validation dataloader, or None when no validation split exists."""
        if self.val_dataset is None:
            return None
        return self._make_loader(self.val_dataset, shuffle=False, drop_last=False)
455
+
456
+
457
def load_dataset_from_json(json_path: str) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Load a dataset from a JSON file.

    Returns:
        Tuple of (samples list, metadata dict); missing keys yield empties.
    """
    with open(json_path, 'r', encoding='utf-8') as fh:
        payload = json.load(fh)
    return payload.get("samples", []), payload.get("metadata", {})
acestep/training/dataset_builder.py ADDED
@@ -0,0 +1,755 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset Builder for LoRA Training
3
+
4
+ Provides functionality to:
5
+ 1. Scan directories for audio files
6
+ 2. Auto-label audio using LLM
7
+ 3. Preview and edit metadata
8
+ 4. Save datasets in JSON format
9
+ """
10
+
11
+ import os
12
+ import json
13
+ import uuid
14
+ from datetime import datetime
15
+ from dataclasses import dataclass, field, asdict
16
+ from typing import List, Dict, Any, Optional, Tuple
17
+ from pathlib import Path
18
+
19
+ import torch
20
+ import torchaudio
21
+ from loguru import logger
22
+
23
+
24
+ # Supported audio formats
25
+ SUPPORTED_AUDIO_FORMATS = {'.wav', '.mp3', '.flac', '.ogg', '.opus'}
26
+
27
+
28
@dataclass
class AudioSample:
    """One audio file plus its training metadata.

    Attributes:
        id: Unique identifier for the sample (auto-generated when empty)
        audio_path: Path to the audio file
        filename: Original filename
        caption: Generated or user-provided caption describing the music
        lyrics: Lyrics or "[Instrumental]" for instrumental tracks
        bpm: Beats per minute
        keyscale: Musical key (e.g., "C Major", "Am")
        timesignature: Time signature (e.g., "4" for 4/4)
        duration: Duration in seconds
        language: Vocal language or "instrumental"
        is_instrumental: Whether the track is instrumental
        custom_tag: User-defined activation tag for LoRA
        labeled: Whether the sample has been labeled
    """
    id: str = ""
    audio_path: str = ""
    filename: str = ""
    caption: str = ""
    lyrics: str = "[Instrumental]"
    bpm: Optional[int] = None
    keyscale: str = ""
    timesignature: str = ""
    duration: float = 0.0
    language: str = "instrumental"
    is_instrumental: bool = True
    custom_tag: str = ""
    labeled: bool = False

    def __post_init__(self):
        # Assign a short random id when none was provided.
        self.id = self.id or str(uuid.uuid4())[:8]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dictionary."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AudioSample":
        """Build an AudioSample from a dictionary."""
        return cls(**data)

    def get_full_caption(self, tag_position: str = "prepend") -> str:
        """Return the caption with the custom activation tag applied.

        Args:
            tag_position: Where to place the tag ("prepend", "append",
                "replace"); any other value leaves the caption untouched.

        Returns:
            Caption with custom tag applied.
        """
        if not self.custom_tag:
            return self.caption
        if tag_position == "replace":
            return self.custom_tag
        if tag_position not in ("prepend", "append"):
            # Unknown position: fall back to the plain caption.
            return self.caption
        if not self.caption:
            return self.custom_tag
        if tag_position == "prepend":
            return f"{self.custom_tag}, {self.caption}"
        return f"{self.caption}, {self.custom_tag}"
94
+
95
+
96
@dataclass
class DatasetMetadata:
    """Dataset-level metadata.

    Attributes:
        name: Dataset name
        custom_tag: Default custom tag for all samples
        tag_position: Where to place custom tag ("prepend", "append", "replace")
        created_at: Creation timestamp (auto-stamped when empty)
        num_samples: Number of samples in the dataset
        all_instrumental: Whether all tracks are instrumental
    """
    name: str = "untitled_dataset"
    custom_tag: str = ""
    tag_position: str = "prepend"
    created_at: str = ""
    num_samples: int = 0
    all_instrumental: bool = True

    def __post_init__(self):
        # Stamp creation time when the caller did not supply one.
        self.created_at = self.created_at or datetime.now().isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dictionary."""
        return asdict(self)
122
+
123
+
124
+ class DatasetBuilder:
125
+ """Builder for creating training datasets from audio files.
126
+
127
+ This class handles:
128
+ - Scanning directories for audio files
129
+ - Auto-labeling using LLM
130
+ - Managing sample metadata
131
+ - Saving/loading datasets
132
+ """
133
+
134
    def __init__(self):
        """Initialize an empty dataset builder with default metadata."""
        # Collected audio samples (populated by scan_directory / load_dataset).
        self.samples: List[AudioSample] = []
        # Dataset-level metadata (name, custom tag, tag position, ...).
        self.metadata = DatasetMetadata()
        # Last directory scanned; bookkeeping only.
        self._current_dir: str = ""
140
+ def scan_directory(self, directory: str) -> Tuple[List[AudioSample], str]:
141
+ """Scan a directory for audio files.
142
+
143
+ Args:
144
+ directory: Path to directory containing audio files
145
+
146
+ Returns:
147
+ Tuple of (list of AudioSample objects, status message)
148
+ """
149
+ if not os.path.exists(directory):
150
+ return [], f"❌ Directory not found: {directory}"
151
+
152
+ if not os.path.isdir(directory):
153
+ return [], f"❌ Not a directory: {directory}"
154
+
155
+ self._current_dir = directory
156
+ self.samples = []
157
+
158
+ # Scan for audio files
159
+ audio_files = []
160
+ for root, dirs, files in os.walk(directory):
161
+ for file in files:
162
+ ext = os.path.splitext(file)[1].lower()
163
+ if ext in SUPPORTED_AUDIO_FORMATS:
164
+ audio_files.append(os.path.join(root, file))
165
+
166
+ if not audio_files:
167
+ return [], f"❌ No audio files found in {directory}\nSupported formats: {', '.join(SUPPORTED_AUDIO_FORMATS)}"
168
+
169
+ # Sort files by name
170
+ audio_files.sort()
171
+
172
+ # Create AudioSample objects
173
+ for audio_path in audio_files:
174
+ try:
175
+ # Get duration
176
+ duration = self._get_audio_duration(audio_path)
177
+
178
+ sample = AudioSample(
179
+ audio_path=audio_path,
180
+ filename=os.path.basename(audio_path),
181
+ duration=duration,
182
+ is_instrumental=self.metadata.all_instrumental,
183
+ custom_tag=self.metadata.custom_tag,
184
+ )
185
+ self.samples.append(sample)
186
+ except Exception as e:
187
+ logger.warning(f"Failed to process {audio_path}: {e}")
188
+
189
+ self.metadata.num_samples = len(self.samples)
190
+
191
+ status = f"✅ Found {len(self.samples)} audio files in {directory}"
192
+ return self.samples, status
193
+
194
+ def _get_audio_duration(self, audio_path: str) -> float:
195
+ """Get the duration of an audio file in seconds.
196
+
197
+ Args:
198
+ audio_path: Path to audio file
199
+
200
+ Returns:
201
+ Duration in seconds
202
+ """
203
+ try:
204
+ info = torchaudio.info(audio_path)
205
+ return info.num_frames / info.sample_rate
206
+ except Exception as e:
207
+ logger.warning(f"Failed to get duration for {audio_path}: {e}")
208
+ return 0.0
209
+
210
    def label_sample(
        self,
        sample_idx: int,
        dit_handler,
        llm_handler,
        progress_callback=None,
    ) -> Tuple[AudioSample, str]:
        """Label a single sample using the LLM.

        Pipeline: encode the audio into semantic codes via the DiT handler,
        then ask the LLM to describe it; the generated metadata is written
        back onto the sample in place.

        Args:
            sample_idx: Index of sample to label
            dit_handler: DiT handler for audio encoding
            llm_handler: LLM handler for caption generation
            progress_callback: Optional callback for progress updates
                (called with a short human-readable status string)

        Returns:
            Tuple of (updated AudioSample, status message). On invalid index
            the sample is None; on failure the unmodified sample is returned
            together with an error message.
        """
        if sample_idx < 0 or sample_idx >= len(self.samples):
            return None, f"❌ Invalid sample index: {sample_idx}"

        sample = self.samples[sample_idx]

        try:
            if progress_callback:
                progress_callback(f"Processing: {sample.filename}")

            # Step 1: Load and encode audio to get audio codes
            audio_codes = self._get_audio_codes(sample.audio_path, dit_handler)

            if not audio_codes:
                return sample, f"❌ Failed to encode audio: {sample.filename}"

            if progress_callback:
                progress_callback(f"Generating metadata for: {sample.filename}")

            # Step 2: Use LLM to understand the audio
            # NOTE(review): assumes understand_audio_from_codes returns
            # (metadata_dict_or_None, status_str) — confirm against the
            # llm_handler API.
            metadata, status = llm_handler.understand_audio_from_codes(
                audio_codes=audio_codes,
                temperature=0.7,
                use_constrained_decoding=True,
            )

            if not metadata:
                return sample, f"❌ LLM labeling failed: {status}"

            # Step 3: Update sample with generated metadata
            sample.caption = metadata.get('caption', '')
            sample.bpm = self._parse_int(metadata.get('bpm'))
            sample.keyscale = metadata.get('keyscale', '')
            sample.timesignature = metadata.get('timesignature', '')
            sample.language = metadata.get('vocal_language', 'instrumental')

            # Handle lyrics based on instrumental flag: the instrumental
            # setting wins over whatever the LLM produced.
            if sample.is_instrumental:
                sample.lyrics = "[Instrumental]"
                sample.language = "instrumental"
            else:
                sample.lyrics = metadata.get('lyrics', '')

            # NOTE: Duration is NOT overwritten from LM metadata.
            # We keep the real audio duration obtained from torchaudio during scan.

            sample.labeled = True
            self.samples[sample_idx] = sample

            return sample, f"✅ Labeled: {sample.filename}"

        except Exception as e:
            logger.exception(f"Error labeling sample {sample.filename}")
            return sample, f"❌ Error: {str(e)}"
281
+
282
+ def label_all_samples(
283
+ self,
284
+ dit_handler,
285
+ llm_handler,
286
+ progress_callback=None,
287
+ ) -> Tuple[List[AudioSample], str]:
288
+ """Label all samples in the dataset.
289
+
290
+ Args:
291
+ dit_handler: DiT handler for audio encoding
292
+ llm_handler: LLM handler for caption generation
293
+ progress_callback: Optional callback for progress updates
294
+
295
+ Returns:
296
+ Tuple of (list of updated samples, status message)
297
+ """
298
+ if not self.samples:
299
+ return [], "❌ No samples to label. Please scan a directory first."
300
+
301
+ success_count = 0
302
+ fail_count = 0
303
+
304
+ for i, sample in enumerate(self.samples):
305
+ if progress_callback:
306
+ progress_callback(f"Labeling {i+1}/{len(self.samples)}: {sample.filename}")
307
+
308
+ _, status = self.label_sample(i, dit_handler, llm_handler, progress_callback)
309
+
310
+ if "✅" in status:
311
+ success_count += 1
312
+ else:
313
+ fail_count += 1
314
+
315
+ status_msg = f"✅ Labeled {success_count}/{len(self.samples)} samples"
316
+ if fail_count > 0:
317
+ status_msg += f" ({fail_count} failed)"
318
+
319
+ return self.samples, status_msg
320
+
321
+ def _get_audio_codes(self, audio_path: str, dit_handler) -> Optional[str]:
322
+ """Encode audio to get semantic codes for LLM understanding.
323
+
324
+ Args:
325
+ audio_path: Path to audio file
326
+ dit_handler: DiT handler with VAE and tokenizer
327
+
328
+ Returns:
329
+ Audio codes string or None if failed
330
+ """
331
+ try:
332
+ # Check if handler has required methods
333
+ if not hasattr(dit_handler, 'convert_src_audio_to_codes'):
334
+ logger.error("DiT handler missing convert_src_audio_to_codes method")
335
+ return None
336
+
337
+ # Use handler's method to convert audio to codes
338
+ codes_string = dit_handler.convert_src_audio_to_codes(audio_path)
339
+
340
+ if codes_string and not codes_string.startswith("❌"):
341
+ return codes_string
342
+ else:
343
+ logger.warning(f"Failed to convert audio to codes: {codes_string}")
344
+ return None
345
+
346
+ except Exception as e:
347
+ logger.exception(f"Error encoding audio {audio_path}")
348
+ return None
349
+
350
+ def _parse_int(self, value: Any) -> Optional[int]:
351
+ """Safely parse an integer value."""
352
+ if value is None or value == "N/A" or value == "":
353
+ return None
354
+ try:
355
+ return int(value)
356
+ except (ValueError, TypeError):
357
+ return None
358
+
359
+ def update_sample(self, sample_idx: int, **kwargs) -> Tuple[AudioSample, str]:
360
+ """Update a sample's metadata.
361
+
362
+ Args:
363
+ sample_idx: Index of sample to update
364
+ **kwargs: Fields to update
365
+
366
+ Returns:
367
+ Tuple of (updated sample, status message)
368
+ """
369
+ if sample_idx < 0 or sample_idx >= len(self.samples):
370
+ return None, f"❌ Invalid sample index: {sample_idx}"
371
+
372
+ sample = self.samples[sample_idx]
373
+
374
+ for key, value in kwargs.items():
375
+ if hasattr(sample, key):
376
+ setattr(sample, key, value)
377
+
378
+ self.samples[sample_idx] = sample
379
+ return sample, f"✅ Updated: {sample.filename}"
380
+
381
+ def set_custom_tag(self, custom_tag: str, tag_position: str = "prepend"):
382
+ """Set the custom tag for all samples.
383
+
384
+ Args:
385
+ custom_tag: Custom activation tag
386
+ tag_position: Where to place tag ("prepend", "append", "replace")
387
+ """
388
+ self.metadata.custom_tag = custom_tag
389
+ self.metadata.tag_position = tag_position
390
+
391
+ for sample in self.samples:
392
+ sample.custom_tag = custom_tag
393
+
394
+ def set_all_instrumental(self, is_instrumental: bool):
395
+ """Set instrumental flag for all samples.
396
+
397
+ Args:
398
+ is_instrumental: Whether all tracks are instrumental
399
+ """
400
+ self.metadata.all_instrumental = is_instrumental
401
+
402
+ for sample in self.samples:
403
+ sample.is_instrumental = is_instrumental
404
+ if is_instrumental:
405
+ sample.lyrics = "[Instrumental]"
406
+ sample.language = "instrumental"
407
+
408
    def get_sample_count(self) -> int:
        """Return the total number of samples currently loaded."""
        return len(self.samples)
411
+
412
    def get_labeled_count(self) -> int:
        """Return how many samples have been labeled (sample.labeled set)."""
        return sum(1 for s in self.samples if s.labeled)
415
+
416
def save_dataset(self, output_path: str, dataset_name: str = None) -> str:
    """Serialize the dataset (metadata plus samples) to a JSON file.

    Args:
        output_path: Destination path for the dataset JSON.
        dataset_name: Optional new name stored in the metadata.

    Returns:
        A human-readable status string (success or failure).
    """
    if not self.samples:
        return "❌ No samples to save"

    if dataset_name:
        self.metadata.name = dataset_name

    # Refresh bookkeeping fields just before serialization.
    self.metadata.num_samples = len(self.samples)
    self.metadata.created_at = datetime.now().isoformat()

    payload = {
        "metadata": self.metadata.to_dict(),
        "samples": []
    }
    # Captions are rendered with the dataset-wide custom tag applied.
    for item in self.samples:
        record = item.to_dict()
        record["caption"] = item.get_full_caption(self.metadata.tag_position)
        payload["samples"].append(record)

    try:
        parent = os.path.dirname(output_path)
        os.makedirs(parent if parent else ".", exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)

        return f"✅ Dataset saved to {output_path}\n{len(self.samples)} samples, tag: '{self.metadata.custom_tag}'"
    except Exception as e:
        logger.exception("Error saving dataset")
        return f"❌ Failed to save dataset: {str(e)}"
458
+
459
def load_dataset(self, dataset_path: str) -> Tuple[List[AudioSample], str]:
    """Load dataset metadata and samples from a JSON file.

    Args:
        dataset_path: Path to the dataset JSON file.

    Returns:
        Tuple of (loaded samples, human-readable status message).
    """
    if not os.path.exists(dataset_path):
        return [], f"❌ Dataset not found: {dataset_path}"

    try:
        with open(dataset_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Rebuild metadata with safe defaults for any missing field.
        if "metadata" in data:
            meta_dict = data["metadata"]
            self.metadata = DatasetMetadata(
                name=meta_dict.get("name", "untitled"),
                custom_tag=meta_dict.get("custom_tag", ""),
                tag_position=meta_dict.get("tag_position", "prepend"),
                created_at=meta_dict.get("created_at", ""),
                num_samples=meta_dict.get("num_samples", 0),
                all_instrumental=meta_dict.get("all_instrumental", True),
            )

        # Materialize each serialized sample; missing "samples" means empty.
        self.samples = [AudioSample.from_dict(entry) for entry in data.get("samples", [])]

        return self.samples, f"✅ Loaded {len(self.samples)} samples from {dataset_path}"

    except Exception as e:
        logger.exception("Error loading dataset")
        return [], f"❌ Failed to load dataset: {str(e)}"
498
+
499
def get_samples_dataframe_data(self) -> List[List[Any]]:
    """Build display rows for the Gradio DataFrame.

    Returns:
        One row per sample:
        [index, filename, duration, labeled flag, bpm, keyscale, caption].
        Missing bpm/keyscale/caption values are rendered as "-".
    """
    rows = []
    for i, sample in enumerate(self.samples):
        # Guard against a missing caption: the original expression called
        # len(sample.caption) directly and raised TypeError for None.
        caption = sample.caption or ""
        # Truncate long captions so the table stays readable.
        caption_display = caption[:50] + "..." if len(caption) > 50 else caption or "-"
        rows.append([
            i,
            sample.filename,
            f"{sample.duration:.1f}s",
            "✅" if sample.labeled else "❌",
            sample.bpm or "-",
            sample.keyscale or "-",
            caption_display,
        ])
    return rows
517
+
518
def to_training_format(self) -> List[Dict[str, Any]]:
    """Export labeled samples as plain dicts consumable by the trainer.

    Unlabeled samples are skipped; captions are rendered with the
    dataset-wide custom tag applied.

    Returns:
        One dict per labeled sample with audio path, rendered caption,
        lyrics, and musical metadata.
    """
    return [
        {
            "audio_path": item.audio_path,
            "caption": item.get_full_caption(self.metadata.tag_position),
            "lyrics": item.lyrics,
            "bpm": item.bpm,
            "keyscale": item.keyscale,
            "timesignature": item.timesignature,
            "duration": item.duration,
            "language": item.language,
            "is_instrumental": item.is_instrumental,
        }
        for item in self.samples
        if item.labeled
    ]
544
+
545
def preprocess_to_tensors(
    self,
    dit_handler,
    output_dir: str,
    max_duration: float = 240.0,
    progress_callback=None,
) -> Tuple[List[str], str]:
    """Preprocess all labeled samples to tensor files for efficient training.

    This method pre-computes all tensors needed by the DiT decoder so that
    training never has to run the VAE or text encoder:
    - target_latents: VAE-encoded audio
    - encoder_hidden_states: Condition encoder output
    - context_latents: Source context (silence latents + all-ones chunk mask
      for text2music)

    Each labeled sample produces one ``<sample.id>.pt`` file; a
    ``manifest.json`` listing all output paths is written at the end.
    Failures are logged per-sample and do not abort the run.

    Args:
        dit_handler: Initialized DiT handler with model, VAE, and text encoder
        output_dir: Directory to save preprocessed .pt files
        max_duration: Maximum audio duration in seconds (default 240s = 4 min)
        progress_callback: Optional callable taking a status string

    Returns:
        Tuple of (list of output .pt paths, status message)
    """
    if not self.samples:
        return [], "❌ No samples to preprocess"

    labeled_samples = [s for s in self.samples if s.labeled]
    if not labeled_samples:
        return [], "❌ No labeled samples to preprocess"

    # Validate handler
    if dit_handler is None or dit_handler.model is None:
        return [], "❌ Model not initialized. Please initialize the service first."

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    output_paths = []
    success_count = 0
    fail_count = 0

    # Get model and components (all assumed to already live on `device`)
    model = dit_handler.model
    vae = dit_handler.vae
    text_encoder = dit_handler.text_encoder
    text_tokenizer = dit_handler.text_tokenizer
    silence_latent = dit_handler.silence_latent
    device = dit_handler.device
    dtype = dit_handler.dtype

    # Model operates on 48 kHz stereo audio.
    target_sample_rate = 48000

    for i, sample in enumerate(labeled_samples):
        try:
            if progress_callback:
                progress_callback(f"Preprocessing {i+1}/{len(labeled_samples)}: {sample.filename}")

            # Step 1: Load and preprocess audio to stereo @ 48kHz
            audio, sr = torchaudio.load(sample.audio_path)

            # Resample if needed
            if sr != target_sample_rate:
                resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
                audio = resampler(audio)

            # Convert to stereo: duplicate mono, drop channels beyond 2
            if audio.shape[0] == 1:
                audio = audio.repeat(2, 1)
            elif audio.shape[0] > 2:
                audio = audio[:2, :]

            # Truncate to max duration
            max_samples = int(max_duration * target_sample_rate)
            if audio.shape[1] > max_samples:
                audio = audio[:, :max_samples]

            # Add batch dimension: [2, T] -> [1, 2, T]; match the VAE's dtype
            audio = audio.unsqueeze(0).to(device).to(vae.dtype)

            # Step 2: VAE encode audio to get target_latents
            with torch.no_grad():
                latent = vae.encode(audio).latent_dist.sample()
                # [1, 64, T_latent] -> [1, T_latent, 64]
                target_latents = latent.transpose(1, 2).to(dtype)

            latent_length = target_latents.shape[1]

            # Step 3: Create attention mask (all ones: the whole clip is valid)
            attention_mask = torch.ones(1, latent_length, device=device, dtype=dtype)

            # Step 4: Encode caption text (padded/truncated to 256 tokens)
            caption = sample.get_full_caption(self.metadata.tag_position)
            text_inputs = text_tokenizer(
                caption,
                padding="max_length",
                max_length=256,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids.to(device)
            text_attention_mask = text_inputs.attention_mask.to(device).to(dtype)

            with torch.no_grad():
                text_outputs = text_encoder(text_input_ids)
                text_hidden_states = text_outputs.last_hidden_state.to(dtype)

            # Step 5: Encode lyrics (token embeddings only, no full encoder pass)
            lyrics = sample.lyrics if sample.lyrics else "[Instrumental]"
            lyric_inputs = text_tokenizer(
                lyrics,
                padding="max_length",
                max_length=512,
                truncation=True,
                return_tensors="pt",
            )
            lyric_input_ids = lyric_inputs.input_ids.to(device)
            lyric_attention_mask = lyric_inputs.attention_mask.to(device).to(dtype)

            with torch.no_grad():
                lyric_hidden_states = text_encoder.embed_tokens(lyric_input_ids).to(dtype)

            # Step 6: Prepare refer_audio (empty for text2music)
            # Create minimal refer_audio placeholder
            refer_audio_hidden = torch.zeros(1, 1, 64, device=device, dtype=dtype)
            refer_audio_order_mask = torch.zeros(1, device=device, dtype=torch.long)

            # Step 7: Run model.encoder to get encoder_hidden_states
            with torch.no_grad():
                encoder_hidden_states, encoder_attention_mask = model.encoder(
                    text_hidden_states=text_hidden_states,
                    text_attention_mask=text_attention_mask,
                    lyric_hidden_states=lyric_hidden_states,
                    lyric_attention_mask=lyric_attention_mask,
                    refer_audio_acoustic_hidden_states_packed=refer_audio_hidden,
                    refer_audio_order_mask=refer_audio_order_mask,
                )

            # Step 8: Build context_latents for text2music
            # For text2music: src_latents = silence_latent, is_covers = 0
            # chunk_masks: 1 = generate, 0 = keep original
            # IMPORTANT: chunk_masks must have same shape as src_latents [B, T, 64]
            # For text2music, we want to generate the entire audio, so chunk_masks = all 1s
            # NOTE(review): assumes silence_latent is [1, T_sil, 64] with
            # T_sil >= pad_len for the padding branch below — confirm.
            src_latents = silence_latent[:, :latent_length, :].to(dtype)
            if src_latents.shape[0] < 1:
                src_latents = src_latents.expand(1, -1, -1)

            # Pad or truncate silence_latent to match latent_length
            if src_latents.shape[1] < latent_length:
                pad_len = latent_length - src_latents.shape[1]
                src_latents = torch.cat([
                    src_latents,
                    silence_latent[:, :pad_len, :].expand(1, -1, -1).to(dtype)
                ], dim=1)
            elif src_latents.shape[1] > latent_length:
                src_latents = src_latents[:, :latent_length, :]

            # chunk_masks = 1 means "generate this region", 0 = keep original
            # Shape must match src_latents: [B, T, 64] (NOT [B, T, 1])
            # For text2music, generate everything -> all 1s with shape [1, T, 64]
            chunk_masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype)
            # context_latents = [src_latents, chunk_masks] -> [B, T, 128]
            context_latents = torch.cat([src_latents, chunk_masks], dim=-1)

            # Step 9: Save all tensors to .pt file (squeeze batch dimension for storage)
            output_data = {
                "target_latents": target_latents.squeeze(0).cpu(),  # [T, 64]
                "attention_mask": attention_mask.squeeze(0).cpu(),  # [T]
                "encoder_hidden_states": encoder_hidden_states.squeeze(0).cpu(),  # [L, D]
                "encoder_attention_mask": encoder_attention_mask.squeeze(0).cpu(),  # [L]
                "context_latents": context_latents.squeeze(0).cpu(),  # [T, 128] (64 src + 64 mask)
                "metadata": {
                    "audio_path": sample.audio_path,
                    "filename": sample.filename,
                    "caption": caption,
                    "lyrics": lyrics,
                    "duration": sample.duration,
                    "bpm": sample.bpm,
                    "keyscale": sample.keyscale,
                    "timesignature": sample.timesignature,
                    "language": sample.language,
                    "is_instrumental": sample.is_instrumental,
                }
            }

            # Save with sample ID as filename
            output_path = os.path.join(output_dir, f"{sample.id}.pt")
            torch.save(output_data, output_path)
            output_paths.append(output_path)
            success_count += 1

        except Exception as e:
            # Per-sample failure: log and continue with the rest of the batch.
            logger.exception(f"Error preprocessing {sample.filename}")
            fail_count += 1
            if progress_callback:
                progress_callback(f"❌ Failed: {sample.filename}: {str(e)}")

    # Save manifest file listing all preprocessed samples
    manifest = {
        "metadata": self.metadata.to_dict(),
        "samples": output_paths,
        "num_samples": len(output_paths),
    }
    manifest_path = os.path.join(output_dir, "manifest.json")
    with open(manifest_path, 'w', encoding='utf-8') as f:
        json.dump(manifest, f, indent=2)

    status = f"✅ Preprocessed {success_count}/{len(labeled_samples)} samples to {output_dir}"
    if fail_count > 0:
        status += f" ({fail_count} failed)"

    return output_paths, status
acestep/training/lora_utils.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LoRA Utilities for ACE-Step
3
+
4
+ Provides utilities for injecting LoRA adapters into the DiT decoder model.
5
+ Uses PEFT (Parameter-Efficient Fine-Tuning) library for LoRA implementation.
6
+ """
7
+
8
+ import os
9
+ from typing import Optional, List, Dict, Any, Tuple
10
+ from loguru import logger
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ try:
16
+ from peft import (
17
+ get_peft_model,
18
+ LoraConfig,
19
+ TaskType,
20
+ PeftModel,
21
+ PeftConfig,
22
+ )
23
+ PEFT_AVAILABLE = True
24
+ except ImportError:
25
+ PEFT_AVAILABLE = False
26
+ logger.warning("PEFT library not installed. LoRA training will not be available.")
27
+
28
+ from acestep.training.configs import LoRAConfig
29
+
30
+
31
def check_peft_available() -> bool:
    """Report whether the optional PEFT dependency was importable."""
    # Module-level flag set by the try/except around the peft import.
    return PEFT_AVAILABLE
34
+
35
+
36
def get_dit_target_modules(model) -> List[str]:
    """Collect decoder module names that are candidates for LoRA injection.

    Only ``nn.Linear`` attention projections (q/k/v/o) inside
    ``model.decoder`` are returned; a model without a decoder yields an
    empty list.

    Args:
        model: The AceStepConditionGenerationModel

    Returns:
        List of module names suitable for LoRA
    """
    projections = ('q_proj', 'k_proj', 'v_proj', 'o_proj')
    names: List[str] = []

    if hasattr(model, 'decoder'):
        for name, module in model.decoder.named_modules():
            # Keep only Linear layers whose name marks an attention projection.
            if isinstance(module, nn.Linear) and any(p in name for p in projections):
                names.append(name)

    return names
56
+
57
+
58
def freeze_non_lora_parameters(model, freeze_encoder: bool = True) -> None:
    """Freeze all non-LoRA parameters in the model.

    Fix: the previous implementation froze *every* parameter — including
    LoRA adapter weights — which contradicted the function's name and made
    the trainable-parameter count below always zero. LoRA parameters
    (names containing 'lora_') are now left trainable.

    Args:
        model: The model to freeze parameters for
        freeze_encoder: Whether to freeze the encoder (condition encoder).
            Kept for API compatibility; currently the encoder is frozen
            together with all other non-LoRA parameters.
    """
    total_params = 0
    trainable_params = 0

    for name, param in model.named_parameters():
        # PEFT names every adapter weight with a 'lora_' component.
        if 'lora_' not in name:
            param.requires_grad = False
        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    logger.info(f"Frozen parameters: {total_params - trainable_params:,}")
    logger.info(f"Trainable parameters: {trainable_params:,}")
80
+
81
+
82
def inject_lora_into_dit(
    model,
    lora_config: LoRAConfig,
) -> Tuple[Any, Dict[str, Any]]:
    """Wrap the model's DiT decoder with PEFT LoRA adapters.

    Args:
        model: The AceStepConditionGenerationModel
        lora_config: LoRA configuration

    Returns:
        Tuple of (peft_model, info_dict) where info_dict reports parameter
        counts and the LoRA hyperparameters used.

    Raises:
        ImportError: If the PEFT library is not installed.
    """
    if not PEFT_AVAILABLE:
        raise ImportError("PEFT library is required for LoRA training. Install with: pip install peft")

    # Translate the project-level config into a PEFT LoraConfig and wrap
    # only the decoder (the DiT) with adapters.
    peft_lora_config = LoraConfig(
        r=lora_config.r,
        lora_alpha=lora_config.alpha,
        lora_dropout=lora_config.dropout,
        target_modules=lora_config.target_modules,
        bias=lora_config.bias,
        task_type=TaskType.FEATURE_EXTRACTION,  # diffusion backbone, not an LM head
    )
    model.decoder = get_peft_model(model.decoder, peft_lora_config)

    # Everything outside the LoRA adapters stays frozen
    # (encoder, tokenizer, detokenizer, base decoder weights).
    for param_name, param in model.named_parameters():
        if 'lora_' not in param_name:
            param.requires_grad = False

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    info = {
        "total_params": total_params,
        "trainable_params": trainable_params,
        "trainable_ratio": trainable_params / total_params if total_params > 0 else 0,
        "lora_r": lora_config.r,
        "lora_alpha": lora_config.alpha,
        "target_modules": lora_config.target_modules,
    }

    logger.info(f"LoRA injected into DiT decoder:")
    logger.info(f"  Total parameters: {total_params:,}")
    logger.info(f"  Trainable parameters: {trainable_params:,} ({info['trainable_ratio']:.2%})")
    logger.info(f"  LoRA rank: {lora_config.r}, alpha: {lora_config.alpha}")

    return model, info
143
+
144
+
145
def save_lora_weights(
    model,
    output_dir: str,
    save_full_model: bool = False,
) -> str:
    """Persist LoRA weights, preferring the PEFT adapter format.

    Args:
        model: Model with LoRA adapters
        output_dir: Directory to save weights (created if missing)
        save_full_model: Whether to save the full model state dict when no
            PEFT adapter is present

    Returns:
        Path to the saved artifact, or "" when nothing could be saved.
    """
    os.makedirs(output_dir, exist_ok=True)

    decoder = getattr(model, 'decoder', None)
    if decoder is not None and hasattr(decoder, 'save_pretrained'):
        # Preferred path: a small, reloadable PEFT adapter directory.
        adapter_path = os.path.join(output_dir, "adapter")
        decoder.save_pretrained(adapter_path)
        logger.info(f"LoRA adapter saved to {adapter_path}")
        return adapter_path

    if save_full_model:
        # Fallback: the entire state dict (much larger on disk).
        model_path = os.path.join(output_dir, "model.pt")
        torch.save(model.state_dict(), model_path)
        logger.info(f"Full model state dict saved to {model_path}")
        return model_path

    # Last resort: pick out only LoRA-named parameters.
    lora_state_dict = {
        name: param.data.clone()
        for name, param in model.named_parameters()
        if 'lora_' in name
    }

    if not lora_state_dict:
        logger.warning("No LoRA parameters found to save!")
        return ""

    lora_path = os.path.join(output_dir, "lora_weights.pt")
    torch.save(lora_state_dict, lora_path)
    logger.info(f"LoRA weights saved to {lora_path}")
    return lora_path
189
+
190
+
191
def load_lora_weights(
    model,
    lora_path: str,
    lora_config: Optional[LoRAConfig] = None,
) -> Any:
    """Load LoRA adapter weights into the model.

    Supports two on-disk formats: a PEFT adapter directory, or a ``.pt``
    state dict produced by :func:`save_lora_weights` (which requires
    ``lora_config`` so the adapter structure can be rebuilt first).

    Args:
        model: The base model (without LoRA)
        lora_path: Path to saved LoRA weights (adapter dir or .pt file)
        lora_config: LoRA configuration (required if loading from .pt file)

    Returns:
        Model with LoRA weights loaded

    Raises:
        FileNotFoundError: If ``lora_path`` does not exist.
        ImportError: If an adapter directory is given but peft is missing.
        ValueError: If ``lora_config`` is missing for a .pt file, or the
            path has an unsupported format.
    """
    if not os.path.exists(lora_path):
        raise FileNotFoundError(f"LoRA weights not found: {lora_path}")

    # Check if it's a PEFT adapter directory
    if os.path.isdir(lora_path):
        if not PEFT_AVAILABLE:
            raise ImportError("PEFT library is required to load adapter. Install with: pip install peft")

        # from_pretrained validates the adapter directory before wrapping.
        peft_config = PeftConfig.from_pretrained(lora_path)
        model.decoder = PeftModel.from_pretrained(model.decoder, lora_path)
        logger.info(f"LoRA adapter loaded from {lora_path}")

    elif lora_path.endswith('.pt'):
        if lora_config is None:
            raise ValueError("lora_config is required when loading from .pt file")

        # Recreate the LoRA structure so the saved tensors have a home.
        model, _ = inject_lora_into_dit(model, lora_config)

        # Fix: weights_only=True — the file is a plain name->tensor dict, so
        # refuse arbitrary pickled objects (a tampered .pt could otherwise
        # execute code on load).
        lora_state_dict = torch.load(lora_path, map_location='cpu', weights_only=True)

        # Copy tensors into the live model in place.
        model_state = model.state_dict()
        for name, param in lora_state_dict.items():
            if name in model_state:
                model_state[name].copy_(param)
            else:
                logger.warning(f"Unexpected key in LoRA state dict: {name}")

        logger.info(f"LoRA weights loaded from {lora_path}")

    else:
        raise ValueError(f"Unsupported LoRA weight format: {lora_path}")

    return model
244
+
245
+
246
def merge_lora_weights(model) -> Any:
    """Fold LoRA adapters into the base weights (removes the PEFT wrapper).

    This permanently integrates the LoRA adaptations into the model
    weights; after merging, the model runs without the PEFT runtime.

    Args:
        model: Model with LoRA adapters

    Returns:
        Model with merged weights
    """
    decoder = getattr(model, 'decoder', None)
    if decoder is None or not hasattr(decoder, 'merge_and_unload'):
        logger.warning("Model does not support LoRA merging")
        return model

    # PEFT model - merge and unload
    model.decoder = decoder.merge_and_unload()
    logger.info("LoRA weights merged into base model")
    return model
266
+
267
+
268
def get_lora_info(model) -> Dict[str, Any]:
    """Summarize LoRA adapter presence in a model.

    Args:
        model: Model to inspect (anything exposing ``named_parameters``)

    Returns:
        Dict with total/LoRA parameter counts, the module names carrying
        adapters, and — when the model has parameters — the LoRA ratio.
    """
    total = 0
    lora_total = 0
    owners: list = []

    for param_name, param in model.named_parameters():
        count = param.numel()
        total += count
        if 'lora_' not in param_name:
            continue
        lora_total += count
        # Strip the trailing '.lora_*' suffix to recover the module name.
        owner = param_name.rsplit('.lora_', 1)[0]
        if owner not in owners:
            owners.append(owner)

    info = {
        "has_lora": lora_total > 0,
        "lora_params": lora_total,
        "total_params": total,
        "modules_with_lora": owners,
    }
    if total > 0:
        info["lora_ratio"] = lora_total / total

    return info
acestep/training/trainer.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LoRA Trainer for ACE-Step
3
+
4
+ Lightning Fabric-based trainer for LoRA fine-tuning of ACE-Step DiT decoder.
5
+ Supports training from preprocessed tensor files for optimal performance.
6
+ """
7
+
8
+ import os
9
+ import time
10
+ from typing import Optional, List, Dict, Any, Tuple, Generator
11
+ from loguru import logger
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch.optim import AdamW
17
+ from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR
18
+
19
+ try:
20
+ from lightning.fabric import Fabric
21
+ from lightning.fabric.loggers import TensorBoardLogger
22
+ LIGHTNING_AVAILABLE = True
23
+ except ImportError:
24
+ LIGHTNING_AVAILABLE = False
25
+ logger.warning("Lightning Fabric not installed. Training will use basic training loop.")
26
+
27
+ from acestep.training.configs import LoRAConfig, TrainingConfig
28
+ from acestep.training.lora_utils import inject_lora_into_dit, save_lora_weights, check_peft_available
29
+ from acestep.training.data_module import PreprocessedDataModule
30
+
31
+
32
# Discrete timesteps visited by the turbo sampler at shift=3.0 (8 steps,
# identical to the inference schedule); training draws t only from these.
TURBO_SHIFT3_TIMESTEPS = [1.0, 0.9545454545454546, 0.9, 0.8333333333333334, 0.75, 0.6428571428571429, 0.5, 0.3]


def sample_discrete_timestep(bsz, device, dtype):
    """Draw per-sample timesteps from the discrete turbo shift=3 schedule.

    Each batch element independently picks one of the 8 discrete timesteps
    used by the turbo model at inference with shift=3.0.

    Args:
        bsz: Batch size
        device: Device
        dtype: Data type (should be bfloat16)

    Returns:
        Tuple of (t, r) where both are the same sampled timestep
    """
    # One independent uniform draw over the schedule per batch element.
    choice = torch.randint(0, len(TURBO_SHIFT3_TIMESTEPS), (bsz,), device=device)
    schedule = torch.tensor(TURBO_SHIFT3_TIMESTEPS, device=device, dtype=dtype)
    t = schedule[choice]
    # This training setup uses r == t.
    return t, t
61
+
62
+
63
class PreprocessedLoRAModule(nn.Module):
    """LoRA Training Module using preprocessed tensors.

    This module trains only the DiT decoder with LoRA adapters.
    All inputs are pre-computed tensors - no VAE or text encoder needed!

    Training flow:
    1. Load pre-computed tensors (target_latents, encoder_hidden_states, context_latents)
    2. Sample noise and timestep
    3. Forward through decoder (with LoRA)
    4. Compute flow matching loss
    """

    def __init__(
        self,
        model: nn.Module,
        lora_config: LoRAConfig,
        training_config: TrainingConfig,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Initialize the training module.

        Args:
            model: The AceStepConditionGenerationModel
            lora_config: LoRA configuration
            training_config: Training configuration
            device: Device to use
            dtype: Data type to use
        """
        super().__init__()

        self.lora_config = lora_config
        self.training_config = training_config
        self.device = device
        self.dtype = dtype

        # Inject LoRA into the decoder only; fall back to full-model access
        # (no adapters) when peft is unavailable.
        if check_peft_available():
            self.model, self.lora_info = inject_lora_into_dit(model, lora_config)
            logger.info(f"LoRA injected: {self.lora_info['trainable_params']:,} trainable params")
        else:
            self.model = model
            self.lora_info = {}
            logger.warning("PEFT not available, training without LoRA adapters")

        # Model config kept around for flow-matching hyperparameters.
        self.config = model.config

        # Running list of per-step losses (floats) for progress reporting.
        self.training_losses = []

    def training_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Single training step using preprocessed tensors.

        Note: This is a distilled turbo model, NO CFG is used.

        Args:
            batch: Dictionary containing pre-computed tensors:
                - target_latents: [B, T, 64] - VAE encoded audio
                - attention_mask: [B, T] - Valid audio mask
                - encoder_hidden_states: [B, L, D] - Condition encoder output
                - encoder_attention_mask: [B, L] - Condition mask
                - context_latents: [B, T, 128] - Source context

        Returns:
            Loss tensor (float32 for stable backward)
        """
        # Use autocast for bf16 mixed precision training.
        # NOTE(review): device_type is hard-coded to 'cuda' — confirm this
        # module is never run on CPU-only hosts.
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            # Get tensors from batch (already on device from Fabric dataloader)
            target_latents = batch["target_latents"].to(self.device)  # x0
            attention_mask = batch["attention_mask"].to(self.device)
            encoder_hidden_states = batch["encoder_hidden_states"].to(self.device)
            encoder_attention_mask = batch["encoder_attention_mask"].to(self.device)
            context_latents = batch["context_latents"].to(self.device)

            bsz = target_latents.shape[0]

            # Flow matching: sample noise x1 and interpolate with data x0
            x1 = torch.randn_like(target_latents)  # Noise
            x0 = target_latents  # Data

            # Sample timesteps from discrete turbo shift=3 schedule (8 steps).
            # r equals t by construction, so timestep_r receives t directly.
            t, r = sample_discrete_timestep(bsz, self.device, torch.bfloat16)
            t_ = t.unsqueeze(-1).unsqueeze(-1)

            # Interpolate: x_t = t * x1 + (1 - t) * x0
            xt = t_ * x1 + (1.0 - t_) * x0

            # Forward through decoder (distilled turbo model, no CFG)
            decoder_outputs = self.model.decoder(
                hidden_states=xt,
                timestep=t,
                timestep_r=t,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                context_latents=context_latents,
            )

            # Flow matching loss: predict the flow field v = x1 - x0
            flow = x1 - x0
            diffusion_loss = F.mse_loss(decoder_outputs[0], flow)

            # Convert loss to float32 for stable backward pass
            diffusion_loss = diffusion_loss.float()

        self.training_losses.append(diffusion_loss.item())

        return diffusion_loss
174
+
175
+
176
+ class LoRATrainer:
177
+ """High-level trainer for ACE-Step LoRA fine-tuning.
178
+
179
+ Uses Lightning Fabric for distributed training and mixed precision.
180
+ Supports training from preprocessed tensor directories.
181
+ """
182
+
183
def __init__(
    self,
    dit_handler,
    lora_config: LoRAConfig,
    training_config: TrainingConfig,
):
    """Create a LoRA trainer bound to an initialized DiT handler.

    Args:
        dit_handler: Initialized DiT handler providing model/device/dtype.
        lora_config: LoRA hyperparameters.
        training_config: Optimizer, schedule and output settings.
    """
    self.dit_handler = dit_handler
    self.lora_config = lora_config
    self.training_config = training_config

    # Populated lazily once a training run starts.
    self.module = None
    self.fabric = None
    self.is_training = False
203
+
204
def train_from_preprocessed(
    self,
    tensor_dir: str,
    training_state: Optional[Dict] = None,
) -> Generator[Tuple[int, float, str], None, None]:
    """Train LoRA adapters from preprocessed tensor files.

    This is the recommended training method for best performance: it only
    runs the DiT decoder, with all conditioning tensors loaded from disk.
    Errors are reported through yielded status messages rather than
    raised, so UI callers can stream progress safely.

    Args:
        tensor_dir: Directory containing preprocessed .pt files
        training_state: Optional state dict for stopping control
            (forwarded to the underlying training loop)

    Yields:
        Tuples of (step, loss, status_message)
    """
    self.is_training = True

    try:
        # Validate tensor directory before touching the model.
        if not os.path.exists(tensor_dir):
            yield 0, 0.0, f"❌ Tensor directory not found: {tensor_dir}"
            return

        # Create training module (injects LoRA into the handler's model).
        self.module = PreprocessedLoRAModule(
            model=self.dit_handler.model,
            lora_config=self.lora_config,
            training_config=self.training_config,
            device=self.dit_handler.device,
            dtype=self.dit_handler.dtype,
        )

        # Create data module over the preprocessed .pt files.
        data_module = PreprocessedDataModule(
            tensor_dir=tensor_dir,
            batch_size=self.training_config.batch_size,
            num_workers=self.training_config.num_workers,
            pin_memory=self.training_config.pin_memory,
        )

        # Setup data
        data_module.setup('fit')

        if len(data_module.train_dataset) == 0:
            yield 0, 0.0, "❌ No valid samples found in tensor directory"
            return

        yield 0, 0.0, f"📂 Loaded {len(data_module.train_dataset)} preprocessed samples"

        # Prefer the Fabric loop (mixed precision, logging); fall back to
        # the basic loop when Lightning is not installed.
        if LIGHTNING_AVAILABLE:
            yield from self._train_with_fabric(data_module, training_state)
        else:
            yield from self._train_basic(data_module, training_state)

    except Exception as e:
        logger.exception("Training failed")
        yield 0, 0.0, f"❌ Training failed: {str(e)}"
    finally:
        # Always clear the flag so the UI can start a new run.
        self.is_training = False
264
+
265
+ def _train_with_fabric(
266
+ self,
267
+ data_module: PreprocessedDataModule,
268
+ training_state: Optional[Dict],
269
+ ) -> Generator[Tuple[int, float, str], None, None]:
270
+ """Train using Lightning Fabric."""
271
+ # Create output directory
272
+ os.makedirs(self.training_config.output_dir, exist_ok=True)
273
+
274
+ # Force BFloat16 precision (only supported precision for this model)
275
+ precision = "bf16-mixed"
276
+
277
+ # Create TensorBoard logger
278
+ tb_logger = TensorBoardLogger(
279
+ root_dir=self.training_config.output_dir,
280
+ name="logs"
281
+ )
282
+
283
+ # Initialize Fabric
284
+ self.fabric = Fabric(
285
+ accelerator="auto",
286
+ devices=1,
287
+ precision=precision,
288
+ loggers=[tb_logger],
289
+ )
290
+ self.fabric.launch()
291
+
292
+ yield 0, 0.0, f"🚀 Starting training (precision: {precision})..."
293
+
294
+ # Get dataloader
295
+ train_loader = data_module.train_dataloader()
296
+
297
+ # Setup optimizer - only LoRA parameters
298
+ trainable_params = [p for p in self.module.model.parameters() if p.requires_grad]
299
+
300
+ if not trainable_params:
301
+ yield 0, 0.0, "❌ No trainable parameters found!"
302
+ return
303
+
304
+ yield 0, 0.0, f"🎯 Training {sum(p.numel() for p in trainable_params):,} parameters"
305
+
306
+ optimizer = AdamW(
307
+ trainable_params,
308
+ lr=self.training_config.learning_rate,
309
+ weight_decay=self.training_config.weight_decay,
310
+ )
311
+
312
+ # Calculate total steps
313
+ total_steps = len(train_loader) * self.training_config.max_epochs // self.training_config.gradient_accumulation_steps
314
+ warmup_steps = min(self.training_config.warmup_steps, max(1, total_steps // 10))
315
+
316
+ # Scheduler
317
+ warmup_scheduler = LinearLR(
318
+ optimizer,
319
+ start_factor=0.1,
320
+ end_factor=1.0,
321
+ total_iters=warmup_steps,
322
+ )
323
+
324
+ main_scheduler = CosineAnnealingWarmRestarts(
325
+ optimizer,
326
+ T_0=max(1, total_steps - warmup_steps),
327
+ T_mult=1,
328
+ eta_min=self.training_config.learning_rate * 0.01,
329
+ )
330
+
331
+ scheduler = SequentialLR(
332
+ optimizer,
333
+ schedulers=[warmup_scheduler, main_scheduler],
334
+ milestones=[warmup_steps],
335
+ )
336
+
337
+ # Convert model to bfloat16 (entire model for consistent dtype)
338
+ self.module.model = self.module.model.to(torch.bfloat16)
339
+
340
+ # Setup with Fabric - only the decoder (which has LoRA)
341
+ self.module.model.decoder, optimizer = self.fabric.setup(self.module.model.decoder, optimizer)
342
+ train_loader = self.fabric.setup_dataloaders(train_loader)
343
+
344
+ # Training loop
345
+ global_step = 0
346
+ accumulation_step = 0
347
+ accumulated_loss = 0.0
348
+
349
+ self.module.model.decoder.train()
350
+
351
+ for epoch in range(self.training_config.max_epochs):
352
+ epoch_loss = 0.0
353
+ num_batches = 0
354
+ epoch_start_time = time.time()
355
+
356
+ for batch_idx, batch in enumerate(train_loader):
357
+ # Check for stop signal
358
+ if training_state and training_state.get("should_stop", False):
359
+ yield global_step, accumulated_loss / max(accumulation_step, 1), "⏹️ Training stopped by user"
360
+ return
361
+
362
+ # Forward pass
363
+ loss = self.module.training_step(batch)
364
+ loss = loss / self.training_config.gradient_accumulation_steps
365
+
366
+ # Backward pass
367
+ self.fabric.backward(loss)
368
+ accumulated_loss += loss.item()
369
+ accumulation_step += 1
370
+
371
+ # Optimizer step
372
+ if accumulation_step >= self.training_config.gradient_accumulation_steps:
373
+ self.fabric.clip_gradients(
374
+ self.module.model.decoder,
375
+ optimizer,
376
+ max_norm=self.training_config.max_grad_norm,
377
+ )
378
+
379
+ optimizer.step()
380
+ scheduler.step()
381
+ optimizer.zero_grad()
382
+
383
+ global_step += 1
384
+
385
+ # Log
386
+ avg_loss = accumulated_loss / accumulation_step
387
+ self.fabric.log("train/loss", avg_loss, step=global_step)
388
+ self.fabric.log("train/lr", scheduler.get_last_lr()[0], step=global_step)
389
+
390
+ if global_step % self.training_config.log_every_n_steps == 0:
391
+ yield global_step, avg_loss, f"Epoch {epoch+1}/{self.training_config.max_epochs}, Step {global_step}, Loss: {avg_loss:.4f}"
392
+
393
+ epoch_loss += accumulated_loss
394
+ num_batches += 1
395
+ accumulated_loss = 0.0
396
+ accumulation_step = 0
397
+
398
+ # End of epoch
399
+ epoch_time = time.time() - epoch_start_time
400
+ avg_epoch_loss = epoch_loss / max(num_batches, 1)
401
+
402
+ self.fabric.log("train/epoch_loss", avg_epoch_loss, step=epoch + 1)
403
+ yield global_step, avg_epoch_loss, f"✅ Epoch {epoch+1}/{self.training_config.max_epochs} in {epoch_time:.1f}s, Loss: {avg_epoch_loss:.4f}"
404
+
405
+ # Save checkpoint
406
+ if (epoch + 1) % self.training_config.save_every_n_epochs == 0:
407
+ checkpoint_dir = os.path.join(self.training_config.output_dir, "checkpoints", f"epoch_{epoch+1}")
408
+ save_lora_weights(self.module.model, checkpoint_dir)
409
+ yield global_step, avg_epoch_loss, f"💾 Checkpoint saved at epoch {epoch+1}"
410
+
411
+ # Save final model
412
+ final_path = os.path.join(self.training_config.output_dir, "final")
413
+ save_lora_weights(self.module.model, final_path)
414
+
415
+ final_loss = self.module.training_losses[-1] if self.module.training_losses else 0.0
416
+ yield global_step, final_loss, f"✅ Training complete! LoRA saved to {final_path}"
417
+
418
+ def _train_basic(
419
+ self,
420
+ data_module: PreprocessedDataModule,
421
+ training_state: Optional[Dict],
422
+ ) -> Generator[Tuple[int, float, str], None, None]:
423
+ """Basic training loop without Fabric."""
424
+ yield 0, 0.0, "🚀 Starting basic training loop..."
425
+
426
+ os.makedirs(self.training_config.output_dir, exist_ok=True)
427
+
428
+ train_loader = data_module.train_dataloader()
429
+
430
+ trainable_params = [p for p in self.module.model.parameters() if p.requires_grad]
431
+
432
+ if not trainable_params:
433
+ yield 0, 0.0, "❌ No trainable parameters found!"
434
+ return
435
+
436
+ optimizer = AdamW(
437
+ trainable_params,
438
+ lr=self.training_config.learning_rate,
439
+ weight_decay=self.training_config.weight_decay,
440
+ )
441
+
442
+ total_steps = len(train_loader) * self.training_config.max_epochs // self.training_config.gradient_accumulation_steps
443
+ warmup_steps = min(self.training_config.warmup_steps, max(1, total_steps // 10))
444
+
445
+ warmup_scheduler = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps)
446
+ main_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=max(1, total_steps - warmup_steps), T_mult=1, eta_min=self.training_config.learning_rate * 0.01)
447
+ scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, main_scheduler], milestones=[warmup_steps])
448
+
449
+ global_step = 0
450
+ accumulation_step = 0
451
+ accumulated_loss = 0.0
452
+
453
+ self.module.model.decoder.train()
454
+
455
+ for epoch in range(self.training_config.max_epochs):
456
+ epoch_loss = 0.0
457
+ num_batches = 0
458
+ epoch_start_time = time.time()
459
+
460
+ for batch in train_loader:
461
+ if training_state and training_state.get("should_stop", False):
462
+ yield global_step, accumulated_loss / max(accumulation_step, 1), "⏹️ Training stopped"
463
+ return
464
+
465
+ loss = self.module.training_step(batch)
466
+ loss = loss / self.training_config.gradient_accumulation_steps
467
+ loss.backward()
468
+ accumulated_loss += loss.item()
469
+ accumulation_step += 1
470
+
471
+ if accumulation_step >= self.training_config.gradient_accumulation_steps:
472
+ torch.nn.utils.clip_grad_norm_(trainable_params, self.training_config.max_grad_norm)
473
+ optimizer.step()
474
+ scheduler.step()
475
+ optimizer.zero_grad()
476
+ global_step += 1
477
+
478
+ if global_step % self.training_config.log_every_n_steps == 0:
479
+ avg_loss = accumulated_loss / accumulation_step
480
+ yield global_step, avg_loss, f"Epoch {epoch+1}, Step {global_step}, Loss: {avg_loss:.4f}"
481
+
482
+ epoch_loss += accumulated_loss
483
+ num_batches += 1
484
+ accumulated_loss = 0.0
485
+ accumulation_step = 0
486
+
487
+ epoch_time = time.time() - epoch_start_time
488
+ avg_epoch_loss = epoch_loss / max(num_batches, 1)
489
+ yield global_step, avg_epoch_loss, f"✅ Epoch {epoch+1}/{self.training_config.max_epochs} in {epoch_time:.1f}s"
490
+
491
+ if (epoch + 1) % self.training_config.save_every_n_epochs == 0:
492
+ checkpoint_dir = os.path.join(self.training_config.output_dir, "checkpoints", f"epoch_{epoch+1}")
493
+ save_lora_weights(self.module.model, checkpoint_dir)
494
+ yield global_step, avg_epoch_loss, f"💾 Checkpoint saved"
495
+
496
+ final_path = os.path.join(self.training_config.output_dir, "final")
497
+ save_lora_weights(self.module.model, final_path)
498
+ final_loss = self.module.training_losses[-1] if self.module.training_losses else 0.0
499
+ yield global_step, final_loss, f"✅ Training complete! LoRA saved to {final_path}"
500
+
501
+ def stop(self):
502
+ """Stop training."""
503
+ self.is_training = False
requirements.txt CHANGED
@@ -24,6 +24,10 @@ numba>=0.63.1
24
  vector-quantize-pytorch>=1.27.15
25
  torchcodec>=0.9.1
26
 
 
 
 
 
27
  # nano-vllm dependencies
28
  triton-windows>=3.0.0,<3.4; sys_platform == 'win32'
29
  triton>=3.0.0; sys_platform != 'win32'
 
24
  vector-quantize-pytorch>=1.27.15
25
  torchcodec>=0.9.1
26
 
27
+ # LoRA Training dependencies (optional)
28
+ peft>=0.7.0
29
+ lightning>=2.0.0
30
+
31
  # nano-vllm dependencies
32
  triton-windows>=3.0.0,<3.4; sys_platform == 'win32'
33
  triton>=3.0.0; sys_platform != 'win32'