| """ |
| LoRA Trainer - Handles LoRA training for custom models |
| """ |
|
|
| import torch |
| import torchaudio |
| from pathlib import Path |
| import logging |
| from typing import List, Dict, Any, Optional, Callable |
| import json |
| from datetime import datetime |
|
|
# Module-scoped logger, named after this module so records can be filtered per file.
logger = logging.getLogger(name=__name__)
|
|
|
|
class LoRATrainer:
    """Manages LoRA training for ACE-Step model.

    Typical flow: ``prepare_dataset`` writes normalized WAV files plus a
    ``metadata.json`` manifest, ``initialize_lora`` (or ``load_lora``) sets up
    the adapter, and ``train`` runs the optimization loop and saves the adapter
    under ``<training_dir>/models/<model_name>``.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize LoRA trainer.

        Args:
            config: Configuration dictionary. Keys read by this class:
                ``training_dir`` (default ``"lora_training"``), ``sample_rate``
                (default 44100), ``force_mono`` (default False),
                ``chunk_duration`` in seconds (default 30) and ``model_path``
                (default ``"ACE-Step/ACE-Step-v1-3.5B"``).
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.training_dir = Path(config.get("training_dir", "lora_training"))
        # parents=True so a nested training_dir (e.g. "runs/lora/exp1") works.
        self.training_dir.mkdir(parents=True, exist_ok=True)

        self.model = None
        self.lora_config = None

        logger.info(f"LoRA Trainer initialized on {self.device}")

    def prepare_dataset(self, audio_files: List[str]) -> List[str]:
        """
        Prepare audio files for training.

        Each file is loaded, resampled to ``config["sample_rate"]``, optionally
        down-mixed to mono (``config["force_mono"]``), peak-normalized and
        written as 16-bit PCM WAV into a new timestamped directory under
        ``<training_dir>/prepared_data``. Files longer than
        ``config["chunk_duration"]`` seconds are split into full-length chunks;
        a trailing partial chunk is dropped. Files that fail to process are
        logged and skipped. A ``metadata.json`` manifest is written alongside.

        Args:
            audio_files: List of audio file paths

        Returns:
            List of prepared file paths

        Raises:
            ValueError: If ``chunk_duration`` * ``sample_rate`` is not positive.
        """
        try:
            logger.info(f"Preparing {len(audio_files)} files for training...")

            # Resolve settings once, before the loop: the manifest needs
            # target_sr even when audio_files is empty.
            target_sr = self.config.get("sample_rate", 44100)
            chunk_duration = self.config.get("chunk_duration", 30)
            chunk_samples = int(chunk_duration * target_sr)
            if chunk_samples <= 0:
                raise ValueError(
                    f"chunk_duration={chunk_duration} yields no samples at {target_sr} Hz"
                )

            prepared_dir = self.training_dir / "prepared_data" / datetime.now().strftime("%Y%m%d_%H%M%S")
            prepared_dir.mkdir(parents=True, exist_ok=True)

            prepared_files = []

            for i, file_path in enumerate(audio_files):
                try:
                    audio, sr = torchaudio.load(file_path)

                    if sr != target_sr:
                        resampler = torchaudio.transforms.Resample(sr, target_sr)
                        audio = resampler(audio)

                    if audio.shape[0] > 1 and self.config.get("force_mono", False):
                        audio = torch.mean(audio, dim=0, keepdim=True)

                    # Peak-normalize; the epsilon avoids dividing by zero on silence.
                    audio = audio / (torch.abs(audio).max() + 1e-8)

                    # Build (output path, samples) pairs, then write them all.
                    if audio.shape[1] > chunk_samples:
                        num_chunks = audio.shape[1] // chunk_samples
                        segments = [
                            (
                                prepared_dir / f"audio_{i:04d}_chunk_{j:02d}.wav",
                                audio[:, j * chunk_samples:(j + 1) * chunk_samples],
                            )
                            for j in range(num_chunks)
                        ]
                    else:
                        segments = [(prepared_dir / f"audio_{i:04d}.wav", audio)]

                    for out_path, segment in segments:
                        torchaudio.save(
                            str(out_path),
                            segment,
                            target_sr,
                            encoding="PCM_S",
                            bits_per_sample=16,
                        )
                        prepared_files.append(str(out_path))

                except Exception as e:
                    # Best effort: one unreadable file must not abort the batch.
                    logger.warning(f"Failed to process {file_path}: {e}")
                    continue

            metadata = {
                "num_files": len(prepared_files),
                "original_files": len(audio_files),
                "sample_rate": target_sr,
                "prepared_at": datetime.now().isoformat(),
                "files": prepared_files,
            }

            metadata_path = prepared_dir / "metadata.json"
            with open(metadata_path, 'w', encoding="utf-8") as f:
                json.dump(metadata, f, indent=2)

            logger.info(f"✅ Prepared {len(prepared_files)} training files")
            return prepared_files

        except Exception as e:
            logger.error(f"Dataset preparation failed: {e}")
            raise

    def initialize_lora(self, rank: int = 16, alpha: int = 32):
        """
        Initialize LoRA configuration.

        Only builds ``self.lora_config``; the adapter is attached to a model
        later, in ``train``.

        Args:
            rank: LoRA rank
            alpha: LoRA alpha
        """
        try:
            from peft import LoraConfig

            self.lora_config = LoraConfig(
                r=rank,
                lora_alpha=alpha,
                target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
                lora_dropout=0.1,
                bias="none",
                task_type="CAUSAL_LM"
            )

            logger.info(f"✅ LoRA initialized: rank={rank}, alpha={alpha}")

        except Exception as e:
            logger.error(f"LoRA initialization failed: {e}")
            raise

    def load_lora(self, lora_path: str, is_trainable: bool = True):
        """
        Load existing LoRA model for continued training.

        Args:
            lora_path: Path to LoRA model
            is_trainable: Whether the loaded adapter weights receive gradients.
                PEFT freezes adapters loaded via ``PeftModel.from_pretrained``
                unless this is True, which would make continued training a
                silent no-op. Pass False for inference/export.
        """
        try:
            from peft import PeftModel

            # NOTE(review): unlike train(), no device_map is passed here, so the
            # model stays wherever from_pretrained places it — confirm intent.
            base_model = self._load_base_model()

            self.model = PeftModel.from_pretrained(
                base_model, lora_path, is_trainable=is_trainable
            )

            logger.info(f"✅ Loaded LoRA from {lora_path}")

        except Exception as e:
            logger.error(f"Failed to load LoRA: {e}")
            raise

    def train(
        self,
        dataset_path: str,
        model_name: str,
        learning_rate: float = 1e-4,
        batch_size: int = 4,
        num_epochs: int = 10,
        progress_callback: Optional[Callable] = None
    ) -> str:
        """
        Train LoRA model.

        If no model is loaded yet, the base model is loaded and wrapped with a
        fresh adapter built from ``self.lora_config``. When no config is set,
        ``initialize_lora`` defaults are used.

        Args:
            dataset_path: Path to prepared dataset
            model_name: Name for the trained model
            learning_rate: Learning rate
            batch_size: Batch size (the last batch of an epoch may be smaller)
            num_epochs: Number of epochs
            progress_callback: Optional callback, called after every optimizer
                step as ``progress_callback(step, total_steps, loss)``

        Returns:
            Path to trained model

        Raises:
            ValueError: If ``batch_size`` < 1 or the dataset is empty.
            RuntimeError: If the model has no trainable parameters.
        """
        try:
            logger.info(f"Starting LoRA training: {model_name}")

            if batch_size < 1:
                raise ValueError(f"batch_size must be >= 1, got {batch_size}")

            # Validate the dataset before paying for a (large) model load.
            dataset = self._load_dataset(dataset_path)
            if not dataset:
                raise ValueError(f"No training files found in {dataset_path}")

            if self.model is None:
                self.model = self._create_peft_model()

            self.model.train()

            # Only adapter weights should be optimized; the base stays frozen.
            trainable_params = [p for p in self.model.parameters() if p.requires_grad]
            if not trainable_params:
                raise RuntimeError(
                    "Model has no trainable parameters; load the adapter with is_trainable=True"
                )

            optimizer = torch.optim.AdamW(
                trainable_params,
                lr=learning_rate,
                weight_decay=0.01
            )

            # The loop below also runs a final partial batch, so count batches
            # with ceiling division (floor division undercounts and can be 0).
            num_batches = (len(dataset) + batch_size - 1) // batch_size
            total_steps = num_batches * num_epochs
            step = 0

            for epoch in range(num_epochs):
                epoch_loss = 0.0

                for batch_idx in range(0, len(dataset), batch_size):
                    batch = dataset[batch_idx:batch_idx + batch_size]

                    loss = self._training_step(batch)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    loss_value = loss.item()
                    epoch_loss += loss_value
                    step += 1

                    if progress_callback:
                        progress_callback(step, total_steps, loss_value)

                avg_loss = epoch_loss / num_batches
                logger.info(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

            output_dir = self.training_dir / "models" / model_name
            output_dir.mkdir(parents=True, exist_ok=True)

            self.model.save_pretrained(str(output_dir))

            info = {
                "model_name": model_name,
                "learning_rate": learning_rate,
                "batch_size": batch_size,
                "num_epochs": num_epochs,
                "dataset_size": len(dataset),
                "trained_at": datetime.now().isoformat(),
                "lora_config": self._lora_info(),
            }

            info_path = output_dir / "training_info.json"
            with open(info_path, 'w', encoding="utf-8") as f:
                json.dump(info, f, indent=2)

            logger.info(f"✅ Training complete! Model saved to {output_dir}")
            return str(output_dir)

        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise

    def _load_base_model(self, **extra_kwargs):
        """Load the base model from ``config["model_path"]`` (fp16 on CUDA, fp32 otherwise)."""
        from transformers import AutoModel

        return AutoModel.from_pretrained(
            self.config.get("model_path", "ACE-Step/ACE-Step-v1-3.5B"),
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
            **extra_kwargs,
        )

    def _create_peft_model(self):
        """Load the base model and wrap it with a fresh adapter from ``self.lora_config``."""
        from peft import get_peft_model

        if self.lora_config is None:
            logger.info("No LoRA config set; initializing with defaults")
            self.initialize_lora()

        base_model = self._load_base_model(device_map="auto")
        return get_peft_model(base_model, self.lora_config)

    def _lora_info(self) -> Dict[str, Any]:
        """Return ``{"rank", "alpha"}`` of the active adapter (None when unknown).

        Prefers the config attached to the model (e.g. one restored by
        ``load_lora``), falling back to ``self.lora_config``.
        """
        peft_configs = getattr(self.model, "peft_config", None) or {}
        # NOTE: with several adapters this picks the first one registered.
        cfg = next(iter(peft_configs.values()), None)
        if cfg is None:
            cfg = self.lora_config
        return {
            "rank": getattr(cfg, "r", None),
            "alpha": getattr(cfg, "lora_alpha", None),
        }

    def _load_dataset(self, dataset_path: str) -> List[Dict[str, Any]]:
        """Load prepared dataset.

        Reads the ``files`` list from ``metadata.json`` (as written by
        ``prepare_dataset``) when present; otherwise uses every ``*.wav`` file
        directly in ``dataset_path``, sorted so batch order is reproducible.
        Audio is not loaded here: each entry is ``{"path": str, "audio": None}``.
        """
        dataset_path = Path(dataset_path)

        metadata_path = dataset_path / "metadata.json"
        if metadata_path.exists():
            with open(metadata_path, 'r', encoding="utf-8") as f:
                files = json.load(f).get("files", [])
        else:
            files = sorted(dataset_path.glob("*.wav"))

        return [{"path": str(file_path), "audio": None} for file_path in files]

    def _training_step(self, batch: List[Dict[str, Any]]) -> torch.Tensor:
        """
        Perform single training step.

        This is a simplified placeholder that ignores ``batch`` and returns a
        constant loss of 0.5 (a leaf tensor, so it produces no gradients for
        the model). An actual implementation would:
        1. Load audio from batch
        2. Encode to latent space
        3. Generate predictions
        4. Calculate loss
        5. Return loss

        Args:
            batch: Training batch

        Returns:
            Loss tensor
        """
        loss = torch.tensor(0.5, requires_grad=True, device=self.device)
        return loss

    def export_for_inference(self, lora_path: str, output_path: str):
        """
        Export LoRA model for inference.

        Loads the adapter on top of the base model, merges the adapter weights
        into the base weights, and saves the merged model.

        Args:
            lora_path: Path to LoRA model
            output_path: Output path for exported model
        """
        try:
            # Inference only: the adapter does not need gradients.
            self.load_lora(lora_path, is_trainable=False)

            merged_model = self.model.merge_and_unload()

            merged_model.save_pretrained(output_path)

            logger.info(f"✅ Exported model to {output_path}")

        except Exception as e:
            logger.error(f"Export failed: {e}")
            raise
|
|