"""Standalone ACE-Step CPU LoRA Training Engine.

Ported from Side-Step (koda-dernet/Side-Step) into a single self-contained
module. No external Side-Step dependency required.

Exports:
    preprocess_audio()      - 2-pass sequential preprocessing
    train_lora_generator()  - Generator-based LoRA training loop
    cancel_training()       - Set the cancel flag
    get_trained_loras()     - List saved adapters
"""

from __future__ import annotations

import gc
import json
import logging
import math
import os
import random
import sys
import threading
import time
import types
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    LinearLR,
    SequentialLR,
)
from torch.utils.data import DataLoader, Dataset

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configurable caps (edit these at the top of the file)
# ---------------------------------------------------------------------------
MAX_AUDIO_DURATION = 240.0  # seconds, cap per audio file
MAX_TRAINING_TIME = 28800  # seconds (8 hours) hard timeout
TARGET_SR = 48000  # sample rate (Hz) all audio is resampled to
AUDIO_EXTENSIONS = frozenset({".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".aac"})

# bfloat16 deadlocks on CPU (known PyTorch bug) -- force float32
CPU_DTYPE = torch.float32

# Module-wide cancellation flag shared by cancel_training() and the training
# loop.  An Event is used so it can be set from another thread (e.g. a UI
# callback) while the training generator is running.
_training_cancel = threading.Event()


def cancel_training() -> None:
    """Request that the running training loop stop.

    The training generator checks this flag at the start of every epoch,
    saves an early-exit adapter and stops.  It also clears the flag when a
    new run starts, so a cancel issued before training begins is discarded.
    """
    _training_cancel.set()


# ============================================================================
# CONFIGS
# ============================================================================
@dataclass
class LoRAConfig:
    """LoRA hyper-parameters forwarded to PEFT by ``inject_lora()``."""

    # LoRA rank (PEFT ``r``).
    r: int = 64
    # LoRA scaling numerator (PEFT ``lora_alpha``).
    alpha: int = 128
    # Dropout on the LoRA branch (PEFT ``lora_dropout``).
    dropout: float = 0.1
    # Names of the sub-modules to wrap with LoRA adapters.  A fresh list is
    # built per instance so configs never share (and mutate) one list.
    target_modules: List[str] = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",
    ])
    # PEFT bias handling mode.
    bias: str = "none"
    # NOTE(review): inject_lora() does not read attention_type or target_mlp;
    # they appear to be kept for config compatibility -- confirm other users.
    attention_type: str = "both"
    target_mlp: bool = True


