Spaces:
Running
Running
| """ | |
| Standalone ACE-Step CPU LoRA Training Engine. | |
| Ported from Side-Step (koda-dernet/Side-Step) into a single self-contained | |
| module. No external Side-Step dependency required. | |
| Exports: | |
| preprocess_audio() - 2-pass sequential preprocessing | |
| train_lora_generator() - Generator-based LoRA training loop | |
| cancel_training() - Set the cancel flag | |
| get_trained_loras() - List saved adapters | |
| """ | |
| from __future__ import annotations | |
| import gc | |
| import json | |
| import logging | |
| import math | |
| import os | |
| import random | |
| import sys | |
| import time | |
| import types | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Callable, Dict, Generator, List, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import ( | |
| CosineAnnealingLR, | |
| LinearLR, | |
| SequentialLR, | |
| ) | |
| from torch.utils.data import DataLoader, Dataset | |
# Module logger; progress and per-file failures are reported through it.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configurable caps (edit these at the top of the file)
# ---------------------------------------------------------------------------
MAX_AUDIO_DURATION = 240.0  # seconds; upper bound on how much of each audio file is used
MAX_TRAINING_TIME = 28800  # seconds (8 hours); hard timeout, presumably enforced by the training loop
TARGET_SR = 48000  # Hz; audio is resampled to this rate before VAE encoding
# File suffixes (compared lower-cased) that audio discovery treats as audio.
AUDIO_EXTENSIONS = frozenset({".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".aac"})
# dtype used by the loaders when device == "cpu"; other devices use bfloat16.
# bfloat16 deadlocks on CPU (known PyTorch bug) -- force float32
CPU_DTYPE = torch.float32
import threading

# Cancellation flag shared with the training loop; train_lora_generator
# clears it when a new run starts.
_training_cancel = threading.Event()


def cancel_training() -> None:
    """Request that the running training loop stop.

    This only raises a module-level ``threading.Event``; the training
    generator checks it and stops after its current epoch.
    """
    _training_cancel.set()
| # ============================================================================ | |
| # CONFIGS | |
| # ============================================================================ | |
@dataclass
class LoRAConfig:
    """Hyper-parameters for the PEFT LoRA adapter injected into the decoder.

    The ``@dataclass`` decorator is required: without it the
    ``field(default_factory=...)`` default below stays a raw
    ``dataclasses.Field`` object and the class cannot be constructed with
    keyword arguments such as ``LoRAConfig(r=8)``.
    """

    r: int = 64  # LoRA rank
    alpha: int = 128  # forwarded to PEFT as ``lora_alpha``
    dropout: float = 0.1  # forwarded to PEFT as ``lora_dropout``
    # Module names PEFT wraps with LoRA; a fresh list per instance.
    target_modules: List[str] = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",
    ])
    bias: str = "none"  # forwarded to PEFT ``LoraConfig(bias=...)``
    # NOTE(review): inject_lora in this module does not read attention_type
    # or target_mlp -- confirm whether any other caller relies on them.
    attention_type: str = "both"
    target_mlp: bool = True
