| """ |
| ACE-Step 1.5 LoRA Training Engine |
| |
| Handles dataset building, VAE encoding, and flow-matching LoRA training |
| of the DiT decoder. Designed to work with the existing AceStepHandler. |
| """ |
|
|
| import os |
| import sys |
| import json |
| import math |
| import time |
| import random |
| import hashlib |
| import argparse |
| import tempfile |
| from pathlib import Path |
| from dataclasses import dataclass, field, asdict |
| from typing import Optional, List, Dict, Any, Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
| import torchaudio |
| import soundfile as sf |
| import numpy as np |
| from loguru import logger |
| from tqdm import tqdm |
|


# ---------------------------------------------------------------------------
# Dataset scanning
# ---------------------------------------------------------------------------

AUDIO_EXTENSIONS = {".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aac"}


@dataclass
class TrackEntry:
    """One audio file + its metadata."""

    audio_path: str
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    keyscale: str = ""
    timesignature: str = "4/4"
    vocal_language: str = "en"
    duration: Optional[float] = None
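
# For reference, a sidecar for ``track.wav`` lives at ``track.json`` and may
# set any subset of the fields above. An illustrative example (values are
# made up):
#
#   {
#       "caption": "warm lo-fi hip hop, dusty drums, mellow keys",
#       "lyrics": "",
#       "bpm": 82,
#       "keyscale": "F minor",
#       "timesignature": "4/4",
#       "vocal_language": "en"
#   }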


def _load_track_entry(audio_path: Path) -> TrackEntry:
    """Load one track + optional sidecar metadata."""
    sidecar = audio_path.with_suffix(".json")
    meta: Dict[str, Any] = {}
    if sidecar.exists():
        try:
            meta = json.loads(sidecar.read_text(encoding="utf-8"))
        except Exception as exc:
            logger.warning(f"Bad sidecar {sidecar}: {exc}")

    try:
        info = torchaudio.info(str(audio_path))
        duration = info.num_frames / info.sample_rate
    except Exception:
        duration = meta.get("duration")

    # Sidecars sometimes store bpm as a string or float; coerce to int,
    # mirroring the coercion done for uploaded sidecars below.
    bpm = meta.get("bpm")
    if bpm is not None and not isinstance(bpm, int):
        try:
            bpm = int(float(bpm))
        except (TypeError, ValueError):
            bpm = None

    return TrackEntry(
        audio_path=str(audio_path),
        caption=meta.get("caption", ""),
        lyrics=meta.get("lyrics", ""),
        bpm=bpm,
        keyscale=meta.get("keyscale", ""),
        timesignature=meta.get("timesignature", "4/4"),
        vocal_language=meta.get("vocal_language", "en"),
        duration=duration,
    )


def scan_dataset_folder(folder: str) -> List[TrackEntry]:
    """Scan *folder* for audio files and optional JSON sidecars.

    For every ``track.wav`` found, if ``track.json`` exists next to it the
    metadata fields are loaded from the sidecar. Missing sidecars are fine;
    the entry simply has empty metadata that can be filled later.
    """
    folder = Path(folder)
    if not folder.is_dir():
        raise FileNotFoundError(f"Dataset folder not found: {folder}")

    entries: List[TrackEntry] = []
    for audio_path in sorted(folder.rglob("*")):
        if audio_path.suffix.lower() not in AUDIO_EXTENSIONS:
            continue
        entries.append(_load_track_entry(audio_path))

    logger.info(f"Scanned {len(entries)} audio files in {folder}")
    return entries


def scan_uploaded_files(file_paths: List[str]) -> List[TrackEntry]:
    """Build entries from dropped/uploaded files.

    Supports uploading audio files together with optional ``.json`` sidecars.
    Sidecars are matched by basename stem (``song.mp3`` <-> ``song.json``).
    """
    meta_by_stem: Dict[str, Dict[str, Any]] = {}
    for path in file_paths:
        p = Path(path)
        if not p.exists() or p.suffix.lower() != ".json":
            continue
        try:
            meta_by_stem[p.stem] = json.loads(p.read_text(encoding="utf-8"))
        except Exception as exc:
            logger.warning(f"Bad uploaded sidecar {p}: {exc}")

    entries: List[TrackEntry] = []
    for path in file_paths:
        p = Path(path)
        if not p.exists() or p.suffix.lower() not in AUDIO_EXTENSIONS:
            continue

        uploaded_meta = meta_by_stem.get(p.stem)
        if uploaded_meta is None:
            entries.append(_load_track_entry(p))
            continue

        try:
            info = torchaudio.info(str(p))
            duration = info.num_frames / info.sample_rate
        except Exception:
            duration = uploaded_meta.get("duration")

        bpm_val = uploaded_meta.get("bpm")
        if isinstance(bpm_val, str) and bpm_val.strip():
            try:
                bpm_val = int(float(bpm_val))
            except Exception:
                bpm_val = None

        entries.append(
            TrackEntry(
                audio_path=str(p),
                caption=uploaded_meta.get("caption", "") or "",
                lyrics=uploaded_meta.get("lyrics", "") or "",
                bpm=bpm_val if isinstance(bpm_val, int) else None,
                keyscale=uploaded_meta.get("keyscale", "") or "",
                timesignature=uploaded_meta.get("timesignature", "4/4") or "4/4",
                vocal_language=uploaded_meta.get("vocal_language", uploaded_meta.get("language", "en")) or "en",
                duration=duration,
            )
        )

    logger.info(
        "Loaded {} uploaded audio files ({} uploaded sidecars detected)".format(
            len(entries), len(meta_by_stem)
        )
    )
    return entries
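
# Example: scan_uploaded_files(["/tmp/song.mp3", "/tmp/song.json"]) yields one
# TrackEntry whose metadata comes from song.json, matched by basename stem.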


# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------

@dataclass
class LoRATrainConfig:
    """All tuneable knobs for a LoRA run."""

    # LoRA adapter
    lora_rank: int = 64
    lora_alpha: int = 64
    lora_dropout: float = 0.1
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )

    # Optimiser
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    optimizer: str = "adamw_8bit"
    max_grad_norm: float = 1.0

    # LR schedule
    warmup_ratio: float = 0.03
    scheduler: str = "constant_with_warmup"

    # Loop control
    num_epochs: int = 50
    batch_size: int = 1
    gradient_accumulation_steps: int = 1
    save_every_n_epochs: int = 10
    log_every_n_steps: int = 5

    # Flow-matching timestep shift
    shift: float = 3.0

    # Audio preprocessing
    max_duration_sec: float = 240.0
    sample_rate: int = 48000

    # Output / resume
    output_dir: str = "lora_output"
    resume_from: Optional[str] = None

    # Runtime
    device: str = "auto"
    dtype: str = "bf16"
    mixed_precision: bool = True
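
# Note: PEFT scales each LoRA update by lora_alpha / lora_rank, so the default
# rank/alpha pair above (64/64) applies the adapter at an effective scale of 1.0.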


# ---------------------------------------------------------------------------
# Trainer
# ---------------------------------------------------------------------------

class LoRATrainer:
    """Thin training loop that wraps the existing AceStepHandler."""

    def __init__(self, handler, config: LoRATrainConfig):
        """
        Args:
            handler: Initialised ``AceStepHandler`` (model, vae, text_encoder loaded).
            config: Training hyper-parameters.
        """
        self.handler = handler
        self.cfg = config

        self.device = handler.device
        self.dtype = handler.dtype

        # Populated by prepare()
        self.peft_model = None
        self.optimizer = None
        self.scheduler = None
        self.global_step = 0
        self.current_epoch = 0

        # Telemetry and cooperative stop flag
        self.loss_history: List[Dict[str, Any]] = []
        self._stop_requested = False

    # ------------------------------------------------------------------
    # LoRA setup
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_lora_target_modules(model, requested_targets: Optional[List[str]]) -> List[str]:
        """Resolve LoRA target module suffixes against the actual decoder module names."""
        linear_module_names = [
            name for name, module in model.named_modules() if isinstance(module, torch.nn.Linear)
        ]

        def _exists_as_suffix(target: str) -> bool:
            return any(name.endswith(target) for name in linear_module_names)

        requested_targets = requested_targets or []
        resolved = [target for target in requested_targets if _exists_as_suffix(target)]
        if resolved:
            return resolved

        # Common attention-projection naming schemes across DiT implementations.
        fallback_groups = [
            ["q_proj", "k_proj", "v_proj", "o_proj"],
            ["to_q", "to_k", "to_v", "to_out.0"],
            ["query", "key", "value", "out_proj"],
            ["wq", "wk", "wv", "wo"],
            ["qkv", "proj_out"],
        ]
        for group in fallback_groups:
            group_resolved = [target for target in group if _exists_as_suffix(target)]
            if len(group_resolved) >= 2:
                return group_resolved

        sample = ", ".join(linear_module_names[:30])
        raise ValueError(
            "Could not find LoRA target modules in decoder. "
            f"Requested={requested_targets}. "
            f"Sample linear modules: {sample}"
        )
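
    # Example: requesting "q_proj" matches a linear module named
    # "layers.3.self_attn.q_proj", since matching is by name suffix.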

    def prepare(self):
        """Attach LoRA adapters to the decoder and build the optimiser."""
        import copy
        from peft import LoraConfig, PeftModel, TaskType, get_peft_model

        # Always start from a pristine copy of the base decoder so repeated
        # runs do not stack adapters.
        if self.handler._base_decoder is None:
            self.handler._base_decoder = copy.deepcopy(self.handler.model.decoder)
        else:
            self.handler.model.decoder = copy.deepcopy(self.handler._base_decoder)
        self.handler.model.decoder = self.handler.model.decoder.to(self.device).to(self.dtype)
        self.handler.model.decoder.eval()

        resume_adapter = None
        if self.cfg.resume_from:
            adapter_cfg = os.path.join(self.cfg.resume_from, "adapter_config.json")
            if os.path.isfile(adapter_cfg):
                resume_adapter = self.cfg.resume_from

        if resume_adapter:
            logger.info(f"Loading existing LoRA adapter for resume: {resume_adapter}")
            self.peft_model = PeftModel.from_pretrained(
                self.handler.model.decoder,
                resume_adapter,
                is_trainable=True,
            )
        else:
            resolved_targets = self._resolve_lora_target_modules(
                self.handler.model.decoder,
                self.cfg.lora_target_modules,
            )
            logger.info(f"Using LoRA target modules: {resolved_targets}")
            peft_cfg = LoraConfig(
                r=self.cfg.lora_rank,
                lora_alpha=self.cfg.lora_alpha,
                lora_dropout=self.cfg.lora_dropout,
                target_modules=resolved_targets,
                bias="none",
                task_type=TaskType.FEATURE_EXTRACTION,
            )
            self.peft_model = get_peft_model(self.handler.model.decoder, peft_cfg)

        self.peft_model.print_trainable_parameters()
        self.handler.model.decoder = self.peft_model
        self.handler.model.decoder.to(self.device).to(self.dtype)
        self.handler.model.decoder.train()
        self.handler.lora_loaded = True
        self.handler.use_lora = True

        # Optimiser over the trainable (LoRA) parameters only.
        trainable_params = [p for p in self.peft_model.parameters() if p.requires_grad]
        if self.cfg.optimizer == "adamw_8bit":
            try:
                import bitsandbytes as bnb
                self.optimizer = bnb.optim.AdamW8bit(
                    trainable_params,
                    lr=self.cfg.learning_rate,
                    weight_decay=self.cfg.weight_decay,
                )
            except ImportError:
                logger.warning("bitsandbytes not found; falling back to standard AdamW")
                self.optimizer = torch.optim.AdamW(
                    trainable_params,
                    lr=self.cfg.learning_rate,
                    weight_decay=self.cfg.weight_decay,
                )
        else:
            self.optimizer = torch.optim.AdamW(
                trainable_params,
                lr=self.cfg.learning_rate,
                weight_decay=self.cfg.weight_decay,
            )

        # Restore optimiser/bookkeeping state when resuming.
        if self.cfg.resume_from and os.path.isfile(
            os.path.join(self.cfg.resume_from, "training_state.pt")
        ):
            state = torch.load(
                os.path.join(self.cfg.resume_from, "training_state.pt"),
                weights_only=False,
            )
            try:
                self.optimizer.load_state_dict(state["optimizer"])
            except Exception as exc:
                logger.warning(f"Could not restore optimizer state, continuing with a fresh optimizer: {exc}")
            self.global_step = int(state.get("global_step", 0))
            # Resume at the epoch after the last checkpointed one.
            self.current_epoch = int(state.get("epoch", -1)) + 1
            loss_path = os.path.join(self.cfg.resume_from, "loss_history.json")
            if os.path.isfile(loss_path):
                try:
                    with open(loss_path, "r", encoding="utf-8") as f:
                        self.loss_history = json.load(f)
                except Exception:
                    pass
            logger.info(
                f"Resumed from {self.cfg.resume_from} "
                f"(epoch {self.current_epoch}, step {self.global_step})"
            )

    # ------------------------------------------------------------------
    # Audio / text encoding
    # ------------------------------------------------------------------

    @staticmethod
    def _coerce_audio_tensor(audio: Any) -> torch.Tensor:
        """Coerce decoded audio into a torch.Tensor with shape [C, T]."""
        if isinstance(audio, list):
            audio = np.asarray(audio, dtype=np.float32)
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        if not torch.is_tensor(audio):
            raise TypeError(f"Unsupported audio type: {type(audio)}")

        # Integer PCM -> float
        if not torch.is_floating_point(audio):
            audio = audio.float()

        # Normalise layout to [C, T]
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)
        elif audio.dim() == 2:
            # Heuristic: treat [T, C] (many samples, few channels) as transposed.
            if audio.shape[0] > audio.shape[1] and audio.shape[1] <= 8:
                audio = audio.transpose(0, 1)
        elif audio.dim() == 3:
            # [B, C, T] -> take the first item in the batch
            audio = audio[0]
        else:
            raise ValueError(f"Unexpected audio dims: {tuple(audio.shape)}")

        return audio.contiguous()

    def _load_audio(self, path: str) -> torch.Tensor:
        """Load audio, resample to 48 kHz stereo, clamp to max_duration."""
        try:
            wav, sr = torchaudio.load(path)
        except Exception as torchaudio_exc:
            # torchaudio failed (missing backend or unsupported codec);
            # fall back to soundfile.
            try:
                audio_np, sr = sf.read(path, dtype="float32", always_2d=True)
                wav = torch.from_numpy(audio_np.T)
            except Exception as sf_exc:
                raise RuntimeError(
                    f"Failed to decode audio '{path}' with torchaudio ({torchaudio_exc}) "
                    f"and soundfile ({sf_exc})."
                ) from sf_exc

        wav = self._coerce_audio_tensor(wav)

        # Resample to the model's sample rate.
        if sr != self.cfg.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.cfg.sample_rate)

        # Force stereo: duplicate mono, drop extra channels.
        if wav.shape[0] == 1:
            wav = wav.repeat(2, 1)
        elif wav.shape[0] > 2:
            wav = wav[:2]

        # Clamp duration.
        max_samples = int(self.cfg.max_duration_sec * self.cfg.sample_rate)
        if wav.shape[1] > max_samples:
            wav = wav[:, :max_samples]

        return wav

    def _encode_audio(self, wav: torch.Tensor) -> torch.Tensor:
        """Encode raw waveform → VAE latent on device."""
        with torch.no_grad():
            latent = self.handler._encode_audio_to_latents(wav)
        if latent.dim() == 2:
            latent = latent.unsqueeze(0)
        latent = latent.to(self.dtype)
        return latent

    def _build_text_embeddings(self, caption: str, lyrics: str):
        """Compute text & lyric embeddings using the text encoder."""
        tokenizer = self.handler.text_tokenizer
        text_encoder = self.handler.text_encoder

        # Caption -> encoder hidden states.
        text_tokens = tokenizer(
            caption or "",
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        ).to(self.device)

        with torch.no_grad():
            text_hidden = text_encoder(
                input_ids=text_tokens["input_ids"]
            ).last_hidden_state.to(self.dtype)
        text_mask = text_tokens["attention_mask"].to(self.dtype)

        # Lyrics -> raw token embeddings (only the embedding table is used here).
        lyric_tokens = tokenizer(
            lyrics or "",
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        ).to(self.device)

        with torch.no_grad():
            lyric_hidden = text_encoder.embed_tokens(
                lyric_tokens["input_ids"]
            ).to(self.dtype)
        lyric_mask = lyric_tokens["attention_mask"].to(self.dtype)

        return text_hidden, text_mask, lyric_hidden, lyric_mask

    # ------------------------------------------------------------------
    # Flow-matching loss
    # ------------------------------------------------------------------

    def _flow_matching_loss(
        self,
        x1: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        context_latents: torch.Tensor,
    ) -> torch.Tensor:
        """Compute rectified-flow MSE loss for one batch.

        Notation follows the ACE-Step convention:
            x0 = noise, x1 = clean latent
            xt = t * x0 + (1 - t) * x1
            target velocity = x0 - x1
        """
        bsz = x1.shape[0]

        # Uniform timestep per sample.
        t = torch.rand(bsz, device=self.device, dtype=self.dtype)

        # Timestep shift, biasing sampling toward the noisy end.
        if self.cfg.shift != 1.0:
            t = self.cfg.shift * t / (1.0 + (self.cfg.shift - 1.0) * t)
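
        # Worked example of the shift mapping t' = s*t / (1 + (s-1)*t) with
        # the default s = 3.0 (values rounded):
        #   t  = 0.10, 0.25, 0.50, 0.75, 0.90
        #   t' = 0.25, 0.50, 0.75, 0.90, 0.96
        # i.e. more of the training budget lands on the noisier timesteps.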

        t = t.clamp(1e-5, 1.0 - 1e-5)

        # Noise sample and linear interpolation between noise and data.
        x0 = torch.randn_like(x1)

        t_expand = t.view(bsz, 1, 1)
        xt = t_expand * x0 + (1.0 - t_expand) * x1

        # Velocity target of the rectified flow.
        velocity_target = x0 - x1

        # Full attention over all latent frames (batch items are padded).
        attention_mask = torch.ones(
            bsz, x1.shape[1], device=self.device, dtype=self.dtype
        )

        # Predict velocity with the (LoRA-wrapped) decoder.
        decoder_out = self.handler.model.decoder(
            hidden_states=xt,
            timestep=t,
            timestep_r=t,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            context_latents=context_latents,
            use_cache=False,
            output_attentions=False,
        )

        velocity_pred = decoder_out[0]
        loss = F.mse_loss(velocity_pred, velocity_target)
        return loss

    @staticmethod
    def _pad_and_stack(tensors: List[torch.Tensor], pad_value: float = 0.0) -> torch.Tensor:
        """Pad variable-length tensors along their first dimension and stack as a batch."""
        normalized = []
        for t in tensors:
            # Drop a leading singleton batch dimension if present.
            if t.dim() >= 2 and t.shape[0] == 1:
                normalized.append(t.squeeze(0))
            else:
                normalized.append(t)

        max_len = max(t.shape[0] for t in normalized)
        template = normalized[0]
        out_shape = (len(normalized), max_len, *template.shape[1:])
        out = template.new_full(out_shape, pad_value)
        for i, t in enumerate(normalized):
            out[i, : t.shape[0]] = t
        return out
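
    # Example: latents of shape [1, 512, D] and [1, 700, D] collate to a
    # [2, 700, D] batch, zero-padded along the time axis.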

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------

    def request_stop(self):
        """Ask the training loop to stop after the current step."""
        self._stop_requested = True

    def train(
        self,
        entries: List[TrackEntry],
        progress_callback=None,
    ) -> str:
        """Run the full LoRA training.

        Args:
            entries: List of scanned TrackEntry objects.
            progress_callback: ``fn(step, total_steps, loss, epoch)`` for UI updates.

        Returns:
            Status message.
        """
        self._stop_requested = False
        # Keep the loss history that prepare() restored when resuming;
        # otherwise start a fresh curve.
        if not self.cfg.resume_from:
            self.loss_history.clear()
        os.makedirs(self.cfg.output_dir, exist_ok=True)

        if not entries:
            return "No training data provided."

        num_entries = len(entries)

        # Pre-encode the whole dataset once so the training loop is pure
        # decoder work.
        logger.info("Pre-encoding dataset through VAE & text encoder ...")
        dataset: List[Dict[str, Any]] = []
        failed_encode: List[str] = []

        self.handler.vae.eval()
        self.handler.text_encoder.eval()

        # Reference-audio slot: a slice of the silence latent (no timbre reference).
        ref_latent = self.handler.silence_latent[:, :750, :].to(self.device).to(self.dtype)
        ref_order_mask = torch.zeros(1, device=self.device, dtype=torch.long)

        for entry in tqdm(entries, desc="Encoding dataset"):
            try:
                wav = self._load_audio(entry.audio_path)
                latent = self._encode_audio(wav)
                text_h, text_m, lyric_h, lyric_m = self._build_text_embeddings(
                    entry.caption, entry.lyrics
                )

                with torch.no_grad():
                    enc_hs, enc_mask, ctx_lat = self.handler.model.prepare_condition(
                        text_hidden_states=text_h,
                        text_attention_mask=text_m,
                        lyric_hidden_states=lyric_h,
                        lyric_attention_mask=lyric_m,
                        refer_audio_acoustic_hidden_states_packed=ref_latent,
                        refer_audio_order_mask=ref_order_mask,
                        hidden_states=latent,
                        attention_mask=torch.ones(
                            1, latent.shape[1],
                            device=self.device, dtype=self.dtype,
                        ),
                        silence_latent=self.handler.silence_latent,
                        src_latents=latent,
                        chunk_masks=torch.ones_like(latent),
                        is_covers=[False],
                    )

                # Cache on CPU to keep VRAM free for the training loop.
                dataset.append(
                    {
                        "latent": latent.cpu(),
                        "enc_hs": enc_hs.cpu(),
                        "enc_mask": enc_mask.cpu(),
                        "ctx_lat": ctx_lat.cpu(),
                        "name": Path(entry.audio_path).stem,
                    }
                )
            except Exception as exc:
                reason = f"{Path(entry.audio_path).name}: {exc}"
                failed_encode.append(reason)
                logger.warning(f"Skipping {entry.audio_path}: {exc}")

        if not dataset:
            preview = "\n".join(f"- {msg}" for msg in failed_encode[:8]) or "- (no detailed errors captured)"
            return (
                "All tracks failed to encode. Check audio files.\n"
                "First errors:\n"
                f"{preview}\n"
                "Tip: try WAV/FLAC files and dataset folder scan instead of temporary uploads."
            )

        logger.info(f"Encoded {len(dataset)}/{num_entries} tracks.")

        # Step accounting. Gradients are flushed at each epoch boundary, so
        # optimiser steps per epoch round up independently of other epochs.
        steps_per_epoch = math.ceil(len(dataset) / self.cfg.batch_size)
        optim_steps_per_epoch = math.ceil(
            steps_per_epoch / self.cfg.gradient_accumulation_steps
        )
        total_optim_steps = optim_steps_per_epoch * self.cfg.num_epochs
        warmup_steps = int(total_optim_steps * self.cfg.warmup_ratio)
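        # e.g. 40 encoded tracks, batch_size=1, grad accumulation 4, 50 epochs:
        #   40 micro-steps/epoch -> 10 optimiser steps/epoch -> 500 steps total.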

        if self.cfg.scheduler in {"constant_with_warmup", "linear", "cosine"}:
            try:
                from transformers import get_scheduler
                self.scheduler = get_scheduler(
                    name=self.cfg.scheduler,
                    optimizer=self.optimizer,
                    num_warmup_steps=warmup_steps,
                    num_training_steps=total_optim_steps,
                )
            except Exception as exc:
                logger.warning(f"Could not create scheduler '{self.cfg.scheduler}', disabling scheduler: {exc}")
                self.scheduler = None
        else:
            self.scheduler = None

        logger.info(
            f"Starting LoRA training: {self.cfg.num_epochs} epochs, "
            f"{len(dataset)} samples, {total_optim_steps} optimiser steps"
        )

        self.peft_model.train()
        accum_loss = 0.0
        step_in_accum = 0

        for epoch in range(self.current_epoch, self.cfg.num_epochs):
            if self._stop_requested:
                break

            self.current_epoch = epoch
            indices = list(range(len(dataset)))
            random.shuffle(indices)

            epoch_loss = 0.0
            epoch_steps = 0

            for i in range(0, len(indices), self.cfg.batch_size):
                if self._stop_requested:
                    break

                batch_indices = indices[i : i + self.cfg.batch_size]
                batch_items = [dataset[j] for j in batch_indices]

                # Collate the cached CPU tensors onto the device.
                latents = self._pad_and_stack([it["latent"] for it in batch_items]).to(self.device, self.dtype)
                enc_hs = self._pad_and_stack([it["enc_hs"] for it in batch_items]).to(self.device, self.dtype)
                enc_mask = self._pad_and_stack([it["enc_mask"] for it in batch_items], pad_value=0.0).to(self.device)
                if enc_mask.dtype != self.dtype:
                    enc_mask = enc_mask.to(self.dtype)
                ctx_lat = self._pad_and_stack([it["ctx_lat"] for it in batch_items]).to(self.device, self.dtype)

                # Forward/backward with gradient accumulation.
                loss = self._flow_matching_loss(latents, enc_hs, enc_mask, ctx_lat)
                loss = loss / self.cfg.gradient_accumulation_steps
                loss.backward()

                accum_loss += loss.item()
                step_in_accum += 1

                if step_in_accum >= self.cfg.gradient_accumulation_steps:
                    torch.nn.utils.clip_grad_norm_(
                        self.peft_model.parameters(), self.cfg.max_grad_norm
                    )
                    self.optimizer.step()
                    if self.scheduler is not None:
                        self.scheduler.step()
                    self.optimizer.zero_grad()

                    self.global_step += 1
                    avg_loss = accum_loss
                    accum_loss = 0.0
                    step_in_accum = 0

                    self.loss_history.append(
                        {
                            "step": self.global_step,
                            "epoch": epoch,
                            "loss": avg_loss,
                            "lr": self.optimizer.param_groups[0]["lr"],
                        }
                    )

                    if self.global_step % self.cfg.log_every_n_steps == 0:
                        logger.info(
                            f"Epoch {epoch+1}/{self.cfg.num_epochs} "
                            f"Step {self.global_step}/{total_optim_steps} "
                            f"Loss {avg_loss:.6f} "
                            f"LR {self.optimizer.param_groups[0]['lr']:.2e}"
                        )

                    if progress_callback:
                        progress_callback(
                            self.global_step, total_optim_steps, avg_loss, epoch
                        )

                epoch_loss += loss.item() * self.cfg.gradient_accumulation_steps
                epoch_steps += 1

            # Flush a trailing partial accumulation window at the epoch end.
            if step_in_accum > 0:
                torch.nn.utils.clip_grad_norm_(self.peft_model.parameters(), self.cfg.max_grad_norm)
                self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()
                self.optimizer.zero_grad()
                self.global_step += 1
                avg_loss = accum_loss
                accum_loss = 0.0
                step_in_accum = 0
                self.loss_history.append(
                    {
                        "step": self.global_step,
                        "epoch": epoch,
                        "loss": avg_loss,
                        "lr": self.optimizer.param_groups[0]["lr"],
                    }
                )

            # Periodic checkpoint (also on the last epoch or on stop request).
            if (
                (epoch + 1) % self.cfg.save_every_n_epochs == 0
                or epoch == self.cfg.num_epochs - 1
                or self._stop_requested
            ):
                self._save_checkpoint(epoch)

            if epoch_steps > 0:
                avg_epoch_loss = epoch_loss / epoch_steps
                logger.info(
                    f"Epoch {epoch+1} complete: avg loss {avg_epoch_loss:.6f}"
                )

        # Final save.
        final_dir = self._save_checkpoint(self.current_epoch, final=True)
        status = (
            "Training stopped early." if self._stop_requested else "Training complete!"
        )
        return f"{status} Adapter saved to {final_dir}"

    # ------------------------------------------------------------------
    # Checkpointing
    # ------------------------------------------------------------------

    def _save_checkpoint(self, epoch: int, final: bool = False) -> str:
        tag = "final" if final else f"epoch-{epoch+1}"
        save_dir = os.path.join(self.cfg.output_dir, tag)
        os.makedirs(save_dir, exist_ok=True)

        # LoRA adapter weights + adapter_config.json
        self.peft_model.save_pretrained(save_dir)

        # Optimiser and bookkeeping state for resuming.
        torch.save(
            {
                "optimizer": self.optimizer.state_dict(),
                "global_step": self.global_step,
                "epoch": epoch,
            },
            os.path.join(save_dir, "training_state.pt"),
        )

        # Loss curve.
        loss_path = os.path.join(save_dir, "loss_history.json")
        with open(loss_path, "w", encoding="utf-8") as f:
            json.dump(self.loss_history, f)

        # Record the full training config for reproducibility.
        cfg_path = os.path.join(save_dir, "train_config.json")
        with open(cfg_path, "w", encoding="utf-8") as f:
            json.dump(asdict(self.cfg), f, indent=2)

        logger.info(f"Checkpoint saved → {save_dir}")
        return save_dir
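
    # Resulting checkpoint layout (illustrative; the adapter weight filename is
    # whatever peft's save_pretrained emits, typically adapter_model.safetensors):
    #
    #   lora_output/epoch-10/
    #       adapter_config.json
    #       adapter_model.safetensors
    #       training_state.pt
    #       loss_history.json
    #       train_config.json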

    # ------------------------------------------------------------------
    # Adapter discovery
    # ------------------------------------------------------------------

    @staticmethod
    def list_adapters(output_dir: str = "lora_output") -> List[str]:
        """Return adapter directories inside *output_dir* (recursive)."""
        results = []
        root = Path(output_dir)
        if not root.is_dir():
            return results
        for cfg in sorted(root.rglob("adapter_config.json")):
            # The parent of an existing file is necessarily a directory.
            results.append(str(cfg.parent))
        return results
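

# Typical programmatic use, assuming an initialised ``AceStepHandler`` named
# ``handler`` (illustrative sketch; mirrors what main() below does):
#
#   cfg = LoRATrainConfig(num_epochs=10, output_dir="lora_output/my_style")
#   trainer = LoRATrainer(handler, cfg)
#   trainer.prepare()
#   entries = scan_dataset_folder("datasets/my_style")
#   print(trainer.train(entries))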


def _build_arg_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="ACE-Step 1.5 LoRA trainer (CLI)")

    # Dataset
    parser.add_argument("--dataset-dir", type=str, default="", help="Local dataset folder path")
    parser.add_argument("--dataset-repo", type=str, default="", help="HF dataset repo id (optional)")
    parser.add_argument("--dataset-revision", type=str, default="main", help="HF dataset revision")
    parser.add_argument("--dataset-subdir", type=str, default="", help="Subdirectory inside downloaded dataset")

    # Model / runtime
    parser.add_argument("--model-config", type=str, default="acestep-v15-base", help="DiT config name")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "xpu", "cpu"])
    parser.add_argument("--offload-to-cpu", action="store_true")
    parser.add_argument("--offload-dit-to-cpu", action="store_true")
    parser.add_argument("--prefer-source", type=str, default="huggingface", choices=["huggingface", "modelscope"])

    # Training loop
    parser.add_argument("--output-dir", type=str, default="lora_output")
    parser.add_argument("--resume-from", type=str, default="")
    parser.add_argument("--num-epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--grad-accum", type=int, default=1)
    parser.add_argument("--save-every", type=int, default=10)
    parser.add_argument("--log-every", type=int, default=5)
    parser.add_argument("--max-duration-sec", type=float, default=240.0)

    # LoRA
    parser.add_argument("--lora-rank", type=int, default=64)
    parser.add_argument("--lora-alpha", type=int, default=64)
    parser.add_argument("--lora-dropout", type=float, default=0.1)

    # Optimisation
    parser.add_argument("--learning-rate", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--optimizer", type=str, default="adamw_8bit", choices=["adamw", "adamw_8bit"])
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--warmup-ratio", type=float, default=0.03)
    parser.add_argument("--scheduler", type=str, default="constant_with_warmup", choices=["constant_with_warmup", "linear", "cosine"])
    parser.add_argument("--shift", type=float, default=3.0)

    # Optional HF upload
    parser.add_argument("--upload-repo", type=str, default="", help="HF model repo to upload final adapter")
    parser.add_argument("--upload-path", type=str, default="", help="Path inside upload repo (optional)")
    parser.add_argument("--upload-private", action="store_true")
    parser.add_argument("--hf-token-env", type=str, default="HF_TOKEN", help="Environment variable name for HF token")

    return parser
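

# Example invocation (the script filename and dataset path are placeholders):
#
#   python acestep_lora_trainer.py \
#       --dataset-dir datasets/my_artist \
#       --output-dir lora_output/my_artist \
#       --num-epochs 20 --lora-rank 32 --learning-rate 5e-5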


def _resolve_dataset_dir(args) -> str:
    if args.dataset_dir:
        return args.dataset_dir

    if not args.dataset_repo:
        raise ValueError("Provide --dataset-dir or --dataset-repo.")

    from huggingface_hub import snapshot_download

    token = os.getenv(args.hf_token_env)
    temp_root = tempfile.mkdtemp(prefix="acestep_lora_dataset_")
    local_dir = os.path.join(temp_root, "dataset")
    logger.info(f"Downloading dataset repo {args.dataset_repo}@{args.dataset_revision} to {local_dir}")
    snapshot_download(
        repo_id=args.dataset_repo,
        repo_type="dataset",
        revision=args.dataset_revision,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        token=token,
    )
    if args.dataset_subdir:
        sub = os.path.join(local_dir, args.dataset_subdir)
        if not os.path.isdir(sub):
            raise FileNotFoundError(f"Dataset subdir not found: {sub}")
        return sub
    return local_dir


def _upload_adapter_if_requested(args, final_dir: str):
    if not args.upload_repo:
        return

    from huggingface_hub import HfApi

    token = os.getenv(args.hf_token_env)
    if not token:
        raise RuntimeError(
            f"{args.hf_token_env} is not set. Needed for upload to {args.upload_repo}."
        )

    api = HfApi(token=token)
    api.create_repo(
        repo_id=args.upload_repo,
        repo_type="model",
        exist_ok=True,
        private=bool(args.upload_private),
    )

    path_in_repo = args.upload_path.strip().strip("/") if args.upload_path else ""
    commit_message = f"Upload ACE-Step LoRA adapter from {Path(final_dir).name}"
    logger.info(f"Uploading adapter from {final_dir} to {args.upload_repo}/{path_in_repo}")
    api.upload_folder(
        repo_id=args.upload_repo,
        repo_type="model",
        folder_path=final_dir,
        path_in_repo=path_in_repo,
        commit_message=commit_message,
    )
    logger.info("Upload complete")


def main():
    args = _build_arg_parser().parse_args()

    dataset_dir = _resolve_dataset_dir(args)
    entries = scan_dataset_folder(dataset_dir)
    if not entries:
        raise RuntimeError(f"No audio files found in dataset: {dataset_dir}")

    from acestep.handler import AceStepHandler

    project_root = str(Path(__file__).resolve().parent)
    handler = AceStepHandler()
    status, ok = handler.initialize_service(
        project_root=project_root,
        config_path=args.model_config,
        device=args.device,
        use_flash_attention=False,
        compile_model=False,
        offload_to_cpu=bool(args.offload_to_cpu),
        offload_dit_to_cpu=bool(args.offload_dit_to_cpu),
        prefer_source=args.prefer_source,
    )
    print(status)
    if not ok:
        raise RuntimeError("Model initialization failed")

    cfg = LoRATrainConfig(
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        optimizer=args.optimizer,
        max_grad_norm=args.max_grad_norm,
        warmup_ratio=args.warmup_ratio,
        scheduler=args.scheduler,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        save_every_n_epochs=args.save_every,
        log_every_n_steps=args.log_every,
        shift=args.shift,
        max_duration_sec=args.max_duration_sec,
        output_dir=args.output_dir,
        resume_from=(args.resume_from.strip() if args.resume_from else None),
        device=args.device,
    )

    trainer = LoRATrainer(handler, cfg)
    trainer.prepare()

    start = time.time()

    def _progress(step, total, loss, epoch):
        elapsed = time.time() - start
        rate = step / elapsed if elapsed > 0 else 0.0
        remaining = max(0.0, total - step)
        eta_sec = remaining / rate if rate > 0 else -1.0
        eta_msg = f"{eta_sec/60:.1f}m" if eta_sec >= 0 else "unknown"
        logger.info(
            f"[progress] step={step}/{total} epoch={epoch+1} loss={loss:.6f} "
            f"elapsed={elapsed/60:.1f}m eta={eta_msg}"
        )

    msg = trainer.train(entries, progress_callback=_progress)
    print(msg)

    final_dir = os.path.join(cfg.output_dir, "final")
    if os.path.isdir(final_dir):
        _upload_adapter_if_requested(args, final_dir)
        print(f"Final adapter directory: {final_dir}")
    else:
        print(f"Warning: final adapter directory not found at {final_dir}")


if __name__ == "__main__":
    main()