ACE-Step-CPU / train_engine.py
Nekochu's picture
fix: adapter saved to clean dir, LM dropdown no 'Default', on-demand download
e62602f
raw
history blame
47.3 kB
"""
Standalone ACE-Step CPU LoRA Training Engine.
Ported from Side-Step (koda-dernet/Side-Step) into a single self-contained
module. No external Side-Step dependency required.
Exports:
preprocess_audio() - 2-pass sequential preprocessing
train_lora_generator() - Generator-based LoRA training loop
cancel_training() - Set the cancel flag
get_trained_loras() - List saved adapters
"""
from __future__ import annotations
import gc
import json
import logging
import math
import os
import random
import sys
import time
import types
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import (
CosineAnnealingLR,
LinearLR,
SequentialLR,
)
from torch.utils.data import DataLoader, Dataset
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Tunable limits -- adjust them here rather than passing them through callers.
# ---------------------------------------------------------------------------
MAX_AUDIO_DURATION = 240.0  # per-file cap in seconds; longer audio is truncated
MAX_TRAINING_TIME = 8 * 60 * 60  # hard wall-clock limit for one training run (8 h)
TARGET_SR = 48_000  # sample rate every clip is resampled to before VAE encoding
AUDIO_EXTENSIONS = frozenset(
    (".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".aac")
)
# bfloat16 can deadlock on CPU (known PyTorch issue), so CPU work is pinned to float32.
CPU_DTYPE = torch.float32

import threading

# Process-wide stop flag; the training loop checks it at the start of each epoch.
_training_cancel = threading.Event()


def cancel_training() -> None:
    """Request that the running training loop stop at its next epoch check."""
    _training_cancel.set()
# ============================================================================
# CONFIGS
# ============================================================================
@dataclass
class LoRAConfig:
    """Hyper-parameters for the LoRA adapter placed on the decoder.

    ``inject_lora`` forwards ``r``, ``alpha`` (as ``lora_alpha``),
    ``dropout`` (as ``lora_dropout``), ``target_modules`` and ``bias`` to
    PEFT.  ``attention_type`` and ``target_mlp`` are stored here but are not
    read by ``inject_lora``.
    """

    r: int = 64  # adapter rank
    alpha: int = 128  # passed to PEFT as lora_alpha
    dropout: float = 0.1  # passed to PEFT as lora_dropout
    # Attention projections wrapped by default; each instance gets its own list.
    target_modules: List[str] = field(
        default_factory=lambda: "q_proj k_proj v_proj o_proj".split()
    )
    bias: str = "none"
    attention_type: str = "both"
    target_mlp: bool = True
# ============================================================================
# TIMESTEP SAMPLING & CFG DROPOUT
# ============================================================================
def sample_timesteps(
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
    timestep_mu: float = -0.4,
    timestep_sigma: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Draw flow-matching timesteps from a logit-normal distribution.

    Each sample gets two independent logit-normal draws, and their
    element-wise maximum is kept.  Under ACE-Step's ``use_meanflow=False``
    convention ``r`` equals ``t``, so the same tensor is returned twice.

    Returns:
        ``(t, r)``: two references to one ``(batch_size,)`` tensor whose
        values come from a sigmoid (range [0, 1]).
    """

    def _logit_normal() -> torch.Tensor:
        gaussian = torch.randn((batch_size,), device=device, dtype=dtype)
        return torch.sigmoid(gaussian * timestep_sigma + timestep_mu)

    draw_a = _logit_normal()
    draw_b = _logit_normal()
    # The element-wise min would be the mean-flow ``r``; it is not used here.
    t = torch.maximum(draw_a, draw_b)
    return t, t
def apply_cfg_dropout(
    encoder_hidden_states: torch.Tensor,
    null_condition_emb: torch.Tensor,
    cfg_ratio: float = 0.15,
) -> torch.Tensor:
    """Randomly swap whole samples' conditioning for the null embedding.

    One uniform draw is made per batch row.  Rows whose draw is below
    ``cfg_ratio`` take ``null_condition_emb`` (broadcast via ``expand_as``);
    all other rows keep ``encoder_hidden_states`` unchanged.  This is the
    classifier-free-guidance dropout applied during training.
    """
    hidden = encoder_hidden_states
    coin = torch.rand(
        size=(hidden.shape[0],), device=hidden.device, dtype=hidden.dtype
    )
    keep_row = (coin >= cfg_ratio).view(-1, 1, 1)
    unconditional = null_condition_emb.expand_as(hidden)
    return torch.where(keep_row, hidden, unconditional)
# ============================================================================
# OPTIMIZER (Adafactor preferred for CPU -- 1.5 bytes/param)
# ============================================================================
def build_optimizer(
    params, lr: float = 1e-4, weight_decay: float = 0.01,
) -> torch.optim.Optimizer:
    """Create the optimizer for the trainable parameters.

    Uses ``transformers``' Adafactor with a fixed learning rate
    (``relative_step=False``, ``scale_parameter=False``) for its small
    optimizer state.  If ``transformers`` cannot be imported, it falls back
    to ``torch.optim.AdamW`` with the same ``lr`` and ``weight_decay``.
    """
    try:
        from transformers.optimization import Adafactor

        logger.info("Using Adafactor optimizer (minimal state memory)")
        optimizer = Adafactor(
            params,
            lr=lr,
            weight_decay=weight_decay,
            scale_parameter=False,
            relative_step=False,
        )
    except ImportError:
        logger.warning("transformers not installed, falling back to AdamW")
        optimizer = AdamW(params, lr=lr, weight_decay=weight_decay)
    return optimizer
def build_scheduler(
    optimizer, total_steps: int, warmup_steps: int, lr: float,
):
    """Build a linear-warmup + cosine-decay learning-rate schedule.

    Warmup is capped at 10% of ``total_steps`` (at least one step).  It ramps
    the LR from ``0.1 * lr`` up to ``lr``, after which cosine decay brings it
    down to ``0.01 * lr`` over the remaining steps.  A non-positive
    ``warmup_steps`` disables warmup, and a plain cosine schedule starting at
    ``lr`` is returned.

    Args:
        optimizer: Optimizer whose param groups are scheduled.
        total_steps: Number of optimizer steps the schedule spans.
        warmup_steps: Requested warmup length (clamped as described above).
        lr: Peak learning rate; the cosine floor is ``eta_min = 0.01 * lr``.

    Returns:
        A ``SequentialLR`` (warmup then cosine), or a ``CosineAnnealingLR``
        when warmup is disabled.
    """
    eta_min = lr * 0.01
    max_warmup = max(1, total_steps // 10)
    warmup_steps = max(0, min(warmup_steps, max_warmup))
    if warmup_steps == 0:
        # A zero-length LinearLR inside SequentialLR starts at
        # start_factor * lr, and the cosine stage is then entered with a
        # plain step() that continues from that reduced LR.  Skip warmup
        # entirely instead.
        return CosineAnnealingLR(optimizer, T_max=max(1, total_steps), eta_min=eta_min)
    warmup = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps)
    remaining = max(1, total_steps - warmup_steps)
    main = CosineAnnealingLR(optimizer, T_max=remaining, eta_min=eta_min)
    return SequentialLR(optimizer, [warmup, main], milestones=[warmup_steps])
# ============================================================================
# DATASET
# ============================================================================
def _collate_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Zero-pad per-sample tensors along dim 0 and stack them into a batch.

    Latent-rate tensors (``target_latents``, ``attention_mask``,
    ``context_latents``) are padded to the longest ``target_latents`` in the
    batch.  Encoder tensors (``encoder_hidden_states``,
    ``encoder_attention_mask``) are padded to the longest
    ``encoder_hidden_states``.
    """

    def _pad_rows(tensor: torch.Tensor, length: int) -> torch.Tensor:
        missing = length - tensor.shape[0]
        if missing <= 0:
            return tensor
        filler = tensor.new_zeros((missing, *tensor.shape[1:]))
        return torch.cat([tensor, filler], dim=0)

    latent_len = max(item["target_latents"].shape[0] for item in batch)
    encoder_len = max(item["encoder_hidden_states"].shape[0] for item in batch)
    target_lengths = {
        "target_latents": latent_len,
        "attention_mask": latent_len,
        "encoder_hidden_states": encoder_len,
        "encoder_attention_mask": encoder_len,
        "context_latents": latent_len,
    }
    return {
        key: torch.stack([_pad_rows(item[key], length) for item in batch])
        for key, length in target_lengths.items()
    }
class TensorDataset(Dataset):
    """Map-style dataset over the preprocessed ``*.pt`` sample files in one directory.

    Pass-1 intermediates (``*.tmp.pt``) are excluded.  Each item is a dict
    holding exactly the keys in ``_REQUIRED``.  NaN/Inf values in
    ``target_latents``, ``encoder_hidden_states`` and ``context_latents``
    are replaced with 0 in place after loading.
    """

    _REQUIRED = frozenset([
        "target_latents",
        "attention_mask",
        "encoder_hidden_states",
        "encoder_attention_mask",
        "context_latents",
    ])

    def __init__(self, tensor_dir: str):
        root = Path(tensor_dir)
        # Only names ending in ".pt" qualify, so e.g. manifest.json is never picked up.
        self.paths: List[str] = [
            str(root / name)
            for name in sorted(os.listdir(tensor_dir))
            if name.endswith(".pt") and not name.endswith(".tmp.pt")
        ]

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        path = self.paths[idx]
        sample = torch.load(path, map_location="cpu", weights_only=True)
        absent = self._REQUIRED.difference(sample.keys())
        if absent:
            raise KeyError(f"Missing keys {sorted(absent)} in {path}")
        for key in ("target_latents", "encoder_hidden_states", "context_latents"):
            tensor = sample[key]
            if not torch.isfinite(tensor).all():
                tensor.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
        return {key: sample[key] for key in self._REQUIRED}
# ============================================================================
# GRADIENT CHECKPOINTING
# ============================================================================
def _find_decoder_layers(decoder: nn.Module) -> Optional[nn.ModuleList]:
    """Return the first non-empty block list on *decoder* or one of its direct children.

    *decoder* itself is searched first, then each direct child in
    registration order.  On every module the attributes ``layers``,
    ``blocks`` and ``transformer_blocks`` are tried in that order.  Returns
    ``None`` if no non-empty ``nn.ModuleList`` is found.
    """

    def _owners():
        yield decoder
        yield from decoder.children()

    for owner in _owners():
        for attr in ("layers", "blocks", "transformer_blocks"):
            candidate = getattr(owner, attr, None)
            if isinstance(candidate, nn.ModuleList) and len(candidate):
                return candidate
    return None
def enable_gradient_checkpointing(decoder: nn.Module) -> bool:
    """Best-effort: enable gradient checkpointing across a chain of wrappers.

    Starting from *decoder*, the walk follows the wrapper attributes
    ``_forward_module``, ``_orig_mod``, ``base_model``, ``model`` and
    ``module``, visiting each distinct ``nn.Module`` once.  On every visited
    module it:

    * calls ``gradient_checkpointing_enable()`` if present, otherwise sets an
      existing ``gradient_checkpointing`` attribute to ``True``;
    * calls ``enable_input_require_grads()`` if present;
    * sets ``config.use_cache = False`` when the config has that field.

    Every failure is ignored.  Returns ``True`` if checkpointing was turned
    on for at least one module.
    """
    switched_on = False
    pending = [decoder]
    seen = set()
    while pending:
        module = pending.pop()
        if not isinstance(module, nn.Module) or id(module) in seen:
            continue
        seen.add(id(module))

        if hasattr(module, "gradient_checkpointing_enable"):
            try:
                module.gradient_checkpointing_enable()
            except Exception:
                pass
            else:
                switched_on = True
        elif hasattr(module, "gradient_checkpointing"):
            try:
                module.gradient_checkpointing = True
            except Exception:
                pass
            else:
                switched_on = True

        if hasattr(module, "enable_input_require_grads"):
            try:
                module.enable_input_require_grads()
            except Exception:
                pass

        config = getattr(module, "config", None)
        if config is not None and hasattr(config, "use_cache"):
            try:
                config.use_cache = False
            except Exception:
                pass

        wrapped = (
            getattr(module, attr, None)
            for attr in ("_forward_module", "_orig_mod", "base_model", "model", "module")
        )
        pending.extend(child for child in wrapped if isinstance(child, nn.Module))
    return switched_on
def force_disable_cache(decoder: nn.Module) -> None:
    """Set ``config.use_cache = False`` on every module in a wrapper chain.

    The chain is followed through the attributes ``_forward_module``,
    ``_orig_mod``, ``base_model``, ``model`` and ``module``.  Each distinct
    ``nn.Module`` is handled once.  Configs without a ``use_cache`` field are
    skipped, and assignment errors are ignored.
    """
    handled = set()

    def _visit(node) -> None:
        if not isinstance(node, nn.Module) or id(node) in handled:
            return
        handled.add(id(node))
        config = getattr(node, "config", None)
        if config is not None and hasattr(config, "use_cache"):
            try:
                config.use_cache = False
            except Exception:
                pass
        for attr in ("_forward_module", "_orig_mod", "base_model", "model", "module"):
            _visit(getattr(node, attr, None))

    _visit(decoder)
# ============================================================================
# LORA INJECTION (PEFT only -- no DoRA/LoKR/LoHA/OFT)
# ============================================================================
def _unwrap_decoder(model):
    """Return the bare decoder module hidden behind *model*'s wrappers.

    Unwrapping proceeds in this order:

    1. Take ``model.decoder`` if it exists, else *model*.
    2. Follow ``_forward_module`` repeatedly.
    3. If a ``base_model`` attribute exists, step to ``base_model.model``
       (or ``base_model`` itself when it has no ``model``).
    4. Step once more into ``.model`` if that is an ``nn.Module``.
    """
    absent = object()
    current = getattr(model, "decoder", model)
    while hasattr(current, "_forward_module"):
        current = current._forward_module
    base = getattr(current, "base_model", absent)
    if base is not absent:
        current = getattr(base, "model", base)
    inner = getattr(current, "model", None)
    if isinstance(inner, nn.Module):
        current = inner
    return current
def inject_lora(model, lora_cfg: LoRAConfig) -> Tuple[Any, Dict[str, Any]]:
    """Wrap ``model.decoder`` with a PEFT LoRA adapter and freeze every other weight.

    Steps:
      1. Unwrap the decoder via ``_unwrap_decoder`` and store the bare
         decoder back on ``model.decoder``.
      2. If the decoder has ``enable_input_require_grads``, patch it so a
         ``NotImplementedError`` from it is swallowed.  The DiT has no
         ``get_input_embeddings`` for that hook to use.
      3. Set ``is_gradient_checkpointing = False`` when the attribute exists.
      4. Replace ``model.decoder`` with ``get_peft_model(decoder, ...)``,
         configured from *lora_cfg* (rank, alpha, dropout, targets, bias).
      5. Freeze every parameter whose name does not contain ``"lora_"``.

    Returns:
        ``(model, stats)``, where *stats* has ``total_params``,
        ``trainable_params`` and ``trainable_ratio``.
    """
    from peft import get_peft_model, LoraConfig as PeftLoraConfig, TaskType

    bare = _unwrap_decoder(model)
    model.decoder = bare

    if hasattr(bare, "enable_input_require_grads"):
        original_hook = bare.enable_input_require_grads

        def _tolerant_hook(self):
            try:
                return original_hook()
            except NotImplementedError:
                return None

        bare.enable_input_require_grads = types.MethodType(_tolerant_hook, bare)

    if hasattr(bare, "is_gradient_checkpointing"):
        try:
            bare.is_gradient_checkpointing = False
        except Exception:
            pass

    adapter_config = PeftLoraConfig(
        r=lora_cfg.r,
        lora_alpha=lora_cfg.alpha,
        lora_dropout=lora_cfg.dropout,
        target_modules=lora_cfg.target_modules,
        bias=lora_cfg.bias,
        task_type=TaskType.FEATURE_EXTRACTION,
    )
    model.decoder = get_peft_model(bare, adapter_config)

    for param_name, param in model.named_parameters():
        if "lora_" not in param_name:
            param.requires_grad = False

    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    stats = {
        "total_params": total,
        "trainable_params": trainable,
        "trainable_ratio": trainable / total if total > 0 else 0,
    }
    return model, stats
def save_lora_adapter(model, output_dir: str) -> None:
    """Write the LoRA adapter of ``model.decoder`` (or *model*) to *output_dir*.

    After following any ``_forward_module`` wrappers:

    * If the decoder provides ``save_pretrained``, that is used.  Then
      ``base_model_name_or_path`` in the resulting ``adapter_config.json`` is
      blanked for portability.  This step is best-effort: a failure only
      logs a warning.
    * Otherwise every parameter whose name contains ``"lora_"`` is written
      to ``adapter_model.safetensors``, or to ``lora_weights.pt`` when
      ``safetensors`` is unavailable.  A warning is logged, and nothing is
      written, if no such parameter exists.

    Args:
        model: Object with a ``decoder`` attribute, or the decoder itself.
        output_dir: Destination directory; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    decoder = model.decoder if hasattr(model, "decoder") else model
    while hasattr(decoder, "_forward_module"):
        decoder = decoder._forward_module

    if hasattr(decoder, "save_pretrained"):
        decoder.save_pretrained(output_dir)
        cfg_path = os.path.join(output_dir, "adapter_config.json")
        if os.path.isfile(cfg_path):
            # Blank base_model_name_or_path so the adapter is portable.
            try:
                with open(cfg_path, "r", encoding="utf-8") as f:
                    cfg = json.load(f)
                if isinstance(cfg, dict) and cfg.get("base_model_name_or_path"):
                    cfg["base_model_name_or_path"] = ""
                    with open(cfg_path, "w", encoding="utf-8") as f:
                        json.dump(cfg, f, indent=2)
            except (OSError, ValueError) as exc:
                # Weights are already on disk; only the path scrub failed.
                logger.warning(
                    "Could not scrub base_model_name_or_path in %s: %s", cfg_path, exc
                )
        logger.info("LoRA adapter saved to %s", output_dir)
        return

    # Fallback for decoders without save_pretrained: dump the LoRA tensors.
    # .contiguous() because safetensors refuses non-contiguous tensors.
    state = {
        name: param.detach().clone().contiguous()
        for name, param in decoder.named_parameters()
        if "lora_" in name
    }
    if not state:
        logger.warning("No LoRA parameters found; nothing saved to %s", output_dir)
        return
    try:
        from safetensors.torch import save_file
    except ImportError:
        torch.save(state, str(Path(output_dir) / "lora_weights.pt"))
    else:
        save_file(state, str(Path(output_dir) / "adapter_model.safetensors"))
    logger.info("LoRA adapter saved (fallback) to %s", output_dir)
# ============================================================================
# MODEL LOADING (FA2 -> SDPA -> eager fallback)
# ============================================================================
_VARIANT_DIR = {
    "turbo": "acestep-v15-turbo",
    "base": "acestep-v15-base",
    "sft": "acestep-v15-sft",
}


def _resolve_model_dir(checkpoint_dir: str, variant: str) -> Path:
    """Locate the model directory for *variant* inside *checkpoint_dir*.

    A known variant's mapped ``acestep-v15-*`` subdirectory (from
    ``_VARIANT_DIR``) is tried first.  After that, a subdirectory named
    exactly *variant* is tried.

    Returns:
        The resolved (absolute) directory path.

    Raises:
        FileNotFoundError: if neither candidate is an existing directory.
    """
    root = Path(checkpoint_dir)
    mapped = _VARIANT_DIR.get(variant)
    candidates = [mapped] if mapped else []
    candidates.append(variant)
    for name in candidates:
        path = (root / name).resolve()
        if path.is_dir():
            return path
    raise FileNotFoundError(
        f"Model directory not found: tried {_VARIANT_DIR.get(variant, variant)!r} "
        f"and {variant!r} under {checkpoint_dir}"
    )
def _ensure_acestep_imports():
    """Pre-register empty placeholder ``acestep.*`` modules in ``sys.modules``.

    These stubs are meant to let ``AutoModel`` load ACE-Step checkpoints
    whose code refers to ``acestep`` modules.  Only blank
    ``types.ModuleType`` objects are inserted; no real ACE-Step code is
    imported.  Package names get an empty ``__path__`` so they look like
    packages.  Names already in ``sys.modules`` are left untouched.
    """
    packages = (
        "acestep", "acestep.models", "acestep.models.common",
        "acestep.models.xl_base", "acestep.models.xl_turbo", "acestep.models.xl_sft",
    )
    leaf_modules = (
        "acestep.models.common.configuration_acestep_v15",
        "acestep.models.common.apg_guidance",
    )
    for qualified_name in packages + leaf_modules:
        if qualified_name in sys.modules:
            continue
        placeholder = types.ModuleType(qualified_name)
        if qualified_name in packages:
            placeholder.__path__ = []
        sys.modules[qualified_name] = placeholder
def _attn_candidates(device: str) -> List[str]:
    """List attention implementations to try, most preferred first.

    ``flash_attention_2`` is included only when all of these hold: the
    device string starts with ``cuda``, ``flash_attn`` imports, and the GPU's
    compute-capability major version is >= 8.  ``sdpa`` and ``eager`` are
    always appended as fallbacks.
    """
    preferred: List[str] = []
    if device.startswith("cuda"):
        try:
            import flash_attn  # noqa: F401

            index = int(device.split(":")[1]) if ":" in device else 0
            if torch.cuda.get_device_properties(index).major >= 8:
                preferred.append("flash_attention_2")
        except Exception:
            # Missing package, unparsable device index or no usable GPU: skip FA2.
            pass
    return preferred + ["sdpa", "eager"]
def load_model_for_training(
    checkpoint_dir: str, variant: str = "base", device: str = "cpu",
) -> Any:
    """Load the ACE-Step DiT with every parameter frozen, in eval mode.

    The attention backends from ``_attn_candidates`` are tried in order, and
    the first that loads is kept.  An error whose text mentions a missing
    package aborts at once instead of falling through to the next backend.

    Args:
        checkpoint_dir: Checkpoint root directory.
        variant: Model variant, resolved via ``_resolve_model_dir``.
        device: ``"cpu"`` loads in float32.  Any other device loads in
            bfloat16 with ``device_map={"": device}``.

    Raises:
        FileNotFoundError: if the model directory cannot be found.
        RuntimeError: if a required package is missing or every backend failed.
    """
    from transformers import AutoModel

    model_dir = _resolve_model_dir(checkpoint_dir, variant)
    weight_dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16  # CPU is always float32
    _ensure_acestep_imports()

    model = None
    last_err = None
    for attn in _attn_candidates(device):
        kwargs = {
            "trust_remote_code": True,
            "attn_implementation": attn,
            "torch_dtype": weight_dtype,
            "low_cpu_mem_usage": False,
        }
        if device != "cpu":
            kwargs["device_map"] = {"": device}
        try:
            model = AutoModel.from_pretrained(str(model_dir), **kwargs)
        except Exception as exc:
            message = str(exc)
            missing_package = (
                "packages that were not found" in message or "No module named" in message
            )
            if missing_package:
                raise RuntimeError(
                    f"Model files in {model_dir} require a missing Python package.\n"
                    f" Original error: {message}"
                ) from exc
            last_err = exc
            logger.warning("attn backend '%s' failed: %s", attn, exc)
            continue
        logger.info("Model loaded with attn_implementation=%s", attn)
        break

    if model is None:
        raise RuntimeError(f"Failed to load model from {model_dir}: {last_err}") from last_err

    for param in model.parameters():
        param.requires_grad_(False)
    model.eval()
    return model
def load_vae(checkpoint_dir: str, device: str = "cpu"):
    """Load the Oobleck audio VAE from ``<checkpoint_dir>/vae`` in eval mode.

    Weights are float32 on CPU and bfloat16 on any other device.

    Raises:
        FileNotFoundError: if the ``vae`` directory does not exist.
    """
    from diffusers.models import AutoencoderOobleck

    vae_dir = Path(checkpoint_dir).joinpath("vae")
    if not vae_dir.is_dir():
        raise FileNotFoundError(f"VAE directory not found: {vae_dir}")
    weight_dtype = torch.bfloat16 if device != "cpu" else CPU_DTYPE
    model = AutoencoderOobleck.from_pretrained(str(vae_dir), torch_dtype=weight_dtype)
    model = model.to(device=device)
    model.eval()
    return model
def load_text_encoder(checkpoint_dir: str, device: str = "cpu"):
    """Load the Qwen3-Embedding-0.6B tokenizer and encoder (encoder in eval mode).

    Returns:
        ``(tokenizer, encoder)``.  The encoder is float32 on CPU and bfloat16
        on any other device.

    Raises:
        FileNotFoundError: if ``<checkpoint_dir>/Qwen3-Embedding-0.6B`` is missing.
    """
    from transformers import AutoModel, AutoTokenizer

    encoder_dir = Path(checkpoint_dir).joinpath("Qwen3-Embedding-0.6B")
    if not encoder_dir.is_dir():
        raise FileNotFoundError(f"Text encoder not found: {encoder_dir}")
    weight_dtype = torch.bfloat16 if device != "cpu" else CPU_DTYPE
    tokenizer = AutoTokenizer.from_pretrained(str(encoder_dir))
    encoder = AutoModel.from_pretrained(str(encoder_dir), torch_dtype=weight_dtype)
    encoder = encoder.to(device=device)
    encoder.eval()
    return tokenizer, encoder
def load_silence_latent(
    checkpoint_dir: str, device: str = "cpu", variant: str = "base",
) -> torch.Tensor:
    """Load ``silence_latent.pt`` with dims 1 and 2 swapped.

    Directories are searched in this order:

    1. *checkpoint_dir* itself.
    2. The subdirectory of the requested *variant*.
    3. Every other known variant subdirectory.

    The stored tensor is transposed on dims 1/2.  Downstream code slices
    frames on dim 1, so it is presumably stored ``(batch, channels, frames)``
    (TODO confirm).

    Returns:
        The latent on *device*, float32 on CPU and bfloat16 elsewhere.

    Raises:
        FileNotFoundError: if no candidate file exists.
    """
    ckpt = Path(checkpoint_dir)
    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
    search_dirs = [ckpt]
    preferred = _VARIANT_DIR.get(variant)
    if preferred:
        search_dirs.append(ckpt / preferred)
    search_dirs.extend(ckpt / sub for sub in _VARIANT_DIR.values())
    # dict.fromkeys drops the repeated preferred-variant entry while keeping order.
    for directory in dict.fromkeys(search_dirs):
        candidate = directory / "silence_latent.pt"
        if candidate.is_file():
            # map_location="cpu": a latent saved from a GPU session must still
            # load on CPU-only hosts; it is moved to *device* below.
            latent = torch.load(str(candidate), map_location="cpu", weights_only=True)
            return latent.transpose(1, 2).to(device=device, dtype=dtype)
    raise FileNotFoundError(f"silence_latent.pt not found under {ckpt}")
def unload_models(*models) -> None:
    """Best-effort release of models or tensors that are no longer needed.

    Each non-``None`` argument that has a ``.to`` method is moved to CPU, and
    errors are ignored.  ``nn.Module.to`` acts in place; for tensors ``.to``
    returns a copy that is discarded.  Then the garbage collector runs, and,
    when CUDA is available, PyTorch's cached GPU memory is released.

    This function cannot drop the caller's own references.  Callers must also
    rebind or ``del`` their variables for the memory to actually be freed.
    """
    for obj in models:
        if obj is None or not hasattr(obj, "to"):
            continue
        try:
            obj.to("cpu")
        except Exception:
            pass  # best-effort: unloading must never abort the caller
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# ============================================================================
# AUDIO LOADING
# ============================================================================
def load_audio_stereo(
    audio_path: str, target_sr: int, max_duration: float,
) -> Tuple[torch.Tensor, int]:
    """Decode an audio file into a ``(2, samples)`` tensor at *target_sr*.

    soundfile is tried first, with librosa used for resampling.  If anything
    in that path fails, the file is decoded again with torchaudio.  Mono
    input is duplicated to two channels, channels beyond the second are
    dropped, and the result is truncated to ``max_duration`` seconds.

    Returns:
        ``(audio, sample_rate)``, where ``sample_rate`` is numerically equal
        to *target_sr*.
    """
    import numpy as np

    try:
        import soundfile as sf

        frames, sr = sf.read(audio_path, dtype="float32", always_2d=True)
        channels_first = np.ascontiguousarray(frames.T)
        sr = int(sr)
        if sr != target_sr:
            import librosa

            channels_first = librosa.resample(
                channels_first, orig_sr=sr, target_sr=target_sr, axis=1
            )
            sr = target_sr
        audio = torch.from_numpy(np.ascontiguousarray(channels_first))
    except Exception:
        import torchaudio

        audio, sr = torchaudio.load(audio_path)
        sr = int(sr)
        if sr != target_sr:
            audio = torchaudio.transforms.Resample(sr, target_sr)(audio)
            sr = target_sr

    channel_count = audio.shape[0]
    if channel_count == 1:
        audio = audio.repeat(2, 1)
    elif channel_count > 2:
        audio = audio[:2, :]

    sample_cap = int(max_duration * target_sr)
    if audio.shape[1] > sample_cap:
        audio = audio[:, :sample_cap]
    return audio, sr
# ============================================================================
# TEXT / LYRICS ENCODING
# ============================================================================
def encode_text(text_encoder, tokenizer, text_prompt: str, device, dtype):
    """Encode a caption with the text encoder's full forward pass.

    The prompt is padded/truncated to exactly 256 tokens.  Inputs are first
    moved to *device*, then to the encoder's own device if that differs.

    Returns:
        ``(hidden_states, attention_mask)``: the encoder's
        ``last_hidden_state`` and the tokenizer mask, both cast to *dtype*.
    """
    tokens = tokenizer(
        text_prompt,
        padding="max_length",
        max_length=256,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = tokens.input_ids.to(device)
    attention = tokens.attention_mask.to(device).to(dtype)
    encoder_device = next(text_encoder.parameters()).device
    if input_ids.device != encoder_device:
        input_ids = input_ids.to(encoder_device)
        attention = attention.to(encoder_device)
    with torch.no_grad():
        hidden = text_encoder(input_ids).last_hidden_state.to(dtype)
    return hidden, attention
def encode_lyrics(text_encoder, tokenizer, lyrics: str, device, dtype):
    """Embed lyrics with only the text encoder's token-embedding table.

    Unlike captions, lyrics are not run through the full encoder.  The
    ``embed_tokens`` lookup output is used directly.  Lyrics are
    padded/truncated to exactly 512 tokens.  Inputs are first moved to
    *device*, then to the encoder's own device if that differs.

    Returns:
        ``(embeddings, attention_mask)``, both cast to *dtype*.
    """
    batch = tokenizer(
        lyrics,
        padding="max_length",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    ids = batch.input_ids.to(device)
    mask = batch.attention_mask.to(device).to(dtype)
    target = next(text_encoder.parameters()).device
    if ids.device != target:
        ids, mask = ids.to(target), mask.to(target)
    with torch.no_grad():
        embeddings = text_encoder.embed_tokens(ids).to(dtype)
    return embeddings, mask
# ============================================================================
# VAE TILED ENCODING
# ============================================================================
def tiled_vae_encode(
    vae, audio: torch.Tensor, dtype: torch.dtype,
    chunk_size: Optional[int] = None, overlap: int = 96000,
) -> torch.Tensor:
    """Encode ``(batch, channels, samples)`` audio with the VAE, using overlapping windows.

    Audio of at most *chunk_size* samples (default: 30 s at ``TARGET_SR``)
    is encoded in a single call, and the result stays on the VAE's device.
    Longer audio is handled in tiles:

    * It is split into cores of ``chunk_size - 2 * overlap`` samples.
    * Each core is encoded together with *overlap* samples of context on both
      sides.
    * The latent frames that belong to the context are trimmed off.
    * The cores are stitched together on CPU into a buffer of
      ``round(samples / ds_factor)`` frames.

    ``ds_factor`` (samples per latent frame) is measured on the first window.

    Returns:
        Latents transposed to ``(batch, frames, latent_channels)`` and cast
        to *dtype*.

    Raises:
        ValueError: if tiling is needed and ``chunk_size <= 2 * overlap``.
    """
    vae_device = next(vae.parameters()).device
    vae_dtype = vae.dtype
    if chunk_size is None:
        chunk_size = TARGET_SR * 30
    batch, _, total_samples = audio.shape

    if total_samples <= chunk_size:
        vae_input = audio.to(vae_device, dtype=vae_dtype)
        with torch.inference_mode():
            latents = vae.encode(vae_input).latent_dist.sample()
        return latents.transpose(1, 2).to(dtype)

    stride = chunk_size - 2 * overlap
    if stride <= 0:
        raise ValueError(f"chunk_size ({chunk_size}) must be > 2 * overlap ({overlap})")

    ds_factor = None
    final = None
    write_pos = 0
    for i in range(math.ceil(total_samples / stride)):
        if final is not None and write_pos >= final.shape[-1]:
            break  # buffer already full; further cores would be discarded
        core_start = i * stride
        core_end = min(core_start + stride, total_samples)
        win_start = max(0, core_start - overlap)
        win_end = min(total_samples, core_end + overlap)
        chunk = audio[:, :, win_start:win_end].to(vae_device, dtype=vae_dtype)
        with torch.inference_mode():
            lat = vae.encode(chunk).latent_dist.sample()
        if ds_factor is None:
            ds_factor = chunk.shape[-1] / lat.shape[-1]
            total_len = int(round(total_samples / ds_factor))
            final = torch.zeros(batch, lat.shape[1], total_len, dtype=lat.dtype, device="cpu")
        trim_start = int(round((core_start - win_start) / ds_factor))
        trim_end = int(round((win_end - core_end) / ds_factor))
        end_idx = lat.shape[-1] - trim_end if trim_end > 0 else lat.shape[-1]
        core = lat[:, :, trim_start:end_idx]
        # Per-window rounding, or a VAE that rounds partial hops up, can yield
        # more frames than the buffer holds; clamp instead of overflowing.
        core_len = min(core.shape[-1], final.shape[-1] - write_pos)
        final[:, :, write_pos:write_pos + core_len] = core[:, :, :core_len].cpu()
        write_pos += core_len
        del chunk, lat, core
    return final[:, :, :write_pos].transpose(1, 2).to(dtype)
# ============================================================================
# ENCODER / CONTEXT HELPERS
# ============================================================================
def run_encoder(
    model, text_hs, text_mask, lyric_hs, lyric_mask, device, dtype,
):
    """Run ``model.encoder`` on text and lyric features, without gradients.

    A single all-zero reference-audio slot of shape ``(1, 1, 64)`` and a
    ``(1,)`` long zero order mask are passed.  This is presumably "no
    reference audio" (TODO confirm against the encoder).

    Returns:
        The ``(hidden_states, attention_mask)`` pair unpacked from the
        encoder's output.
    """
    empty_reference = torch.zeros(1, 1, 64, device=device, dtype=dtype)
    reference_order = torch.zeros(1, device=device, dtype=torch.long)
    encoder_inputs = {
        "text_hidden_states": text_hs,
        "text_attention_mask": text_mask,
        "lyric_hidden_states": lyric_hs,
        "lyric_attention_mask": lyric_mask,
        "refer_audio_acoustic_hidden_states_packed": empty_reference,
        "refer_audio_order_mask": reference_order,
    }
    with torch.no_grad():
        hidden, mask = model.encoder(**encoder_inputs)
    return hidden, mask
def build_context_latents(silence_latent, latent_length: int, device, dtype):
    """Build the DiT context tensor: silence frames followed by an all-ones mask.

    The first ``latent_length`` frames of *silence_latent* are used; if it is
    shorter, it is tiled along the frame axis.  They are concatenated on the
    last dim with a 64-channel ones tensor.

    Args:
        silence_latent: ``(batch, frames, channels)`` tensor; a 2-D
            ``(frames, channels)`` tensor is treated as batch 1.
        latent_length: Number of frames in the target latent.
        device: Device of the result.
        dtype: Dtype of the result.

    Returns:
        Tensor of shape ``(batch, latent_length, channels + 64)``.

    Raises:
        ValueError: if *silence_latent* has no batch rows or no frames.
    """
    if silence_latent.dim() == 2:
        silence_latent = silence_latent.unsqueeze(0)
    batch, frames = silence_latent.shape[0], silence_latent.shape[1]
    if batch == 0 or frames == 0:
        raise ValueError(
            f"silence_latent must be non-empty, got shape {tuple(silence_latent.shape)}"
        )
    if frames < latent_length:
        # Tile enough copies to cover any length; a single wrap-around copy
        # is too short once latent_length exceeds 2 * frames.
        silence_latent = silence_latent.repeat(1, math.ceil(latent_length / frames), 1)
    src = silence_latent[:, :latent_length, :].to(device=device, dtype=dtype)
    masks = torch.ones(batch, latent_length, 64, device=device, dtype=dtype)
    return torch.cat([src, masks], dim=-1)
# ============================================================================
# AUDIO DISCOVERY
# ============================================================================
def _discover_audio_files(audio_dir: str) -> List[Path]:
    """Recursively collect audio files under *audio_dir* in a deterministic order.

    Sub-directories and file names are both visited in sorted order, so the
    result does not depend on the filesystem's enumeration order.  A file
    qualifies when its lower-cased suffix is in ``AUDIO_EXTENSIONS``.
    """
    found: List[Path] = []
    for root, dirs, names in os.walk(audio_dir):
        # os.walk descends into ``dirs`` in list order; sorting in place fixes it.
        dirs.sort()
        found.extend(
            Path(root) / name
            for name in sorted(names)
            if Path(name).suffix.lower() in AUDIO_EXTENSIONS
        )
    return found
def _detect_max_duration(files: List[Path]) -> float:
    """Return the longest duration among *files*, capped at ``MAX_AUDIO_DURATION``.

    Durations are read with ``soundfile.info`` (header only).  The result is
    used to truncate audio during preprocessing, so every file is checked
    until the cap is reached: an unmeasured file could otherwise be cut
    short.  ``MAX_AUDIO_DURATION`` is returned in three cases: ``soundfile``
    is unavailable, any file's duration cannot be read (e.g. a container
    libsndfile does not support), or no positive duration is found.
    """
    try:
        import soundfile as sf
    except ImportError:
        return MAX_AUDIO_DURATION
    longest = 0.0
    for path in files:
        try:
            duration = float(sf.info(str(path)).duration)
        except Exception as exc:
            logger.warning(
                "Could not read duration of %s (%s); using the %.0fs cap",
                path, exc, MAX_AUDIO_DURATION,
            )
            return MAX_AUDIO_DURATION
        longest = max(longest, duration)
        if longest >= MAX_AUDIO_DURATION:
            break  # cannot exceed the cap; skip the remaining headers
    return min(longest if longest > 0 else MAX_AUDIO_DURATION, MAX_AUDIO_DURATION)
# ============================================================================
# PREPROCESSING (2-pass sequential)
# ============================================================================
def preprocess_audio(
    audio_dir: str,
    output_dir: str,
    checkpoint_dir: str,
    device: str = "cpu",
    variant: str = "base",
    max_duration: float = 0,
    progress_callback: Optional[Callable] = None,
    cancel_check: Optional[Callable] = None,
) -> Dict[str, Any]:
    """2-pass sequential preprocessing of a folder of audio files.

    Pass 1: load the VAE and text encoder.  For each file, encode the audio
    (VAE latents), the caption (the file stem) and the lyrics
    (``"[Instrumental]"``), then write a ``<stem>.tmp.pt`` intermediate.

    Pass 2: load the DiT model, run its encoder, build the context latents
    and write the final ``<stem>.pt`` sample.  Each sample is first written
    to ``<stem>.pt.__writing__`` and then moved into place with
    ``os.replace``.

    Pass-1 models are released before the DiT is loaded.  Files whose
    ``<stem>.pt`` already exists are skipped.  NOTE(review): files in
    different sub-directories, or with different extensions, that share a
    stem map to the same output file.

    Args:
        audio_dir: Directory searched recursively for audio files.
        output_dir: Destination for the ``.pt`` samples; created if missing.
        checkpoint_dir: ACE-Step checkpoint root.
        device: Torch device string.  CPU uses float32; other devices use
            bfloat16.
        variant: DiT / silence-latent variant.
        max_duration: Per-file cap in seconds; ``<= 0`` auto-detects it.
        progress_callback: Called as ``(done, total, message)`` after each
            file that completes a pass.
        cancel_check: Polled before each file; a truthy result stops
            processing.  Cancelling during pass 1 skips pass 2.

    Returns:
        Dict with ``processed``, ``failed``, ``total`` and ``output_dir``.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    # Remove staging files left behind by an interrupted pass 2.
    for orphan in out.glob("*.__writing__"):
        try:
            orphan.unlink()
        except OSError:
            pass

    audio_files = _discover_audio_files(audio_dir)
    if not audio_files:
        return {"processed": 0, "failed": 0, "total": 0, "output_dir": str(out)}
    total = len(audio_files)
    if max_duration <= 0:
        max_duration = _detect_max_duration(audio_files)
    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16

    # ---- Pass 1: VAE + Text Encoder ----
    logger.info("Pass 1/2: Loading VAE + Text Encoder...")
    vae = tokenizer = text_enc = silence_lat = None
    intermediates: List[Path] = []
    p1_failed = 0
    cancelled = False
    try:
        # Loaded inside the try so a failure part-way through loading still
        # unloads whatever was already loaded.
        vae = load_vae(checkpoint_dir, device)
        tokenizer, text_enc = load_text_encoder(checkpoint_dir, device)
        silence_lat = load_silence_latent(checkpoint_dir, device, variant=variant)
        for i, af in enumerate(audio_files):
            if cancel_check and cancel_check():
                cancelled = True
                break
            stem = af.stem
            if (out / f"{stem}.pt").exists():
                continue  # already preprocessed by an earlier run
            try:
                audio, _ = load_audio_stereo(str(af), TARGET_SR, max_duration)
                audio = audio.unsqueeze(0).to(device=device, dtype=vae.dtype)
                with torch.no_grad():
                    target_latents = tiled_vae_encode(vae, audio, dtype)
                del audio
                if torch.isnan(target_latents).any() or torch.isinf(target_latents).any():
                    p1_failed += 1
                    logger.error("Pass 1 FAIL %s: non-finite VAE latents", af.name)
                    del target_latents
                    continue
                lat_len = target_latents.shape[1]
                att_mask = torch.ones(1, lat_len, device=device, dtype=dtype)
                # No sidecar metadata is read: the file stem is the caption and
                # every track is treated as instrumental.
                caption = af.stem
                lyrics = "[Instrumental]"
                with torch.no_grad():
                    text_hs, text_mask = encode_text(text_enc, tokenizer, caption, device, dtype)
                    lyric_hs, lyric_mask = encode_lyrics(text_enc, tokenizer, lyrics, device, dtype)
                has_bad = any(
                    torch.isnan(t).any() or torch.isinf(t).any()
                    for t in (text_hs, lyric_hs)
                )
                if has_bad:
                    p1_failed += 1
                    logger.error("Pass 1 FAIL %s: non-finite text/lyric embeddings", af.name)
                    del target_latents, att_mask, text_hs, text_mask, lyric_hs, lyric_mask
                    continue
                tmp_path = out / f"{stem}.tmp.pt"
                torch.save({
                    "target_latents": target_latents.squeeze(0).cpu(),
                    "attention_mask": att_mask.squeeze(0).cpu(),
                    "text_hidden_states": text_hs.cpu(),
                    "text_attention_mask": text_mask.cpu(),
                    "lyric_hidden_states": lyric_hs.cpu(),
                    "lyric_attention_mask": lyric_mask.cpu(),
                    "silence_latent": silence_lat.cpu(),
                    "latent_length": lat_len,
                    "metadata": {
                        "audio_path": str(af),
                        "filename": af.name,
                        "caption": caption,
                        "lyrics": lyrics,
                    },
                }, tmp_path)
                del target_latents, att_mask, text_hs, text_mask, lyric_hs, lyric_mask
                intermediates.append(tmp_path)
                if progress_callback:
                    progress_callback(i + 1, total, f"[Pass 1] {af.name}")
            except Exception as exc:
                p1_failed += 1
                logger.error("Pass 1 FAIL %s: %s", af.name, exc)
    finally:
        logger.info("Unloading VAE + Text Encoder...")
        unload_models(vae, text_enc, tokenizer, silence_lat)
        # unload_models cannot drop these local references.  Release them
        # here so pass-1 models can be freed before the DiT is loaded.
        vae = tokenizer = text_enc = silence_lat = None
        gc.collect()

    # ---- Pass 2: DIT Encoder ----
    if cancelled:
        logger.info("Preprocessing cancelled during pass 1; skipping pass 2")
    if cancelled or not intermediates:
        return {"processed": 0, "failed": p1_failed, "total": total, "output_dir": str(out)}
    logger.info("Pass 2/2: Loading DIT model (variant=%s)...", variant)
    model = load_model_for_training(checkpoint_dir, variant, device)
    processed = 0
    p2_failed = 0
    p2_total = len(intermediates)
    try:
        for i, tmp_path in enumerate(intermediates):
            if cancel_check and cancel_check():
                break
            try:
                data = torch.load(str(tmp_path), weights_only=True)
                m_device = next(model.parameters()).device
                m_dtype = next(model.parameters()).dtype
                text_hs = data["text_hidden_states"].to(m_device, dtype=m_dtype)
                text_mask = data["text_attention_mask"].to(m_device, dtype=m_dtype)
                lyric_hs = data["lyric_hidden_states"].to(m_device, dtype=m_dtype)
                lyric_mask = data["lyric_attention_mask"].to(m_device, dtype=m_dtype)
                silence_lat = data["silence_latent"].to(m_device, dtype=m_dtype)
                lat_len = data["latent_length"]
                enc_hs, enc_mask = run_encoder(
                    model, text_hs, text_mask, lyric_hs, lyric_mask,
                    str(m_device), m_dtype,
                )
                del text_hs, text_mask, lyric_hs, lyric_mask
                if silence_lat.dim() == 2:
                    silence_lat = silence_lat.unsqueeze(0)
                ctx = build_context_latents(silence_lat, lat_len, str(m_device), m_dtype)
                del silence_lat
                has_bad = any(
                    torch.isnan(t).any() or torch.isinf(t).any()
                    for t in (enc_hs, ctx)
                )
                if has_bad:
                    p2_failed += 1
                    logger.error("Pass 2 FAIL %s: non-finite encoder/context output", tmp_path.stem)
                    del enc_hs, enc_mask, ctx, data
                    continue
                base_name = tmp_path.name.replace(".tmp.pt", ".pt")
                final_path = out / base_name
                staging_path = out / (base_name + ".__writing__")
                torch.save({
                    "target_latents": data["target_latents"],
                    "attention_mask": data["attention_mask"],
                    "encoder_hidden_states": enc_hs.squeeze(0).cpu(),
                    "encoder_attention_mask": enc_mask.squeeze(0).cpu(),
                    "context_latents": ctx.squeeze(0).cpu(),
                    "metadata": data.get("metadata", {}),
                }, staging_path)
                # Rename into place so a crash never leaves a truncated <stem>.pt.
                os.replace(staging_path, final_path)
                del enc_hs, enc_mask, ctx, data
                tmp_path.unlink(missing_ok=True)
                processed += 1
                if progress_callback:
                    progress_callback(i + 1, p2_total, f"[Pass 2] {tmp_path.stem}")
            except Exception as exc:
                p2_failed += 1
                logger.error("Pass 2 FAIL %s: %s", tmp_path.stem, exc)
    finally:
        logger.info("Unloading DIT model...")
        unload_models(model)
    failed = p1_failed + p2_failed
    return {"processed": processed, "failed": failed, "total": total, "output_dir": str(out)}