| # ============================================================================ | |
| # TIMESTEP SAMPLING & CFG DROPOUT | |
| # ============================================================================ | |
def sample_timesteps(
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
    timestep_mu: float = -0.4,
    timestep_sigma: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Draw flow-matching timesteps ``(t, r)`` for a batch.

    Two logit-normal samples (``sigmoid(N(mu, sigma))``) are drawn per item
    and ``t`` is their element-wise maximum. Mean-flow is disabled (ACE-Step
    convention), so ``r`` is returned equal to ``t``. The second draw is kept
    so the global RNG advances by exactly two ``randn`` calls per invocation.
    """
    def _logit_normal() -> torch.Tensor:
        noise = torch.randn((batch_size,), device=device, dtype=dtype)
        return torch.sigmoid(noise * timestep_sigma + timestep_mu)

    first = _logit_normal()
    second = _logit_normal()
    t = torch.maximum(first, second)
    # use_meanflow=False forces r=t (ACE-Step convention)
    return t, t
def apply_cfg_dropout(
    encoder_hidden_states: torch.Tensor,
    null_condition_emb: torch.Tensor,
    cfg_ratio: float = 0.15,
) -> torch.Tensor:
    """Randomly swap whole conditioning sequences for the null embedding.

    Each batch item is independently replaced by ``null_condition_emb``
    (broadcast to the input's shape) with probability ``cfg_ratio``, which
    trains the unconditional branch used by classifier-free guidance.
    The mask is built over dim 0 and broadcast as ``(-1, 1, 1)``, so the
    input is expected to be 3-D.
    """
    batch = encoder_hidden_states.shape[0]
    opts = dict(device=encoder_hidden_states.device, dtype=encoder_hidden_states.dtype)
    dropped = torch.rand(size=(batch,), **opts) < cfg_ratio
    keep = torch.where(
        dropped,
        torch.zeros(size=(batch,), **opts),
        torch.ones(size=(batch,), **opts),
    )
    unconditional = null_condition_emb.expand_as(encoder_hidden_states)
    return torch.where(keep.view(-1, 1, 1) > 0, encoder_hidden_states, unconditional)
| # ============================================================================ | |
| # OPTIMIZER (Adafactor preferred for CPU -- 1.5 bytes/param) | |
| # ============================================================================ | |
def build_optimizer(
    params, lr: float = 1e-4, weight_decay: float = 0.01,
) -> torch.optim.Optimizer:
    """Create the optimizer for the trainable (LoRA) parameters.

    Prefers ``transformers``' Adafactor (smaller optimizer state) configured
    with ``relative_step=False`` / ``scale_parameter=False`` so the explicit
    ``lr`` is used. Falls back to ``torch.optim.AdamW`` with the same ``lr``
    and ``weight_decay`` when ``transformers`` cannot be imported.
    """
    try:
        from transformers.optimization import Adafactor
        logger.info("Using Adafactor optimizer (minimal state memory)")
        optimizer = Adafactor(
            params,
            lr=lr,
            weight_decay=weight_decay,
            scale_parameter=False,
            relative_step=False,
        )
    except ImportError:
        logger.warning("transformers not installed, falling back to AdamW")
        optimizer = AdamW(params, lr=lr, weight_decay=weight_decay)
    return optimizer
def build_scheduler(
    optimizer, total_steps: int, warmup_steps: int, lr: float,
):
    """Linear warmup followed by cosine annealing.

    Warmup ramps the LR from ``0.1 * lr`` to ``lr`` and is clamped to the
    range ``[1, max(1, total_steps // 10)]`` steps; the cosine stage then
    anneals over the remaining steps down to ``0.01 * lr``.

    The lower clamp of 1 matters. With ``warmup_steps == 0``, ``LinearLR``
    still applies ``start_factor`` when it is constructed. ``SequentialLR``
    would then never hit the milestone at step 0, so the (recursively
    defined) cosine stage would anneal from ``0.1 * lr`` for the whole run.

    Args:
        optimizer: Optimizer whose param groups are scheduled.
        total_steps: Total number of optimizer steps planned.
        warmup_steps: Requested warmup length (clamped as described above).
        lr: Peak learning rate; also sets the cosine floor ``0.01 * lr``.

    Returns:
        A ``SequentialLR`` chaining the warmup and cosine schedulers.
    """
    max_warmup = max(1, total_steps // 10)
    warmup_steps = max(1, min(warmup_steps, max_warmup))
    warmup = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps)
    remaining = max(1, total_steps - warmup_steps)
    main = CosineAnnealingLR(optimizer, T_max=remaining, eta_min=lr * 0.01)
    return SequentialLR(optimizer, [warmup, main], milestones=[warmup_steps])
| # ============================================================================ | |
| # DATASET | |
| # ============================================================================ | |
| def _collate_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]: | |
| max_t = max(s["target_latents"].shape[0] for s in batch) | |
| max_e = max(s["encoder_hidden_states"].shape[0] for s in batch) | |
| def pad(t, max_len, dim=0): | |
| diff = max_len - t.shape[dim] | |
| if diff <= 0: | |
| return t | |
| shape = list(t.shape) | |
| shape[dim] = diff | |
| return torch.cat([t, t.new_zeros(*shape)], dim=dim) | |
| return { | |
| "target_latents": torch.stack([pad(s["target_latents"], max_t) for s in batch]), | |
| "attention_mask": torch.stack([pad(s["attention_mask"], max_t) for s in batch]), | |
| "encoder_hidden_states": torch.stack([pad(s["encoder_hidden_states"], max_e) for s in batch]), | |
| "encoder_attention_mask": torch.stack([pad(s["encoder_attention_mask"], max_e) for s in batch]), | |
| "context_latents": torch.stack([pad(s["context_latents"], max_t) for s in batch]), | |
| } | |
class TensorDataset(Dataset):
    """Dataset over the preprocessed ``.pt`` sample files in one directory.

    Files are listed once, in sorted name order, skipping ``*.tmp.pt``
    intermediates. Each item is a dict with exactly the ``_REQUIRED`` keys.
    NaN/Inf values in the latent and encoder tensors are zeroed in place.
    """

    _REQUIRED = frozenset([
        "target_latents", "attention_mask", "encoder_hidden_states",
        "encoder_attention_mask", "context_latents",
    ])

    def __init__(self, tensor_dir: str):
        root = Path(tensor_dir)
        self.paths: List[str] = [
            str(root / name)
            for name in sorted(os.listdir(tensor_dir))
            if name.endswith(".pt") and not name.endswith(".tmp.pt")
        ]

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        path = self.paths[idx]
        data = torch.load(path, map_location="cpu", weights_only=True)
        absent = self._REQUIRED - data.keys()
        if absent:
            raise KeyError(f"Missing keys {sorted(absent)} in {path}")
        for key in ("target_latents", "encoder_hidden_states", "context_latents"):
            tensor = data[key]
            if torch.isnan(tensor).any() or torch.isinf(tensor).any():
                tensor.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
        return {key: data[key] for key in self._REQUIRED}
| # ============================================================================ | |
| # GRADIENT CHECKPOINTING | |
| # ============================================================================ | |
| def _find_decoder_layers(decoder: nn.Module) -> Optional[nn.ModuleList]: | |
| for attr in ("layers", "blocks", "transformer_blocks"): | |
| c = getattr(decoder, attr, None) | |
| if isinstance(c, nn.ModuleList) and len(c) > 0: | |
| return c | |
| for child in decoder.children(): | |
| for attr in ("layers", "blocks", "transformer_blocks"): | |
| c = getattr(child, attr, None) | |
| if isinstance(c, nn.ModuleList) and len(c) > 0: | |
| return c | |
| return None | |
def enable_gradient_checkpointing(decoder: nn.Module) -> bool:
    """Enable gradient checkpointing on the decoder to save memory.

    Visits ``decoder`` and every module reachable through the wrapper
    attributes ``_forward_module``, ``_orig_mod``, ``base_model``, ``model``
    and ``module`` (each at most once). For each module:

    * calls ``gradient_checkpointing_enable()`` if present, otherwise sets a
      ``gradient_checkpointing`` attribute to ``True`` if present;
    * calls ``enable_input_require_grads()`` if present;
    * sets ``config.use_cache = False`` if that attribute exists.

    Every action is best effort; exceptions are swallowed.

    Returns:
        ``True`` if checkpointing was switched on for at least one module.
    """
    def _attempt(action) -> bool:
        # Run ``action``; report success instead of propagating errors.
        try:
            action()
        except Exception:
            return False
        return True

    wrapper_attrs = ("_forward_module", "_orig_mod", "base_model", "model", "module")
    pending = [decoder]
    seen_ids = set()
    turned_on = False
    while pending:
        current = pending.pop()
        if not isinstance(current, nn.Module) or id(current) in seen_ids:
            continue
        seen_ids.add(id(current))

        if hasattr(current, "gradient_checkpointing_enable"):
            turned_on = _attempt(current.gradient_checkpointing_enable) or turned_on
        elif hasattr(current, "gradient_checkpointing"):
            turned_on = _attempt(
                lambda: setattr(current, "gradient_checkpointing", True)
            ) or turned_on

        if hasattr(current, "enable_input_require_grads"):
            _attempt(current.enable_input_require_grads)

        config = getattr(current, "config", None)
        if config is not None and hasattr(config, "use_cache"):
            _attempt(lambda: setattr(config, "use_cache", False))

        for attr in wrapper_attrs:
            child = getattr(current, attr, None)
            if isinstance(child, nn.Module):
                pending.append(child)
    return turned_on
def force_disable_cache(decoder: nn.Module) -> None:
    """Set ``config.use_cache = False`` on every module in the wrapper chain.

    Follows the ``_forward_module``, ``_orig_mod``, ``base_model``, ``model``
    and ``module`` attributes from ``decoder``, visiting each module once.
    Assignment failures are ignored.
    """
    wrapper_attrs = ("_forward_module", "_orig_mod", "base_model", "model", "module")
    todo = [decoder]
    done = set()
    while todo:
        node = todo.pop()
        if not isinstance(node, nn.Module) or id(node) in done:
            continue
        done.add(id(node))
        cfg = getattr(node, "config", None)
        if cfg is not None and hasattr(cfg, "use_cache"):
            try:
                cfg.use_cache = False
            except Exception:
                pass
        todo.extend(
            child
            for child in (getattr(node, attr, None) for attr in wrapper_attrs)
            if isinstance(child, nn.Module)
        )
| # ============================================================================ | |
| # LORA INJECTION (PEFT only -- no DoRA/LoKR/LoHA/OFT) | |
| # ============================================================================ | |
| def _unwrap_decoder(model): | |
| decoder = model.decoder if hasattr(model, "decoder") else model | |
| while hasattr(decoder, "_forward_module"): | |
| decoder = decoder._forward_module | |
| if hasattr(decoder, "base_model"): | |
| bm = decoder.base_model | |
| decoder = bm.model if hasattr(bm, "model") else bm | |
| if hasattr(decoder, "model") and isinstance(decoder.model, nn.Module): | |
| decoder = decoder.model | |
| return decoder | |
def inject_lora(model, lora_cfg: LoRAConfig) -> Tuple[Any, Dict[str, Any]]:
    """Wrap the model's decoder with a PEFT LoRA adapter and freeze everything else.

    Steps (order matters):
      1. Unwrap the decoder with ``_unwrap_decoder`` and assign the unwrapped
         module back to ``model.decoder``.
      2. If the decoder has ``enable_input_require_grads``, replace it with a
         wrapper that returns ``None`` instead of raising ``NotImplementedError``.
      3. Set ``decoder.is_gradient_checkpointing = False`` when that attribute
         exists (assignment errors ignored).
      4. Build a PEFT ``LoraConfig`` from ``lora_cfg`` (``r``, ``alpha``,
         ``dropout``, ``target_modules``, ``bias``) with
         ``TaskType.FEATURE_EXTRACTION``, and replace ``model.decoder`` with
         ``get_peft_model(decoder, peft_cfg)``.
      5. Set ``requires_grad = False`` on every parameter of ``model`` whose
         name does not contain ``"lora_"``.

    NOTE(review): ``lora_cfg.attention_type`` and ``lora_cfg.target_mlp`` are
    not read here; only ``target_modules`` selects which layers get LoRA.

    Args:
        model: Model with a ``decoder``; mutated in place (``model.decoder``
            is reassigned twice).
        lora_cfg: LoRA hyper-parameters.

    Returns:
        ``(model, stats)`` where ``stats`` holds ``total_params``,
        ``trainable_params`` (parameters with ``requires_grad``) and
        ``trainable_ratio`` (0 when the model has no parameters).
    """
    from peft import get_peft_model, LoraConfig as PeftLoraConfig, TaskType
    decoder = _unwrap_decoder(model)
    model.decoder = decoder
    # Guard enable_input_require_grads for DiT (no get_input_embeddings)
    # NOTE(review): presumably invoked during PEFT / checkpointing setup;
    # the replacement makes a NotImplementedError there a no-op.
    if hasattr(decoder, "enable_input_require_grads"):
        # `orig` is the bound original method; `_safe` ignores its own `self`
        # and simply delegates to it.
        orig = decoder.enable_input_require_grads
        def _safe(self):
            try:
                return orig()
            except NotImplementedError:
                return None
        decoder.enable_input_require_grads = types.MethodType(_safe, decoder)
    # NOTE(review): the reason for forcing this flag off is not visible here;
    # confirm against how the training loop enables checkpointing.
    if hasattr(decoder, "is_gradient_checkpointing"):
        try:
            decoder.is_gradient_checkpointing = False
        except Exception:
            pass
    peft_cfg = PeftLoraConfig(
        r=lora_cfg.r,
        lora_alpha=lora_cfg.alpha,
        lora_dropout=lora_cfg.dropout,
        target_modules=lora_cfg.target_modules,
        bias=lora_cfg.bias,
        task_type=TaskType.FEATURE_EXTRACTION,
    )
    model.decoder = get_peft_model(decoder, peft_cfg)
    # Freeze everything except LoRA weights, across the whole model (not just
    # the decoder).
    for name, param in model.named_parameters():
        if "lora_" not in name:
            param.requires_grad = False
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return model, {
        "total_params": total,
        "trainable_params": trainable,
        "trainable_ratio": trainable / total if total > 0 else 0,
    }
def save_lora_adapter(model, output_dir: str) -> None:
    """Write the decoder's LoRA adapter to ``output_dir``.

    If the decoder (after following ``_forward_module`` links) has
    ``save_pretrained`` -- e.g. a PEFT model -- that is used. Then
    ``base_model_name_or_path`` in ``adapter_config.json`` is blanked so the
    adapter does not embed a local path; a failure there is logged, not raised.

    Otherwise every parameter whose name contains ``"lora_"`` is copied into
    ``adapter_model.safetensors``, or ``lora_weights.pt`` when ``safetensors``
    is not installed. If no such parameter exists nothing is written and a
    warning is logged (previously this path still reported a successful save).

    Args:
        model: Model holding the adapter on ``model.decoder`` (or the decoder).
        output_dir: Destination directory; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    decoder = model.decoder if hasattr(model, "decoder") else model
    while hasattr(decoder, "_forward_module"):
        decoder = decoder._forward_module

    if hasattr(decoder, "save_pretrained"):
        decoder.save_pretrained(output_dir)
        # Scrub base_model path for portability
        cfg_path = os.path.join(output_dir, "adapter_config.json")
        if os.path.isfile(cfg_path):
            try:
                with open(cfg_path, "r", encoding="utf-8") as f:
                    cfg = json.load(f)
                if cfg.get("base_model_name_or_path"):
                    cfg["base_model_name_or_path"] = ""
                    with open(cfg_path, "w", encoding="utf-8") as f:
                        json.dump(cfg, f, indent=2)
            except Exception as exc:
                # Best effort: the adapter weights are already on disk.
                logger.warning("Could not scrub base_model path in %s: %s", cfg_path, exc)
        logger.info("LoRA adapter saved to %s", output_dir)
        return

    # Fallback: manual extraction. Clone + contiguous so safetensors accepts
    # the tensors (it rejects shared or non-contiguous storage).
    state = {
        name: param.detach().clone().contiguous()
        for name, param in decoder.named_parameters()
        if "lora_" in name
    }
    if not state:
        logger.warning("No LoRA parameters found; nothing saved to %s", output_dir)
        return
    try:
        from safetensors.torch import save_file
    except ImportError:
        torch.save(state, str(Path(output_dir) / "lora_weights.pt"))
    else:
        save_file(state, str(Path(output_dir) / "adapter_model.safetensors"))
    logger.info("LoRA adapter saved (fallback) to %s", output_dir)
| # ============================================================================ | |
| # MODEL LOADING (FA2 -> SDPA -> eager fallback) | |
| # ============================================================================ | |
| _VARIANT_DIR = { | |
| "turbo": "acestep-v15-turbo", | |
| "base": "acestep-v15-base", | |
| "sft": "acestep-v15-sft", | |
| } | |
| def _resolve_model_dir(checkpoint_dir: str, variant: str) -> Path: | |
| base = Path(checkpoint_dir).resolve() | |
| subdir = _VARIANT_DIR.get(variant) | |
| if subdir: | |
| p = (Path(checkpoint_dir) / subdir).resolve() | |
| if p.is_dir(): | |
| return p | |
| p = (Path(checkpoint_dir) / variant).resolve() | |
| if p.is_dir(): | |
| return p | |
| raise FileNotFoundError( | |
| f"Model directory not found: tried {_VARIANT_DIR.get(variant, variant)!r} " | |
| f"and {variant!r} under {checkpoint_dir}" | |
| ) | |
| def _ensure_acestep_imports(): | |
| """Register stub modules so AutoModel can load ACE-Step checkpoints.""" | |
| for name in ( | |
| "acestep", "acestep.models", "acestep.models.common", | |
| "acestep.models.xl_base", "acestep.models.xl_turbo", "acestep.models.xl_sft", | |
| ): | |
| if name not in sys.modules: | |
| stub = types.ModuleType(name) | |
| stub.__path__ = [] | |
| sys.modules[name] = stub | |
| # Try to load real modules from adjacent ACE-Step checkout | |
| for name in ( | |
| "acestep.models.common.configuration_acestep_v15", | |
| "acestep.models.common.apg_guidance", | |
| ): | |
| if name not in sys.modules: | |
| sys.modules[name] = types.ModuleType(name) | |
| def _attn_candidates(device: str) -> List[str]: | |
| """FA2 -> SDPA -> eager, filtered by availability.""" | |
| candidates = [] | |
| if device.startswith("cuda"): | |
| try: | |
| import flash_attn # noqa: F401 | |
| dev_idx = int(device.split(":")[1]) if ":" in device else 0 | |
| props = torch.cuda.get_device_properties(dev_idx) | |
| if props.major >= 8: | |
| candidates.append("flash_attention_2") | |
| except (ImportError, Exception): | |
| pass | |
| candidates.extend(["sdpa", "eager"]) | |
| return candidates | |
def load_model_for_training(
    checkpoint_dir: str, variant: str = "base", device: str = "cpu",
) -> Any:
    """Load the ACE-Step DiT model for ``variant`` with all parameters frozen.

    Attention backends from ``_attn_candidates`` are tried in order until one
    loads. Weights are float32 (``CPU_DTYPE``) on CPU and bfloat16 elsewhere.
    Non-CPU devices are placed via ``device_map``. A load error mentioning a
    missing Python package aborts immediately instead of trying the next
    backend. The model is returned in eval mode.

    Raises:
        RuntimeError: a required package is missing, or every backend failed.
    """
    from transformers import AutoModel

    model_dir = _resolve_model_dir(checkpoint_dir, variant)
    # CPU always uses float32
    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
    _ensure_acestep_imports()
    backends = _attn_candidates(device)

    model = None
    last_err = None
    for attn in backends:
        load_kwargs = {
            "trust_remote_code": True,
            "attn_implementation": attn,
            "torch_dtype": dtype,
            "low_cpu_mem_usage": False,
        }
        if device != "cpu":
            load_kwargs["device_map"] = {"": device}
        try:
            model = AutoModel.from_pretrained(str(model_dir), **load_kwargs)
        except Exception as exc:
            err_text = str(exc)
            if "packages that were not found" in err_text or "No module named" in err_text:
                raise RuntimeError(
                    f"Model files in {model_dir} require a missing Python package.\n"
                    f" Original error: {err_text}"
                ) from exc
            last_err = exc
            logger.warning("attn backend '%s' failed: %s", attn, exc)
            continue
        logger.info("Model loaded with attn_implementation=%s", attn)
        break

    if model is None:
        raise RuntimeError(f"Failed to load model from {model_dir}: {last_err}") from last_err
    model.requires_grad_(False)
    model.eval()
    return model
def load_vae(checkpoint_dir: str, device: str = "cpu"):
    """Load the ``AutoencoderOobleck`` VAE from ``<checkpoint_dir>/vae``.

    Weights are float32 (``CPU_DTYPE``) when ``device == "cpu"`` and bfloat16
    otherwise. The VAE is moved to ``device`` and put in eval mode.

    Raises:
        FileNotFoundError: if the ``vae`` sub-directory is missing.
    """
    from diffusers.models import AutoencoderOobleck

    vae_dir = Path(checkpoint_dir) / "vae"
    if not vae_dir.is_dir():
        raise FileNotFoundError(f"VAE directory not found: {vae_dir}")
    load_dtype = torch.bfloat16 if device != "cpu" else CPU_DTYPE
    autoencoder = AutoencoderOobleck.from_pretrained(str(vae_dir), torch_dtype=load_dtype)
    autoencoder = autoencoder.to(device=device)
    autoencoder.eval()
    return autoencoder
def load_text_encoder(checkpoint_dir: str, device: str = "cpu"):
    """Load the Qwen3-Embedding-0.6B tokenizer and encoder from the checkpoint.

    Returns ``(tokenizer, encoder)``. The encoder is float32 on CPU and
    bfloat16 elsewhere, moved to ``device`` and put in eval mode.

    Raises:
        FileNotFoundError: if ``<checkpoint_dir>/Qwen3-Embedding-0.6B`` is missing.
    """
    from transformers import AutoModel, AutoTokenizer

    encoder_dir = Path(checkpoint_dir) / "Qwen3-Embedding-0.6B"
    if not encoder_dir.is_dir():
        raise FileNotFoundError(f"Text encoder not found: {encoder_dir}")
    load_dtype = torch.bfloat16 if device != "cpu" else CPU_DTYPE
    tok = AutoTokenizer.from_pretrained(str(encoder_dir))
    enc = AutoModel.from_pretrained(str(encoder_dir), torch_dtype=load_dtype).to(device=device)
    enc.eval()
    return tok, enc
def load_silence_latent(
    checkpoint_dir: str, device: str = "cpu", variant: str = "base",
) -> torch.Tensor:
    """Load ``silence_latent.pt``, swap dims 1 and 2, and move it to ``device``.

    Search order:

    1. ``<checkpoint_dir>/silence_latent.pt``;
    2. the sub-directory for ``variant``;
    3. every other known variant sub-directory (duplicates removed).

    The file is deserialised onto CPU (``map_location="cpu"``) so a tensor
    saved from a GPU still loads on a CPU-only host. It is then cast to
    float32 on CPU (bfloat16 elsewhere) on ``device``.

    Raises:
        FileNotFoundError: if no candidate file exists.
    """
    ckpt = Path(checkpoint_dir)
    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
    candidates = [ckpt / "silence_latent.pt"]
    for subdir in (_VARIANT_DIR.get(variant), *_VARIANT_DIR.values()):
        if subdir:
            path = ckpt / subdir / "silence_latent.pt"
            if path not in candidates:
                candidates.append(path)
    for c in candidates:
        if c.is_file():
            # NOTE(review): assumes a 3-D tensor, presumably (1, C, T) on disk
            # -> (1, T, C) after the transpose; confirm against the checkpoint.
            sl = torch.load(str(c), map_location="cpu", weights_only=True).transpose(1, 2)
            return sl.to(device=device, dtype=dtype)
    raise FileNotFoundError(f"silence_latent.pt not found under {ckpt}")
def unload_models(*models) -> None:
    """Best-effort release of models/tensors that are no longer needed.

    Every non-``None`` argument with a ``.to`` method gets ``.to("cpu")``
    (errors ignored). For ``nn.Module`` this moves weights off an accelerator
    in place; for plain tensors ``.to`` returns a copy and has no effect.
    The garbage collector is then run and, when CUDA is available, PyTorch's
    CUDA cache is emptied.

    This cannot free objects the *caller* still references -- the former
    ``del obj`` only removed the loop variable. Callers must drop their own
    references (``del`` / rebind to ``None``) for memory to be reclaimed.
    """
    for obj in models:
        if obj is None or not hasattr(obj, "to"):
            continue
        try:
            obj.to("cpu")
        except Exception:
            # A failed move must not abort the rest of the cleanup.
            pass
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
| # ============================================================================ | |
| # AUDIO LOADING | |
| # ============================================================================ | |
def load_audio_stereo(
    audio_path: str, target_sr: int, max_duration: float,
) -> Tuple[torch.Tensor, int]:
    """Load audio as a 2-channel float32 tensor at ``target_sr``, capped in length.

    ``soundfile`` is tried first, with ``librosa`` for resampling; on any error
    the file is re-read with ``torchaudio``.

    Only about ``max_duration`` seconds (plus a 1 s margin for resampler
    context) are kept at the *source* rate before resampling, so very long
    files are not resampled in full. The result is trimmed to exactly
    ``int(max_duration * target_sr)`` samples. Mono input is duplicated to two
    channels; input with more than two channels keeps the first two.

    Returns:
        ``(audio, sr)`` with ``audio`` shaped ``(2, samples)`` and ``sr ==
        target_sr``.
    """
    import numpy as np

    def _source_limit(sr: int) -> int:
        # Extra second keeps the resampling filter fully fed at the cut; the
        # exact cap is applied after resampling.
        return int(math.ceil((max_duration + 1.0) * sr))

    try:
        import soundfile as sf
        with sf.SoundFile(audio_path) as snd:
            sr = int(snd.samplerate)
            data = snd.read(frames=_source_limit(sr), dtype="float32", always_2d=True)
        audio_np = np.ascontiguousarray(data.T)
        if sr != target_sr:
            import librosa
            audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=target_sr, axis=1)
            sr = target_sr
        audio = torch.from_numpy(np.ascontiguousarray(audio_np))
    except Exception:
        import torchaudio
        audio, sr = torchaudio.load(audio_path)
        sr = int(sr)
        audio = audio[:, :_source_limit(sr)]
        if sr != target_sr:
            audio = torchaudio.transforms.Resample(sr, target_sr)(audio)
            sr = target_sr
    if audio.shape[0] == 1:
        audio = audio.repeat(2, 1)
    elif audio.shape[0] > 2:
        audio = audio[:2, :]
    max_samples = int(max_duration * target_sr)
    if audio.shape[1] > max_samples:
        audio = audio[:, :max_samples]
    return audio, sr
| # ============================================================================ | |
| # TEXT / LYRICS ENCODING | |
| # ============================================================================ | |
def encode_text(text_encoder, tokenizer, text_prompt: str, device, dtype):
    """Encode a caption with the text encoder's full forward pass.

    The prompt is padded/truncated to 256 tokens. Token ids and the attention
    mask are moved to ``device`` and then, if that differs, to the encoder's
    own parameter device. Only the ids are fed to the encoder; the mask is
    returned alongside. Returns ``(last_hidden_state, mask)``, both cast to
    ``dtype``, computed under ``torch.no_grad()``.
    """
    batch = tokenizer(
        text_prompt,
        padding="max_length",
        max_length=256,
        truncation=True,
        return_tensors="pt",
    )
    token_ids = batch.input_ids.to(device)
    attn = batch.attention_mask.to(device).to(dtype)
    encoder_device = next(text_encoder.parameters()).device
    if token_ids.device != encoder_device:
        token_ids, attn = token_ids.to(encoder_device), attn.to(encoder_device)
    with torch.no_grad():
        hidden = text_encoder(token_ids).last_hidden_state.to(dtype)
    return hidden, attn
def encode_lyrics(text_encoder, tokenizer, lyrics: str, device, dtype):
    """Embed lyrics with the text encoder's token-embedding table only.

    Unlike the caption path, no transformer layers run: the padded/truncated
    512-token ids go through ``text_encoder.embed_tokens``. Ids and mask are
    moved to ``device`` and then, if different, to the encoder's parameter
    device. Returns ``(embeddings, mask)`` cast to ``dtype``, computed under
    ``torch.no_grad()``.
    """
    batch = tokenizer(
        lyrics,
        padding="max_length",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    token_ids = batch.input_ids.to(device)
    attn = batch.attention_mask.to(device).to(dtype)
    encoder_device = next(text_encoder.parameters()).device
    if token_ids.device != encoder_device:
        token_ids, attn = token_ids.to(encoder_device), attn.to(encoder_device)
    with torch.no_grad():
        embedded = text_encoder.embed_tokens(token_ids).to(dtype)
    return embedded, attn
| # ============================================================================ | |
| # VAE TILED ENCODING | |
| # ============================================================================ | |
def tiled_vae_encode(
    vae, audio: torch.Tensor, dtype: torch.dtype,
    chunk_size: Optional[int] = None, overlap: int = 96000,
) -> torch.Tensor:
    """Encode audio with the VAE in overlapping windows to bound peak memory.

    Audio no longer than ``chunk_size`` samples is encoded in one call.

    Otherwise the signal is split into cores of
    ``stride = chunk_size - 2 * overlap`` samples. Each core is encoded with up
    to ``overlap`` samples of context on both sides, and that context is
    trimmed off in latent space using the downsampling factor measured on the
    first window. The trimmed cores are concatenated along time.

    Cores are collected in a list rather than written into a buffer pre-sized
    to ``round(S / ds_factor)``: per-window rounding of the trims can make the
    cores add up to more frames than that estimate, which used to crash the
    slice assignment with a shape mismatch.

    Args:
        vae: Object exposing ``encode(x).latent_dist.sample()``, ``.dtype`` and
            ``.parameters()``.
        audio: Waveform batch unpacked as ``(B, C, S)``.
        dtype: dtype of the returned latents.
        chunk_size: Window length in samples; defaults to 30 s at ``TARGET_SR``.
        overlap: Context samples added on each side of a core.

    Returns:
        Latents transposed from the VAE's channel-first output to
        ``(B, T, D)``. The tiled path returns them on CPU.

    Raises:
        ValueError: if ``chunk_size <= 2 * overlap`` on the tiled path.
    """
    vae_device = next(vae.parameters()).device
    vae_dtype = vae.dtype
    if chunk_size is None:
        chunk_size = TARGET_SR * 30
    _, _, S = audio.shape
    if S <= chunk_size:
        vae_input = audio.to(vae_device, dtype=vae_dtype)
        with torch.inference_mode():
            latents = vae.encode(vae_input).latent_dist.sample()
        return latents.transpose(1, 2).to(dtype)

    stride = chunk_size - 2 * overlap
    if stride <= 0:
        raise ValueError(f"chunk_size ({chunk_size}) must be > 2 * overlap ({overlap})")
    num_steps = math.ceil(S / stride)
    ds_factor = None
    cores: List[torch.Tensor] = []
    for i in range(num_steps):
        core_start = i * stride
        core_end = min(core_start + stride, S)
        win_start = max(0, core_start - overlap)
        win_end = min(S, core_end + overlap)
        chunk = audio[:, :, win_start:win_end].to(vae_device, dtype=vae_dtype)
        with torch.inference_mode():
            lat = vae.encode(chunk).latent_dist.sample()
        if ds_factor is None:
            # Samples per latent frame, measured once on the first window.
            ds_factor = chunk.shape[-1] / lat.shape[-1]
        trim_start = int(round((core_start - win_start) / ds_factor))
        trim_end = int(round((win_end - core_end) / ds_factor))
        end_idx = lat.shape[-1] - trim_end if trim_end > 0 else lat.shape[-1]
        cores.append(lat[:, :, trim_start:end_idx].cpu())
        del chunk, lat
    return torch.cat(cores, dim=-1).transpose(1, 2).to(dtype)
| # ============================================================================ | |
| # ENCODER / CONTEXT HELPERS | |
| # ============================================================================ | |
def run_encoder(
    model, text_hs, text_mask, lyric_hs, lyric_mask, device, dtype,
):
    """Run the DiT's condition encoder on pre-computed text and lyric features.

    No reference audio is supplied: an all-zero ``(1, 1, 64)`` packed reference
    and a zero ``long`` order mask of length 1 stand in for it. Executes under
    ``torch.no_grad()`` and returns ``model.encoder``'s ``(hidden_states, mask)``.
    """
    no_reference = torch.zeros(1, 1, 64, device=device, dtype=dtype)
    reference_order = torch.zeros(1, device=device, dtype=torch.long)
    encoder_inputs = {
        "text_hidden_states": text_hs,
        "text_attention_mask": text_mask,
        "lyric_hidden_states": lyric_hs,
        "lyric_attention_mask": lyric_mask,
        "refer_audio_acoustic_hidden_states_packed": no_reference,
        "refer_audio_order_mask": reference_order,
    }
    with torch.no_grad():
        hidden, mask = model.encoder(**encoder_inputs)
    return hidden, mask
def build_context_latents(silence_latent, latent_length: int, device, dtype):
    """Build ``context_latents``: silence latents followed by an all-ones chunk mask.

    The first batch item of ``silence_latent`` (indexed as ``(batch, time,
    features)``) is tiled along time until it covers ``latent_length`` frames,
    cut to exactly that length, and concatenated on the last dim with a
    ``(1, latent_length, 64)`` tensor of ones.

    Previously only a single extra copy was appended when the silence latent
    was too short, so latents more than twice its length failed to
    concatenate. The batch-expand branch (which could never succeed) is gone,
    and the silence part is now placed on ``device`` like the mask.

    Returns:
        Tensor of shape ``(1, latent_length, features + 64)`` on ``device``.

    Raises:
        ValueError: if ``silence_latent`` has no time frames while
            ``latent_length > 0``.
    """
    src = silence_latent[:1].to(device=device, dtype=dtype)
    available = src.shape[1]
    if latent_length > available:
        if available == 0:
            raise ValueError("silence_latent has no time frames to tile")
        src = src.repeat(1, math.ceil(latent_length / available), 1)
    src = src[:, :latent_length, :]
    masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype)
    return torch.cat([src, masks], dim=-1)
| # ============================================================================ | |
| # AUDIO DISCOVERY | |
| # ============================================================================ | |
def _discover_audio_files(audio_dir: str) -> List[Path]:
    """Recursively list audio files under ``audio_dir`` in a deterministic order.

    A file counts as audio when its lower-cased suffix is in
    ``AUDIO_EXTENSIONS``. Both sub-directory names and file names are sorted,
    so the result no longer depends on the filesystem's enumeration order.
    That order matters because output files are keyed by stem.
    """
    files: List[Path] = []
    for root, dirs, names in os.walk(audio_dir):
        # Sorting in place makes the top-down os.walk descend alphabetically.
        dirs.sort()
        files.extend(
            Path(root) / name
            for name in sorted(names)
            if Path(name).suffix.lower() in AUDIO_EXTENSIONS
        )
    return files
def _detect_max_duration(files: List[Path]) -> float:
    """Return the crop length in seconds to use for ``files``.

    This is the longest duration among ``files``, capped at
    ``MAX_AUDIO_DURATION``. Every file is inspected, since ``soundfile.info``
    reads headers only.

    The cap itself is returned when any of these hold, so that no file is
    cropped shorter than the global limit:

    * ``soundfile`` is not installed;
    * any file's header cannot be read;
    * no positive duration is found.

    The earlier version looked only at the first 50 files and ignored
    unreadable ones, which silently truncated longer files to a
    too-short detected maximum.
    """
    try:
        import soundfile as sf
    except ImportError:
        return MAX_AUDIO_DURATION
    longest = 0.0
    for f in files:
        try:
            longest = max(longest, float(sf.info(str(f)).duration))
        except Exception:
            # Unknown duration: never risk cropping this file below the cap.
            return MAX_AUDIO_DURATION
    if longest <= 0:
        return MAX_AUDIO_DURATION
    return min(longest, MAX_AUDIO_DURATION)
| # ============================================================================ | |
| # PREPROCESSING (2-pass sequential) | |
| # ============================================================================ | |
def preprocess_audio(
    audio_dir: str,
    output_dir: str,
    checkpoint_dir: str,
    device: str = "cpu",
    variant: str = "base",
    max_duration: float = 0,
    progress_callback: Optional[Callable] = None,
    cancel_check: Optional[Callable] = None,
) -> Dict[str, Any]:
    """2-pass sequential preprocessing of an audio folder into training tensors.

    Pass 1: load VAE + text encoder. VAE-encode each audio file and encode its
        caption (the file stem) and lyrics (``"[Instrumental]"``). Save an
        intermediate ``<stem>.tmp.pt``.
    Pass 2: load the DiT model, run its condition encoder, and build the
        context latents. Write the final ``<stem>.pt`` via a staging file and
        ``os.replace``.

    Only one pass's models are alive at a time. After pass 1 the models are
    unloaded *and* this function's references to them are dropped, because
    ``unload_models`` cannot free objects the caller still holds. If
    cancellation is requested by the end of pass 1, the DiT is not loaded.

    Files whose ``<stem>.pt`` already exists are skipped. Outputs are keyed by
    stem, so a later file repeating an already-seen stem (``a.wav`` and
    ``a.mp3``, or the same name in two folders) is skipped and counted as
    failed. Otherwise it would overwrite the first file's intermediate.

    Args:
        audio_dir: Folder searched recursively for audio files.
        output_dir: Destination for the ``.pt`` files (created if missing).
        checkpoint_dir: ACE-Step checkpoint root.
        device: Torch device string; ``"cpu"`` uses float32.
        variant: Model variant for the DiT and silence latent.
        max_duration: Crop length in seconds; ``<= 0`` auto-detects it.
        progress_callback: Called as ``(done, total, message)``.
        cancel_check: Polled before each file; a truthy result stops early.

    Returns:
        Dict with ``processed``, ``failed``, ``total`` and ``output_dir``.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)
    # Clean orphaned staging files left by an interrupted pass 2.
    for orphan in out.glob("*.__writing__"):
        try:
            orphan.unlink()
        except OSError:
            pass
    audio_files = _discover_audio_files(audio_dir)
    if not audio_files:
        return {"processed": 0, "failed": 0, "total": 0, "output_dir": str(out)}
    total = len(audio_files)
    if max_duration <= 0:
        max_duration = _detect_max_duration(audio_files)
    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16

    # ---- Pass 1: VAE + Text Encoder ----
    logger.info("Pass 1/2: Loading VAE + Text Encoder...")
    vae = load_vae(checkpoint_dir, device)
    tokenizer, text_enc = load_text_encoder(checkpoint_dir, device)
    silence_lat = load_silence_latent(checkpoint_dir, device, variant=variant)
    intermediates: List[Path] = []
    seen_stems = set()
    p1_failed = 0
    try:
        for i, af in enumerate(audio_files):
            if cancel_check and cancel_check():
                break
            stem = af.stem
            if stem in seen_stems:
                p1_failed += 1
                logger.warning("Pass 1 SKIP %s: duplicate file stem %r", af, stem)
                continue
            seen_stems.add(stem)
            final_pt = out / f"{stem}.pt"
            if final_pt.exists():
                continue
            try:
                audio, _ = load_audio_stereo(str(af), TARGET_SR, max_duration)
                audio = audio.unsqueeze(0).to(device=device, dtype=vae.dtype)
                with torch.no_grad():
                    target_latents = tiled_vae_encode(vae, audio, dtype)
                del audio
                if torch.isnan(target_latents).any() or torch.isinf(target_latents).any():
                    p1_failed += 1
                    logger.error("Pass 1 FAIL %s: NaN/Inf in VAE latents", af.name)
                    del target_latents
                    continue
                lat_len = target_latents.shape[1]
                att_mask = torch.ones(1, lat_len, device=device, dtype=dtype)
                caption = af.stem
                lyrics = "[Instrumental]"
                text_prompt = caption
                with torch.no_grad():
                    text_hs, text_mask = encode_text(text_enc, tokenizer, text_prompt, device, dtype)
                    lyric_hs, lyric_mask = encode_lyrics(text_enc, tokenizer, lyrics, device, dtype)
                has_bad = any(
                    torch.isnan(t).any() or torch.isinf(t).any()
                    for t in [text_hs, lyric_hs]
                )
                if has_bad:
                    p1_failed += 1
                    logger.error("Pass 1 FAIL %s: NaN/Inf in text/lyric encodings", af.name)
                    del target_latents, att_mask, text_hs, text_mask, lyric_hs, lyric_mask
                    continue
                tmp_path = out / f"{stem}.tmp.pt"
                torch.save({
                    "target_latents": target_latents.squeeze(0).cpu(),
                    "attention_mask": att_mask.squeeze(0).cpu(),
                    "text_hidden_states": text_hs.cpu(),
                    "text_attention_mask": text_mask.cpu(),
                    "lyric_hidden_states": lyric_hs.cpu(),
                    "lyric_attention_mask": lyric_mask.cpu(),
                    "silence_latent": silence_lat.cpu(),
                    "latent_length": lat_len,
                    "metadata": {
                        "audio_path": str(af),
                        "filename": af.name,
                        "caption": caption,
                        "lyrics": lyrics,
                    },
                }, tmp_path)
                del target_latents, att_mask, text_hs, text_mask, lyric_hs, lyric_mask
                intermediates.append(tmp_path)
                if progress_callback:
                    progress_callback(i + 1, total, f"[Pass 1] {af.name}")
            except Exception as exc:
                p1_failed += 1
                logger.error("Pass 1 FAIL %s: %s", af.name, exc)
    finally:
        logger.info("Unloading VAE + Text Encoder...")
        unload_models(vae, text_enc, tokenizer, silence_lat)
        # unload_models cannot drop *our* references; do it here so the
        # pass-1 models are really freed before the DiT is loaded.
        del vae, text_enc, tokenizer, silence_lat
        gc.collect()

    if not intermediates:
        return {"processed": 0, "failed": p1_failed, "total": total, "output_dir": str(out)}
    if cancel_check and cancel_check():
        logger.info("Preprocessing cancelled; skipping pass 2 (DIT model not loaded)")
        return {"processed": 0, "failed": p1_failed, "total": total, "output_dir": str(out)}

    # ---- Pass 2: DIT Encoder ----
    logger.info("Pass 2/2: Loading DIT model (variant=%s)...", variant)
    model = load_model_for_training(checkpoint_dir, variant, device)
    processed = 0
    p2_failed = 0
    p2_total = len(intermediates)
    try:
        first_param = next(model.parameters())
        m_device = first_param.device
        m_dtype = first_param.dtype
        for i, tmp_path in enumerate(intermediates):
            if cancel_check and cancel_check():
                break
            try:
                data = torch.load(str(tmp_path), weights_only=True)
                text_hs = data["text_hidden_states"].to(m_device, dtype=m_dtype)
                text_mask = data["text_attention_mask"].to(m_device, dtype=m_dtype)
                lyric_hs = data["lyric_hidden_states"].to(m_device, dtype=m_dtype)
                lyric_mask = data["lyric_attention_mask"].to(m_device, dtype=m_dtype)
                silence_lat = data["silence_latent"].to(m_device, dtype=m_dtype)
                lat_len = data["latent_length"]
                enc_hs, enc_mask = run_encoder(
                    model, text_hs, text_mask, lyric_hs, lyric_mask,
                    str(m_device), m_dtype,
                )
                del text_hs, text_mask, lyric_hs, lyric_mask
                if silence_lat.dim() == 2:
                    silence_lat = silence_lat.unsqueeze(0)
                ctx = build_context_latents(silence_lat, lat_len, str(m_device), m_dtype)
                del silence_lat
                has_bad = any(
                    torch.isnan(t).any() or torch.isinf(t).any()
                    for t in [enc_hs, ctx]
                )
                if has_bad:
                    p2_failed += 1
                    logger.error("Pass 2 FAIL %s: NaN/Inf in encoder/context", tmp_path.stem)
                    del enc_hs, enc_mask, ctx, data
                    continue
                base_name = tmp_path.name.replace(".tmp.pt", ".pt")
                final_path = out / base_name
                staging_path = out / (base_name + ".__writing__")
                torch.save({
                    "target_latents": data["target_latents"],
                    "attention_mask": data["attention_mask"],
                    "encoder_hidden_states": enc_hs.squeeze(0).cpu(),
                    "encoder_attention_mask": enc_mask.squeeze(0).cpu(),
                    "context_latents": ctx.squeeze(0).cpu(),
                    "metadata": data.get("metadata", {}),
                }, staging_path)
                # Atomic publish: readers never see a half-written .pt file.
                os.replace(staging_path, final_path)
                del enc_hs, enc_mask, ctx, data
                tmp_path.unlink(missing_ok=True)
                processed += 1
                if progress_callback:
                    progress_callback(i + 1, p2_total, f"[Pass 2] {tmp_path.stem}")
            except Exception as exc:
                p2_failed += 1
                logger.error("Pass 2 FAIL %s: %s", tmp_path.stem, exc)
    finally:
        logger.info("Unloading DIT model...")
        unload_models(model)
    failed = p1_failed + p2_failed
    return {"processed": processed, "failed": failed, "total": total, "output_dir": str(out)}
| # ============================================================================ | |
| # TRAINING LOOP (generator for Gradio compatibility) | |
| # ============================================================================ | |
def train_lora_generator(
    dataset_dir: str,
    output_dir: str,
    checkpoint_dir: str,
    epochs: int = 1000,
    lr: float = 3e-4,
    rank: int = 64,
    alpha: int = 128,
    dropout: float = 0.1,
    batch_size: int = 1,
    gradient_accumulation_steps: int = 4,
    warmup_steps: int = 100,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
    save_every_n_epochs: int = 50,
    seed: int = 42,
    variant: str = "base",
    device: str = "cpu",
    cfg_ratio: float = 0.15,
    timestep_mu: float = -0.4,
    timestep_sigma: float = 1.0,
    target_modules: Optional[List[str]] = None,
    log_every: int = 10,
    resume_from: Optional[str] = None,
) -> Generator[str, None, None]:
    """Run LoRA training, yielding human-readable progress strings.

    This is a generator for Gradio live-update compatibility. Status lines are
    prefixed with ``[INFO]``, ``[OK]``, ``[WARN]`` or ``[FAIL]``. A completed,
    cancelled or timed-out run ends with ``"[DONE]"``; a failed run ends with
    a ``[FAIL]`` line instead.

    Call ``cancel_training()`` to stop at the start of the next epoch. The
    current adapter is then saved to ``<output_dir>/early_exit``.

    Args:
        dataset_dir: Directory of preprocessed ``.pt`` samples.
        output_dir: Receives the final adapter (written directly into it),
            plus ``best/``, ``checkpoints/epoch_N/``, ``early_exit/`` and
            ``timeout_exit/`` subdirectories as applicable.
        checkpoint_dir: Base model checkpoint directory.
        epochs: Total epoch count (also the end point when resuming).
        lr, weight_decay, warmup_steps: Optimizer / scheduler settings.
        rank, alpha, dropout, target_modules: LoRA settings; the default
            target modules are the attention q/k/v/o projections.
        batch_size: DataLoader batch size.
        gradient_accumulation_steps: Micro-batches per optimizer step.
            Values below 1 are treated as 1.
        max_grad_norm: Gradient clipping norm applied before each step.
        save_every_n_epochs: Periodic checkpoint interval. ``0`` disables
            periodic checkpoints.
        seed: Seeds ``torch`` and ``random`` before LoRA injection.
        variant, device: Model variant and device string.
        cfg_ratio: Probability used for classifier-free-guidance dropout.
            Only applied when the model exposes ``null_condition_emb``.
        timestep_mu, timestep_sigma: Timestep sampling parameters.
        log_every: Emit a step log every N optimizer steps. Values below 1
            are treated as 1.
        resume_from: Checkpoint directory, or a file inside it, holding
            ``adapter_model.safetensors`` and/or ``training_state.pt``.

    Yields:
        Progress / status strings.

    Once loaded, the model is always released via ``unload_models`` on every
    exit path: normal return, an exception escaping the loop, or the consumer
    closing the generator early.
    """
    _training_cancel.clear()
    train_start = time.time()
    if target_modules is None:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
    # Both values are used as a divisor / modulus below. A 0 coming from a UI
    # field would otherwise crash mid-run with ZeroDivisionError.
    gradient_accumulation_steps = max(1, gradient_accumulation_steps)
    log_every = max(1, log_every)

    ds_path = Path(dataset_dir)
    if not ds_path.is_dir():
        yield f"[FAIL] Dataset directory not found: {ds_path}"
        return
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    yield "[INFO] Loading model..."
    try:
        model = load_model_for_training(checkpoint_dir, variant, device)
    except Exception as exc:
        yield f"[FAIL] Model load failed: {exc}"
        return

    # The model is resident from here on. ``finally`` releases it no matter how
    # the generator exits. ``close()`` on an abandoned generator raises
    # GeneratorExit at the paused ``yield``, which still runs ``finally``.
    try:
        # float32 on CPU (bfloat16 deadlocks)
        dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
        model = model.to(dtype=dtype)

        # Seed *before* LoRA injection so the adapter's random initialization
        # is reproducible as well as the data order.
        torch.manual_seed(seed)
        random.seed(seed)

        yield "[INFO] Injecting LoRA..."
        lora_cfg = LoRAConfig(
            r=rank, alpha=alpha, dropout=dropout,
            target_modules=target_modules, bias="none",
        )
        try:
            model, info = inject_lora(model, lora_cfg)
        except Exception as exc:
            yield f"[FAIL] LoRA injection failed: {exc}"
            return
        yield f"[OK] LoRA injected: {info['trainable_params']:,} trainable params"

        # Gradient checkpointing + cache disable
        force_disable_cache(model.decoder)
        ckpt_ok = enable_gradient_checkpointing(model.decoder)
        # NOTE(review): presumably the decoder input must require grad when
        # checkpointing is active so recomputed segments stay in the autograd
        # graph -- confirm against enable_gradient_checkpointing.
        force_input_grads = ckpt_ok
        if ckpt_ok:
            yield "[INFO] Gradient checkpointing enabled"

        # Dataset
        dataset = TensorDataset(dataset_dir)
        if len(dataset) == 0:
            yield "[FAIL] No valid .pt files found in dataset directory"
            return
        yield f"[OK] Loaded {len(dataset)} preprocessed samples"
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True,
            num_workers=0, collate_fn=_collate_batch, drop_last=False,
        )

        # Optimizer & scheduler
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        if not trainable_params:
            yield "[FAIL] No trainable parameters found"
            return
        optimizer = build_optimizer(trainable_params, lr=lr, weight_decay=weight_decay)
        steps_per_epoch = max(1, math.ceil(len(loader) / gradient_accumulation_steps))
        total_steps = steps_per_epoch * epochs
        scheduler = build_scheduler(optimizer, total_steps, warmup_steps, lr)
        yield f"[INFO] Training {sum(p.numel() for p in trainable_params):,} params for {epochs} epochs"
        yield f"[INFO] Steps/epoch: {steps_per_epoch}, total: {total_steps}"

        # Null condition embedding for CFG dropout (None when the model lacks it)
        null_cond = getattr(model, "null_condition_emb", None)

        # Resume checkpoint
        start_epoch = 0
        global_step = 0
        if resume_from and not Path(resume_from).exists():
            yield f"[WARN] Resume path not found: {resume_from}, starting fresh"
        elif resume_from:
            try:
                yield f"[INFO] Resuming from {resume_from}"
                ckpt_dir = Path(resume_from)
                if ckpt_dir.is_file():
                    ckpt_dir = ckpt_dir.parent
                # Load adapter weights
                aw = ckpt_dir / "adapter_model.safetensors"
                if aw.exists():
                    from safetensors.torch import load_file
                    state = load_file(str(aw))
                    decoder = model.decoder
                    # Unwrap any wrapper modules that expose ``_forward_module``.
                    while hasattr(decoder, "_forward_module"):
                        decoder = decoder._forward_module
                    decoder.load_state_dict(state, strict=False)
                # Load training state
                ts = ckpt_dir / "training_state.pt"
                if ts.exists():
                    tstate = torch.load(str(ts), map_location=device, weights_only=True)
                    start_epoch = tstate.get("epoch", 0)
                    global_step = tstate.get("global_step", 0)
                    if "optimizer_state_dict" in tstate:
                        try:
                            optimizer.load_state_dict(tstate["optimizer_state_dict"])
                        except Exception as exc:
                            yield f"[WARN] Optimizer state not restored: {exc}"
                    if "scheduler_state_dict" in tstate:
                        try:
                            scheduler.load_state_dict(tstate["scheduler_state_dict"])
                        except Exception as exc:
                            yield f"[WARN] Scheduler state not restored: {exc}"
                yield f"[OK] Resumed from epoch {start_epoch}, step {global_step}"
            except Exception as exc:
                yield f"[WARN] Checkpoint load failed: {exc}, starting fresh"
                start_epoch = 0
                global_step = 0

        # Training loop
        model.decoder.train()
        non_blocking = device != "cpu"
        torch_device = torch.device(device)
        acc_step = 0
        acc_loss = 0.0
        optimizer.zero_grad(set_to_none=True)
        best_loss = float("inf")
        best_epoch = 0
        consecutive_nan = 0
        max_nan = 10

        for epoch in range(start_epoch, epochs):
            # Cancel check (only at epoch boundaries)
            if _training_cancel.is_set():
                _training_cancel.clear()
                early_path = str(out_path / "early_exit")
                model.decoder.eval()
                save_lora_adapter(model, early_path)
                model.decoder.train()
                yield f"[OK] Cancelled at epoch {epoch + 1}, saved to {early_path}"
                yield "[DONE]"
                return

            # Timeout check (wall clock since the generator started)
            elapsed = time.time() - train_start
            if elapsed > MAX_TRAINING_TIME:
                early_path = str(out_path / "timeout_exit")
                model.decoder.eval()
                save_lora_adapter(model, early_path)
                yield f"[WARN] Training timed out after {int(elapsed)}s, saved to {early_path}"
                yield "[DONE]"
                return

            epoch_loss = 0.0
            num_updates = 0
            epoch_start = time.time()
            for batch in loader:
                tgt = batch["target_latents"].to(device, dtype=dtype, non_blocking=non_blocking)
                att = batch["attention_mask"].to(device, dtype=dtype, non_blocking=non_blocking)
                enc_hs = batch["encoder_hidden_states"].to(device, dtype=dtype, non_blocking=non_blocking)
                enc_mask = batch["encoder_attention_mask"].to(device, dtype=dtype, non_blocking=non_blocking)
                ctx = batch["context_latents"].to(device, dtype=dtype, non_blocking=non_blocking)
                bsz = tgt.shape[0]

                # CFG dropout
                if null_cond is not None and cfg_ratio > 0:
                    enc_hs = apply_cfg_dropout(enc_hs, null_cond, cfg_ratio)

                # Timestep sampling
                t, _r = sample_timesteps(bsz, torch_device, dtype, timestep_mu, timestep_sigma)

                # Flow matching: xt interpolates data (t=0) toward noise (t=1);
                # the regression target is the velocity x1 - x0.
                x1 = torch.randn_like(tgt)
                x0 = tgt
                # Two unsqueezes broadcast t over the remaining dims; assumes
                # tgt is rank-3 (batch, ..., ...) -- TODO confirm layout.
                t_ = t.unsqueeze(-1).unsqueeze(-1)
                xt = t_ * x1 + (1.0 - t_) * x0
                if force_input_grads:
                    xt = xt.requires_grad_(True)

                dec_out = model.decoder(
                    hidden_states=xt,
                    timestep=t,
                    timestep_r=t,
                    attention_mask=att,
                    encoder_hidden_states=enc_hs,
                    encoder_attention_mask=enc_mask,
                    context_latents=ctx,
                )
                flow = x1 - x0
                # Compute the MSE itself in fp32 rather than only casting its
                # result. This is a no-op on CPU, where dtype is already float32.
                loss = F.mse_loss(dec_out[0].float(), flow.float())

                # NaN / Inf guard
                if not torch.isfinite(loss):
                    consecutive_nan += 1
                    del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow
                    if consecutive_nan >= max_nan:
                        yield f"[FAIL] {consecutive_nan} consecutive NaN losses, halting"
                        return
                    # Discard the partially accumulated window.
                    if acc_step > 0:
                        optimizer.zero_grad(set_to_none=True)
                        acc_loss = 0.0
                        acc_step = 0
                    continue
                consecutive_nan = 0

                loss = loss / gradient_accumulation_steps
                loss.backward()
                acc_loss += loss.item()
                del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow
                acc_step += 1

                if acc_step >= gradient_accumulation_steps:
                    torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    global_step += 1
                    # acc_loss sums loss/gas terms; undo the scaling to get the mean.
                    avg_loss = acc_loss * gradient_accumulation_steps / acc_step
                    if global_step % log_every == 0:
                        current_lr = scheduler.get_last_lr()[0]
                        yield (
                            f"Epoch {epoch + 1}/{epochs}, "
                            f"Step {global_step}, "
                            f"Loss: {avg_loss:.4f}, "
                            f"LR: {current_lr:.2e}"
                        )
                    optimizer.zero_grad(set_to_none=True)
                    epoch_loss += avg_loss
                    num_updates += 1
                    acc_loss = 0.0
                    acc_step = 0

            # Flush a partial accumulation window left at the end of the epoch.
            if acc_step > 0:
                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
                optimizer.step()
                scheduler.step()
                global_step += 1
                avg_loss = acc_loss * gradient_accumulation_steps / acc_step
                optimizer.zero_grad(set_to_none=True)
                epoch_loss += avg_loss
                num_updates += 1
                acc_loss = 0.0
                acc_step = 0

            epoch_time = time.time() - epoch_start
            avg_epoch_loss = epoch_loss / max(num_updates, 1)
            # Only an improvement larger than 0.001 counts as a new best.
            is_best = avg_epoch_loss < best_loss - 0.001
            if is_best:
                best_loss = avg_epoch_loss
                best_epoch = epoch + 1
            best_str = f" (best: {best_loss:.4f} @ ep{best_epoch})" if best_epoch > 0 else ""
            yield (
                f"[OK] Epoch {epoch + 1}/{epochs} in {epoch_time:.1f}s, "
                f"Loss: {avg_epoch_loss:.4f}{best_str}"
            )

            # Save best (not persisted during the first 9 epochs)
            if is_best and epoch + 1 >= 10:
                best_path = str(out_path / "best")
                model.decoder.eval()
                save_lora_adapter(model, best_path)
                model.decoder.train()
                yield f"[OK] Best model saved (epoch {epoch + 1}, loss: {best_loss:.4f})"

            # Periodic checkpoint (adapter + optimizer/scheduler state for resume)
            if save_every_n_epochs and (epoch + 1) % save_every_n_epochs == 0:
                ckpt_path = str(out_path / "checkpoints" / f"epoch_{epoch + 1}")
                model.decoder.eval()
                save_lora_adapter(model, ckpt_path)
                tstate = {
                    "epoch": epoch + 1,
                    "global_step": global_step,
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                }
                os.makedirs(ckpt_path, exist_ok=True)
                torch.save(tstate, str(Path(ckpt_path) / "training_state.pt"))
                model.decoder.train()
                yield f"[OK] Checkpoint saved at epoch {epoch + 1}"

        # Sanity check
        if global_step == 0:
            yield "[FAIL] Training completed 0 steps -- no batches processed"
            return

        # Final save (directly to output_dir, not a subdirectory)
        model.decoder.eval()
        save_lora_adapter(model, str(out_path))
        best_note = ""
        if best_epoch > 0 and (out_path / "best").exists():
            best_note = f"\n Best: {out_path / 'best'} (epoch {best_epoch}, loss: {best_loss:.4f})"
        yield (
            f"[OK] Training complete! LoRA saved to {out_path}{best_note}\n"
            f" Adapter ready for inference."
        )
        yield "[DONE]"
    finally:
        unload_models(model)
| # ============================================================================ | |
| # ADAPTER LISTING | |
| # ============================================================================ | |
def get_trained_loras(adapter_dir: str) -> List[str]:
    """Return every directory under *adapter_dir* that holds LoRA adapter files.

    A directory (including *adapter_dir* itself) qualifies when it directly
    contains at least one of ``adapter_config.json``,
    ``adapter_model.safetensors`` or ``lora_weights.pt``. The result is
    de-duplicated and sorted. A missing or non-directory *adapter_dir* yields
    an empty list.
    """
    marker_files = {"adapter_config.json", "adapter_model.safetensors", "lora_weights.pt"}
    search_root = Path(adapter_dir)
    if not search_root.is_dir():
        return []
    adapter_dirs = {
        dirpath
        for dirpath, _subdirs, filenames in os.walk(str(search_root))
        if marker_files.intersection(filenames)
    }
    return sorted(adapter_dirs)