#
============================================================================ # TIMESTEP SAMPLING & CFG DROPOUT # ============================================================================ def sample_timesteps( batch_size: int, device: torch.device, dtype: torch.dtype, timestep_mu: float = -0.4, timestep_sigma: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: t = torch.sigmoid( torch.randn((batch_size,), device=device, dtype=dtype) * timestep_sigma + timestep_mu ) r = torch.sigmoid( torch.randn((batch_size,), device=device, dtype=dtype) * timestep_sigma + timestep_mu ) t, r = torch.maximum(t, r), torch.minimum(t, r) # use_meanflow=False forces r=t (ACE-Step convention) return t, t def apply_cfg_dropout( encoder_hidden_states: torch.Tensor, null_condition_emb: torch.Tensor, cfg_ratio: float = 0.15, ) -> torch.Tensor: bsz = encoder_hidden_states.shape[0] device = encoder_hidden_states.device dtype = encoder_hidden_states.dtype mask = torch.where( torch.rand(size=(bsz,), device=device, dtype=dtype) < cfg_ratio, torch.zeros(size=(bsz,), device=device, dtype=dtype), torch.ones(size=(bsz,), device=device, dtype=dtype), ).view(-1, 1, 1) return torch.where( mask > 0, encoder_hidden_states, null_condition_emb.expand_as(encoder_hidden_states), ) # ============================================================================ # OPTIMIZER (Adafactor preferred for CPU -- 1.5 bytes/param) # ============================================================================ def build_optimizer( params, lr: float = 1e-4, weight_decay: float = 0.01, ) -> torch.optim.Optimizer: try: from transformers.optimization import Adafactor logger.info("Using Adafactor optimizer (minimal state memory)") return Adafactor( params, lr=lr, weight_decay=weight_decay, scale_parameter=False, relative_step=False, ) except ImportError: logger.warning("transformers not installed, falling back to AdamW") return AdamW(params, lr=lr, weight_decay=weight_decay) def build_scheduler( optimizer, total_steps: int, 
warmup_steps: int, lr: float, ): _max_warmup = max(1, total_steps // 10) if warmup_steps > _max_warmup: warmup_steps = _max_warmup warmup = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps) remaining = max(1, total_steps - warmup_steps) main = CosineAnnealingLR(optimizer, T_max=remaining, eta_min=lr * 0.01) return SequentialLR(optimizer, [warmup, main], milestones=[warmup_steps]) # ============================================================================ # DATASET # ============================================================================ def _collate_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]: max_t = max(s["target_latents"].shape[0] for s in batch) max_e = max(s["encoder_hidden_states"].shape[0] for s in batch) def pad(t, max_len, dim=0): diff = max_len - t.shape[dim] if diff <= 0: return t shape = list(t.shape) shape[dim] = diff return torch.cat([t, t.new_zeros(*shape)], dim=dim) return { "target_latents": torch.stack([pad(s["target_latents"], max_t) for s in batch]), "attention_mask": torch.stack([pad(s["attention_mask"], max_t) for s in batch]), "encoder_hidden_states": torch.stack([pad(s["encoder_hidden_states"], max_e) for s in batch]), "encoder_attention_mask": torch.stack([pad(s["encoder_attention_mask"], max_e) for s in batch]), "context_latents": torch.stack([pad(s["context_latents"], max_t) for s in batch]), } class TensorDataset(Dataset): _REQUIRED = frozenset([ "target_latents", "attention_mask", "encoder_hidden_states", "encoder_attention_mask", "context_latents", ]) def __init__(self, tensor_dir: str): self.paths: List[str] = [] for f in sorted(os.listdir(tensor_dir)): if f.endswith(".pt") and not f.endswith(".tmp.pt") and f != "manifest.json": self.paths.append(str(Path(tensor_dir) / f)) def __len__(self) -> int: return len(self.paths) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: data = torch.load(self.paths[idx], map_location="cpu", weights_only=True) missing = self._REQUIRED - 
data.keys() if missing: raise KeyError(f"Missing keys {sorted(missing)} in {self.paths[idx]}") for k in ("target_latents", "encoder_hidden_states", "context_latents"): t = data[k] if torch.isnan(t).any() or torch.isinf(t).any(): t.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0) return {k: data[k] for k in self._REQUIRED} # ============================================================================ # GRADIENT CHECKPOINTING # ============================================================================ def _find_decoder_layers(decoder: nn.Module) -> Optional[nn.ModuleList]: for attr in ("layers", "blocks", "transformer_blocks"): c = getattr(decoder, attr, None) if isinstance(c, nn.ModuleList) and len(c) > 0: return c for child in decoder.children(): for attr in ("layers", "blocks", "transformer_blocks"): c = getattr(child, attr, None) if isinstance(c, nn.ModuleList) and len(c) > 0: return c return None def enable_gradient_checkpointing(decoder: nn.Module) -> bool: """Enable gradient checkpointing on the decoder to save memory.""" enabled = False # Walk wrapper chain stack = [decoder] visited = set() while stack: mod = stack.pop() if not isinstance(mod, nn.Module): continue mid = id(mod) if mid in visited: continue visited.add(mid) if hasattr(mod, "gradient_checkpointing_enable"): try: mod.gradient_checkpointing_enable() enabled = True except Exception: pass elif hasattr(mod, "gradient_checkpointing"): try: mod.gradient_checkpointing = True enabled = True except Exception: pass if hasattr(mod, "enable_input_require_grads"): try: mod.enable_input_require_grads() except Exception: pass cfg = getattr(mod, "config", None) if cfg is not None and hasattr(cfg, "use_cache"): try: cfg.use_cache = False except Exception: pass for a in ("_forward_module", "_orig_mod", "base_model", "model", "module"): child = getattr(mod, a, None) if isinstance(child, nn.Module): stack.append(child) return enabled def force_disable_cache(decoder: nn.Module) -> None: stack = [decoder] visited = 
set() while stack: mod = stack.pop() if not isinstance(mod, nn.Module): continue mid = id(mod) if mid in visited: continue visited.add(mid) cfg = getattr(mod, "config", None) if cfg is not None and hasattr(cfg, "use_cache"): try: cfg.use_cache = False except Exception: pass for a in ("_forward_module", "_orig_mod", "base_model", "model", "module"): child = getattr(mod, a, None) if isinstance(child, nn.Module): stack.append(child) # ============================================================================ # LORA INJECTION (PEFT only -- no DoRA/LoKR/LoHA/OFT) # ============================================================================ def _unwrap_decoder(model): decoder = model.decoder if hasattr(model, "decoder") else model while hasattr(decoder, "_forward_module"): decoder = decoder._forward_module if hasattr(decoder, "base_model"): bm = decoder.base_model decoder = bm.model if hasattr(bm, "model") else bm if hasattr(decoder, "model") and isinstance(decoder.model, nn.Module): decoder = decoder.model return decoder def inject_lora(model, lora_cfg: LoRAConfig) -> Tuple[Any, Dict[str, Any]]: from peft import get_peft_model, LoraConfig as PeftLoraConfig, TaskType decoder = _unwrap_decoder(model) model.decoder = decoder # Guard enable_input_require_grads for DiT (no get_input_embeddings) if hasattr(decoder, "enable_input_require_grads"): orig = decoder.enable_input_require_grads def _safe(self): try: return orig() except NotImplementedError: return None decoder.enable_input_require_grads = types.MethodType(_safe, decoder) if hasattr(decoder, "is_gradient_checkpointing"): try: decoder.is_gradient_checkpointing = False except Exception: pass peft_cfg = PeftLoraConfig( r=lora_cfg.r, lora_alpha=lora_cfg.alpha, lora_dropout=lora_cfg.dropout, target_modules=lora_cfg.target_modules, bias=lora_cfg.bias, task_type=TaskType.FEATURE_EXTRACTION, ) model.decoder = get_peft_model(decoder, peft_cfg) for name, param in model.named_parameters(): if "lora_" not in name: 
param.requires_grad = False total = sum(p.numel() for p in model.parameters()) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) return model, { "total_params": total, "trainable_params": trainable, "trainable_ratio": trainable / total if total > 0 else 0, } def save_lora_adapter(model, output_dir: str) -> None: os.makedirs(output_dir, exist_ok=True) decoder = model.decoder if hasattr(model, "decoder") else model while hasattr(decoder, "_forward_module"): decoder = decoder._forward_module if hasattr(decoder, "save_pretrained"): decoder.save_pretrained(output_dir) # Scrub base_model path for portability cfg_path = os.path.join(output_dir, "adapter_config.json") if os.path.isfile(cfg_path): try: with open(cfg_path, "r") as f: cfg = json.load(f) if cfg.get("base_model_name_or_path"): cfg["base_model_name_or_path"] = "" with open(cfg_path, "w") as f: json.dump(cfg, f, indent=2) except Exception: pass logger.info("LoRA adapter saved to %s", output_dir) else: # Fallback: manual extraction state = {} for name, param in decoder.named_parameters(): if "lora_" in name: state[name] = param.data.clone() if state: try: from safetensors.torch import save_file save_file(state, str(Path(output_dir) / "adapter_model.safetensors")) except ImportError: torch.save(state, str(Path(output_dir) / "lora_weights.pt")) logger.info("LoRA adapter saved (fallback) to %s", output_dir) # ============================================================================ # MODEL LOADING (FA2 -> SDPA -> eager fallback) # ============================================================================ _VARIANT_DIR = { "turbo": "acestep-v15-turbo", "base": "acestep-v15-base", "sft": "acestep-v15-sft", } def _resolve_model_dir(checkpoint_dir: str, variant: str) -> Path: base = Path(checkpoint_dir).resolve() subdir = _VARIANT_DIR.get(variant) if subdir: p = (Path(checkpoint_dir) / subdir).resolve() if p.is_dir(): return p p = (Path(checkpoint_dir) / variant).resolve() if p.is_dir(): 
return p raise FileNotFoundError( f"Model directory not found: tried {_VARIANT_DIR.get(variant, variant)!r} " f"and {variant!r} under {checkpoint_dir}" ) def _ensure_acestep_imports(): """Register stub modules so AutoModel can load ACE-Step checkpoints.""" for name in ( "acestep", "acestep.models", "acestep.models.common", "acestep.models.xl_base", "acestep.models.xl_turbo", "acestep.models.xl_sft", ): if name not in sys.modules: stub = types.ModuleType(name) stub.__path__ = [] sys.modules[name] = stub # Try to load real modules from adjacent ACE-Step checkout for name in ( "acestep.models.common.configuration_acestep_v15", "acestep.models.common.apg_guidance", ): if name not in sys.modules: sys.modules[name] = types.ModuleType(name) def _attn_candidates(device: str) -> List[str]: """FA2 -> SDPA -> eager, filtered by availability.""" candidates = [] if device.startswith("cuda"): try: import flash_attn # noqa: F401 dev_idx = int(device.split(":")[1]) if ":" in device else 0 props = torch.cuda.get_device_properties(dev_idx) if props.major >= 8: candidates.append("flash_attention_2") except (ImportError, Exception): pass candidates.extend(["sdpa", "eager"]) return candidates def load_model_for_training( checkpoint_dir: str, variant: str = "base", device: str = "cpu", ) -> Any: from transformers import AutoModel model_dir = _resolve_model_dir(checkpoint_dir, variant) # CPU always uses float32 dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16 _ensure_acestep_imports() candidates = _attn_candidates(device) model = None last_err = None for idx, attn in enumerate(candidates): try: load_kwargs = dict( trust_remote_code=True, attn_implementation=attn, torch_dtype=dtype, low_cpu_mem_usage=False, ) if device != "cpu": load_kwargs["device_map"] = {"": device} model = AutoModel.from_pretrained(str(model_dir), **load_kwargs) logger.info("Model loaded with attn_implementation=%s", attn) break except Exception as exc: err_text = str(exc) if "packages that were not found" in 
err_text or "No module named" in err_text: raise RuntimeError( f"Model files in {model_dir} require a missing Python package.\n" f" Original error: {err_text}" ) from exc last_err = exc logger.warning("attn backend '%s' failed: %s", attn, exc) if model is None: raise RuntimeError(f"Failed to load model from {model_dir}: {last_err}") from last_err for param in model.parameters(): param.requires_grad = False model.eval() return model def load_vae(checkpoint_dir: str, device: str = "cpu"): from diffusers.models import AutoencoderOobleck vae_path = Path(checkpoint_dir) / "vae" if not vae_path.is_dir(): raise FileNotFoundError(f"VAE directory not found: {vae_path}") dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16 vae = AutoencoderOobleck.from_pretrained(str(vae_path), torch_dtype=dtype) vae = vae.to(device=device) vae.eval() return vae def load_text_encoder(checkpoint_dir: str, device: str = "cpu"): from transformers import AutoModel, AutoTokenizer text_path = Path(checkpoint_dir) / "Qwen3-Embedding-0.6B" if not text_path.is_dir(): raise FileNotFoundError(f"Text encoder not found: {text_path}") dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16 tokenizer = AutoTokenizer.from_pretrained(str(text_path)) encoder = AutoModel.from_pretrained(str(text_path), torch_dtype=dtype) encoder = encoder.to(device=device) encoder.eval() return tokenizer, encoder def load_silence_latent( checkpoint_dir: str, device: str = "cpu", variant: str = "base", ) -> torch.Tensor: ckpt = Path(checkpoint_dir) dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16 candidates = [ckpt / "silence_latent.pt"] subdir = _VARIANT_DIR.get(variant) if subdir: candidates.append(ckpt / subdir / "silence_latent.pt") for sd in _VARIANT_DIR.values(): candidates.append(ckpt / sd / "silence_latent.pt") for c in candidates: if c.is_file(): sl = torch.load(str(c), weights_only=True).transpose(1, 2) return sl.to(device=device, dtype=dtype) raise FileNotFoundError(f"silence_latent.pt not found under 
{ckpt}") def unload_models(*models) -> None: for obj in models: if obj is None: continue if hasattr(obj, "to"): try: obj.to("cpu") except Exception: pass del obj gc.collect() # ============================================================================ # AUDIO LOADING # ============================================================================ def load_audio_stereo( audio_path: str, target_sr: int, max_duration: float, ) -> Tuple[torch.Tensor, int]: import numpy as np try: import soundfile as sf data, sr = sf.read(audio_path, dtype="float32", always_2d=True) audio_np = np.ascontiguousarray(data.T) sr = int(sr) if sr != target_sr: import librosa audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=target_sr, axis=1) sr = target_sr audio = torch.from_numpy(np.ascontiguousarray(audio_np)) except Exception: import torchaudio audio, sr = torchaudio.load(audio_path) sr = int(sr) if sr != target_sr: audio = torchaudio.transforms.Resample(sr, target_sr)(audio) sr = target_sr if audio.shape[0] == 1: audio = audio.repeat(2, 1) elif audio.shape[0] > 2: audio = audio[:2, :] max_samples = int(max_duration * target_sr) if audio.shape[1] > max_samples: audio = audio[:, :max_samples] return audio, sr # ============================================================================ # TEXT / LYRICS ENCODING # ============================================================================ def encode_text(text_encoder, tokenizer, text_prompt: str, device, dtype): inputs = tokenizer( text_prompt, padding="max_length", max_length=256, truncation=True, return_tensors="pt", ) ids = inputs.input_ids.to(device) mask = inputs.attention_mask.to(device).to(dtype) enc_dev = next(text_encoder.parameters()).device if ids.device != enc_dev: ids = ids.to(enc_dev) mask = mask.to(enc_dev) with torch.no_grad(): hs = text_encoder(ids).last_hidden_state.to(dtype) return hs, mask def encode_lyrics(text_encoder, tokenizer, lyrics: str, device, dtype): inputs = tokenizer( lyrics, padding="max_length", 
max_length=512, truncation=True, return_tensors="pt", ) ids = inputs.input_ids.to(device) mask = inputs.attention_mask.to(device).to(dtype) enc_dev = next(text_encoder.parameters()).device if ids.device != enc_dev: ids = ids.to(enc_dev) mask = mask.to(enc_dev) with torch.no_grad(): hs = text_encoder.embed_tokens(ids).to(dtype) return hs, mask # ============================================================================ # VAE TILED ENCODING # ============================================================================ def tiled_vae_encode( vae, audio: torch.Tensor, dtype: torch.dtype, chunk_size: Optional[int] = None, overlap: int = 96000, ) -> torch.Tensor: vae_device = next(vae.parameters()).device vae_dtype = vae.dtype if chunk_size is None: chunk_size = TARGET_SR * 30 B, C, S = audio.shape if S <= chunk_size: vae_input = audio.to(vae_device, dtype=vae_dtype) with torch.inference_mode(): latents = vae.encode(vae_input).latent_dist.sample() return latents.transpose(1, 2).to(dtype) stride = chunk_size - 2 * overlap if stride <= 0: raise ValueError(f"chunk_size ({chunk_size}) must be > 2 * overlap ({overlap})") num_steps = math.ceil(S / stride) ds_factor = None write_pos = 0 final = None for i in range(num_steps): core_start = i * stride core_end = min(core_start + stride, S) win_start = max(0, core_start - overlap) win_end = min(S, core_end + overlap) chunk = audio[:, :, win_start:win_end].to(vae_device, dtype=vae_dtype) with torch.inference_mode(): lat = vae.encode(chunk).latent_dist.sample() if ds_factor is None: ds_factor = chunk.shape[-1] / lat.shape[-1] total_len = int(round(S / ds_factor)) final = torch.zeros(B, lat.shape[1], total_len, dtype=lat.dtype, device="cpu") trim_start = int(round((core_start - win_start) / ds_factor)) trim_end = int(round((win_end - core_end) / ds_factor)) end_idx = lat.shape[-1] - trim_end if trim_end > 0 else lat.shape[-1] core = lat[:, :, trim_start:end_idx] core_len = core.shape[-1] final[:, :, write_pos:write_pos + core_len] = 
core.cpu() write_pos += core_len del chunk, lat, core final = final[:, :, :write_pos] return final.transpose(1, 2).to(dtype) # ============================================================================ # ENCODER / CONTEXT HELPERS # ============================================================================ def run_encoder( model, text_hs, text_mask, lyric_hs, lyric_mask, device, dtype, ): refer = torch.zeros(1, 1, 64, device=device, dtype=dtype) order_mask = torch.zeros(1, device=device, dtype=torch.long) with torch.no_grad(): enc_hs, enc_mask = model.encoder( text_hidden_states=text_hs, text_attention_mask=text_mask, lyric_hidden_states=lyric_hs, lyric_attention_mask=lyric_mask, refer_audio_acoustic_hidden_states_packed=refer, refer_audio_order_mask=order_mask, ) return enc_hs, enc_mask def build_context_latents(silence_latent, latent_length: int, device, dtype): src = silence_latent[:, :latent_length, :].to(dtype) if src.shape[0] < 1: src = src.expand(1, -1, -1) if src.shape[1] < latent_length: pad_len = latent_length - src.shape[1] src = torch.cat([src, silence_latent[:, :pad_len, :].expand(1, -1, -1).to(dtype)], dim=1) elif src.shape[1] > latent_length: src = src[:, :latent_length, :] masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype) return torch.cat([src, masks], dim=-1) # ============================================================================ # AUDIO DISCOVERY # ============================================================================ def _discover_audio_files(audio_dir: str) -> List[Path]: files = [] for root, _, names in os.walk(audio_dir): for name in sorted(names): if Path(name).suffix.lower() in AUDIO_EXTENSIONS: files.append(Path(root) / name) return files def _detect_max_duration(files: List[Path]) -> float: """Return the longest audio file duration (capped at MAX_AUDIO_DURATION).""" max_dur = 0.0 try: import soundfile as sf for f in files[:50]: try: info = sf.info(str(f)) max_dur = max(max_dur, info.duration) except 
Exception: pass except ImportError: pass return min(max_dur if max_dur > 0 else MAX_AUDIO_DURATION, MAX_AUDIO_DURATION) # ============================================================================ # PREPROCESSING (2-pass sequential) # ============================================================================ def preprocess_audio( audio_dir: str, output_dir: str, checkpoint_dir: str, device: str = "cpu", variant: str = "base", max_duration: float = 0, progress_callback: Optional[Callable] = None, cancel_check: Optional[Callable] = None, ) -> Dict[str, Any]: """2-pass sequential preprocessing. Pass 1: Load VAE + text encoder, encode audio + text, save intermediates. Pass 2: Load DIT model, run encoder, build context, save final .pt files. """ out = Path(output_dir) out.mkdir(parents=True, exist_ok=True) # Clean orphaned staging files for orphan in out.glob("*.__writing__"): try: orphan.unlink() except OSError: pass audio_files = _discover_audio_files(audio_dir) if not audio_files: return {"processed": 0, "failed": 0, "total": 0, "output_dir": str(out)} total = len(audio_files) if max_duration <= 0: max_duration = _detect_max_duration(audio_files) dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16 # ---- Pass 1: VAE + Text Encoder ---- logger.info("Pass 1/2: Loading VAE + Text Encoder...") vae = load_vae(checkpoint_dir, device) tokenizer, text_enc = load_text_encoder(checkpoint_dir, device) silence_lat = load_silence_latent(checkpoint_dir, device, variant=variant) intermediates: List[Path] = [] p1_failed = 0 try: for i, af in enumerate(audio_files): if cancel_check and cancel_check(): break stem = af.stem final_pt = out / f"{stem}.pt" if final_pt.exists(): continue try: audio, _ = load_audio_stereo(str(af), TARGET_SR, max_duration) audio = audio.unsqueeze(0).to(device=device, dtype=vae.dtype) with torch.no_grad(): target_latents = tiled_vae_encode(vae, audio, dtype) del audio if torch.isnan(target_latents).any() or torch.isinf(target_latents).any(): 
p1_failed += 1 del target_latents continue lat_len = target_latents.shape[1] att_mask = torch.ones(1, lat_len, device=device, dtype=dtype) caption = af.stem lyrics = "[Instrumental]" text_prompt = caption with torch.no_grad(): text_hs, text_mask = encode_text(text_enc, tokenizer, text_prompt, device, dtype) lyric_hs, lyric_mask = encode_lyrics(text_enc, tokenizer, lyrics, device, dtype) has_bad = any( torch.isnan(t).any() or torch.isinf(t).any() for t in [text_hs, lyric_hs] ) if has_bad: p1_failed += 1 del target_latents, att_mask, text_hs, text_mask, lyric_hs, lyric_mask continue tmp_path = out / f"{stem}.tmp.pt" torch.save({ "target_latents": target_latents.squeeze(0).cpu(), "attention_mask": att_mask.squeeze(0).cpu(), "text_hidden_states": text_hs.cpu(), "text_attention_mask": text_mask.cpu(), "lyric_hidden_states": lyric_hs.cpu(), "lyric_attention_mask": lyric_mask.cpu(), "silence_latent": silence_lat.cpu(), "latent_length": lat_len, "metadata": { "audio_path": str(af), "filename": af.name, "caption": caption, "lyrics": lyrics, }, }, tmp_path) del target_latents, att_mask, text_hs, text_mask, lyric_hs, lyric_mask intermediates.append(tmp_path) if progress_callback: progress_callback(i + 1, total, f"[Pass 1] {af.name}") except Exception as exc: p1_failed += 1 logger.error("Pass 1 FAIL %s: %s", af.name, exc) finally: logger.info("Unloading VAE + Text Encoder...") unload_models(vae, text_enc, tokenizer, silence_lat) # ---- Pass 2: DIT Encoder ---- if not intermediates: return {"processed": 0, "failed": p1_failed, "total": total, "output_dir": str(out)} logger.info("Pass 2/2: Loading DIT model (variant=%s)...", variant) model = load_model_for_training(checkpoint_dir, variant, device) processed = 0 p2_failed = 0 p2_total = len(intermediates) try: for i, tmp_path in enumerate(intermediates): if cancel_check and cancel_check(): break try: data = torch.load(str(tmp_path), weights_only=True) m_device = next(model.parameters()).device m_dtype = 
next(model.parameters()).dtype text_hs = data["text_hidden_states"].to(m_device, dtype=m_dtype) text_mask = data["text_attention_mask"].to(m_device, dtype=m_dtype) lyric_hs = data["lyric_hidden_states"].to(m_device, dtype=m_dtype) lyric_mask = data["lyric_attention_mask"].to(m_device, dtype=m_dtype) silence_lat = data["silence_latent"].to(m_device, dtype=m_dtype) lat_len = data["latent_length"] enc_hs, enc_mask = run_encoder( model, text_hs, text_mask, lyric_hs, lyric_mask, str(m_device), m_dtype, ) del text_hs, text_mask, lyric_hs, lyric_mask if silence_lat.dim() == 2: silence_lat = silence_lat.unsqueeze(0) ctx = build_context_latents(silence_lat, lat_len, str(m_device), m_dtype) del silence_lat has_bad = any( torch.isnan(t).any() or torch.isinf(t).any() for t in [enc_hs, ctx] ) if has_bad: p2_failed += 1 del enc_hs, enc_mask, ctx, data continue base_name = tmp_path.name.replace(".tmp.pt", ".pt") final_path = out / base_name staging_path = out / (base_name + ".__writing__") torch.save({ "target_latents": data["target_latents"], "attention_mask": data["attention_mask"], "encoder_hidden_states": enc_hs.squeeze(0).cpu(), "encoder_attention_mask": enc_mask.squeeze(0).cpu(), "context_latents": ctx.squeeze(0).cpu(), "metadata": data.get("metadata", {}), }, staging_path) os.replace(staging_path, final_path) del enc_hs, enc_mask, ctx, data tmp_path.unlink(missing_ok=True) processed += 1 if progress_callback: progress_callback(i + 1, p2_total, f"[Pass 2] {tmp_path.stem}") except Exception as exc: p2_failed += 1 logger.error("Pass 2 FAIL %s: %s", tmp_path.stem, exc) finally: logger.info("Unloading DIT model...") unload_models(model) failed = p1_failed + p2_failed return {"processed": processed, "failed": failed, "total": total, "output_dir": str(out)} # ============================================================================ # TRAINING LOOP (generator for Gradio compatibility) # ============================================================================ def 
train_lora_generator( dataset_dir: str, output_dir: str, checkpoint_dir: str, epochs: int = 1000, lr: float = 3e-4, rank: int = 64, alpha: int = 128, dropout: float = 0.1, batch_size: int = 1, gradient_accumulation_steps: int = 4, warmup_steps: int = 100, weight_decay: float = 0.01, max_grad_norm: float = 1.0, save_every_n_epochs: int = 50, seed: int = 42, variant: str = "base", device: str = "cpu", cfg_ratio: float = 0.15, timestep_mu: float = -0.4, timestep_sigma: float = 1.0, target_modules: Optional[List[str]] = None, log_every: int = 10, resume_from: Optional[str] = None, ) -> Generator[str, None, None]: """Run LoRA training, yielding progress strings each epoch. This is a generator for Gradio live-update compatibility. Call cancel_training() to stop after the current epoch. """ _training_cancel.clear() train_start = time.time() if target_modules is None: target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] ds_path = Path(dataset_dir) if not ds_path.is_dir(): yield f"[FAIL] Dataset directory not found: {ds_path}" return out_path = Path(output_dir) out_path.mkdir(parents=True, exist_ok=True) yield "[INFO] Loading model..." try: model = load_model_for_training(checkpoint_dir, variant, device) except Exception as exc: yield f"[FAIL] Model load failed: {exc}" return # float32 on CPU (bfloat16 deadlocks) dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16 model = model.to(dtype=dtype) yield "[INFO] Injecting LoRA..." 
lora_cfg = LoRAConfig( r=rank, alpha=alpha, dropout=dropout, target_modules=target_modules, bias="none", ) try: model, info = inject_lora(model, lora_cfg) except Exception as exc: yield f"[FAIL] LoRA injection failed: {exc}" unload_models(model) return yield f"[OK] LoRA injected: {info['trainable_params']:,} trainable params" # Gradient checkpointing + cache disable force_disable_cache(model.decoder) ckpt_ok = enable_gradient_checkpointing(model.decoder) force_input_grads = ckpt_ok if ckpt_ok: yield "[INFO] Gradient checkpointing enabled" # Dataset dataset = TensorDataset(dataset_dir) if len(dataset) == 0: yield "[FAIL] No valid .pt files found in dataset directory" unload_models(model) return yield f"[OK] Loaded {len(dataset)} preprocessed samples" loader = DataLoader( dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=_collate_batch, drop_last=False, ) # Optimizer & scheduler torch.manual_seed(seed) random.seed(seed) trainable_params = [p for p in model.parameters() if p.requires_grad] if not trainable_params: yield "[FAIL] No trainable parameters found" unload_models(model) return optimizer = build_optimizer(trainable_params, lr=lr, weight_decay=weight_decay) steps_per_epoch = max(1, math.ceil(len(loader) / gradient_accumulation_steps)) total_steps = steps_per_epoch * epochs scheduler = build_scheduler(optimizer, total_steps, warmup_steps, lr) yield f"[INFO] Training {sum(p.numel() for p in trainable_params):,} params for {epochs} epochs" yield f"[INFO] Steps/epoch: {steps_per_epoch}, total: {total_steps}" # Null condition embedding for CFG dropout null_cond = getattr(model, "null_condition_emb", None) # Resume checkpoint start_epoch = 0 global_step = 0 if resume_from and Path(resume_from).exists(): try: yield f"[INFO] Resuming from {resume_from}" ckpt_dir = Path(resume_from) if ckpt_dir.is_file(): ckpt_dir = ckpt_dir.parent # Load adapter weights aw = ckpt_dir / "adapter_model.safetensors" if aw.exists(): from safetensors.torch import 
load_file state = load_file(str(aw)) decoder = model.decoder while hasattr(decoder, "_forward_module"): decoder = decoder._forward_module decoder.load_state_dict(state, strict=False) # Load training state ts = ckpt_dir / "training_state.pt" if ts.exists(): tstate = torch.load(str(ts), map_location=device, weights_only=True) start_epoch = tstate.get("epoch", 0) global_step = tstate.get("global_step", 0) if "optimizer_state_dict" in tstate: try: optimizer.load_state_dict(tstate["optimizer_state_dict"]) except Exception: pass if "scheduler_state_dict" in tstate: try: scheduler.load_state_dict(tstate["scheduler_state_dict"]) except Exception: pass yield f"[OK] Resumed from epoch {start_epoch}, step {global_step}" except Exception as exc: yield f"[WARN] Checkpoint load failed: {exc}, starting fresh" start_epoch = 0 global_step = 0 # Training loop model.decoder.train() acc_step = 0 acc_loss = 0.0 optimizer.zero_grad(set_to_none=True) best_loss = float("inf") best_epoch = 0 consecutive_nan = 0 MAX_NAN = 10 for epoch in range(start_epoch, epochs): # Cancel check if _training_cancel.is_set(): _training_cancel.clear() early_path = str(out_path / "early_exit") model.decoder.eval() save_lora_adapter(model, early_path) model.decoder.train() yield f"[OK] Cancelled at epoch {epoch + 1}, saved to {early_path}" yield "[DONE]" unload_models(model) return # Timeout check elapsed = time.time() - train_start if elapsed > MAX_TRAINING_TIME: early_path = str(out_path / "timeout_exit") model.decoder.eval() save_lora_adapter(model, early_path) yield f"[WARN] Training timed out after {int(elapsed)}s, saved to {early_path}" yield "[DONE]" unload_models(model) return epoch_loss = 0.0 num_updates = 0 epoch_start = time.time() for batch in loader: # Forward nb = device != "cpu" tgt = batch["target_latents"].to(device, dtype=dtype, non_blocking=nb) att = batch["attention_mask"].to(device, dtype=dtype, non_blocking=nb) enc_hs = batch["encoder_hidden_states"].to(device, dtype=dtype, 
non_blocking=nb) enc_mask = batch["encoder_attention_mask"].to(device, dtype=dtype, non_blocking=nb) ctx = batch["context_latents"].to(device, dtype=dtype, non_blocking=nb) bsz = tgt.shape[0] # CFG dropout if null_cond is not None and cfg_ratio > 0: enc_hs = apply_cfg_dropout(enc_hs, null_cond, cfg_ratio) # Timestep sampling t, _r = sample_timesteps(bsz, torch.device(device), dtype, timestep_mu, timestep_sigma) # Flow matching noise x1 = torch.randn_like(tgt) x0 = tgt t_ = t.unsqueeze(-1).unsqueeze(-1) xt = t_ * x1 + (1.0 - t_) * x0 if force_input_grads: xt = xt.requires_grad_(True) # Decoder forward dec_out = model.decoder( hidden_states=xt, timestep=t, timestep_r=t, attention_mask=att, encoder_hidden_states=enc_hs, encoder_attention_mask=enc_mask, context_latents=ctx, ) flow = x1 - x0 loss = F.mse_loss(dec_out[0], flow) loss = loss.float() # fp32 for stable backward # NaN guard if torch.isnan(loss) or torch.isinf(loss): consecutive_nan += 1 del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow if consecutive_nan >= MAX_NAN: yield f"[FAIL] {consecutive_nan} consecutive NaN losses, halting" unload_models(model) return if acc_step > 0: optimizer.zero_grad(set_to_none=True) acc_loss = 0.0 acc_step = 0 continue consecutive_nan = 0 loss = loss / gradient_accumulation_steps loss.backward() acc_loss += loss.item() del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow acc_step += 1 if acc_step >= gradient_accumulation_steps: torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm) optimizer.step() scheduler.step() global_step += 1 avg_loss = acc_loss * gradient_accumulation_steps / acc_step if global_step % log_every == 0: current_lr = scheduler.get_last_lr()[0] yield ( f"Epoch {epoch + 1}/{epochs}, " f"Step {global_step}, " f"Loss: {avg_loss:.4f}, " f"LR: {current_lr:.2e}" ) optimizer.zero_grad(set_to_none=True) epoch_loss += avg_loss num_updates += 1 acc_loss = 0.0 acc_step = 0 # Flush remainder if acc_step > 0: 
torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm) optimizer.step() scheduler.step() global_step += 1 avg_loss = acc_loss * gradient_accumulation_steps / acc_step optimizer.zero_grad(set_to_none=True) epoch_loss += avg_loss num_updates += 1 acc_loss = 0.0 acc_step = 0 epoch_time = time.time() - epoch_start avg_epoch_loss = epoch_loss / max(num_updates, 1) is_best = avg_epoch_loss < best_loss - 0.001 if is_best: best_loss = avg_epoch_loss best_epoch = epoch + 1 best_str = f" (best: {best_loss:.4f} @ ep{best_epoch})" if best_epoch > 0 else "" yield ( f"[OK] Epoch {epoch + 1}/{epochs} in {epoch_time:.1f}s, " f"Loss: {avg_epoch_loss:.4f}{best_str}" ) # Save best if is_best and epoch + 1 >= 10: best_path = str(out_path / "best") model.decoder.eval() save_lora_adapter(model, best_path) model.decoder.train() yield f"[OK] Best model saved (epoch {epoch + 1}, loss: {best_loss:.4f})" # Periodic checkpoint if (epoch + 1) % save_every_n_epochs == 0: ckpt_path = str(out_path / "checkpoints" / f"epoch_{epoch + 1}") model.decoder.eval() save_lora_adapter(model, ckpt_path) tstate = { "epoch": epoch + 1, "global_step": global_step, "optimizer_state_dict": optimizer.state_dict(), "scheduler_state_dict": scheduler.state_dict(), } os.makedirs(ckpt_path, exist_ok=True) torch.save(tstate, str(Path(ckpt_path) / "training_state.pt")) model.decoder.train() yield f"[OK] Checkpoint saved at epoch {epoch + 1}" # Sanity check if global_step == 0: yield "[FAIL] Training completed 0 steps -- no batches processed" unload_models(model) return # Final save final_path = str(out_path / "final") model.decoder.eval() save_lora_adapter(model, final_path) final_loss = avg_epoch_loss if num_updates > 0 else 0.0 best_note = "" if best_epoch > 0 and Path(out_path / "best").exists(): best_note = f"\n Best: {out_path / 'best'} (epoch {best_epoch}, loss: {best_loss:.4f})" yield ( f"[OK] Training complete! 
LoRA saved to {final_path}{best_note}\n" f" For inference, set your LoRA path to: {final_path}" ) yield "[DONE]" unload_models(model) # ============================================================================ # ADAPTER LISTING # ============================================================================ def get_trained_loras(adapter_dir: str) -> List[str]: """List all saved LoRA adapter directories under adapter_dir.""" result = [] base = Path(adapter_dir) if not base.is_dir(): return result for root, dirs, files in os.walk(str(base)): for f in files: if f in ("adapter_config.json", "adapter_model.safetensors", "lora_weights.pt"): result.append(root) break return sorted(set(result))