Spaces:
Running
Running
add GPU/CUDA auto-detect, mixed precision, flash_attn, txt caption parser
Browse files- train_engine.py +299 -44
train_engine.py
CHANGED
|
@@ -1,10 +1,16 @@
|
|
| 1 |
"""
|
| 2 |
-
Standalone ACE-Step
|
| 3 |
|
| 4 |
Ported from Side-Step (koda-dernet/Side-Step) into a single self-contained
|
| 5 |
module. No external Side-Step dependency required.
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
Exports:
|
|
|
|
|
|
|
| 8 |
preprocess_audio() - 2-pass sequential preprocessing
|
| 9 |
train_lora_generator() - Generator-based LoRA training loop
|
| 10 |
cancel_training() - Set the cancel flag
|
|
@@ -63,6 +69,93 @@ def cancel_training() -> None:
|
|
| 63 |
_training_cancel.set()
|
| 64 |
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
# ============================================================================
|
| 67 |
# CONFIGS
|
| 68 |
# ============================================================================
|
|
@@ -448,7 +541,12 @@ def _ensure_acestep_imports():
|
|
| 448 |
|
| 449 |
|
| 450 |
def _attn_candidates(device: str) -> List[str]:
|
| 451 |
-
"""FA2 -> SDPA -> eager, filtered by availability.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
candidates = []
|
| 453 |
if device.startswith("cuda"):
|
| 454 |
try:
|
|
@@ -457,8 +555,21 @@ def _attn_candidates(device: str) -> List[str]:
|
|
| 457 |
props = torch.cuda.get_device_properties(dev_idx)
|
| 458 |
if props.major >= 8:
|
| 459 |
candidates.append("flash_attention_2")
|
| 460 |
-
|
| 461 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 462 |
candidates.extend(["sdpa", "eager"])
|
| 463 |
return candidates
|
| 464 |
|
|
@@ -469,8 +580,12 @@ def load_model_for_training(
|
|
| 469 |
from transformers import AutoModel
|
| 470 |
|
| 471 |
model_dir = _resolve_model_dir(checkpoint_dir, variant)
|
| 472 |
-
|
| 473 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
|
| 475 |
_ensure_acestep_imports()
|
| 476 |
|
|
@@ -489,7 +604,7 @@ def load_model_for_training(
|
|
| 489 |
if device != "cpu":
|
| 490 |
load_kwargs["device_map"] = {"": device}
|
| 491 |
model = AutoModel.from_pretrained(str(model_dir), **load_kwargs)
|
| 492 |
-
logger.info("Model loaded with attn_implementation=%s", attn)
|
| 493 |
break
|
| 494 |
except Exception as exc:
|
| 495 |
err_text = str(exc)
|
|
@@ -499,11 +614,23 @@ def load_model_for_training(
|
|
| 499 |
f" Original error: {err_text}"
|
| 500 |
) from exc
|
| 501 |
last_err = exc
|
| 502 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
|
| 504 |
if model is None:
|
| 505 |
raise RuntimeError(f"Failed to load model from {model_dir}: {last_err}") from last_err
|
| 506 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
for param in model.parameters():
|
| 508 |
param.requires_grad = False
|
| 509 |
model.eval()
|
|
@@ -517,10 +644,11 @@ def load_vae(checkpoint_dir: str, device: str = "cpu"):
|
|
| 517 |
if not vae_path.is_dir():
|
| 518 |
raise FileNotFoundError(f"VAE directory not found: {vae_path}")
|
| 519 |
|
| 520 |
-
dtype =
|
| 521 |
vae = AutoencoderOobleck.from_pretrained(str(vae_path), torch_dtype=dtype)
|
| 522 |
vae = vae.to(device=device)
|
| 523 |
vae.eval()
|
|
|
|
| 524 |
return vae
|
| 525 |
|
| 526 |
|
|
@@ -531,11 +659,12 @@ def load_text_encoder(checkpoint_dir: str, device: str = "cpu"):
|
|
| 531 |
if not text_path.is_dir():
|
| 532 |
raise FileNotFoundError(f"Text encoder not found: {text_path}")
|
| 533 |
|
| 534 |
-
dtype =
|
| 535 |
tokenizer = AutoTokenizer.from_pretrained(str(text_path))
|
| 536 |
encoder = AutoModel.from_pretrained(str(text_path), torch_dtype=dtype)
|
| 537 |
encoder = encoder.to(device=device)
|
| 538 |
encoder.eval()
|
|
|
|
| 539 |
return tokenizer, encoder
|
| 540 |
|
| 541 |
|
|
@@ -543,7 +672,7 @@ def load_silence_latent(
|
|
| 543 |
checkpoint_dir: str, device: str = "cpu", variant: str = "base",
|
| 544 |
) -> torch.Tensor:
|
| 545 |
ckpt = Path(checkpoint_dir)
|
| 546 |
-
dtype =
|
| 547 |
|
| 548 |
candidates = [ckpt / "silence_latent.pt"]
|
| 549 |
subdir = _VARIANT_DIR.get(variant)
|
|
@@ -571,6 +700,14 @@ def unload_models(*models) -> None:
|
|
| 571 |
pass
|
| 572 |
del obj
|
| 573 |
gc.collect()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
|
| 575 |
|
| 576 |
# ============================================================================
|
|
@@ -1904,16 +2041,44 @@ def _write_caption_sidecar(audio_path: Path, analysis: Dict[str, Any]) -> Path:
|
|
| 1904 |
return sidecar_path
|
| 1905 |
|
| 1906 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1907 |
def _read_caption_sidecar(audio_path: Path) -> Optional[Dict[str, Any]]:
|
| 1908 |
-
"""Read
|
| 1909 |
-
|
| 1910 |
-
if
|
| 1911 |
-
|
| 1912 |
-
|
| 1913 |
-
|
| 1914 |
-
|
| 1915 |
-
|
| 1916 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1917 |
|
| 1918 |
|
| 1919 |
# ============================================================================
|
|
@@ -1924,7 +2089,7 @@ def preprocess_audio(
|
|
| 1924 |
audio_dir: str,
|
| 1925 |
output_dir: str,
|
| 1926 |
checkpoint_dir: str,
|
| 1927 |
-
device: str = "
|
| 1928 |
variant: str = "base",
|
| 1929 |
max_duration: float = 0,
|
| 1930 |
progress_callback: Optional[Callable] = None,
|
|
@@ -1934,7 +2099,13 @@ def preprocess_audio(
|
|
| 1934 |
|
| 1935 |
Pass 1: Load VAE + text encoder, encode audio + text, save intermediates.
|
| 1936 |
Pass 2: Load DIT model, run encoder, build context, save final .pt files.
|
|
|
|
|
|
|
|
|
|
| 1937 |
"""
|
|
|
|
|
|
|
|
|
|
| 1938 |
out = Path(output_dir)
|
| 1939 |
out.mkdir(parents=True, exist_ok=True)
|
| 1940 |
|
|
@@ -1954,7 +2125,7 @@ def preprocess_audio(
|
|
| 1954 |
if max_duration <= 0:
|
| 1955 |
max_duration = _detect_max_duration(audio_files)
|
| 1956 |
|
| 1957 |
-
dtype =
|
| 1958 |
|
| 1959 |
# ---- Pass 1: VAE + Text Encoder ----
|
| 1960 |
logger.info("Pass 1/2: Loading VAE + Text Encoder...")
|
|
@@ -2086,6 +2257,7 @@ def preprocess_audio(
|
|
| 2086 |
finally:
|
| 2087 |
logger.info("Unloading VAE + Text Encoder...")
|
| 2088 |
unload_models(vae, text_enc, tokenizer, silence_lat)
|
|
|
|
| 2089 |
|
| 2090 |
# ---- Pass 2: DIT Encoder ----
|
| 2091 |
if not intermediates:
|
|
@@ -2162,6 +2334,7 @@ def preprocess_audio(
|
|
| 2162 |
finally:
|
| 2163 |
logger.info("Unloading DIT model...")
|
| 2164 |
unload_models(model)
|
|
|
|
| 2165 |
|
| 2166 |
failed = p1_failed + p2_failed
|
| 2167 |
return {"processed": processed, "failed": failed, "total": total, "output_dir": str(out)}
|
|
@@ -2188,7 +2361,7 @@ def train_lora_generator(
|
|
| 2188 |
save_every_n_epochs: int = 0,
|
| 2189 |
seed: int = 42,
|
| 2190 |
variant: str = "base",
|
| 2191 |
-
device: str = "
|
| 2192 |
cfg_ratio: float = 0.15,
|
| 2193 |
timestep_mu: float = -0.4,
|
| 2194 |
timestep_sigma: float = 1.0,
|
|
@@ -2200,10 +2373,20 @@ def train_lora_generator(
|
|
| 2200 |
|
| 2201 |
This is a generator for Gradio live-update compatibility.
|
| 2202 |
Call cancel_training() to stop after the current epoch.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2203 |
"""
|
| 2204 |
_training_cancel.clear()
|
| 2205 |
train_start = time.time()
|
| 2206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2207 |
if target_modules is None:
|
| 2208 |
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
|
| 2209 |
|
|
@@ -2215,6 +2398,13 @@ def train_lora_generator(
|
|
| 2215 |
out_path = Path(output_dir)
|
| 2216 |
out_path.mkdir(parents=True, exist_ok=True)
|
| 2217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2218 |
yield "[INFO] Loading model..."
|
| 2219 |
|
| 2220 |
try:
|
|
@@ -2223,10 +2413,14 @@ def train_lora_generator(
|
|
| 2223 |
yield f"[FAIL] Model load failed: {exc}"
|
| 2224 |
return
|
| 2225 |
|
| 2226 |
-
#
|
| 2227 |
-
|
| 2228 |
model = model.to(dtype=dtype)
|
| 2229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2230 |
yield "[INFO] Injecting LoRA..."
|
| 2231 |
|
| 2232 |
lora_cfg = LoRAConfig(
|
|
@@ -2262,11 +2456,14 @@ def train_lora_generator(
|
|
| 2262 |
loader = DataLoader(
|
| 2263 |
dataset, batch_size=batch_size, shuffle=True,
|
| 2264 |
num_workers=0, collate_fn=_collate_batch, drop_last=False,
|
|
|
|
| 2265 |
)
|
| 2266 |
|
| 2267 |
# Optimizer & scheduler
|
| 2268 |
torch.manual_seed(seed)
|
| 2269 |
random.seed(seed)
|
|
|
|
|
|
|
| 2270 |
|
| 2271 |
trainable_params = [p for p in model.parameters() if p.requires_grad]
|
| 2272 |
if not trainable_params:
|
|
@@ -2282,6 +2479,13 @@ def train_lora_generator(
|
|
| 2282 |
yield f"[INFO] Training {sum(p.numel() for p in trainable_params):,} params for {epochs} epochs"
|
| 2283 |
yield f"[INFO] Steps/epoch: {steps_per_epoch}, total: {total_steps}"
|
| 2284 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2285 |
# Null condition embedding for CFG dropout
|
| 2286 |
null_cond = getattr(model, "null_condition_emb", None)
|
| 2287 |
|
|
@@ -2350,7 +2554,9 @@ def train_lora_generator(
|
|
| 2350 |
model.decoder.train()
|
| 2351 |
yield f"[OK] Cancelled at epoch {epoch + 1}, saved to {early_path}"
|
| 2352 |
yield "[DONE]"
|
|
|
|
| 2353 |
unload_models(model)
|
|
|
|
| 2354 |
return
|
| 2355 |
|
| 2356 |
# Timeout check
|
|
@@ -2361,7 +2567,9 @@ def train_lora_generator(
|
|
| 2361 |
save_lora_adapter(model, early_path)
|
| 2362 |
yield f"[WARN] Training timed out after {int(elapsed)}s, saved to {early_path}"
|
| 2363 |
yield "[DONE]"
|
|
|
|
| 2364 |
unload_models(model)
|
|
|
|
| 2365 |
return
|
| 2366 |
|
| 2367 |
epoch_loss = 0.0
|
|
@@ -2369,8 +2577,8 @@ def train_lora_generator(
|
|
| 2369 |
epoch_start = time.time()
|
| 2370 |
|
| 2371 |
for batch in loader:
|
| 2372 |
-
#
|
| 2373 |
-
nb =
|
| 2374 |
tgt = batch["target_latents"].to(device, dtype=dtype, non_blocking=nb)
|
| 2375 |
att = batch["attention_mask"].to(device, dtype=dtype, non_blocking=nb)
|
| 2376 |
enc_hs = batch["encoder_hidden_states"].to(device, dtype=dtype, non_blocking=nb)
|
|
@@ -2395,19 +2603,34 @@ def train_lora_generator(
|
|
| 2395 |
if force_input_grads:
|
| 2396 |
xt = xt.requires_grad_(True)
|
| 2397 |
|
| 2398 |
-
# Decoder forward
|
| 2399 |
-
|
| 2400 |
-
|
| 2401 |
-
|
| 2402 |
-
|
| 2403 |
-
|
| 2404 |
-
|
| 2405 |
-
|
| 2406 |
-
|
| 2407 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2408 |
|
| 2409 |
-
flow = x1 - x0
|
| 2410 |
-
loss = F.mse_loss(dec_out[0], flow)
|
| 2411 |
loss = loss.float() # fp32 for stable backward
|
| 2412 |
|
| 2413 |
# NaN guard
|
|
@@ -2416,7 +2639,9 @@ def train_lora_generator(
|
|
| 2416 |
del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow
|
| 2417 |
if consecutive_nan >= MAX_NAN:
|
| 2418 |
yield f"[FAIL] {consecutive_nan} consecutive NaN losses, halting"
|
|
|
|
| 2419 |
unload_models(model)
|
|
|
|
| 2420 |
return
|
| 2421 |
if acc_step > 0:
|
| 2422 |
optimizer.zero_grad(set_to_none=True)
|
|
@@ -2426,14 +2651,27 @@ def train_lora_generator(
|
|
| 2426 |
consecutive_nan = 0
|
| 2427 |
|
| 2428 |
loss = loss / gradient_accumulation_steps
|
| 2429 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2430 |
acc_loss += loss.item()
|
| 2431 |
del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow
|
| 2432 |
acc_step += 1
|
| 2433 |
|
| 2434 |
if acc_step >= gradient_accumulation_steps:
|
| 2435 |
-
|
| 2436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2437 |
scheduler.step()
|
| 2438 |
global_step += 1
|
| 2439 |
|
|
@@ -2454,10 +2692,20 @@ def train_lora_generator(
|
|
| 2454 |
acc_loss = 0.0
|
| 2455 |
acc_step = 0
|
| 2456 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2457 |
# Flush remainder
|
| 2458 |
if acc_step > 0:
|
| 2459 |
-
|
| 2460 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2461 |
scheduler.step()
|
| 2462 |
global_step += 1
|
| 2463 |
avg_loss = acc_loss * gradient_accumulation_steps / acc_step
|
|
@@ -2506,10 +2754,15 @@ def train_lora_generator(
|
|
| 2506 |
model.decoder.train()
|
| 2507 |
yield f"[OK] Checkpoint saved at epoch {epoch + 1}"
|
| 2508 |
|
|
|
|
|
|
|
|
|
|
| 2509 |
# Sanity check
|
| 2510 |
if global_step == 0:
|
| 2511 |
yield "[FAIL] Training completed 0 steps -- no batches processed"
|
|
|
|
| 2512 |
unload_models(model)
|
|
|
|
| 2513 |
return
|
| 2514 |
|
| 2515 |
# Final save (directly to output_dir, not a subdirectory)
|
|
@@ -2525,7 +2778,9 @@ def train_lora_generator(
|
|
| 2525 |
f" Adapter ready for inference."
|
| 2526 |
)
|
| 2527 |
yield "[DONE]"
|
|
|
|
| 2528 |
unload_models(model)
|
|
|
|
| 2529 |
|
| 2530 |
|
| 2531 |
# ============================================================================
|
|
|
|
| 1 |
"""
|
| 2 |
+
Standalone ACE-Step LoRA Training Engine (CPU + GPU).
|
| 3 |
|
| 4 |
Ported from Side-Step (koda-dernet/Side-Step) into a single self-contained
|
| 5 |
module. No external Side-Step dependency required.
|
| 6 |
|
| 7 |
+
Auto-detects GPU (CUDA > MPS > CPU) and uses it when available,
|
| 8 |
+
falling back to CPU. bfloat16 is used on GPU; float32 is forced
|
| 9 |
+
on CPU (bfloat16 deadlocks on CPU -- known PyTorch bug).
|
| 10 |
+
|
| 11 |
Exports:
|
| 12 |
+
detect_device() - Auto-detect best available device
|
| 13 |
+
select_dtype() - Pick dtype for a device
|
| 14 |
preprocess_audio() - 2-pass sequential preprocessing
|
| 15 |
train_lora_generator() - Generator-based LoRA training loop
|
| 16 |
cancel_training() - Set the cancel flag
|
|
|
|
| 69 |
_training_cancel.set()
|
| 70 |
|
| 71 |
|
| 72 |
+
# ============================================================================
|
| 73 |
+
# DEVICE DETECTION & DTYPE SELECTION
|
| 74 |
+
# ============================================================================
|
| 75 |
+
|
| 76 |
+
def detect_device(requested: str = "auto") -> str:
    """Return the best available device string.

    Priority: CUDA (best GPU by VRAM) > MPS (Apple Silicon) > CPU.
    Pass an explicit device string (e.g. "cuda:0", "cpu") to skip
    auto-detection.

    Args:
        requested: "auto" to detect, or an explicit torch device string,
            which is returned unchanged.

    Returns:
        A device string such as "cuda:1", "mps", or "cpu".
    """
    if requested != "auto":
        return requested

    if torch.cuda.is_available():
        # Pick the GPU with the most VRAM when multiple are present
        count = torch.cuda.device_count()
        if count <= 1:
            best_idx = 0
        else:
            best_idx, best_mem = 0, 0
            for i in range(count):
                mem = torch.cuda.get_device_properties(i).total_memory
                if mem > best_mem:
                    best_idx, best_mem = i, mem
            if best_idx != 0:
                logger.info(
                    "Multiple CUDA devices (%d). Selected cuda:%d (%s, %.0f MiB).",
                    count, best_idx,
                    torch.cuda.get_device_name(best_idx),
                    best_mem / (1024 ** 2),
                )
        device = f"cuda:{best_idx}"
        logger.info("Auto-detected device: %s (%s)", device, torch.cuda.get_device_name(best_idx))
        return device

    # Use the documented torch.backends.mps API (available since torch 1.12).
    # The previous torch.mps.is_available() attribute is missing on older
    # builds, which made working MPS machines silently fall back to CPU.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        logger.info("Auto-detected device: mps (Apple Silicon)")
        return "mps"

    logger.info("Auto-detected device: cpu")
    return "cpu"
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def select_dtype(device: str) -> torch.dtype:
    """Select the appropriate training dtype for *device*.

    GPU: bfloat16 if supported, else float16.
    CPU: MUST stay float32 (bfloat16 deadlocks on CPU).

    Args:
        device: Device string, e.g. "cuda:0", "mps", or "cpu".

    Returns:
        The torch dtype to use on that device.
    """
    kind = device.partition(":")[0]

    if kind == "cuda":
        # Prefer bfloat16 on Ampere+ (compute capability >= 8.0);
        # any probing failure falls through to the float16 default.
        try:
            index = int(device.partition(":")[2]) if ":" in device else 0
            if torch.cuda.get_device_properties(index).major >= 8:
                return torch.bfloat16
        except Exception:
            pass
        return torch.float16

    # MPS / other accelerators -- float32 is safest
    if kind == "mps":
        return torch.float32

    # "cpu" and anything unrecognized: always float32
    return CPU_DTYPE
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _cuda_sync(device: str) -> None:
|
| 145 |
+
"""Synchronize CUDA if the device is a CUDA device (no-op otherwise)."""
|
| 146 |
+
if device.startswith("cuda") and torch.cuda.is_available():
|
| 147 |
+
torch.cuda.synchronize()
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _clear_gpu_cache(device: str) -> None:
|
| 151 |
+
"""Free cached GPU memory for the given device type."""
|
| 152 |
+
dev_type = device.split(":")[0]
|
| 153 |
+
if dev_type == "cuda" and torch.cuda.is_available():
|
| 154 |
+
torch.cuda.empty_cache()
|
| 155 |
+
elif dev_type == "mps" and hasattr(torch, "mps") and torch.mps.is_available():
|
| 156 |
+
torch.mps.empty_cache()
|
| 157 |
+
|
| 158 |
+
|
| 159 |
# ============================================================================
|
| 160 |
# CONFIGS
|
| 161 |
# ============================================================================
|
|
|
|
| 541 |
|
| 542 |
|
| 543 |
def _attn_candidates(device: str) -> List[str]:
|
| 544 |
+
"""FA2 -> SDPA -> eager, filtered by availability.
|
| 545 |
+
|
| 546 |
+
On CUDA with flash_attn installed and compute capability >= 8.0,
|
| 547 |
+
flash_attention_2 is tried first. On CPU, flash_attention_2 is
|
| 548 |
+
always skipped (it requires CUDA).
|
| 549 |
+
"""
|
| 550 |
candidates = []
|
| 551 |
if device.startswith("cuda"):
|
| 552 |
try:
|
|
|
|
| 555 |
props = torch.cuda.get_device_properties(dev_idx)
|
| 556 |
if props.major >= 8:
|
| 557 |
candidates.append("flash_attention_2")
|
| 558 |
+
logger.info(
|
| 559 |
+
"flash_attention_2 available (compute %d.%d, flash_attn installed)",
|
| 560 |
+
props.major, props.minor,
|
| 561 |
+
)
|
| 562 |
+
else:
|
| 563 |
+
logger.info(
|
| 564 |
+
"flash_attention_2 skipped: compute %d.%d < 8.0 (need Ampere+)",
|
| 565 |
+
props.major, props.minor,
|
| 566 |
+
)
|
| 567 |
+
except ImportError:
|
| 568 |
+
logger.info("flash_attention_2 skipped: flash_attn package not installed")
|
| 569 |
+
except Exception as exc:
|
| 570 |
+
logger.info("flash_attention_2 skipped: %s", exc)
|
| 571 |
+
else:
|
| 572 |
+
logger.info("flash_attention_2 skipped: device is %s (not CUDA)", device)
|
| 573 |
candidates.extend(["sdpa", "eager"])
|
| 574 |
return candidates
|
| 575 |
|
|
|
|
| 580 |
from transformers import AutoModel
|
| 581 |
|
| 582 |
model_dir = _resolve_model_dir(checkpoint_dir, variant)
|
| 583 |
+
dtype = select_dtype(device)
|
| 584 |
+
|
| 585 |
+
logger.info(
|
| 586 |
+
"Loading model from %s (variant=%s, device=%s, dtype=%s)",
|
| 587 |
+
model_dir, variant, device, dtype,
|
| 588 |
+
)
|
| 589 |
|
| 590 |
_ensure_acestep_imports()
|
| 591 |
|
|
|
|
| 604 |
if device != "cpu":
|
| 605 |
load_kwargs["device_map"] = {"": device}
|
| 606 |
model = AutoModel.from_pretrained(str(model_dir), **load_kwargs)
|
| 607 |
+
logger.info("Model loaded with attn_implementation=%s on %s", attn, device)
|
| 608 |
break
|
| 609 |
except Exception as exc:
|
| 610 |
err_text = str(exc)
|
|
|
|
| 614 |
f" Original error: {err_text}"
|
| 615 |
) from exc
|
| 616 |
last_err = exc
|
| 617 |
+
next_attn = candidates[idx + 1] if idx + 1 < len(candidates) else None
|
| 618 |
+
if next_attn:
|
| 619 |
+
logger.warning("attn backend '%s' failed: %s; trying '%s'", attn, exc, next_attn)
|
| 620 |
+
else:
|
| 621 |
+
logger.warning("attn backend '%s' failed: %s", attn, exc)
|
| 622 |
|
| 623 |
if model is None:
|
| 624 |
raise RuntimeError(f"Failed to load model from {model_dir}: {last_err}") from last_err
|
| 625 |
|
| 626 |
+
# If device_map was not used (CPU), move model explicitly
|
| 627 |
+
if device != "cpu":
|
| 628 |
+
# device_map already placed weights; just verify dtype
|
| 629 |
+
if any(p.dtype != dtype for p in model.parameters()):
|
| 630 |
+
model = model.to(dtype=dtype)
|
| 631 |
+
else:
|
| 632 |
+
model = model.to(device=device, dtype=dtype)
|
| 633 |
+
|
| 634 |
for param in model.parameters():
|
| 635 |
param.requires_grad = False
|
| 636 |
model.eval()
|
|
|
|
| 644 |
if not vae_path.is_dir():
|
| 645 |
raise FileNotFoundError(f"VAE directory not found: {vae_path}")
|
| 646 |
|
| 647 |
+
dtype = select_dtype(device)
|
| 648 |
vae = AutoencoderOobleck.from_pretrained(str(vae_path), torch_dtype=dtype)
|
| 649 |
vae = vae.to(device=device)
|
| 650 |
vae.eval()
|
| 651 |
+
logger.info("VAE loaded on %s (dtype=%s)", device, dtype)
|
| 652 |
return vae
|
| 653 |
|
| 654 |
|
|
|
|
| 659 |
if not text_path.is_dir():
|
| 660 |
raise FileNotFoundError(f"Text encoder not found: {text_path}")
|
| 661 |
|
| 662 |
+
dtype = select_dtype(device)
|
| 663 |
tokenizer = AutoTokenizer.from_pretrained(str(text_path))
|
| 664 |
encoder = AutoModel.from_pretrained(str(text_path), torch_dtype=dtype)
|
| 665 |
encoder = encoder.to(device=device)
|
| 666 |
encoder.eval()
|
| 667 |
+
logger.info("Text encoder loaded on %s (dtype=%s)", device, dtype)
|
| 668 |
return tokenizer, encoder
|
| 669 |
|
| 670 |
|
|
|
|
| 672 |
checkpoint_dir: str, device: str = "cpu", variant: str = "base",
|
| 673 |
) -> torch.Tensor:
|
| 674 |
ckpt = Path(checkpoint_dir)
|
| 675 |
+
dtype = select_dtype(device)
|
| 676 |
|
| 677 |
candidates = [ckpt / "silence_latent.pt"]
|
| 678 |
subdir = _VARIANT_DIR.get(variant)
|
|
|
|
| 700 |
pass
|
| 701 |
del obj
|
| 702 |
gc.collect()
|
| 703 |
+
# Free GPU memory after unloading
|
| 704 |
+
if torch.cuda.is_available():
|
| 705 |
+
torch.cuda.empty_cache()
|
| 706 |
+
if hasattr(torch, "mps") and hasattr(torch.mps, "is_available") and torch.mps.is_available():
|
| 707 |
+
try:
|
| 708 |
+
torch.mps.empty_cache()
|
| 709 |
+
except Exception:
|
| 710 |
+
pass
|
| 711 |
|
| 712 |
|
| 713 |
# ============================================================================
|
|
|
|
| 2041 |
return sidecar_path
|
| 2042 |
|
| 2043 |
|
| 2044 |
+
def _parse_txt_caption(text: str) -> Dict[str, Any]:
|
| 2045 |
+
"""Parse user's .txt caption format into structured fields."""
|
| 2046 |
+
result: Dict[str, Any] = {}
|
| 2047 |
+
lyrics_match = re.search(r'lyrics say "(.*?)" at tempo', text, re.DOTALL)
|
| 2048 |
+
if lyrics_match:
|
| 2049 |
+
result["lyrics"] = lyrics_match.group(1).strip()
|
| 2050 |
+
caption_part = text[:lyrics_match.start()].strip().rstrip(",").strip()
|
| 2051 |
+
else:
|
| 2052 |
+
result["lyrics"] = "[Instrumental]"
|
| 2053 |
+
caption_part = text.strip()
|
| 2054 |
+
bpm_match = re.search(r'at tempo (\d+) BPM', text)
|
| 2055 |
+
if bpm_match:
|
| 2056 |
+
result["bpm"] = bpm_match.group(1)
|
| 2057 |
+
caption_part = re.sub(r'\s*at tempo \d+ BPM.*', '', caption_part).strip()
|
| 2058 |
+
key_match = re.search(r'in the key of ([A-G][#b]?[-\d]*)', text)
|
| 2059 |
+
if key_match:
|
| 2060 |
+
result["key"] = key_match.group(1)
|
| 2061 |
+
result["caption"] = caption_part if caption_part else text[:200]
|
| 2062 |
+
return result
|
| 2063 |
+
|
| 2064 |
+
|
| 2065 |
def _read_caption_sidecar(audio_path: Path) -> Optional[Dict[str, Any]]:
|
| 2066 |
+
"""Read .json or .txt caption sidecar."""
|
| 2067 |
+
json_path = audio_path.with_suffix(".json")
|
| 2068 |
+
if json_path.is_file():
|
| 2069 |
+
try:
|
| 2070 |
+
with open(json_path, "r", encoding="utf-8") as f:
|
| 2071 |
+
return json.load(f)
|
| 2072 |
+
except Exception:
|
| 2073 |
+
pass
|
| 2074 |
+
txt_path = audio_path.with_suffix(".txt")
|
| 2075 |
+
if txt_path.is_file():
|
| 2076 |
+
try:
|
| 2077 |
+
with open(txt_path, "r", encoding="utf-8") as f:
|
| 2078 |
+
return _parse_txt_caption(f.read())
|
| 2079 |
+
except Exception:
|
| 2080 |
+
pass
|
| 2081 |
+
return None
|
| 2082 |
|
| 2083 |
|
| 2084 |
# ============================================================================
|
|
|
|
| 2089 |
audio_dir: str,
|
| 2090 |
output_dir: str,
|
| 2091 |
checkpoint_dir: str,
|
| 2092 |
+
device: str = "auto",
|
| 2093 |
variant: str = "base",
|
| 2094 |
max_duration: float = 0,
|
| 2095 |
progress_callback: Optional[Callable] = None,
|
|
|
|
| 2099 |
|
| 2100 |
Pass 1: Load VAE + text encoder, encode audio + text, save intermediates.
|
| 2101 |
Pass 2: Load DIT model, run encoder, build context, save final .pt files.
|
| 2102 |
+
|
| 2103 |
+
Args:
|
| 2104 |
+
device: "auto" to auto-detect GPU/CPU, or explicit device string.
|
| 2105 |
"""
|
| 2106 |
+
device = detect_device(device)
|
| 2107 |
+
logger.info("Preprocessing on device: %s", device)
|
| 2108 |
+
|
| 2109 |
out = Path(output_dir)
|
| 2110 |
out.mkdir(parents=True, exist_ok=True)
|
| 2111 |
|
|
|
|
| 2125 |
if max_duration <= 0:
|
| 2126 |
max_duration = _detect_max_duration(audio_files)
|
| 2127 |
|
| 2128 |
+
dtype = select_dtype(device)
|
| 2129 |
|
| 2130 |
# ---- Pass 1: VAE + Text Encoder ----
|
| 2131 |
logger.info("Pass 1/2: Loading VAE + Text Encoder...")
|
|
|
|
| 2257 |
finally:
|
| 2258 |
logger.info("Unloading VAE + Text Encoder...")
|
| 2259 |
unload_models(vae, text_enc, tokenizer, silence_lat)
|
| 2260 |
+
_clear_gpu_cache(device)
|
| 2261 |
|
| 2262 |
# ---- Pass 2: DIT Encoder ----
|
| 2263 |
if not intermediates:
|
|
|
|
| 2334 |
finally:
|
| 2335 |
logger.info("Unloading DIT model...")
|
| 2336 |
unload_models(model)
|
| 2337 |
+
_clear_gpu_cache(device)
|
| 2338 |
|
| 2339 |
failed = p1_failed + p2_failed
|
| 2340 |
return {"processed": processed, "failed": failed, "total": total, "output_dir": str(out)}
|
|
|
|
| 2361 |
save_every_n_epochs: int = 0,
|
| 2362 |
seed: int = 42,
|
| 2363 |
variant: str = "base",
|
| 2364 |
+
device: str = "auto",
|
| 2365 |
cfg_ratio: float = 0.15,
|
| 2366 |
timestep_mu: float = -0.4,
|
| 2367 |
timestep_sigma: float = 1.0,
|
|
|
|
| 2373 |
|
| 2374 |
This is a generator for Gradio live-update compatibility.
|
| 2375 |
Call cancel_training() to stop after the current epoch.
|
| 2376 |
+
|
| 2377 |
+
Args:
|
| 2378 |
+
device: "auto" to auto-detect GPU/CPU, or explicit device string.
|
| 2379 |
+
GPU uses mixed-precision (bfloat16/float16); CPU stays float32.
|
| 2380 |
"""
|
| 2381 |
_training_cancel.clear()
|
| 2382 |
train_start = time.time()
|
| 2383 |
|
| 2384 |
+
# Auto-detect device
|
| 2385 |
+
device = detect_device(device)
|
| 2386 |
+
dtype = select_dtype(device)
|
| 2387 |
+
dev_type = device.split(":")[0]
|
| 2388 |
+
use_amp = dev_type == "cuda"
|
| 2389 |
+
|
| 2390 |
if target_modules is None:
|
| 2391 |
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
|
| 2392 |
|
|
|
|
| 2398 |
out_path = Path(output_dir)
|
| 2399 |
out_path.mkdir(parents=True, exist_ok=True)
|
| 2400 |
|
| 2401 |
+
yield f"[INFO] Device: {device}, dtype: {dtype}, AMP: {use_amp}"
|
| 2402 |
+
|
| 2403 |
+
if dev_type == "cuda":
|
| 2404 |
+
gpu_name = torch.cuda.get_device_name(device)
|
| 2405 |
+
gpu_mem = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)
|
| 2406 |
+
yield f"[INFO] GPU: {gpu_name} ({gpu_mem:.1f} GiB VRAM)"
|
| 2407 |
+
|
| 2408 |
yield "[INFO] Loading model..."
|
| 2409 |
|
| 2410 |
try:
|
|
|
|
| 2413 |
yield f"[FAIL] Model load failed: {exc}"
|
| 2414 |
return
|
| 2415 |
|
| 2416 |
+
# Ensure model is in the correct dtype (load_model_for_training handles this,
|
| 2417 |
+
# but be explicit for safety)
|
| 2418 |
model = model.to(dtype=dtype)
|
| 2419 |
|
| 2420 |
+
# Move model to device if not already there (CPU path)
|
| 2421 |
+
if dev_type == "cpu":
|
| 2422 |
+
model = model.to(device=device)
|
| 2423 |
+
|
| 2424 |
yield "[INFO] Injecting LoRA..."
|
| 2425 |
|
| 2426 |
lora_cfg = LoRAConfig(
|
|
|
|
| 2456 |
loader = DataLoader(
|
| 2457 |
dataset, batch_size=batch_size, shuffle=True,
|
| 2458 |
num_workers=0, collate_fn=_collate_batch, drop_last=False,
|
| 2459 |
+
pin_memory=(dev_type == "cuda"),
|
| 2460 |
)
|
| 2461 |
|
| 2462 |
# Optimizer & scheduler
|
| 2463 |
torch.manual_seed(seed)
|
| 2464 |
random.seed(seed)
|
| 2465 |
+
if dev_type == "cuda":
|
| 2466 |
+
torch.cuda.manual_seed_all(seed)
|
| 2467 |
|
| 2468 |
trainable_params = [p for p in model.parameters() if p.requires_grad]
|
| 2469 |
if not trainable_params:
|
|
|
|
| 2479 |
yield f"[INFO] Training {sum(p.numel() for p in trainable_params):,} params for {epochs} epochs"
|
| 2480 |
yield f"[INFO] Steps/epoch: {steps_per_epoch}, total: {total_steps}"
|
| 2481 |
|
| 2482 |
+
# GradScaler for mixed precision on GPU (only for float16, not bfloat16)
|
| 2483 |
+
use_grad_scaler = use_amp and dtype == torch.float16
|
| 2484 |
+
grad_scaler = None
|
| 2485 |
+
if use_grad_scaler:
|
| 2486 |
+
grad_scaler = torch.cuda.amp.GradScaler()
|
| 2487 |
+
yield "[INFO] GradScaler enabled (float16 mixed precision)"
|
| 2488 |
+
|
| 2489 |
# Null condition embedding for CFG dropout
|
| 2490 |
null_cond = getattr(model, "null_condition_emb", None)
|
| 2491 |
|
|
|
|
| 2554 |
model.decoder.train()
|
| 2555 |
yield f"[OK] Cancelled at epoch {epoch + 1}, saved to {early_path}"
|
| 2556 |
yield "[DONE]"
|
| 2557 |
+
_cuda_sync(device)
|
| 2558 |
unload_models(model)
|
| 2559 |
+
_clear_gpu_cache(device)
|
| 2560 |
return
|
| 2561 |
|
| 2562 |
# Timeout check
|
|
|
|
| 2567 |
save_lora_adapter(model, early_path)
|
| 2568 |
yield f"[WARN] Training timed out after {int(elapsed)}s, saved to {early_path}"
|
| 2569 |
yield "[DONE]"
|
| 2570 |
+
_cuda_sync(device)
|
| 2571 |
unload_models(model)
|
| 2572 |
+
_clear_gpu_cache(device)
|
| 2573 |
return
|
| 2574 |
|
| 2575 |
epoch_loss = 0.0
|
|
|
|
| 2577 |
epoch_start = time.time()
|
| 2578 |
|
| 2579 |
for batch in loader:
|
| 2580 |
+
# Move batch tensors to device
|
| 2581 |
+
nb = dev_type != "cpu"
|
| 2582 |
tgt = batch["target_latents"].to(device, dtype=dtype, non_blocking=nb)
|
| 2583 |
att = batch["attention_mask"].to(device, dtype=dtype, non_blocking=nb)
|
| 2584 |
enc_hs = batch["encoder_hidden_states"].to(device, dtype=dtype, non_blocking=nb)
|
|
|
|
| 2603 |
if force_input_grads:
|
| 2604 |
xt = xt.requires_grad_(True)
|
| 2605 |
|
| 2606 |
+
# Decoder forward -- use AMP autocast on GPU for mixed precision
|
| 2607 |
+
if use_amp:
|
| 2608 |
+
with torch.cuda.amp.autocast(dtype=dtype):
|
| 2609 |
+
dec_out = model.decoder(
|
| 2610 |
+
hidden_states=xt,
|
| 2611 |
+
timestep=t,
|
| 2612 |
+
timestep_r=t,
|
| 2613 |
+
attention_mask=att,
|
| 2614 |
+
encoder_hidden_states=enc_hs,
|
| 2615 |
+
encoder_attention_mask=enc_mask,
|
| 2616 |
+
context_latents=ctx,
|
| 2617 |
+
)
|
| 2618 |
+
flow = x1 - x0
|
| 2619 |
+
loss = F.mse_loss(dec_out[0], flow)
|
| 2620 |
+
else:
|
| 2621 |
+
# CPU path -- no autocast
|
| 2622 |
+
dec_out = model.decoder(
|
| 2623 |
+
hidden_states=xt,
|
| 2624 |
+
timestep=t,
|
| 2625 |
+
timestep_r=t,
|
| 2626 |
+
attention_mask=att,
|
| 2627 |
+
encoder_hidden_states=enc_hs,
|
| 2628 |
+
encoder_attention_mask=enc_mask,
|
| 2629 |
+
context_latents=ctx,
|
| 2630 |
+
)
|
| 2631 |
+
flow = x1 - x0
|
| 2632 |
+
loss = F.mse_loss(dec_out[0], flow)
|
| 2633 |
|
|
|
|
|
|
|
| 2634 |
loss = loss.float() # fp32 for stable backward
|
| 2635 |
|
| 2636 |
# NaN guard
|
|
|
|
| 2639 |
del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow
|
| 2640 |
if consecutive_nan >= MAX_NAN:
|
| 2641 |
yield f"[FAIL] {consecutive_nan} consecutive NaN losses, halting"
|
| 2642 |
+
_cuda_sync(device)
|
| 2643 |
unload_models(model)
|
| 2644 |
+
_clear_gpu_cache(device)
|
| 2645 |
return
|
| 2646 |
if acc_step > 0:
|
| 2647 |
optimizer.zero_grad(set_to_none=True)
|
|
|
|
| 2651 |
consecutive_nan = 0
|
| 2652 |
|
| 2653 |
loss = loss / gradient_accumulation_steps
|
| 2654 |
+
|
| 2655 |
+
# Backward -- use GradScaler on float16 GPU
|
| 2656 |
+
if grad_scaler is not None:
|
| 2657 |
+
grad_scaler.scale(loss).backward()
|
| 2658 |
+
else:
|
| 2659 |
+
loss.backward()
|
| 2660 |
+
|
| 2661 |
acc_loss += loss.item()
|
| 2662 |
del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow
|
| 2663 |
acc_step += 1
|
| 2664 |
|
| 2665 |
if acc_step >= gradient_accumulation_steps:
|
| 2666 |
+
if grad_scaler is not None:
|
| 2667 |
+
grad_scaler.unscale_(optimizer)
|
| 2668 |
+
torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
|
| 2669 |
+
grad_scaler.step(optimizer)
|
| 2670 |
+
grad_scaler.update()
|
| 2671 |
+
else:
|
| 2672 |
+
torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
|
| 2673 |
+
optimizer.step()
|
| 2674 |
+
|
| 2675 |
scheduler.step()
|
| 2676 |
global_step += 1
|
| 2677 |
|
|
|
|
| 2692 |
acc_loss = 0.0
|
| 2693 |
acc_step = 0
|
| 2694 |
|
| 2695 |
+
# Periodic GPU cache cleanup
|
| 2696 |
+
if dev_type == "cuda" and global_step % log_every == 0:
|
| 2697 |
+
torch.cuda.empty_cache()
|
| 2698 |
+
|
| 2699 |
# Flush remainder
|
| 2700 |
if acc_step > 0:
|
| 2701 |
+
if grad_scaler is not None:
|
| 2702 |
+
grad_scaler.unscale_(optimizer)
|
| 2703 |
+
torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
|
| 2704 |
+
grad_scaler.step(optimizer)
|
| 2705 |
+
grad_scaler.update()
|
| 2706 |
+
else:
|
| 2707 |
+
torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
|
| 2708 |
+
optimizer.step()
|
| 2709 |
scheduler.step()
|
| 2710 |
global_step += 1
|
| 2711 |
avg_loss = acc_loss * gradient_accumulation_steps / acc_step
|
|
|
|
| 2754 |
model.decoder.train()
|
| 2755 |
yield f"[OK] Checkpoint saved at epoch {epoch + 1}"
|
| 2756 |
|
| 2757 |
+
# Clear GPU cache after epoch + checkpoint save
|
| 2758 |
+
_clear_gpu_cache(device)
|
| 2759 |
+
|
| 2760 |
# Sanity check
|
| 2761 |
if global_step == 0:
|
| 2762 |
yield "[FAIL] Training completed 0 steps -- no batches processed"
|
| 2763 |
+
_cuda_sync(device)
|
| 2764 |
unload_models(model)
|
| 2765 |
+
_clear_gpu_cache(device)
|
| 2766 |
return
|
| 2767 |
|
| 2768 |
# Final save (directly to output_dir, not a subdirectory)
|
|
|
|
| 2778 |
f" Adapter ready for inference."
|
| 2779 |
)
|
| 2780 |
yield "[DONE]"
|
| 2781 |
+
_cuda_sync(device)
|
| 2782 |
unload_models(model)
|
| 2783 |
+
_clear_gpu_cache(device)
|
| 2784 |
|
| 2785 |
|
| 2786 |
# ============================================================================
|