pedroapfilho commited on
Commit
3bc6e37
·
unverified ·
1 Parent(s): 5b76ce1

Rewrite Tab 3 to wire real training pipeline with HF dataset support

Browse files

Replace placeholder LoRA trainer (hardcoded loss=0.5) with 4-step wizard:
- Data Source: upload audio files or download from HuggingFace Hub
- Label & Review: auto-label via LLM with editable dataframe
- Preprocess: VAE + text encoding to training tensors
- Train: real Fabric-based LoRA training with stop control

Files changed (2) hide show
  1. app.py +443 -138
  2. src/lora_trainer.py +61 -351
app.py CHANGED
@@ -16,9 +16,12 @@ import spaces
16
 
17
  from src.ace_step_engine import ACEStepEngine
18
  from src.timeline_manager import TimelineManager
19
- from src.lora_trainer import LoRATrainer
20
  from src.audio_processor import AudioProcessor
21
  from src.utils import setup_logging, load_config
 
 
 
22
 
23
  # Setup
24
  logger = setup_logging()
@@ -27,9 +30,13 @@ config = load_config()
27
  # Lazy initialize components (will be initialized on first use)
28
  ace_engine = None
29
  timeline_manager = None
30
- lora_trainer = None
31
  audio_processor = None
32
 
 
 
 
 
33
  def get_ace_engine():
34
  """Lazy-load ACE-Step engine."""
35
  global ace_engine
@@ -44,12 +51,12 @@ def get_timeline_manager():
44
  timeline_manager = TimelineManager(config)
45
  return timeline_manager
46
 
47
- def get_lora_trainer():
48
- """Lazy-load LoRA trainer."""
49
- global lora_trainer
50
- if lora_trainer is None:
51
- lora_trainer = LoRATrainer(config)
52
- return lora_trainer
53
 
54
  def get_audio_processor():
55
  """Lazy-load audio processor."""
@@ -277,64 +284,246 @@ def timeline_reset(session_state: dict) -> Tuple[None, None, str, dict]:
277
  return None, None, "Timeline cleared", session_state
278
 
279
 
280
- # ==================== TAB 3: LORA TRAINING ====================
 
 
 
 
 
 
 
 
 
281
 
282
- def lora_upload_files(files: List[str]) -> str:
283
- """Upload and prepare audio files for LoRA training."""
284
  try:
285
- prepared_files = get_lora_trainer().prepare_dataset(files)
286
- return f" Prepared {len(prepared_files)} files for training"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  except Exception as e:
288
- return f" Error: {str(e)}"
 
 
289
 
290
  @spaces.GPU(duration=300)
 
 
 
 
291
 
292
- def lora_train(
293
- dataset_path: str,
294
- model_name: str,
295
- learning_rate: float,
296
- batch_size: int,
297
- num_epochs: int,
298
- rank: int,
299
- alpha: int,
300
- use_existing_lora: bool,
301
- existing_lora_path: Optional[str] = None,
302
- progress=gr.Progress()
303
- ) -> Tuple[str, str]:
304
- """Train LoRA model on uploaded dataset."""
 
 
 
 
 
 
 
 
 
 
 
 
305
  try:
306
- logger.info(f"Starting LoRA training: {model_name}")
307
-
308
- # Initialize or load LoRA
309
- if use_existing_lora and existing_lora_path:
310
- lora_trainer.load_lora(existing_lora_path)
311
- else:
312
- lora_trainer.initialize_lora(rank=rank, alpha=alpha)
313
-
314
- # Train
315
- def progress_callback(step, total_steps, loss):
316
- progress((step, total_steps), desc=f"Training (loss: {loss:.4f})")
317
-
318
- result_path = lora_trainer.train(
319
- dataset_path=dataset_path,
320
- model_name=model_name,
321
- learning_rate=learning_rate,
322
- batch_size=batch_size,
323
- num_epochs=num_epochs,
324
- progress_callback=progress_callback
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  )
326
-
327
- info = f"✅ Training complete! Model saved to {result_path}"
328
- return result_path, info
329
-
 
 
330
  except Exception as e:
331
- logger.error(f"LoRA training failed: {e}")
332
- return None, f"Error: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
333
 
 
 
334
 
335
- def lora_download(lora_path: str) -> str:
336
- """Provide LoRA model for download."""
337
- return lora_path if Path(lora_path).exists() else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
 
340
  # ==================== GRADIO UI ====================
@@ -543,103 +732,219 @@ def create_ui():
543
  outputs=[tl_full_audio, tl_timeline_viz, timeline_state, tl_info]
544
  )
545
 
546
- # ============ TAB 3: LORA TRAINING ============
547
  with gr.Tab("🎓 LoRA Training Studio"):
548
  gr.Markdown("""
549
  ### Train Custom LoRA Models
550
- Upload audio files to train specialized models for voice cloning, style adaptation, etc.
551
  """)