# ============================================================================
# TRAINING LOOP (generator for Gradio compatibility)
# ============================================================================
def train_lora_generator(
dataset_dir: str,
output_dir: str,
checkpoint_dir: str,
epochs: int = 1000,
lr: float = 3e-4,
rank: int = 64,
alpha: int = 128,
dropout: float = 0.1,
batch_size: int = 1,
gradient_accumulation_steps: int = 4,
warmup_steps: int = 100,
weight_decay: float = 0.01,
max_grad_norm: float = 1.0,
save_every_n_epochs: int = 50,
seed: int = 42,
variant: str = "base",
device: str = "cpu",
cfg_ratio: float = 0.15,
timestep_mu: float = -0.4,
timestep_sigma: float = 1.0,
target_modules: Optional[List[str]] = None,
log_every: int = 10,
resume_from: Optional[str] = None,
) -> Generator[str, None, None]:
"""Run LoRA training, yielding progress strings each epoch.
This is a generator for Gradio live-update compatibility.
Call cancel_training() to stop after the current epoch.
"""
_training_cancel.clear()
train_start = time.time()
if target_modules is None:
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
ds_path = Path(dataset_dir)
if not ds_path.is_dir():
yield f"[FAIL] Dataset directory not found: {ds_path}"
return
out_path = Path(output_dir)
out_path.mkdir(parents=True, exist_ok=True)
yield "[INFO] Loading model..."
try:
model = load_model_for_training(checkpoint_dir, variant, device)
except Exception as exc:
yield f"[FAIL] Model load failed: {exc}"
return
# float32 on CPU (bfloat16 deadlocks)
dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
model = model.to(dtype=dtype)
yield "[INFO] Injecting LoRA..."
lora_cfg = LoRAConfig(
r=rank, alpha=alpha, dropout=dropout,
target_modules=target_modules, bias="none",
)
try:
model, info = inject_lora(model, lora_cfg)
except Exception as exc:
yield f"[FAIL] LoRA injection failed: {exc}"
unload_models(model)
return
yield f"[OK] LoRA injected: {info['trainable_params']:,} trainable params"
# Gradient checkpointing + cache disable
force_disable_cache(model.decoder)
ckpt_ok = enable_gradient_checkpointing(model.decoder)
force_input_grads = ckpt_ok
if ckpt_ok:
yield "[INFO] Gradient checkpointing enabled"
# Dataset
dataset = TensorDataset(dataset_dir)
if len(dataset) == 0:
yield "[FAIL] No valid .pt files found in dataset directory"
unload_models(model)
return
yield f"[OK] Loaded {len(dataset)} preprocessed samples"
loader = DataLoader(
dataset, batch_size=batch_size, shuffle=True,
num_workers=0, collate_fn=_collate_batch, drop_last=False,
)
# Optimizer & scheduler
torch.manual_seed(seed)
random.seed(seed)
trainable_params = [p for p in model.parameters() if p.requires_grad]
if not trainable_params:
yield "[FAIL] No trainable parameters found"
unload_models(model)
return
optimizer = build_optimizer(trainable_params, lr=lr, weight_decay=weight_decay)
steps_per_epoch = max(1, math.ceil(len(loader) / gradient_accumulation_steps))
total_steps = steps_per_epoch * epochs
scheduler = build_scheduler(optimizer, total_steps, warmup_steps, lr)
yield f"[INFO] Training {sum(p.numel() for p in trainable_params):,} params for {epochs} epochs"
yield f"[INFO] Steps/epoch: {steps_per_epoch}, total: {total_steps}"
# Null condition embedding for CFG dropout
null_cond = getattr(model, "null_condition_emb", None)
# Resume checkpoint
start_epoch = 0
global_step = 0
if resume_from and Path(resume_from).exists():
try:
yield f"[INFO] Resuming from {resume_from}"
ckpt_dir = Path(resume_from)
if ckpt_dir.is_file():
ckpt_dir = ckpt_dir.parent
# Load adapter weights
aw = ckpt_dir / "adapter_model.safetensors"
if aw.exists():
from safetensors.torch import load_file
state = load_file(str(aw))
decoder = model.decoder
while hasattr(decoder, "_forward_module"):
decoder = decoder._forward_module
decoder.load_state_dict(state, strict=False)
# Load training state
ts = ckpt_dir / "training_state.pt"
if ts.exists():
tstate = torch.load(str(ts), map_location=device, weights_only=True)
start_epoch = tstate.get("epoch", 0)
global_step = tstate.get("global_step", 0)
if "optimizer_state_dict" in tstate:
try:
optimizer.load_state_dict(tstate["optimizer_state_dict"])
except Exception:
pass
if "scheduler_state_dict" in tstate:
try:
scheduler.load_state_dict(tstate["scheduler_state_dict"])
except Exception:
pass
yield f"[OK] Resumed from epoch {start_epoch}, step {global_step}"
except Exception as exc:
yield f"[WARN] Checkpoint load failed: {exc}, starting fresh"
start_epoch = 0
global_step = 0
# Training loop
model.decoder.train()
acc_step = 0
acc_loss = 0.0
optimizer.zero_grad(set_to_none=True)
best_loss = float("inf")
best_epoch = 0
consecutive_nan = 0
MAX_NAN = 10
for epoch in range(start_epoch, epochs):
# Cancel check
if _training_cancel.is_set():
_training_cancel.clear()
early_path = str(out_path / "early_exit")
model.decoder.eval()
save_lora_adapter(model, early_path)
model.decoder.train()
yield f"[OK] Cancelled at epoch {epoch + 1}, saved to {early_path}"
yield "[DONE]"
unload_models(model)
return
# Timeout check
elapsed = time.time() - train_start
if elapsed > MAX_TRAINING_TIME:
early_path = str(out_path / "timeout_exit")
model.decoder.eval()
save_lora_adapter(model, early_path)
yield f"[WARN] Training timed out after {int(elapsed)}s, saved to {early_path}"
yield "[DONE]"
unload_models(model)
return
epoch_loss = 0.0
num_updates = 0
epoch_start = time.time()
for batch in loader:
# Forward
nb = device != "cpu"
tgt = batch["target_latents"].to(device, dtype=dtype, non_blocking=nb)
att = batch["attention_mask"].to(device, dtype=dtype, non_blocking=nb)
enc_hs = batch["encoder_hidden_states"].to(device, dtype=dtype, non_blocking=nb)
enc_mask = batch["encoder_attention_mask"].to(device, dtype=dtype, non_blocking=nb)
ctx = batch["context_latents"].to(device, dtype=dtype, non_blocking=nb)
bsz = tgt.shape[0]
# CFG dropout
if null_cond is not None and cfg_ratio > 0:
enc_hs = apply_cfg_dropout(enc_hs, null_cond, cfg_ratio)
# Timestep sampling
t, _r = sample_timesteps(bsz, torch.device(device), dtype, timestep_mu, timestep_sigma)
# Flow matching noise
x1 = torch.randn_like(tgt)
x0 = tgt
t_ = t.unsqueeze(-1).unsqueeze(-1)
xt = t_ * x1 + (1.0 - t_) * x0
if force_input_grads:
xt = xt.requires_grad_(True)
# Decoder forward
dec_out = model.decoder(
hidden_states=xt,
timestep=t,
timestep_r=t,
attention_mask=att,
encoder_hidden_states=enc_hs,
encoder_attention_mask=enc_mask,
context_latents=ctx,
)
flow = x1 - x0
loss = F.mse_loss(dec_out[0], flow)
loss = loss.float() # fp32 for stable backward
# NaN guard
if torch.isnan(loss) or torch.isinf(loss):
consecutive_nan += 1
del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow
if consecutive_nan >= MAX_NAN:
yield f"[FAIL] {consecutive_nan} consecutive NaN losses, halting"
unload_models(model)
return
if acc_step > 0:
optimizer.zero_grad(set_to_none=True)
acc_loss = 0.0
acc_step = 0
continue
consecutive_nan = 0
loss = loss / gradient_accumulation_steps
loss.backward()
acc_loss += loss.item()
del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow
acc_step += 1
if acc_step >= gradient_accumulation_steps:
torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
optimizer.step()
scheduler.step()
global_step += 1
avg_loss = acc_loss * gradient_accumulation_steps / acc_step
if global_step % log_every == 0:
current_lr = scheduler.get_last_lr()[0]
yield (
f"Epoch {epoch + 1}/{epochs}, "
f"Step {global_step}, "
f"Loss: {avg_loss:.4f}, "
f"LR: {current_lr:.2e}"
)
optimizer.zero_grad(set_to_none=True)
epoch_loss += avg_loss
num_updates += 1
acc_loss = 0.0
acc_step = 0
# Flush remainder
if acc_step > 0:
torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
optimizer.step()
scheduler.step()
global_step += 1
avg_loss = acc_loss * gradient_accumulation_steps / acc_step
optimizer.zero_grad(set_to_none=True)
epoch_loss += avg_loss
num_updates += 1
acc_loss = 0.0
acc_step = 0
epoch_time = time.time() - epoch_start
avg_epoch_loss = epoch_loss / max(num_updates, 1)
is_best = avg_epoch_loss < best_loss - 0.001
if is_best:
best_loss = avg_epoch_loss
best_epoch = epoch + 1
best_str = f" (best: {best_loss:.4f} @ ep{best_epoch})" if best_epoch > 0 else ""
yield (
f"[OK] Epoch {epoch + 1}/{epochs} in {epoch_time:.1f}s, "
f"Loss: {avg_epoch_loss:.4f}{best_str}"
)
# Save best
if is_best and epoch + 1 >= 10:
best_path = str(out_path / "best")
model.decoder.eval()
save_lora_adapter(model, best_path)
model.decoder.train()
yield f"[OK] Best model saved (epoch {epoch + 1}, loss: {best_loss:.4f})"
# Periodic checkpoint
if (epoch + 1) % save_every_n_epochs == 0:
ckpt_path = str(out_path / "checkpoints" / f"epoch_{epoch + 1}")
model.decoder.eval()
save_lora_adapter(model, ckpt_path)
tstate = {
"epoch": epoch + 1,
"global_step": global_step,
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
}
os.makedirs(ckpt_path, exist_ok=True)
torch.save(tstate, str(Path(ckpt_path) / "training_state.pt"))
model.decoder.train()
yield f"[OK] Checkpoint saved at epoch {epoch + 1}"
# Sanity check
if global_step == 0:
yield "[FAIL] Training completed 0 steps -- no batches processed"
unload_models(model)
return
# Final save (directly to output_dir, not a subdirectory)
model.decoder.eval()
save_lora_adapter(model, str(out_path))
final_loss = avg_epoch_loss if num_updates > 0 else 0.0
best_note = ""
if best_epoch > 0 and Path(out_path / "best").exists():
best_note = f"\n Best: {out_path / 'best'} (epoch {best_epoch}, loss: {best_loss:.4f})"
yield (
f"[OK] Training complete! LoRA saved to {out_path}{best_note}\n"
f" Adapter ready for inference."
)
yield "[DONE]"
unload_models(model)
# ============================================================================
# ADAPTER LISTING
# ============================================================================
def get_trained_loras(adapter_dir: str) -> List[str]:
    """List all saved LoRA adapter directories under *adapter_dir*.

    A directory qualifies when it directly contains at least one of
    ``adapter_config.json``, ``adapter_model.safetensors`` or
    ``lora_weights.pt``. The search is recursive, so nested adapter
    directories (for example the ``best/`` and ``checkpoints/epoch_N/``
    subdirectories written by the training loop) are reported alongside a
    top-level adapter.

    Args:
        adapter_dir: Root directory to search. An empty string or ``None``
            yields an empty list instead of scanning the current working
            directory (``Path("")`` would otherwise resolve to ``"."``).

    Returns:
        Sorted list of matching directory paths, in the form produced by
        ``os.walk`` (i.e. prefixed with the normalized *adapter_dir*).
        Empty if *adapter_dir* is falsy or not an existing directory.
    """
    # Guard: Path("") == Path("."), which would silently walk the CWD.
    if not adapter_dir:
        return []
    base = Path(adapter_dir)
    if not base.is_dir():
        return []
    markers = {"adapter_config.json", "adapter_model.safetensors", "lora_weights.pt"}
    # os.walk visits each directory exactly once, so no de-duplication is needed.
    return sorted(
        root
        for root, _dirs, files in os.walk(str(base))
        if not markers.isdisjoint(files)
    )