Nekochu committed on
Commit 917e4ed · 1 Parent(s): 04c031f

add GPU/CUDA auto-detect, mixed precision, flash_attn, txt caption parser
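A minimal usage sketch of the new device/dtype helpers (illustrative only; the functions are defined in the diff below, and the printed values depend on the host):

    # Sketch -- assumes train_engine.py is importable from the working directory.
    from train_engine import detect_device, select_dtype

    device = detect_device("auto")  # "cuda:N" (largest-VRAM GPU), "mps", or "cpu"
    dtype = select_dtype(device)    # bf16 on Ampere+, fp16 on older CUDA, fp32 on CPU/MPS
    print(device, dtype)            # e.g. cuda:0 torch.bfloat16

    # preprocess_audio() and train_lora_generator() now default to device="auto"
    # and run this same detection internally.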

Files changed (1)
  1. train_engine.py +299 -44
train_engine.py CHANGED
@@ -1,10 +1,16 @@
 """
-Standalone ACE-Step CPU LoRA Training Engine.
+Standalone ACE-Step LoRA Training Engine (CPU + GPU).
 
 Ported from Side-Step (koda-dernet/Side-Step) into a single self-contained
 module. No external Side-Step dependency required.
 
+Auto-detects GPU (CUDA > MPS > CPU) and uses it when available,
+falling back to CPU. bfloat16 is used on GPU; float32 is forced
+on CPU (bfloat16 deadlocks on CPU -- known PyTorch bug).
+
 Exports:
+    detect_device()        - Auto-detect best available device
+    select_dtype()         - Pick dtype for a device
     preprocess_audio()     - 2-pass sequential preprocessing
     train_lora_generator() - Generator-based LoRA training loop
     cancel_training()      - Set the cancel flag
@@ -63,6 +69,93 @@ def cancel_training() -> None:
     _training_cancel.set()
 
 
+# ============================================================================
+# DEVICE DETECTION & DTYPE SELECTION
+# ============================================================================
+
+def detect_device(requested: str = "auto") -> str:
+    """Return the best available device string.
+
+    Priority: CUDA (best GPU by VRAM) > MPS (Apple Silicon) > CPU.
+    Pass an explicit device string (e.g. "cuda:0", "cpu") to skip
+    auto-detection.
+    """
+    if requested != "auto":
+        return requested
+
+    if torch.cuda.is_available():
+        # Pick the GPU with the most VRAM when multiple are present
+        count = torch.cuda.device_count()
+        if count <= 1:
+            best_idx = 0
+        else:
+            best_idx, best_mem = 0, 0
+            for i in range(count):
+                mem = torch.cuda.get_device_properties(i).total_memory
+                if mem > best_mem:
+                    best_idx, best_mem = i, mem
+            if best_idx != 0:
+                logger.info(
+                    "Multiple CUDA devices (%d). Selected cuda:%d (%s, %.0f MiB).",
+                    count, best_idx,
+                    torch.cuda.get_device_name(best_idx),
+                    best_mem / (1024 ** 2),
+                )
+        device = f"cuda:{best_idx}"
+        logger.info("Auto-detected device: %s (%s)", device, torch.cuda.get_device_name(best_idx))
+        return device
+
+    if hasattr(torch, "mps") and hasattr(torch.mps, "is_available") and torch.mps.is_available():
+        logger.info("Auto-detected device: mps (Apple Silicon)")
+        return "mps"
+
+    logger.info("Auto-detected device: cpu")
+    return "cpu"
+
+
+def select_dtype(device: str) -> torch.dtype:
+    """Select the appropriate training dtype for *device*.
+
+    GPU: bfloat16 if supported, else float16.
+    CPU: MUST stay float32 (bfloat16 deadlocks on CPU).
+    """
+    dev_type = device.split(":")[0]
+    if dev_type == "cpu":
+        return CPU_DTYPE  # always float32
+
+    if dev_type == "cuda":
+        # Prefer bfloat16 on Ampere+ (compute capability >= 8.0)
+        try:
+            idx = int(device.split(":")[1]) if ":" in device else 0
+            props = torch.cuda.get_device_properties(idx)
+            if props.major >= 8:
+                return torch.bfloat16
+        except Exception:
+            pass
+        return torch.float16
+
+    # MPS / other accelerators -- float32 is safest
+    if dev_type == "mps":
+        return torch.float32
+
+    return CPU_DTYPE
+
+
+def _cuda_sync(device: str) -> None:
+    """Synchronize CUDA if the device is a CUDA device (no-op otherwise)."""
+    if device.startswith("cuda") and torch.cuda.is_available():
+        torch.cuda.synchronize()
+
+
+def _clear_gpu_cache(device: str) -> None:
+    """Free cached GPU memory for the given device type."""
+    dev_type = device.split(":")[0]
+    if dev_type == "cuda" and torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    elif dev_type == "mps" and hasattr(torch, "mps") and torch.mps.is_available():
+        torch.mps.empty_cache()
+
+
 # ============================================================================
 # CONFIGS
 # ============================================================================
@@ -448,7 +541,12 @@ def _ensure_acestep_imports():
 
 
 def _attn_candidates(device: str) -> List[str]:
-    """FA2 -> SDPA -> eager, filtered by availability."""
+    """FA2 -> SDPA -> eager, filtered by availability.
+
+    On CUDA with flash_attn installed and compute capability >= 8.0,
+    flash_attention_2 is tried first. On CPU, flash_attention_2 is
+    always skipped (it requires CUDA).
+    """
     candidates = []
     if device.startswith("cuda"):
         try:
@@ -457,8 +555,21 @@ def _attn_candidates(device: str) -> List[str]:
             props = torch.cuda.get_device_properties(dev_idx)
             if props.major >= 8:
                 candidates.append("flash_attention_2")
-        except (ImportError, Exception):
-            pass
+                logger.info(
+                    "flash_attention_2 available (compute %d.%d, flash_attn installed)",
+                    props.major, props.minor,
+                )
+            else:
+                logger.info(
+                    "flash_attention_2 skipped: compute %d.%d < 8.0 (need Ampere+)",
+                    props.major, props.minor,
+                )
+        except ImportError:
+            logger.info("flash_attention_2 skipped: flash_attn package not installed")
+        except Exception as exc:
+            logger.info("flash_attention_2 skipped: %s", exc)
+    else:
+        logger.info("flash_attention_2 skipped: device is %s (not CUDA)", device)
     candidates.extend(["sdpa", "eager"])
     return candidates
 
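Note: the fallback order is easiest to see by example. A sketch of the expected candidate lists, assuming flash_attn imports cleanly and the GPU reports compute capability >= 8.0 (on other setups flash_attention_2 is dropped):

    _attn_candidates("cuda:0")  # -> ["flash_attention_2", "sdpa", "eager"]
    _attn_candidates("cpu")     # -> ["sdpa", "eager"]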
 
@@ -469,8 +580,12 @@ def load_model_for_training(
     from transformers import AutoModel
 
     model_dir = _resolve_model_dir(checkpoint_dir, variant)
-    # CPU always uses float32
-    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
+    dtype = select_dtype(device)
+
+    logger.info(
+        "Loading model from %s (variant=%s, device=%s, dtype=%s)",
+        model_dir, variant, device, dtype,
+    )
 
     _ensure_acestep_imports()
 
@@ -489,7 +604,7 @@
             if device != "cpu":
                 load_kwargs["device_map"] = {"": device}
             model = AutoModel.from_pretrained(str(model_dir), **load_kwargs)
-            logger.info("Model loaded with attn_implementation=%s", attn)
+            logger.info("Model loaded with attn_implementation=%s on %s", attn, device)
             break
         except Exception as exc:
             err_text = str(exc)
@@ -499,11 +614,23 @@
                     f" Original error: {err_text}"
                 ) from exc
             last_err = exc
-            logger.warning("attn backend '%s' failed: %s", attn, exc)
+            next_attn = candidates[idx + 1] if idx + 1 < len(candidates) else None
+            if next_attn:
+                logger.warning("attn backend '%s' failed: %s; trying '%s'", attn, exc, next_attn)
+            else:
+                logger.warning("attn backend '%s' failed: %s", attn, exc)
 
     if model is None:
         raise RuntimeError(f"Failed to load model from {model_dir}: {last_err}") from last_err
 
+    # If device_map was not used (CPU), move model explicitly
+    if device != "cpu":
+        # device_map already placed weights; just verify dtype
+        if any(p.dtype != dtype for p in model.parameters()):
+            model = model.to(dtype=dtype)
+    else:
+        model = model.to(device=device, dtype=dtype)
+
     for param in model.parameters():
         param.requires_grad = False
     model.eval()
@@ -517,10 +644,11 @@ def load_vae(checkpoint_dir: str, device: str = "cpu"):
517
  if not vae_path.is_dir():
518
  raise FileNotFoundError(f"VAE directory not found: {vae_path}")
519
 
520
- dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
521
  vae = AutoencoderOobleck.from_pretrained(str(vae_path), torch_dtype=dtype)
522
  vae = vae.to(device=device)
523
  vae.eval()
 
524
  return vae
525
 
526
 
@@ -531,11 +659,12 @@ def load_text_encoder(checkpoint_dir: str, device: str = "cpu"):
531
  if not text_path.is_dir():
532
  raise FileNotFoundError(f"Text encoder not found: {text_path}")
533
 
534
- dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
535
  tokenizer = AutoTokenizer.from_pretrained(str(text_path))
536
  encoder = AutoModel.from_pretrained(str(text_path), torch_dtype=dtype)
537
  encoder = encoder.to(device=device)
538
  encoder.eval()
 
539
  return tokenizer, encoder
540
 
541
 
@@ -543,7 +672,7 @@ def load_silence_latent(
     checkpoint_dir: str, device: str = "cpu", variant: str = "base",
 ) -> torch.Tensor:
     ckpt = Path(checkpoint_dir)
-    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
+    dtype = select_dtype(device)
 
     candidates = [ckpt / "silence_latent.pt"]
     subdir = _VARIANT_DIR.get(variant)
@@ -571,6 +700,14 @@ def unload_models(*models) -> None:
             pass
         del obj
     gc.collect()
+    # Free GPU memory after unloading
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    if hasattr(torch, "mps") and hasattr(torch.mps, "is_available") and torch.mps.is_available():
+        try:
+            torch.mps.empty_cache()
+        except Exception:
+            pass
 
 
 # ============================================================================
@@ -1904,16 +2041,44 @@ def _write_caption_sidecar(audio_path: Path, analysis: Dict[str, Any]) -> Path:
     return sidecar_path
 
 
+def _parse_txt_caption(text: str) -> Dict[str, Any]:
+    """Parse user's .txt caption format into structured fields."""
+    result: Dict[str, Any] = {}
+    lyrics_match = re.search(r'lyrics say "(.*?)" at tempo', text, re.DOTALL)
+    if lyrics_match:
+        result["lyrics"] = lyrics_match.group(1).strip()
+        caption_part = text[:lyrics_match.start()].strip().rstrip(",").strip()
+    else:
+        result["lyrics"] = "[Instrumental]"
+        caption_part = text.strip()
+    bpm_match = re.search(r'at tempo (\d+) BPM', text)
+    if bpm_match:
+        result["bpm"] = bpm_match.group(1)
+        caption_part = re.sub(r'\s*at tempo \d+ BPM.*', '', caption_part).strip()
+    key_match = re.search(r'in the key of ([A-G][#b]?[-\d]*)', text)
+    if key_match:
+        result["key"] = key_match.group(1)
+    result["caption"] = caption_part if caption_part else text[:200]
+    return result
+
+
 def _read_caption_sidecar(audio_path: Path) -> Optional[Dict[str, Any]]:
-    """Read an existing .json caption sidecar if it exists."""
-    sidecar_path = audio_path.with_suffix(".json")
-    if not sidecar_path.is_file():
-        return None
-    try:
-        with open(sidecar_path, "r", encoding="utf-8") as f:
-            return json.load(f)
-    except Exception:
-        return None
+    """Read .json or .txt caption sidecar."""
+    json_path = audio_path.with_suffix(".json")
+    if json_path.is_file():
+        try:
+            with open(json_path, "r", encoding="utf-8") as f:
+                return json.load(f)
+        except Exception:
+            pass
+    txt_path = audio_path.with_suffix(".txt")
+    if txt_path.is_file():
+        try:
+            with open(txt_path, "r", encoding="utf-8") as f:
+                return _parse_txt_caption(f.read())
+        except Exception:
+            pass
+    return None
 
 
 # ============================================================================
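Note: a quick check of the phrase patterns the new .txt parser targets (sample caption invented for illustration):

    sample = ('dreamy synthwave, female vocals, lyrics say "neon lights call '
              'my name" at tempo 104 BPM in the key of C')
    print(_parse_txt_caption(sample))
    # {'lyrics': 'neon lights call my name', 'bpm': '104', 'key': 'C',
    #  'caption': 'dreamy synthwave, female vocals'}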
@@ -1924,7 +2089,7 @@ def preprocess_audio(
     audio_dir: str,
     output_dir: str,
     checkpoint_dir: str,
-    device: str = "cpu",
+    device: str = "auto",
     variant: str = "base",
     max_duration: float = 0,
     progress_callback: Optional[Callable] = None,
@@ -1934,7 +2099,13 @@
 
     Pass 1: Load VAE + text encoder, encode audio + text, save intermediates.
     Pass 2: Load DIT model, run encoder, build context, save final .pt files.
+
+    Args:
+        device: "auto" to auto-detect GPU/CPU, or explicit device string.
     """
+    device = detect_device(device)
+    logger.info("Preprocessing on device: %s", device)
+
     out = Path(output_dir)
     out.mkdir(parents=True, exist_ok=True)
 
@@ -1954,7 +2125,7 @@
     if max_duration <= 0:
         max_duration = _detect_max_duration(audio_files)
 
-    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
+    dtype = select_dtype(device)
 
     # ---- Pass 1: VAE + Text Encoder ----
     logger.info("Pass 1/2: Loading VAE + Text Encoder...")
@@ -2086,6 +2257,7 @@
     finally:
         logger.info("Unloading VAE + Text Encoder...")
         unload_models(vae, text_enc, tokenizer, silence_lat)
+        _clear_gpu_cache(device)
 
     # ---- Pass 2: DIT Encoder ----
     if not intermediates:
@@ -2162,6 +2334,7 @@
     finally:
         logger.info("Unloading DIT model...")
        unload_models(model)
+        _clear_gpu_cache(device)
 
     failed = p1_failed + p2_failed
     return {"processed": processed, "failed": failed, "total": total, "output_dir": str(out)}
@@ -2188,7 +2361,7 @@
     save_every_n_epochs: int = 0,
     seed: int = 42,
     variant: str = "base",
-    device: str = "cpu",
+    device: str = "auto",
     cfg_ratio: float = 0.15,
     timestep_mu: float = -0.4,
     timestep_sigma: float = 1.0,
@@ -2200,10 +2373,20 @@
 
     This is a generator for Gradio live-update compatibility.
     Call cancel_training() to stop after the current epoch.
+
+    Args:
+        device: "auto" to auto-detect GPU/CPU, or explicit device string.
+            GPU uses mixed-precision (bfloat16/float16); CPU stays float32.
     """
     _training_cancel.clear()
     train_start = time.time()
 
+    # Auto-detect device
+    device = detect_device(device)
+    dtype = select_dtype(device)
+    dev_type = device.split(":")[0]
+    use_amp = dev_type == "cuda"
+
     if target_modules is None:
         target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
 
@@ -2215,6 +2398,13 @@
     out_path = Path(output_dir)
     out_path.mkdir(parents=True, exist_ok=True)
 
+    yield f"[INFO] Device: {device}, dtype: {dtype}, AMP: {use_amp}"
+
+    if dev_type == "cuda":
+        gpu_name = torch.cuda.get_device_name(device)
+        gpu_mem = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)
+        yield f"[INFO] GPU: {gpu_name} ({gpu_mem:.1f} GiB VRAM)"
+
     yield "[INFO] Loading model..."
 
     try:
@@ -2223,10 +2413,14 @@
         yield f"[FAIL] Model load failed: {exc}"
         return
 
-    # float32 on CPU (bfloat16 deadlocks)
-    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
+    # Ensure model is in the correct dtype (load_model_for_training handles this,
+    # but be explicit for safety)
     model = model.to(dtype=dtype)
 
+    # Move model to device if not already there (CPU path)
+    if dev_type == "cpu":
+        model = model.to(device=device)
+
     yield "[INFO] Injecting LoRA..."
 
     lora_cfg = LoRAConfig(
@@ -2262,11 +2456,14 @@
     loader = DataLoader(
         dataset, batch_size=batch_size, shuffle=True,
         num_workers=0, collate_fn=_collate_batch, drop_last=False,
+        pin_memory=(dev_type == "cuda"),
     )
 
     # Optimizer & scheduler
     torch.manual_seed(seed)
     random.seed(seed)
+    if dev_type == "cuda":
+        torch.cuda.manual_seed_all(seed)
 
     trainable_params = [p for p in model.parameters() if p.requires_grad]
     if not trainable_params:
@@ -2282,6 +2479,13 @@
     yield f"[INFO] Training {sum(p.numel() for p in trainable_params):,} params for {epochs} epochs"
     yield f"[INFO] Steps/epoch: {steps_per_epoch}, total: {total_steps}"
 
+    # GradScaler for mixed precision on GPU (only for float16, not bfloat16)
+    use_grad_scaler = use_amp and dtype == torch.float16
+    grad_scaler = None
+    if use_grad_scaler:
+        grad_scaler = torch.cuda.amp.GradScaler()
+        yield "[INFO] GradScaler enabled (float16 mixed precision)"
+
     # Null condition embedding for CFG dropout
     null_cond = getattr(model, "null_condition_emb", None)
 
@@ -2350,7 +2554,9 @@
                 model.decoder.train()
                 yield f"[OK] Cancelled at epoch {epoch + 1}, saved to {early_path}"
             yield "[DONE]"
+            _cuda_sync(device)
             unload_models(model)
+            _clear_gpu_cache(device)
             return
 
         # Timeout check
@@ -2361,7 +2567,9 @@
             save_lora_adapter(model, early_path)
             yield f"[WARN] Training timed out after {int(elapsed)}s, saved to {early_path}"
             yield "[DONE]"
+            _cuda_sync(device)
             unload_models(model)
+            _clear_gpu_cache(device)
             return
 
         epoch_loss = 0.0
@@ -2369,8 +2577,8 @@
         epoch_start = time.time()
 
         for batch in loader:
-            # Forward
-            nb = device != "cpu"
+            # Move batch tensors to device
+            nb = dev_type != "cpu"
             tgt = batch["target_latents"].to(device, dtype=dtype, non_blocking=nb)
             att = batch["attention_mask"].to(device, dtype=dtype, non_blocking=nb)
             enc_hs = batch["encoder_hidden_states"].to(device, dtype=dtype, non_blocking=nb)
@@ -2395,19 +2603,34 @@
             if force_input_grads:
                 xt = xt.requires_grad_(True)
 
-            # Decoder forward
-            dec_out = model.decoder(
-                hidden_states=xt,
-                timestep=t,
-                timestep_r=t,
-                attention_mask=att,
-                encoder_hidden_states=enc_hs,
-                encoder_attention_mask=enc_mask,
-                context_latents=ctx,
-            )
+            # Decoder forward -- use AMP autocast on GPU for mixed precision
+            if use_amp:
+                with torch.cuda.amp.autocast(dtype=dtype):
+                    dec_out = model.decoder(
+                        hidden_states=xt,
+                        timestep=t,
+                        timestep_r=t,
+                        attention_mask=att,
+                        encoder_hidden_states=enc_hs,
+                        encoder_attention_mask=enc_mask,
+                        context_latents=ctx,
+                    )
+                flow = x1 - x0
+                loss = F.mse_loss(dec_out[0], flow)
+            else:
+                # CPU path -- no autocast
+                dec_out = model.decoder(
+                    hidden_states=xt,
+                    timestep=t,
+                    timestep_r=t,
+                    attention_mask=att,
+                    encoder_hidden_states=enc_hs,
+                    encoder_attention_mask=enc_mask,
+                    context_latents=ctx,
+                )
+                flow = x1 - x0
+                loss = F.mse_loss(dec_out[0], flow)
 
-            flow = x1 - x0
-            loss = F.mse_loss(dec_out[0], flow)
             loss = loss.float()  # fp32 for stable backward
 
             # NaN guard
@@ -2416,7 +2639,9 @@
                 del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow
                 if consecutive_nan >= MAX_NAN:
                     yield f"[FAIL] {consecutive_nan} consecutive NaN losses, halting"
+                    _cuda_sync(device)
                     unload_models(model)
+                    _clear_gpu_cache(device)
                     return
                 if acc_step > 0:
                     optimizer.zero_grad(set_to_none=True)
@@ -2426,14 +2651,27 @@
             consecutive_nan = 0
 
             loss = loss / gradient_accumulation_steps
-            loss.backward()
+
+            # Backward -- use GradScaler on float16 GPU
+            if grad_scaler is not None:
+                grad_scaler.scale(loss).backward()
+            else:
+                loss.backward()
+
             acc_loss += loss.item()
             del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow
             acc_step += 1
 
             if acc_step >= gradient_accumulation_steps:
-                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
-                optimizer.step()
+                if grad_scaler is not None:
+                    grad_scaler.unscale_(optimizer)
+                    torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
+                    grad_scaler.step(optimizer)
+                    grad_scaler.update()
+                else:
+                    torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
+                    optimizer.step()
+
                 scheduler.step()
                 global_step += 1
 
@@ -2454,10 +2692,20 @@
                 acc_loss = 0.0
                 acc_step = 0
 
+                # Periodic GPU cache cleanup
+                if dev_type == "cuda" and global_step % log_every == 0:
+                    torch.cuda.empty_cache()
+
         # Flush remainder
         if acc_step > 0:
-            torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
-            optimizer.step()
+            if grad_scaler is not None:
+                grad_scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
+                grad_scaler.step(optimizer)
+                grad_scaler.update()
+            else:
+                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
+                optimizer.step()
             scheduler.step()
             global_step += 1
             avg_loss = acc_loss * gradient_accumulation_steps / acc_step
@@ -2506,10 +2754,15 @@
             model.decoder.train()
             yield f"[OK] Checkpoint saved at epoch {epoch + 1}"
 
+        # Clear GPU cache after epoch + checkpoint save
+        _clear_gpu_cache(device)
+
     # Sanity check
     if global_step == 0:
         yield "[FAIL] Training completed 0 steps -- no batches processed"
+        _cuda_sync(device)
         unload_models(model)
+        _clear_gpu_cache(device)
         return
 
     # Final save (directly to output_dir, not a subdirectory)
@@ -2525,7 +2778,9 @@
         f" Adapter ready for inference."
     )
     yield "[DONE]"
+    _cuda_sync(device)
     unload_models(model)
+    _clear_gpu_cache(device)
 
 
 # ============================================================================
 