Nekochu commited on
Commit
372f08e
·
1 Parent(s): 9d2d424

fix: train on XL turbo (matches XL GGUF for inference), add XL checkpoint download

Browse files
Files changed (2) hide show
  1. Dockerfile +7 -1
  2. app.py +2 -2
Dockerfile CHANGED
@@ -79,11 +79,17 @@ RUN pip3 install --no-cache-dir --extra-index-url https://download.pytorch.org/w
79
  # Clone ACE-Step repo for training module
80
  RUN git clone --depth 1 https://github.com/ace-step/ACE-Step-1.5 /app/ace-step-source
81
 
82
- # Pre-download training checkpoints (avoids 10GB runtime download)
 
83
  RUN python3 -c "from huggingface_hub import snapshot_download; \
84
  snapshot_download('ACE-Step/Ace-Step1.5', local_dir='/app/checkpoints', \
85
  ignore_patterns=['*.md', '*.txt', '.gitattributes'])"
86
 
 
 
 
 
 
87
  # Copy application files
88
  COPY app.py /app/app.py
89
  COPY train_engine.py /app/train_engine.py
 
79
  # Clone ACE-Step repo for training module
80
  RUN git clone --depth 1 https://github.com/ace-step/ACE-Step-1.5 /app/ace-step-source
81
 
82
+ # Pre-download training checkpoints (avoids runtime download)
83
+ # Base repo has VAE + text encoder + standard turbo
84
  RUN python3 -c "from huggingface_hub import snapshot_download; \
85
  snapshot_download('ACE-Step/Ace-Step1.5', local_dir='/app/checkpoints', \
86
  ignore_patterns=['*.md', '*.txt', '.gitattributes'])"
87
 
88
+ # XL turbo checkpoint for training (matches XL GGUF used for inference)
89
+ RUN python3 -c "from huggingface_hub import snapshot_download; \
90
+ snapshot_download('ACE-Step/acestep-v15-xl-turbo', local_dir='/app/checkpoints/acestep-v15-xl-turbo', \
91
+ ignore_patterns=['*.md', '*.txt', '.gitattributes'])"
92
+
93
  # Copy application files
94
  COPY app.py /app/app.py
95
  COPY train_engine.py /app/train_engine.py
app.py CHANGED
@@ -514,7 +514,7 @@ def gradio_main():
514
  output_dir=preprocessed_dir,
515
  checkpoint_dir=ACE_CHECKPOINT_DIR,
516
  device="cpu",
517
- variant="turbo",
518
  max_duration=float(MAX_AUDIO_DURATION),
519
  progress_callback=preprocess_progress,
520
  cancel_check=lambda: False,
@@ -554,7 +554,7 @@ def gradio_main():
554
  max_grad_norm=1.0,
555
  save_every_n_epochs=max(1, epochs // 2),
556
  seed=42,
557
- variant="turbo",
558
  device="cpu",
559
  log_every=5,
560
  ):
 
514
  output_dir=preprocessed_dir,
515
  checkpoint_dir=ACE_CHECKPOINT_DIR,
516
  device="cpu",
517
+ variant="xl-turbo",
518
  max_duration=float(MAX_AUDIO_DURATION),
519
  progress_callback=preprocess_progress,
520
  cancel_check=lambda: False,
 
554
  max_grad_norm=1.0,
555
  save_every_n_epochs=max(1, epochs // 2),
556
  seed=42,
557
+ variant="xl-turbo",
558
  device="cpu",
559
  log_every=5,
560
  ):