File size: 12,813 Bytes
a602628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
"""
LoRA Trainer - Handles LoRA training for custom models
"""

import torch
import torchaudio
from pathlib import Path
import logging
from typing import List, Dict, Any, Optional, Callable
import json
from datetime import datetime

logger = logging.getLogger(__name__)


class LoRATrainer:
    """Manages LoRA training for the ACE-Step model.

    Typical flow:
        1. ``prepare_dataset()``  - normalize/chunk raw audio into a dataset dir
        2. ``initialize_lora()``  - build the LoRA adapter config
        3. ``train()``            - fit the adapter and save it
        4. ``export_for_inference()`` - merge adapter into the base model
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize LoRA trainer.

        Args:
            config: Configuration dictionary. Recognized keys:
                "training_dir" (str), "sample_rate" (int), "force_mono" (bool),
                "chunk_duration" (seconds, int), "model_path" (str).
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.training_dir = Path(config.get("training_dir", "lora_training"))
        self.training_dir.mkdir(exist_ok=True)

        # Populated lazily: `model` by train()/load_lora(),
        # `lora_config` by initialize_lora().
        self.model = None
        self.lora_config = None

        logger.info(f"LoRA Trainer initialized on {self.device}")

    def prepare_dataset(self, audio_files: List[str]) -> List[str]:
        """
        Prepare audio files for training: resample, optionally downmix to mono,
        peak-normalize, and split long files into fixed-duration chunks.

        Args:
            audio_files: List of audio file paths

        Returns:
            List of prepared file paths (16-bit PCM WAV)

        Raises:
            Exception: Re-raised on any unrecoverable preparation failure
                (per-file load failures are logged and skipped instead).
        """
        try:
            logger.info(f"Preparing {len(audio_files)} files for training...")

            prepared_dir = self.training_dir / "prepared_data" / datetime.now().strftime("%Y%m%d_%H%M%S")
            prepared_dir.mkdir(parents=True, exist_ok=True)

            # Hoisted out of the loop: previously assigned inside it, which
            # raised NameError at the metadata step when `audio_files` was
            # empty or every file failed to load.
            target_sr = self.config.get("sample_rate", 44100)
            chunk_duration = self.config.get("chunk_duration", 30)  # seconds
            chunk_samples = int(chunk_duration * target_sr)

            prepared_files = []

            for i, file_path in enumerate(audio_files):
                try:
                    # Load audio (shape: [channels, samples])
                    audio, sr = torchaudio.load(file_path)

                    # Resample to target sample rate if needed
                    if sr != target_sr:
                        resampler = torchaudio.transforms.Resample(sr, target_sr)
                        audio = resampler(audio)

                    # Convert to mono if needed (for some training scenarios)
                    if audio.shape[0] > 1 and self.config.get("force_mono", False):
                        audio = torch.mean(audio, dim=0, keepdim=True)

                    # Peak-normalize; epsilon guards against silent (all-zero) clips
                    audio = audio / (torch.abs(audio).max() + 1e-8)

                    if audio.shape[1] > chunk_samples:
                        # Split into full-length chunks; any trailing partial
                        # chunk is dropped (floor division).
                        num_chunks = audio.shape[1] // chunk_samples
                        for j in range(num_chunks):
                            start = j * chunk_samples
                            end = start + chunk_samples
                            chunk = audio[:, start:end]

                            # Save chunk
                            chunk_path = prepared_dir / f"audio_{i:04d}_chunk_{j:02d}.wav"
                            torchaudio.save(
                                str(chunk_path),
                                chunk,
                                target_sr,
                                encoding="PCM_S",
                                bits_per_sample=16
                            )
                            prepared_files.append(str(chunk_path))
                    else:
                        # Save as-is
                        output_path = prepared_dir / f"audio_{i:04d}.wav"
                        torchaudio.save(
                            str(output_path),
                            audio,
                            target_sr,
                            encoding="PCM_S",
                            bits_per_sample=16
                        )
                        prepared_files.append(str(output_path))

                except Exception as e:
                    # Best-effort: skip unreadable/corrupt files, keep going
                    logger.warning(f"Failed to process {file_path}: {e}")
                    continue

            # Save dataset metadata alongside the prepared audio so
            # _load_dataset() can reconstruct the file list later.
            metadata = {
                "num_files": len(prepared_files),
                "original_files": len(audio_files),
                "sample_rate": target_sr,
                "prepared_at": datetime.now().isoformat(),
                "files": prepared_files
            }

            metadata_path = prepared_dir / "metadata.json"
            with open(metadata_path, 'w') as f:
                json.dump(metadata, f, indent=2)

            logger.info(f"✅ Prepared {len(prepared_files)} training files")
            return prepared_files

        except Exception as e:
            logger.error(f"Dataset preparation failed: {e}")
            raise

    def initialize_lora(self, rank: int = 16, alpha: int = 32):
        """
        Initialize LoRA configuration.

        Args:
            rank: LoRA rank (adapter bottleneck dimension)
            alpha: LoRA alpha (scaling factor)

        Raises:
            Exception: Re-raised if `peft` is unavailable or config is invalid.
        """
        try:
            # NOTE: only LoraConfig is needed here; the previous unused
            # `get_peft_model` import was removed.
            from peft import LoraConfig

            self.lora_config = LoraConfig(
                r=rank,
                lora_alpha=alpha,
                target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],  # Attention layers
                lora_dropout=0.1,
                bias="none",
                task_type="CAUSAL_LM"
            )

            logger.info(f"✅ LoRA initialized: rank={rank}, alpha={alpha}")

        except Exception as e:
            logger.error(f"LoRA initialization failed: {e}")
            raise

    def load_lora(self, lora_path: str):
        """
        Load existing LoRA model for continued training.

        Note: this sets ``self.model`` but does NOT set ``self.lora_config``;
        ``train()`` tolerates that (the saved training info will record the
        adapter config as unknown).

        Args:
            lora_path: Path to LoRA model

        Raises:
            Exception: Re-raised on load failure.
        """
        try:
            from peft import PeftModel
            from transformers import AutoModel

            # Load base model (fp16 on GPU to halve memory; fp32 on CPU)
            base_model = AutoModel.from_pretrained(
                self.config.get("model_path", "ACE-Step/ACE-Step-v1-3.5B"),
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32
            )

            # Attach the pre-trained LoRA adapter
            self.model = PeftModel.from_pretrained(base_model, lora_path)

            logger.info(f"✅ Loaded LoRA from {lora_path}")

        except Exception as e:
            logger.error(f"Failed to load LoRA: {e}")
            raise

    def train(
        self,
        dataset_path: str,
        model_name: str,
        learning_rate: float = 1e-4,
        batch_size: int = 4,
        num_epochs: int = 10,
        progress_callback: Optional[Callable] = None
    ) -> str:
        """
        Train LoRA model.

        Args:
            dataset_path: Path to prepared dataset
            model_name: Name for the trained model
            learning_rate: Learning rate
            batch_size: Batch size
            num_epochs: Number of epochs
            progress_callback: Optional callback invoked as
                ``progress_callback(step, total_steps, loss)``

        Returns:
            Path to trained model directory

        Raises:
            ValueError: If the dataset is empty.
            RuntimeError: If no model is loaded and initialize_lora() was
                never called.
            Exception: Re-raised on any other training failure.
        """
        try:
            logger.info(f"Starting LoRA training: {model_name}")

            # Load dataset
            dataset = self._load_dataset(dataset_path)
            if not dataset:
                # Previously this fell through to a ZeroDivisionError in the
                # epoch-loss average; fail loudly and clearly instead.
                raise ValueError(f"No training files found in {dataset_path}")

            # Load base model if not already loaded
            if self.model is None:
                if self.lora_config is None:
                    # get_peft_model(base_model, None) fails with an opaque
                    # error; guard with an actionable message.
                    raise RuntimeError(
                        "No LoRA config - call initialize_lora() or load_lora() before train()"
                    )

                from transformers import AutoModel
                from peft import get_peft_model

                base_model = AutoModel.from_pretrained(
                    self.config.get("model_path", "ACE-Step/ACE-Step-v1-3.5B"),
                    torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                    device_map="auto"
                )

                self.model = get_peft_model(base_model, self.lora_config)

            self.model.train()

            # Setup optimizer
            optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=learning_rate,
                weight_decay=0.01
            )

            # Ceiling division: the batch loop below also processes the
            # partial tail batch, so floor division undercounted steps.
            batches_per_epoch = (len(dataset) + batch_size - 1) // batch_size
            total_steps = batches_per_epoch * num_epochs
            step = 0

            for epoch in range(num_epochs):
                epoch_loss = 0.0
                num_batches = 0

                for batch_idx in range(0, len(dataset), batch_size):
                    batch = dataset[batch_idx:batch_idx + batch_size]

                    # Forward pass (simplified - actual implementation would be more complex)
                    loss = self._training_step(batch)

                    # Backward pass
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    epoch_loss += loss.item()
                    num_batches += 1
                    step += 1

                    # Progress callback
                    if progress_callback:
                        progress_callback(step, total_steps, loss.item())

                # Average over batches actually processed (fixes the
                # ZeroDivisionError when len(dataset) < batch_size).
                avg_loss = epoch_loss / max(num_batches, 1)
                logger.info(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

            # Save trained model
            output_dir = self.training_dir / "models" / model_name
            output_dir.mkdir(parents=True, exist_ok=True)

            self.model.save_pretrained(str(output_dir))

            # Save training info. lora_config is None when the adapter was
            # loaded via load_lora(); previously this dereference crashed.
            info = {
                "model_name": model_name,
                "learning_rate": learning_rate,
                "batch_size": batch_size,
                "num_epochs": num_epochs,
                "dataset_size": len(dataset),
                "trained_at": datetime.now().isoformat(),
                "lora_config": (
                    {
                        "rank": self.lora_config.r,
                        "alpha": self.lora_config.lora_alpha
                    }
                    if self.lora_config is not None else None
                )
            }

            info_path = output_dir / "training_info.json"
            with open(info_path, 'w') as f:
                json.dump(info, f, indent=2)

            logger.info(f"✅ Training complete! Model saved to {output_dir}")
            return str(output_dir)

        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise

    def _load_dataset(self, dataset_path: str) -> List[Dict[str, Any]]:
        """Load prepared dataset.

        Reads the file list from metadata.json if present, otherwise scans
        the directory for *.wav files. Audio content is not loaded here
        (lazy loading - each entry's "audio" is None).
        """
        dataset_path = Path(dataset_path)

        # Load metadata
        metadata_path = dataset_path / "metadata.json"
        if metadata_path.exists():
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)
            files = metadata.get("files", [])
        else:
            # Scan directory for audio files
            files = list(dataset_path.glob("*.wav"))

        dataset = []
        for file_path in files:
            dataset.append({
                "path": str(file_path),
                "audio": None  # Lazy loading
            })

        return dataset

    def _training_step(self, batch: List[Dict[str, Any]]) -> torch.Tensor:
        """
        Perform single training step.

        This is a simplified placeholder - actual implementation would:
        1. Load audio from batch
        2. Encode to latent space
        3. Generate predictions
        4. Calculate loss
        5. Return loss

        NOTE(review): the placeholder loss is not connected to the model's
        parameters, so no gradients actually reach the adapter weights -
        replace before real training.

        Args:
            batch: Training batch

        Returns:
            Loss tensor
        """
        # Placeholder loss calculation
        # Actual implementation would process audio through model
        loss = torch.tensor(0.5, requires_grad=True, device=self.device)
        return loss

    def export_for_inference(self, lora_path: str, output_path: str):
        """
        Export LoRA model for inference by merging the adapter weights
        into the base model and saving the merged result.

        Args:
            lora_path: Path to LoRA model
            output_path: Output path for exported model

        Raises:
            Exception: Re-raised on export failure.
        """
        try:
            # Load LoRA
            self.load_lora(lora_path)

            # Merge LoRA weights into the base model (drops the adapter wrapper)
            merged_model = self.model.merge_and_unload()

            # Save merged model
            merged_model.save_pretrained(output_path)

            logger.info(f"✅ Exported model to {output_path}")

        except Exception as e:
            logger.error(f"Export failed: {e}")
            raise