552
-
553
- with gr.Row():
554
- with gr.Column():
555
- gr.Markdown("#### 1. Upload Training Data")
556
- lora_files = gr.File(
557
- label="Audio Files",
558
- file_count="multiple",
559
- file_types=["audio"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
560
  )
561
- lora_upload_btn = gr.Button("📤 Upload & Prepare Dataset")
562
- lora_upload_status = gr.Textbox(label="Upload Status", lines=2)
563
-
564
- gr.Markdown("#### 2. Training Configuration")
565
- lora_dataset_path = gr.Textbox(
566
- label="Dataset Path",
567
- placeholder="Path to prepared dataset"
568
  )
569
- lora_model_name = gr.Textbox(
570
- label="Model Name",
571
- placeholder="my_custom_lora"
572
  )
573
-
574
- with gr.Row():
575
- lora_learning_rate = gr.Number(
576
- label="Learning Rate",
577
- value=1e-4
578
- )
579
- lora_batch_size = gr.Slider(
580
- minimum=1, maximum=16, value=4, step=1,
581
- label="Batch Size"
582
- )
583
-
584
- with gr.Row():
585
- lora_num_epochs = gr.Slider(
586
- minimum=1, maximum=100, value=10, step=1,
587
- label="Epochs"
588
- )
589
- lora_rank = gr.Slider(
590
- minimum=4, maximum=128, value=16, step=4,
591
- label="LoRA Rank"
592
- )
593
- lora_alpha = gr.Slider(
594
- minimum=4, maximum=128, value=32, step=4,
595
- label="LoRA Alpha"
596
- )
597
-
598
- lora_use_existing = gr.Checkbox(
599
- label="Continue training from existing LoRA",
600
- value=False
601
  )
602
- lora_existing_path = gr.Textbox(
603
- label="Existing LoRA Path",
604
- placeholder="Path to existing LoRA model"
 
 
 
605
  )
606
-
607
- lora_train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
608
-
609
- with gr.Column():
610
- lora_train_status = gr.Textbox(label="Training Status", lines=3)
611
- lora_model_path = gr.Textbox(label="Trained Model Path", lines=1)
612
- lora_download_btn = gr.Button("💾 Download Model")
613
- lora_download_file = gr.File(label="Download")
614
-
615
- gr.Markdown("""
616
- #### Training Tips
617
- - Upload 10+ audio samples for best results
618
- - Keep samples consistent in style/quality
619
- - Higher rank = more capacity but slower training
620
- - Start with 10-20 epochs and adjust
621
- - Use existing LoRA to continue training
622
- """)
623
-
624
- # Event handlers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625
  lora_upload_btn.click(
626
- fn=lora_upload_files,
627
- inputs=[lora_files],
628
- outputs=[lora_upload_status]
629
  )
630
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
631
  lora_train_btn.click(
632
- fn=lora_train,
633
- inputs=[lora_dataset_path, lora_model_name, lora_learning_rate,
634
- lora_batch_size, lora_num_epochs, lora_rank, lora_alpha,
635
- lora_use_existing, lora_existing_path],
636
- outputs=[lora_model_path, lora_train_status]
 
 
637
  )
638
-
639
- lora_download_btn.click(
640
- fn=lora_download,
 
 
 
 
 
 
641
  inputs=[lora_model_path],
642
- outputs=[lora_download_file]
643
  )
644
 
645
  gr.Markdown("""
 
16
 
17
  from src.ace_step_engine import ACEStepEngine
18
  from src.timeline_manager import TimelineManager
19
+ from src.lora_trainer import download_hf_dataset
20
  from src.audio_processor import AudioProcessor
21
  from src.utils import setup_logging, load_config
22
+ from acestep.training.dataset_builder import DatasetBuilder
23
+ from acestep.training.configs import LoRAConfig, TrainingConfig
24
+ from acestep.training.trainer import LoRATrainer as FabricLoRATrainer
25
 
26
  # Setup
27
  logger = setup_logging()
 
30
  # Lazy initialize components (will be initialized on first use)
31
  ace_engine = None
32
  timeline_manager = None
33
+ dataset_builder = None
34
  audio_processor = None
35
 
36
+ # Module-level mutable dict for training stop signal
37
+ # (gr.State is not shared between concurrent Gradio calls)
38
+ _training_control = {"should_stop": False}
39
+
40
  def get_ace_engine():
41
  """Lazy-load ACE-Step engine."""
42
  global ace_engine
 
51
  timeline_manager = TimelineManager(config)
52
  return timeline_manager
53
 
54
def get_dataset_builder():
    """Return the shared DatasetBuilder, creating it on first access."""
    global dataset_builder
    if dataset_builder is None:
        dataset_builder = DatasetBuilder()
    return dataset_builder
60
 
61
  def get_audio_processor():
62
  """Lazy-load audio processor."""
 
284
  return None, None, "Timeline cleared", session_state
285
 
286
 
287
+ # ==================== TAB 3: LORA TRAINING STUDIO ====================
288
+
289
+ DATAFRAME_HEADERS = ["#", "Filename", "Duration", "Lyrics", "Labeled", "BPM", "Key", "Caption"]
290
+
291
+
292
def _build_review_dataframe():
    """Return dataframe rows reflecting the dataset builder's current samples."""
    return get_dataset_builder().get_samples_dataframe_data()
296
+
297
 
298
def lora_upload_and_scan(files, training_state):
    """Copy uploaded audio files into the working directory and scan them.

    Args:
        files: Temp-file paths handed over by gr.File, or None/empty.
        training_state: Wizard session dict (created here if missing).

    Returns:
        Tuple of (status message, updated training_state).
    """
    try:
        if not files:
            return "No files uploaded", training_state

        import shutil

        work_dir = Path("lora_training") / "uploaded"
        work_dir.mkdir(parents=True, exist_ok=True)

        # Stage every upload under one directory so the builder can scan it.
        for uploaded in files:
            source = Path(uploaded)
            shutil.copy2(str(source), str(work_dir / source.name))

        scanned, _scan_msg = get_dataset_builder().scan_directory(str(work_dir))

        training_state = training_state or {}
        training_state["audio_dir"] = str(work_dir)

        return f"Scanned {len(scanned)} audio files from uploads", training_state

    except Exception as e:
        logger.error(f"Upload scan failed: {e}")
        return f"Error: {e}", training_state or {}
324
+
325
+
326
def lora_download_hf(dataset_id, hf_token, training_state):
    """Download a HuggingFace dataset and scan it for audio files.

    Args:
        dataset_id: Hub repo ID, e.g. "pedroapfilho/lofi-tracks".
        hf_token: Optional token for private repos (may be empty or None).
        training_state: Wizard session dict (created here if missing).

    Returns:
        Tuple of (status message, updated training_state).
    """
    try:
        if not dataset_id or not dataset_id.strip():
            return "Enter a dataset ID (e.g. pedroapfilho/lofi-tracks)", training_state

        output_dir = str(Path("lora_training") / "hf_datasets")
        token = hf_token.strip() if hf_token else None

        local_dir, dl_status = download_hf_dataset(
            dataset_id.strip(), output_dir, hf_token=token
        )

        if not local_dir:
            return f"Download failed: {dl_status}", training_state

        _samples, scan_status = get_dataset_builder().scan_directory(local_dir)

        training_state = training_state or {}
        training_state["audio_dir"] = local_dir

        return f"{dl_status} | {scan_status}", training_state

    except Exception as e:
        logger.error(f"HF download failed: {e}")
        return f"Error: {e}", training_state or {}
353
+
354
 
355
@spaces.GPU(duration=300)
def lora_auto_label(training_state, progress=gr.Progress()):
    """Run LLM-based auto-labeling over every loaded sample.

    Returns:
        Tuple of (dataframe rows for review, status message).
    """
    try:
        builder = get_dataset_builder()
        if builder.get_sample_count() == 0:
            return [], "No samples loaded. Upload files or download a dataset first."

        engine = get_ace_engine()
        if not engine.is_initialized():
            return [], "ACE-Step engine not initialized. Models may still be loading."

        def report(msg):
            # Surface labeling progress messages in the Gradio progress bar.
            progress(0, desc=msg)

        _samples, status = builder.label_all_samples(
            dit_handler=engine.dit_handler,
            llm_handler=engine.llm_handler,
            progress_callback=report,
        )
        return _build_review_dataframe(), status

    except Exception as e:
        logger.error(f"Auto-label failed: {e}")
        return [], f"Error: {e}"
382
+
383
+
384
def lora_save_edits(df_data, training_state):
    """Save user edits from the review dataframe back to samples.

    Args:
        df_data: Rows from the review gr.Dataframe. Depending on Gradio
            version/config this arrives as a pandas DataFrame or a list of
            row lists, so both shapes are normalized to a list of rows here.
        training_state: Wizard session dict (unused; kept for signature
            parity with the other tab callbacks).

    Returns:
        Status message string.
    """
    try:
        builder = get_dataset_builder()

        # BUG FIX: Gradio's Dataframe passes a pandas DataFrame by default,
        # and `if not df_data` on a DataFrame raises "truth value is
        # ambiguous". Normalize to a plain list of rows first.
        if hasattr(df_data, "values"):
            df_data = df_data.values.tolist()

        if not df_data or len(df_data) == 0:
            return "No data to save"

        updated = 0
        for row in df_data:
            # Column 0 is the sample index; columns 5-7 (BPM, Key, Caption)
            # are the user-editable fields; "-" is the "no value" placeholder.
            idx = int(row[0])
            updates = {}

            bpm_val = row[5]
            if bpm_val and bpm_val != "-":
                try:
                    updates["bpm"] = int(bpm_val)
                except (ValueError, TypeError):
                    pass

            key_val = row[6]
            if key_val and key_val != "-":
                updates["keyscale"] = str(key_val)

            caption_val = row[7]
            if caption_val and caption_val != "-":
                updates["caption"] = str(caption_val)

            if updates:
                builder.update_sample(idx, **updates)
                updated += 1

        return f"Updated {updated} samples"

    except Exception as e:
        logger.error(f"Save edits failed: {e}")
        return f"Error: {e}"
422
+
423
+
424
@spaces.GPU(duration=300)
def lora_preprocess(training_state, progress=gr.Progress()):
    """Preprocess labeled samples into training tensors (VAE + text encoding).

    Args:
        training_state: Wizard session dict held in gr.State. Mutated IN
            PLACE so the "tensor_dir" key survives — this callback only
            outputs the status textbox, not the state.
        progress: Gradio progress reporter.

    Returns:
        Status message string.
    """
    try:
        builder = get_dataset_builder()

        if builder.get_labeled_count() == 0:
            return "No labeled samples. Run auto-label first."

        engine = get_ace_engine()
        if not engine.is_initialized():
            return "ACE-Step engine not initialized."

        tensor_dir = str(Path("lora_training") / "tensors")

        def progress_callback(msg):
            progress(0, desc=msg)

        output_paths, status = builder.preprocess_to_tensors(
            dit_handler=engine.dit_handler,
            output_dir=tensor_dir,
            progress_callback=progress_callback,
        )

        # BUG FIX: `training_state or {}` replaced the caller's dict whenever
        # it was empty ({} is falsy — and gr.State(value={}) starts empty), so
        # "tensor_dir" was written to a throwaway dict and lost; training then
        # always reported "No preprocessed tensors found". Mutate the existing
        # dict in place so gr.State observes the update.
        if training_state is None:
            training_state = {}
        training_state["tensor_dir"] = tensor_dir

        return status

    except Exception as e:
        logger.error(f"Preprocess failed: {e}")
        return f"Error: {e}"
456
+
457
+
458
@spaces.GPU(duration=600)
def lora_train_real(
    lr, batch_size, epochs, rank, alpha,
    grad_accum, model_name, training_state,
    progress=gr.Progress(),
):
    """Train a LoRA adapter with the real Fabric-based trainer.

    Consumes tensors produced by the preprocessing step, streams progress
    into the Gradio progress bar, and honours the module-level stop flag
    between training steps.

    Returns:
        Tuple of (final model path, status message); path is "" on failure.
    """
    try:
        training_state = training_state or {}
        tensor_dir = training_state.get("tensor_dir", "")
        if not tensor_dir or not Path(tensor_dir).exists():
            return "", "No preprocessed tensors found. Run preprocessing first."

        engine = get_ace_engine()
        if not engine.is_initialized():
            return "", "ACE-Step engine not initialized."

        output_dir = str(Path("lora_training") / "models" / (model_name or "lora_model"))

        trainer = FabricLoRATrainer(
            dit_handler=engine.dit_handler,
            lora_config=LoRAConfig(r=int(rank), alpha=int(alpha)),
            training_config=TrainingConfig(
                learning_rate=float(lr),
                batch_size=int(batch_size),
                max_epochs=int(epochs),
                gradient_accumulation_steps=int(grad_accum),
                output_dir=output_dir,
            ),
        )

        # Reset the shared stop flag before a fresh run begins.
        _training_control["should_stop"] = False
        status_line = ""

        for step, loss, message in trainer.train_from_preprocessed(
            tensor_dir=tensor_dir,
            training_state=_training_control,
        ):
            status_line = f"Step {step} | Loss: {loss:.4f} | {message}"
            progress(0, desc=status_line)

            # Checked once per yielded step so a stop request is prompt.
            if _training_control.get("should_stop"):
                trainer.stop()
                status_line = f"Training stopped at step {step} (loss: {loss:.4f})"
                break

        return str(Path(output_dir) / "final"), status_line

    except Exception as e:
        logger.error(f"Training failed: {e}")
        return "", f"Error: {e}"
514
+
515
+
516
def lora_stop_training():
    """Flip the shared stop flag; training halts after the current step."""
    _training_control["should_stop"] = True
    return "Stop signal sent. Training will stop after current step."
520
+
521
+
522
def lora_download_model(model_path):
    """Return the trained model path for gr.File download, or None if absent."""
    if not model_path:
        return None
    return model_path if Path(model_path).exists() else None
527
 
528
 
529
  # ==================== GRADIO UI ====================
 
732
  outputs=[tl_full_audio, tl_timeline_viz, timeline_state, tl_info]
733
  )
734
 
735
+ # ============ TAB 3: LORA TRAINING STUDIO ============
736
  with gr.Tab("🎓 LoRA Training Studio"):
737
  gr.Markdown("""
738
  ### Train Custom LoRA Models
739
+ Step-by-step wizard: provide audio data, auto-label with LLM, preprocess, and train.
740
  """)
741
+
742
+ training_state = gr.State(value={})
743
+
744
+ with gr.Tabs():
745
+
746
+ # ---------- Sub-tab 1: Data Source ----------
747
+ with gr.Tab("1. Data Source"):
748
+ gr.Markdown("Choose one: upload audio files or download from HuggingFace.")
749
+
750
+ with gr.Row():
751
+ with gr.Column():
752
+ gr.Markdown("#### Upload Files")
753
+ lora_files = gr.File(
754
+ label="Audio Files (WAV, MP3, FLAC, OGG, OPUS)",
755
+ file_count="multiple",
756
+ file_types=["audio"],
757
+ )
758
+ lora_upload_btn = gr.Button(
759
+ "Upload & Scan", variant="primary"
760
+ )
761
+
762
+ with gr.Column():
763
+ gr.Markdown("#### HuggingFace Dataset")
764
+ lora_hf_id = gr.Textbox(
765
+ label="Dataset ID",
766
+ placeholder="pedroapfilho/lofi-tracks",
767
+ )
768
+ lora_hf_token = gr.Textbox(
769
+ label="HF Token (optional, for private repos)",
770
+ type="password",
771
+ )
772
+ lora_hf_btn = gr.Button(
773
+ "Download & Scan", variant="primary"
774
+ )
775
+
776
+ lora_source_status = gr.Textbox(
777
+ label="Status", lines=2, interactive=False
778
  )
779
+
780
+ # ---------- Sub-tab 2: Label & Review ----------
781
+ with gr.Tab("2. Label & Review"):
782
+ gr.Markdown(
783
+ "Auto-label samples using the LLM, then review and edit metadata."
 
 
784
  )
785
+
786
+ lora_label_btn = gr.Button(
787
+ "Auto-Label All Samples", variant="primary"
788
  )
789
+ lora_label_status = gr.Textbox(
790
+ label="Label Status", lines=2, interactive=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
791
  )
792
+
793
+ lora_review_df = gr.Dataframe(
794
+ headers=DATAFRAME_HEADERS,
795
+ label="Sample Review (editable: BPM, Key, Caption)",
796
+ interactive=True,
797
+ wrap=True,
798
  )
799
+
800
+ lora_save_btn = gr.Button("Save Edits")
801
+ lora_save_status = gr.Textbox(
802
+ label="Save Status", interactive=False
803
+ )
804
+
805
+ # ---------- Sub-tab 3: Preprocess ----------
806
+ with gr.Tab("3. Preprocess"):
807
+ gr.Markdown(
808
+ "Encode audio through VAE and text encoders to create training tensors."
809
+ )
810
+
811
+ lora_preprocess_btn = gr.Button(
812
+ "Preprocess to Tensors", variant="primary"
813
+ )
814
+ lora_preprocess_status = gr.Textbox(
815
+ label="Preprocess Status", lines=3, interactive=False
816
+ )
817
+
818
+ # ---------- Sub-tab 4: Train ----------
819
+ with gr.Tab("4. Train"):
820
+ gr.Markdown("Configure and run LoRA training.")
821
+
822
+ with gr.Row():
823
+ with gr.Column():
824
+ lora_model_name = gr.Textbox(
825
+ label="Model Name",
826
+ value="my_lora",
827
+ placeholder="my_lora",
828
+ )
829
+
830
+ with gr.Row():
831
+ lora_lr = gr.Number(
832
+ label="Learning Rate", value=1e-4
833
+ )
834
+ lora_batch_size = gr.Slider(
835
+ minimum=1, maximum=8, value=1, step=1,
836
+ label="Batch Size",
837
+ )
838
+
839
+ with gr.Row():
840
+ lora_epochs = gr.Slider(
841
+ minimum=1, maximum=500, value=100, step=1,
842
+ label="Epochs",
843
+ )
844
+ lora_grad_accum = gr.Slider(
845
+ minimum=1, maximum=16, value=4, step=1,
846
+ label="Gradient Accumulation",
847
+ )
848
+
849
+ with gr.Row():
850
+ lora_rank = gr.Slider(
851
+ minimum=4, maximum=128, value=8, step=4,
852
+ label="LoRA Rank",
853
+ )
854
+ lora_alpha = gr.Slider(
855
+ minimum=4, maximum=128, value=16, step=4,
856
+ label="LoRA Alpha",
857
+ )
858
+
859
+ with gr.Row():
860
+ lora_train_btn = gr.Button(
861
+ "Start Training",
862
+ variant="primary",
863
+ size="lg",
864
+ )
865
+ lora_stop_btn = gr.Button(
866
+ "Stop Training",
867
+ variant="stop",
868
+ size="lg",
869
+ )
870
+
871
+ with gr.Column():
872
+ lora_train_status = gr.Textbox(
873
+ label="Training Status",
874
+ lines=4,
875
+ interactive=False,
876
+ )
877
+ lora_model_path = gr.Textbox(
878
+ label="Model Path",
879
+ interactive=False,
880
+ )
881
+ lora_dl_btn = gr.Button("Download Model")
882
+ lora_dl_file = gr.File(label="Download")
883
+
884
+ gr.Markdown("""
885
+ #### Tips
886
+ - Upload 10+ audio samples for best results
887
+ - Keep samples consistent in style/quality
888
+ - Higher rank = more capacity but slower training
889
+ - Default settings (rank=8, lr=1e-4, 100 epochs) are a good starting point
890
+ """)
891
+
892
+ # ---------- Event handlers ----------
893
+
894
+ # Data Source
895
  lora_upload_btn.click(
896
+ fn=lora_upload_and_scan,
897
+ inputs=[lora_files, training_state],
898
+ outputs=[lora_source_status, training_state],
899
  )
900
+
901
+ lora_hf_btn.click(
902
+ fn=lora_download_hf,
903
+ inputs=[lora_hf_id, lora_hf_token, training_state],
904
+ outputs=[lora_source_status, training_state],
905
+ )
906
+
907
+ # Label & Review
908
+ lora_label_btn.click(
909
+ fn=lora_auto_label,
910
+ inputs=[training_state],
911
+ outputs=[lora_review_df, lora_label_status],
912
+ )
913
+
914
+ lora_save_btn.click(
915
+ fn=lora_save_edits,
916
+ inputs=[lora_review_df, training_state],
917
+ outputs=[lora_save_status],
918
+ )
919
+
920
+ # Preprocess
921
+ lora_preprocess_btn.click(
922
+ fn=lora_preprocess,
923
+ inputs=[training_state],
924
+ outputs=[lora_preprocess_status],
925
+ )
926
+
927
+ # Train
928
  lora_train_btn.click(
929
+ fn=lora_train_real,
930
+ inputs=[
931
+ lora_lr, lora_batch_size, lora_epochs,
932
+ lora_rank, lora_alpha, lora_grad_accum,
933
+ lora_model_name, training_state,
934
+ ],
935
+ outputs=[lora_model_path, lora_train_status],
936
  )
937
+
938
+ lora_stop_btn.click(
939
+ fn=lora_stop_training,
940
+ inputs=[],
941
+ outputs=[lora_train_status],
942
+ )
943
+
944
+ lora_dl_btn.click(
945
+ fn=lora_download_model,
946
  inputs=[lora_model_path],
947
+ outputs=[lora_dl_file],
948
  )
949
 
950
  gr.Markdown("""
src/lora_trainer.py CHANGED
@@ -1,359 +1,69 @@
1
  """
2
- LoRA Trainer - Handles LoRA training for custom models
 
 
 
3
  """
4
 
5
- import torch
6
- import torchaudio
7
- from pathlib import Path
8
  import logging
9
- from typing import List, Dict, Any, Optional, Callable
10
- import json
11
- from datetime import datetime
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- class LoRATrainer:
17
- """Manages LoRA training for ACE-Step model."""
18
-
19
- def __init__(self, config: Dict[str, Any]):
20
- """
21
- Initialize LoRA trainer.
22
-
23
- Args:
24
- config: Configuration dictionary
25
- """
26
- self.config = config
27
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
- self.training_dir = Path(config.get("training_dir", "lora_training"))
29
- self.training_dir.mkdir(exist_ok=True)
30
-
31
- self.model = None
32
- self.lora_config = None
33
-
34
- logger.info(f"LoRA Trainer initialized on {self.device}")
35
-
36
- def prepare_dataset(self, audio_files: List[str]) -> List[str]:
37
- """
38
- Prepare audio files for training.
39
-
40
- Args:
41
- audio_files: List of audio file paths
42
-
43
- Returns:
44
- List of prepared file paths
45
- """
46
- try:
47
- logger.info(f"Preparing {len(audio_files)} files for training...")
48
-
49
- prepared_dir = self.training_dir / "prepared_data" / datetime.now().strftime("%Y%m%d_%H%M%S")
50
- prepared_dir.mkdir(parents=True, exist_ok=True)
51
-
52
- prepared_files = []
53
-
54
- for i, file_path in enumerate(audio_files):
55
- try:
56
- # Load audio
57
- audio, sr = torchaudio.load(file_path)
58
-
59
- # Resample to target sample rate if needed
60
- target_sr = self.config.get("sample_rate", 44100)
61
- if sr != target_sr:
62
- resampler = torchaudio.transforms.Resample(sr, target_sr)
63
- audio = resampler(audio)
64
-
65
- # Convert to mono if needed (for some training scenarios)
66
- if audio.shape[0] > 1 and self.config.get("force_mono", False):
67
- audio = torch.mean(audio, dim=0, keepdim=True)
68
-
69
- # Normalize
70
- audio = audio / (torch.abs(audio).max() + 1e-8)
71
-
72
- # Split long files into chunks if needed
73
- chunk_duration = self.config.get("chunk_duration", 30) # seconds
74
- chunk_samples = int(chunk_duration * target_sr)
75
-
76
- if audio.shape[1] > chunk_samples:
77
- # Split into chunks
78
- num_chunks = audio.shape[1] // chunk_samples
79
- for j in range(num_chunks):
80
- start = j * chunk_samples
81
- end = start + chunk_samples
82
- chunk = audio[:, start:end]
83
-
84
- # Save chunk
85
- chunk_path = prepared_dir / f"audio_{i:04d}_chunk_{j:02d}.wav"
86
- torchaudio.save(
87
- str(chunk_path),
88
- chunk,
89
- target_sr,
90
- encoding="PCM_S",
91
- bits_per_sample=16
92
- )
93
- prepared_files.append(str(chunk_path))
94
- else:
95
- # Save as-is
96
- output_path = prepared_dir / f"audio_{i:04d}.wav"
97
- torchaudio.save(
98
- str(output_path),
99
- audio,
100
- target_sr,
101
- encoding="PCM_S",
102
- bits_per_sample=16
103
- )
104
- prepared_files.append(str(output_path))
105
-
106
- except Exception as e:
107
- logger.warning(f"Failed to process {file_path}: {e}")
108
- continue
109
-
110
- # Save dataset metadata
111
- metadata = {
112
- "num_files": len(prepared_files),
113
- "original_files": len(audio_files),
114
- "sample_rate": target_sr,
115
- "prepared_at": datetime.now().isoformat(),
116
- "files": prepared_files
117
- }
118
-
119
- metadata_path = prepared_dir / "metadata.json"
120
- with open(metadata_path, 'w') as f:
121
- json.dump(metadata, f, indent=2)
122
-
123
- logger.info(f"✅ Prepared {len(prepared_files)} training files")
124
- return prepared_files
125
-
126
- except Exception as e:
127
- logger.error(f"Dataset preparation failed: {e}")
128
- raise
129
-
130
- def initialize_lora(self, rank: int = 16, alpha: int = 32):
131
- """
132
- Initialize LoRA configuration.
133
-
134
- Args:
135
- rank: LoRA rank
136
- alpha: LoRA alpha
137
- """
138
- try:
139
- from peft import LoraConfig, get_peft_model
140
-
141
- self.lora_config = LoraConfig(
142
- r=rank,
143
- lora_alpha=alpha,
144
- target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], # Attention layers
145
- lora_dropout=0.1,
146
- bias="none",
147
- task_type="CAUSAL_LM"
148
- )
149
-
150
- logger.info(f"✅ LoRA initialized: rank={rank}, alpha={alpha}")
151
-
152
- except Exception as e:
153
- logger.error(f"LoRA initialization failed: {e}")
154
- raise
155
-
156
- def load_lora(self, lora_path: str):
157
- """
158
- Load existing LoRA model for continued training.
159
-
160
- Args:
161
- lora_path: Path to LoRA model
162
- """
163
- try:
164
- from peft import PeftModel
165
- from transformers import AutoModel
166
-
167
- # Load base model
168
- base_model = AutoModel.from_pretrained(
169
- self.config.get("model_path", "ACE-Step/ACE-Step-v1-3.5B"),
170
- torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32
171
- )
172
-
173
- # Load with LoRA
174
- self.model = PeftModel.from_pretrained(base_model, lora_path)
175
-
176
- logger.info(f"✅ Loaded LoRA from {lora_path}")
177
-
178
- except Exception as e:
179
- logger.error(f"Failed to load LoRA: {e}")
180
- raise
181
-
182
- def train(
183
- self,
184
- dataset_path: str,
185
- model_name: str,
186
- learning_rate: float = 1e-4,
187
- batch_size: int = 4,
188
- num_epochs: int = 10,
189
- progress_callback: Optional[Callable] = None
190
- ) -> str:
191
- """
192
- Train LoRA model.
193
-
194
- Args:
195
- dataset_path: Path to prepared dataset
196
- model_name: Name for the trained model
197
- learning_rate: Learning rate
198
- batch_size: Batch size
199
- num_epochs: Number of epochs
200
- progress_callback: Optional callback for progress updates
201
-
202
- Returns:
203
- Path to trained model
204
- """
205
- try:
206
- logger.info(f"Starting LoRA training: {model_name}")
207
-
208
- # Load dataset
209
- dataset = self._load_dataset(dataset_path)
210
-
211
- # Load base model if not already loaded
212
- if self.model is None:
213
- from transformers import AutoModel
214
- from peft import get_peft_model
215
-
216
- base_model = AutoModel.from_pretrained(
217
- self.config.get("model_path", "ACE-Step/ACE-Step-v1-3.5B"),
218
- torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
219
- device_map="auto"
220
- )
221
-
222
- self.model = get_peft_model(base_model, self.lora_config)
223
-
224
- self.model.train()
225
-
226
- # Setup optimizer
227
- optimizer = torch.optim.AdamW(
228
- self.model.parameters(),
229
- lr=learning_rate,
230
- weight_decay=0.01
231
- )
232
-
233
- # Training loop
234
- total_steps = (len(dataset) // batch_size) * num_epochs
235
- step = 0
236
-
237
- for epoch in range(num_epochs):
238
- epoch_loss = 0.0
239
-
240
- for batch_idx in range(0, len(dataset), batch_size):
241
- batch = dataset[batch_idx:batch_idx + batch_size]
242
-
243
- # Forward pass (simplified - actual implementation would be more complex)
244
- loss = self._training_step(batch)
245
-
246
- # Backward pass
247
- optimizer.zero_grad()
248
- loss.backward()
249
- optimizer.step()
250
-
251
- epoch_loss += loss.item()
252
- step += 1
253
-
254
- # Progress callback
255
- if progress_callback:
256
- progress_callback(step, total_steps, loss.item())
257
-
258
- avg_loss = epoch_loss / (len(dataset) // batch_size)
259
- logger.info(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")
260
-
261
- # Save trained model
262
- output_dir = self.training_dir / "models" / model_name
263
- output_dir.mkdir(parents=True, exist_ok=True)
264
-
265
- self.model.save_pretrained(str(output_dir))
266
-
267
- # Save training info
268
- info = {
269
- "model_name": model_name,
270
- "learning_rate": learning_rate,
271
- "batch_size": batch_size,
272
- "num_epochs": num_epochs,
273
- "dataset_size": len(dataset),
274
- "trained_at": datetime.now().isoformat(),
275
- "lora_config": {
276
- "rank": self.lora_config.r,
277
- "alpha": self.lora_config.lora_alpha
278
- }
279
- }
280
-
281
- info_path = output_dir / "training_info.json"
282
- with open(info_path, 'w') as f:
283
- json.dump(info, f, indent=2)
284
-
285
- logger.info(f"✅ Training complete! Model saved to {output_dir}")
286
- return str(output_dir)
287
-
288
- except Exception as e:
289
- logger.error(f"Training failed: {e}")
290
- raise
291
-
292
- def _load_dataset(self, dataset_path: str) -> List[Dict[str, Any]]:
293
- """Load prepared dataset."""
294
- dataset_path = Path(dataset_path)
295
-
296
- # Load metadata
297
- metadata_path = dataset_path / "metadata.json"
298
- if metadata_path.exists():
299
- with open(metadata_path, 'r') as f:
300
- metadata = json.load(f)
301
- files = metadata.get("files", [])
302
- else:
303
- # Scan directory for audio files
304
- files = list(dataset_path.glob("*.wav"))
305
-
306
- dataset = []
307
- for file_path in files:
308
- dataset.append({
309
- "path": str(file_path),
310
- "audio": None # Lazy loading
311
- })
312
-
313
- return dataset
314
-
315
- def _training_step(self, batch: List[Dict[str, Any]]) -> torch.Tensor:
316
- """
317
- Perform single training step.
318
-
319
- This is a simplified placeholder - actual implementation would:
320
- 1. Load audio from batch
321
- 2. Encode to latent space
322
- 3. Generate predictions
323
- 4. Calculate loss
324
- 5. Return loss
325
-
326
- Args:
327
- batch: Training batch
328
-
329
- Returns:
330
- Loss tensor
331
- """
332
- # Placeholder loss calculation
333
- # Actual implementation would process audio through model
334
- loss = torch.tensor(0.5, requires_grad=True, device=self.device)
335
- return loss
336
-
337
- def export_for_inference(self, lora_path: str, output_path: str):
338
- """
339
- Export LoRA model for inference.
340
-
341
- Args:
342
- lora_path: Path to LoRA model
343
- output_path: Output path for exported model
344
- """
345
- try:
346
- # Load LoRA
347
- self.load_lora(lora_path)
348
-
349
- # Merge LoRA with base model
350
- merged_model = self.model.merge_and_unload()
351
-
352
- # Save merged model
353
- merged_model.save_pretrained(output_path)
354
-
355
- logger.info(f"✅ Exported model to {output_path}")
356
-
357
- except Exception as e:
358
- logger.error(f"Export failed: {e}")
359
- raise
 
1
  """
2
+ HuggingFace Dataset Download Utility for LoRA Training Studio.
3
+
4
+ Provides a helper to download audio datasets from HuggingFace Hub.
5
+ The actual training pipeline lives in acestep/training/.
6
  """
7
 
 
 
 
8
  import logging
9
+ from pathlib import Path
10
+ from typing import Optional, Tuple
 
11
 
12
  logger = logging.getLogger(__name__)
13
 
14
+ AUDIO_EXTENSIONS = ["*.wav", "*.mp3", "*.flac", "*.ogg", "*.opus"]
15
+
16
+
17
+ def download_hf_dataset(
18
+ dataset_id: str,
19
+ output_dir: str,
20
+ hf_token: Optional[str] = None,
21
+ ) -> Tuple[str, str]:
22
+ """
23
+ Download an audio dataset from HuggingFace Hub.
24
+
25
+ Uses snapshot_download to fetch only audio files from the repo,
26
+ skipping non-audio content like READMEs, metadata, etc.
27
+
28
+ Args:
29
+ dataset_id: HuggingFace dataset repo ID (e.g. "pedroapfilho/lofi-tracks")
30
+ output_dir: Local directory to download into
31
+ hf_token: Optional HuggingFace token for private repos
32
+
33
+ Returns:
34
+ Tuple of (local_dir, status_message)
35
+ """
36
+ try:
37
+ from huggingface_hub import snapshot_download
38
+
39
+ output_path = Path(output_dir)
40
+ output_path.mkdir(parents=True, exist_ok=True)
41
+
42
+ logger.info(f"Downloading dataset '{dataset_id}' to {output_dir}...")
43
+
44
+ local_dir = snapshot_download(
45
+ repo_id=dataset_id,
46
+ repo_type="dataset",
47
+ local_dir=str(output_path / dataset_id.replace("/", "_")),
48
+ token=hf_token or None,
49
+ allow_patterns=AUDIO_EXTENSIONS,
50
+ )
51
+
52
+ audio_count = sum(
53
+ 1
54
+ for ext in AUDIO_EXTENSIONS
55
+ for _ in Path(local_dir).rglob(ext)
56
+ )
57
+
58
+ status = f"Downloaded {audio_count} audio files from {dataset_id}"
59
+ logger.info(status)
60
+ return local_dir, status
61
 
62
+ except ImportError:
63
+ msg = "huggingface_hub is not installed. Run: pip install huggingface_hub"
64
+ logger.error(msg)
65
+ return "", msg
66
+ except Exception as e:
67
+ msg = f"Failed to download dataset: {e}"
68
+ logger.error(msg)
69
+ return "", msg