Nathan9 committed on
Commit
103c8f5
·
verified ·
1 Parent(s): cc8f373

Upload 8 files

Browse files
.gitattributes CHANGED
@@ -1,2 +1,5 @@
1
  model-00001-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
2
  tokenizer.model filter=lfs diff=lfs merge=lfs -text
 
 
 
 
1
  model-00001-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
2
  tokenizer.model filter=lfs diff=lfs merge=lfs -text
3
+ Document1.pdf filter=lfs diff=lfs merge=lfs -text
4
+ model-00002-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
5
+ model-00003-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
Document1.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e678058adddf0a03284f79f65242699fee2cf5191b239a9a668ada8be9862e90
3
+ size 6035292
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9a7c7adf0142010ea7fb2d6d60b2698b86f36847d00d0afa4170c3a9fb66a9c
3
+ size 4934842808
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4c9f1d21524ad189e63230a62a62997c52205f9ce3099948c7fc3d27385d0dc
3
+ size 2598483736
prepare_data.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from pathlib import Path
4
+ import librosa
5
+ import taglib
6
+ from tqdm import tqdm
7
+ import logging
8
+ import soundfile as sf
9
+
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
class MusicDataPreprocessor:
    """Scan a directory tree for MP3/WAV files, convert them to 16 kHz mono
    WAV, and write a JSON manifest describing every processed track.

    Output layout under ``output_dir``:
        audio/                        converted .wav files
        metadata/dataset_info.json    per-file metadata plus run statistics
    """

    def __init__(self, input_dir: str, output_dir: str):
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        self.metadata = []  # one dict per successfully processed file

        # Create necessary directories up front so later writes cannot fail
        # on a missing parent.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        (self.output_dir / "audio").mkdir(exist_ok=True)
        (self.output_dir / "metadata").mkdir(exist_ok=True)

    def extract_metadata(self, audio_path: Path) -> dict:
        """Extract tag and signal metadata from an audio file (MP3 or WAV).

        Returns the metadata dict, or None when the file cannot be read;
        callers must check for None.
        """
        audio_file = None
        try:
            audio_format = audio_path.suffix.lower()[1:]  # extension without dot
            audio_file = taglib.File(str(audio_path))

            # Decode to get the true duration; resample to 16 kHz to match
            # the rate used later when the file is converted.
            y, sr = librosa.load(audio_path, sr=16000)
            duration = librosa.get_duration(y=y, sr=sr)

            # BUG FIX: a present-but-empty tag list used to raise IndexError
            # on [0] and silently drop the file; `or ["unknown"]` covers it.
            metadata = {
                "filename": audio_path.name,
                "format": audio_format,
                "duration": duration,
                "genre": (audio_file.tags.get("GENRE") or ["unknown"])[0],
                "title": (audio_file.tags.get("TITLE") or ["unknown"])[0],
                "artist": (audio_file.tags.get("ARTIST") or ["unknown"])[0],
                "sample_rate": sr,
                "channels": audio_file.channels
            }

            return metadata

        except Exception as e:
            logger.error(f"Error processing {audio_path}: {str(e)}")
            return None
        finally:
            # BUG FIX: the taglib handle was previously never closed,
            # leaking one file descriptor per processed track.
            if audio_file is not None:
                audio_file.close()

    def process_files(self):
        """Process all audio files (MP3 and WAV) in the input directory."""
        # NOTE: this character-class glob also matches a few accidental
        # combinations (e.g. ".wa3", ".mpv"); those are counted as "other".
        audio_files = list(self.input_dir.glob("**/*.[mw][pa][3v]"))

        formats_found = {"mp3": 0, "wav": 0, "other": 0}
        formats_processed = {"mp3": 0, "wav": 0}

        logger.info(f"Found {len(audio_files)} audio files to process")

        for audio_path in tqdm(audio_files, desc="Processing audio files"):
            # Track format statistics.
            file_ext = audio_path.suffix.lower()[1:]
            if file_ext == "mp3":
                formats_found["mp3"] += 1
            elif file_ext == "wav":
                formats_found["wav"] += 1
            else:
                formats_found["other"] += 1
                logger.warning(f"Unexpected file format: {file_ext} for file {audio_path}")

            metadata = self.extract_metadata(audio_path)

            if metadata:
                # Convert everything to 16 kHz mono WAV.
                output_audio_path = self.output_dir / "audio" / f"{audio_path.stem}.wav"
                try:
                    y, sr = librosa.load(audio_path, sr=16000, mono=True)
                    sf.write(output_audio_path, y, sr, format='WAV')

                    # BUG FIX: formats_processed only has "mp3"/"wav" keys, so
                    # a successfully converted file with another matched
                    # extension used to raise KeyError here (after the WAV was
                    # already written) and be dropped from the manifest.
                    formats_processed[file_ext] = formats_processed.get(file_ext, 0) + 1

                    # Record where the converted copy lives, relative to the
                    # dataset root.
                    metadata["processed_path"] = str(output_audio_path.relative_to(self.output_dir))
                    self.metadata.append(metadata)

                except Exception as e:
                    logger.error(f"Error saving {audio_path}: {str(e)}")
                    continue

        # Persist the manifest and run statistics.
        with open(self.output_dir / "metadata" / "dataset_info.json", "w") as f:
            json.dump({
                "files": self.metadata,
                "stats": {
                    "total_processed": len(self.metadata),
                    "formats_found": formats_found,
                    "formats_processed": formats_processed
                }
            }, f, indent=2)

        logger.info(f"Processed {len(self.metadata)} files successfully")
        logger.info(f"Files found: MP3: {formats_found['mp3']}, WAV: {formats_found['wav']}")
        logger.info(f"Files processed: MP3: {formats_processed['mp3']}, WAV: {formats_processed['wav']}")
108
def _main() -> None:
    """CLI entry point: parse the two directory options and run the preprocessor."""
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--input_dir", type=str, required=True, help="Directory containing music files")
    cli.add_argument("--output_dir", type=str, required=True, help="Directory to save processed files")
    opts = cli.parse_args()
    MusicDataPreprocessor(opts.input_dir, opts.output_dir).process_files()


if __name__ == "__main__":
    _main()
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.42.0
3
+ datasets>=2.14.0
4
+ accelerate>=0.27.0
5
+ librosa>=0.10.0
6
+ pytaglib>=2.0.0
7
+ tqdm>=4.65.0
8
+ numpy>=1.24.0
9
+ einops>=0.6.0
10
+ flash-attn>=2.3.0 # Optional, for CUDA acceleration
11
+ safetensors>=0.4.0
12
+ soundfile>=0.12.0
13
+ pydub>=0.25.1 # For better MP3 support
14
+ huggingface_hub>=0.20.3
15
+ tokenizers>=0.15.0
train_hcf.py ADDED
@@ -0,0 +1,639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import logging
5
+ from pathlib import Path
6
+ from dataclasses import dataclass
7
+ from typing import Optional, List, Dict, Tuple, Any
8
+ import transformers
9
+ from transformers import (
10
+ AutoModelForCausalLM,
11
+ AutoTokenizer,
12
+ TrainingArguments,
13
+ Trainer,
14
+ DataCollatorForLanguageModeling
15
+ )
16
+ from datasets import Dataset, load_dataset
17
+ import numpy as np
18
+ from accelerate import Accelerator
19
+ from safetensors import safe_open
20
+ from safetensors.torch import save_file, load_file
21
+
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
@dataclass
class TensorInfo:
    """Stores metadata about tensor indices and shape"""
    # Tensor shape as reported by the SafeTensors header, e.g. (4096, 4096).
    shape: Tuple[int, ...]
    # String form of the torch dtype, e.g. "torch.float32".
    dtype: str
    # Optional per-weight group indices loaded from file metadata; None when absent.
    indices: Optional[torch.Tensor] = None
    # Per-index statistics filled in later by SafeTensorHCFAnalyzer.
    hcf_patterns: Optional[Dict] = None
33
class SafeTensorHCFAnalyzer:
    """
    Analyzes HCF (highest common factor) patterns in model weights stored in
    SafeTensors format, and estimates memory savings from shared patterns.

    Several numeric routines (_calculate_hcf, _gcd_float,
    _estimate_memory_savings, _generate_optimization_suggestions) are
    simplified placeholders — see the inline notes.
    """

    def __init__(self, tolerance: float = 1e-5):
        # Batch HCFs at or below this value are skipped when combining.
        self.tolerance = tolerance
        # tensor name -> TensorInfo, populated by load_safetensor_file().
        self.tensor_info = {}
        # Header metadata parsed from the SafeTensors file, if any.
        self.metadata = {}

    def load_safetensor_file(self,
                             filepath: str,
                             device: str = 'cpu',
                             load_indices: bool = True) -> "Dict[str, TensorInfo]":
        """
        Load tensor shapes/dtypes (and optional per-weight indices) from a
        SafeTensors file with proper memory management.

        Args:
            filepath: Path to .safetensors file
            device: Device to load tensors to
            load_indices: Whether to load weight indices from the metadata

        Returns:
            Dictionary mapping tensor names to their TensorInfo.

        Raises:
            RuntimeError: wrapping any underlying I/O or parse failure.
        """
        try:
            # Read the header metadata without materialising any tensors.
            with safe_open(filepath, framework="pt") as f:
                raw_metadata = f.metadata()  # read once instead of twice
                self.metadata = json.loads(raw_metadata) if raw_metadata else {}

            # Load tensors efficiently.
            tensors = load_file(filepath, device=device)

            for tensor_name, tensor in tensors.items():
                self.tensor_info[tensor_name] = TensorInfo(
                    shape=tuple(tensor.shape),
                    dtype=str(tensor.dtype)
                )

                # Indices may be stored inline (list) or as a path to a
                # separate torch file.
                if load_indices and tensor_name in self.metadata:
                    if 'indices' in self.metadata[tensor_name]:
                        indices_data = self.metadata[tensor_name]['indices']
                        if isinstance(indices_data, list):
                            self.tensor_info[tensor_name].indices = torch.tensor(
                                indices_data, device=device
                            )
                        elif isinstance(indices_data, str) and os.path.exists(indices_data):
                            # Load indices from separate file if provided as path.
                            self.tensor_info[tensor_name].indices = torch.load(indices_data)

            return self.tensor_info

        except Exception as e:
            # BUG FIX: chain the original exception so the real traceback is
            # not lost behind the RuntimeError wrapper.
            raise RuntimeError(f"Error loading SafeTensor file: {str(e)}") from e

    def analyze_safetensor_weights(self,
                                   filepath: str,
                                   batch_size: int = 1000) -> Dict:
        """
        Analyze weights from a SafeTensors file in memory-efficient batches.

        Only tensors whose indices were loaded by load_safetensor_file() are
        analyzed; the rest are skipped.

        Args:
            filepath: Path to .safetensors file
            batch_size: Number of weights to process at once

        Returns:
            Dict with keys 'tensor_hcfs', 'shared_patterns',
            'optimization_suggestions' and 'memory_impact'.
        """
        results = {
            'tensor_hcfs': {},
            'shared_patterns': [],
            'optimization_suggestions': [],
            'memory_impact': {}
        }

        # Process tensors one at a time straight from the file.
        with safe_open(filepath, framework="pt") as f:
            for tensor_name in f.keys():
                tensor_data = f.get_tensor(tensor_name)

                if tensor_name in self.tensor_info and self.tensor_info[tensor_name].indices is not None:
                    indices = self.tensor_info[tensor_name].indices
                    unique_indices = torch.unique(indices)

                    # Compute an HCF per index group, batching large groups.
                    tensor_hcfs = {}
                    for idx in unique_indices:
                        mask = (indices == idx)
                        indexed_weights = tensor_data[mask]

                        if len(indexed_weights) > batch_size:
                            hcf = self._process_large_weight_group(indexed_weights, batch_size)
                        else:
                            hcf = self._calculate_hcf(indexed_weights)

                        tensor_hcfs[idx.item()] = hcf

                    results['tensor_hcfs'][tensor_name] = tensor_hcfs

                    # Find optimization opportunities.
                    patterns = self._analyze_weight_patterns(tensor_data, indices)
                    self.tensor_info[tensor_name].hcf_patterns = patterns

                    # Potential memory savings versus the dense representation.
                    # (element count — previously computed for every tensor
                    # even when unused)
                    tensor_size = np.prod(tensor_data.shape)
                    savings = self._estimate_memory_savings(patterns, tensor_data.dtype)
                    results['memory_impact'][tensor_name] = {
                        'original_size': tensor_size * tensor_data.element_size(),
                        'potential_savings': savings
                    }

        # Cross-tensor analysis.
        results['shared_patterns'] = self._find_shared_patterns()
        results['optimization_suggestions'] = self._generate_optimization_suggestions(results)

        return results

    def _calculate_hcf(self, weights: torch.Tensor) -> float:
        """Calculate the HCF for a tensor of weights.

        Placeholder implementation: 0.0 for an empty tensor, 1.0 otherwise.
        A real implementation would compute a tolerance-aware float GCD.
        """
        if len(weights) == 0:
            return 0.0
        return 1.0  # Simplified for example

    def _gcd_float(self, a: float, b: float) -> float:
        """Greatest common divisor for floats (placeholder: returns min(a, b))."""
        return min(a, b)  # Simplified for example

    def _process_large_weight_group(self,
                                    weights: torch.Tensor,
                                    batch_size: int) -> float:
        """Fold per-batch HCFs of a large weight group into a single HCF.

        Batches whose HCF is <= self.tolerance are treated as numerically
        zero and skipped (except the first batch, which seeds the value).
        """
        current_hcf = None

        for i in range(0, len(weights), batch_size):
            batch = weights[i:i + batch_size]
            batch_hcf = self._calculate_hcf(batch)

            if current_hcf is None:
                current_hcf = batch_hcf
            elif batch_hcf > self.tolerance:
                current_hcf = self._gcd_float(current_hcf, batch_hcf)

        return current_hcf if current_hcf is not None else 0.0

    def _analyze_weight_patterns(self,
                                 weights: torch.Tensor,
                                 indices: torch.Tensor) -> Dict:
        """Per-index summary statistics (mean/std/size/hcf) of weight groups."""
        patterns = {}
        unique_indices = torch.unique(indices)

        for idx in unique_indices:
            mask = (indices == idx)
            pattern_weights = weights[mask]

            patterns[idx.item()] = {
                'mean': float(pattern_weights.mean()),
                'std': float(pattern_weights.std()),
                'size': len(pattern_weights),
                'hcf': self._calculate_hcf(pattern_weights)
            }

        return patterns

    def _estimate_memory_savings(self, patterns: Dict, dtype: torch.dtype) -> int:
        """Estimate potential memory savings from patterns.

        Placeholder: half the element count. ``dtype`` is kept for interface
        stability; a real implementation would weight by element size.
        """
        return sum(p['size'] for p in patterns.values()) // 2

    def _find_shared_patterns(self) -> List[Dict]:
        """Group index patterns by a mean/std signature shared across tensors."""
        shared_patterns = []
        pattern_groups = {}

        for tensor_name, info in self.tensor_info.items():
            if info.hcf_patterns:
                for idx, pattern in info.hcf_patterns.items():
                    # Two patterns are "the same" when mean and std agree to
                    # four decimal places.
                    signature = f"{pattern['mean']:.4f}_{pattern['std']:.4f}"

                    if signature not in pattern_groups:
                        pattern_groups[signature] = []
                    pattern_groups[signature].append({
                        'tensor': tensor_name,
                        'index': idx,
                        'pattern': pattern
                    })

        # Any signature seen more than once is a sharing opportunity.
        for signature, group in pattern_groups.items():
            if len(group) > 1:
                shared_patterns.append({
                    'signature': signature,
                    'occurrences': group,
                    # All duplicates beyond the first could share storage.
                    'potential_savings': sum(p['pattern']['size'] for p in group[1:])
                })

        return shared_patterns

    def _generate_optimization_suggestions(self, results: Dict) -> List[Dict]:
        """Suggest quantization for tensors whose estimated savings exceed ~1 MB."""
        suggestions = []
        for tensor_name, impact in results['memory_impact'].items():
            if impact['potential_savings'] > 1000000:  # If savings > 1MB
                suggestions.append({
                    'tensor': tensor_name,
                    'suggestion': 'Consider weight quantization',
                    'impact': f"Save {impact['potential_savings'] / 1024 / 1024:.2f}MB"
                })
        return suggestions
251
@dataclass
class TrainingStatistics:
    """Running totals gathered while the HCF-aware training loop executes."""
    memory_savings: int = 0
    quantization_error: float = 0.0
    convergence_rate: float = 0.0
    epoch: int = 0
    batch_count: int = 0

    def update(self, batch_stats: Dict[str, Any]):
        """Fold a single batch's reported statistics into this record."""
        # Savings accumulate across batches; error and convergence keep the
        # most recently reported value.
        self.memory_savings = self.memory_savings + batch_stats.get('memory_savings', 0)
        if 'quantization_error' in batch_stats:
            self.quantization_error = batch_stats['quantization_error']
        if 'convergence_rate' in batch_stats:
            self.convergence_rate = batch_stats['convergence_rate']
        self.batch_count = self.batch_count + 1
267
class HCFTrainingOptimizer(torch.optim.Adam):
    """
    Adam optimizer extended with HCF-aware post-step passes: optional
    8-bit-style in-place weight quantization and (placeholder) pattern
    maintenance. Per-step statistics are exposed via get_stats().
    """
    def __init__(self,
                 params,
                 lr=0.001,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 weight_quantization=True,
                 maintain_patterns=True):
        super().__init__(params, lr, betas, eps, weight_decay)
        # Toggle for the post-step quantization pass.
        self.weight_quantization = weight_quantization
        # Toggle for the (placeholder) pattern-maintenance pass.
        self.maintain_patterns = maintain_patterns
        self.analyzer = SafeTensorHCFAnalyzer()
        # Statistics of the most recent step; see _apply_weight_quantization.
        self.stats = {'memory_savings': 0, 'quantization_error': 0.0}

    def step(self, closure=None):
        """Perform one Adam step, then apply the enabled HCF passes."""
        # Run standard optimization step.
        loss = super().step(closure)

        if self.weight_quantization:
            self._apply_weight_quantization()

        if self.maintain_patterns:
            self._maintain_weight_patterns()

        return loss

    def _apply_weight_quantization(self):
        """Quantize multi-dimensional parameters in place (8-bit style).

        Records the estimated byte savings and the summed mean-squared
        quantization error in self.stats.
        """
        savings = 0
        total_error = 0.0

        for group in self.param_groups:
            for p in group['params']:
                # Only touch parameters that took part in the step.
                if p.grad is None or not p.requires_grad:
                    continue

                # Simplified placeholder — a real implementation would use
                # the HCF analysis to pick per-tensor factors.
                if p.dim() > 1:  # Only apply to matrices/tensors
                    # Scale so the largest magnitude maps to 127 (int8 range).
                    factor = torch.max(torch.abs(p.data)) / 127

                    # BUG FIX: an all-zero parameter yields factor == 0 and
                    # the division below would fill the weights with NaNs.
                    if factor == 0:
                        continue

                    # Quantize weights.
                    quantized = torch.round(p.data / factor) * factor

                    # Calculate error and savings (assumes 8-bit storage,
                    # i.e. 1 byte per weight after quantization).
                    error = torch.mean((p.data - quantized)**2).item()
                    savings += p.numel() * (p.element_size() - 1)

                    # Apply quantized weights.
                    p.data.copy_(quantized)

                    # BUG FIX: accumulate inside this branch; previously
                    # `error` could be referenced while unbound when the
                    # first parameter was 1-dimensional.
                    total_error += error

        # Update statistics.
        self.stats['memory_savings'] = savings
        self.stats['quantization_error'] = total_error

    def _maintain_weight_patterns(self):
        """Maintain efficient weight patterns identified by HCF analysis.

        Placeholder: a real implementation would re-impose shared patterns
        found by SafeTensorHCFAnalyzer onto the updated weights.
        """
        pass

    def get_stats(self):
        """Return the statistics dict recorded by the last step."""
        return self.stats
341
class HCFAwareTrainer:
    """
    Minimal training-loop wrapper that pairs a model with an HCF-aware
    optimizer and runs a weight analysis at the end of every epoch.
    """
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.analyzer = SafeTensorHCFAnalyzer()

    def train_epoch(self, train_loader, criterion, epoch):
        """Run one epoch; returns the accumulated TrainingStatistics."""
        self.model.train()
        stats = TrainingStatistics(epoch=epoch)

        for batch_idx, batch in enumerate(train_loader):
            inputs, targets = self._prepare_batch(batch)

            # Standard forward/backward cycle.
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            # The optimizer applies its HCF passes inside step().
            self.optimizer.step()

            # Fold this batch's optimizer statistics into the epoch totals.
            stats.update(self.optimizer.get_stats())

            # Periodic progress report.
            if batch_idx % 50 == 0:
                logger.info(f"Epoch {epoch} | Batch {batch_idx}/{len(train_loader)} | "
                            f"Memory Savings: {stats.memory_savings/1024/1024:.2f}MB | "
                            f"Quantization Error: {stats.quantization_error:.6f}")

        # Snapshot and analyze the weights once per epoch.
        self._analyze_model_weights()

        return stats

    def _prepare_batch(self, batch):
        """Split a batch into (inputs, targets); accepts dicts or pairs."""
        if isinstance(batch, dict):
            inputs = batch.get('input_ids')
            # Language-model convention: targets default to the inputs.
            return inputs, batch.get('labels', inputs)
        # Otherwise assume a positional (inputs, targets) pair.
        inputs, targets = batch
        return inputs, targets

    def _analyze_model_weights(self):
        """Dump current weights to a temporary SafeTensors file, analyze, clean up."""
        model_path = "temp_model.safetensors"
        save_file(dict(self.model.named_parameters()), model_path)

        results = self.analyzer.analyze_safetensor_weights(model_path)

        # Report the findings.
        logger.info(f"Weight Analysis: Found {len(results['shared_patterns'])} shared patterns")
        logger.info(f"Potential memory savings: "
                    f"{sum(i['potential_savings'] for i in results['memory_impact'].values())/1024/1024:.2f}MB")

        # Remove the temporary snapshot.
        if os.path.exists(model_path):
            os.remove(model_path)
416
@dataclass
class ModelConfig:
    """Identifies a pretrained model/tokenizer pair on the Hugging Face Hub."""
    # Short size label matching the --model_size CLI choice ("1b" / "7b").
    name: str
    # Hub repo id passed to AutoModelForCausalLM.from_pretrained.
    model_id: str
    # Hub repo id passed to AutoTokenizer.from_pretrained (same repo here).
    tokenizer_id: str

# Supported fine-tuning targets, keyed by the --model_size CLI choice.
# NOTE(review): "7b" maps to Stage1 and "1b" to Stage2 — confirm this
# size/stage pairing is intentional.
CONFIGS = {
    "7b": ModelConfig(
        name="7b",
        model_id="scrapegoat/ScrapeGoat-Music-Stage1",
        tokenizer_id="scrapegoat/ScrapeGoat-Music-Stage1"
    ),
    "1b": ModelConfig(
        name="1b",
        model_id="scrapegoat/ScrapeGoat-Music-Stage2",
        tokenizer_id="scrapegoat/ScrapeGoat-Music-Stage2"
    )
}
435
class MusicFineTuner:
    """Fine-tunes a ScrapeGoat music model on text prompts rendered from the
    prepare_data.py manifest, using either the HCF-aware loop or the
    standard Hugging Face Trainer.
    """
    def __init__(
        self,
        model_size: str,
        dataset_path: str,
        output_dir: str,
        device: str = "auto",
        batch_size: int = 4,
        gradient_accumulation_steps: int = 4,
        learning_rate: float = 1e-5,
        num_epochs: int = 3,
        use_hcf: bool = True
    ):
        # KeyError here means an unsupported model_size; valid keys are the
        # CONFIGS keys ("1b", "7b").
        self.config = CONFIGS[model_size]
        self.dataset_path = Path(dataset_path)
        self.output_dir = Path(output_dir)
        self.device = self._setup_device(device)
        self.use_hcf = use_hcf
        self.training_args = TrainingArguments(
            output_dir=str(self.output_dir),
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            num_train_epochs=num_epochs,
            logging_steps=100,
            save_steps=1000,
            # NOTE(review): `evaluation_strategy` was renamed `eval_strategy`
            # in newer transformers releases — confirm against the pinned
            # transformers version (requirements.txt asks for >=4.42.0).
            evaluation_strategy="steps",
            eval_steps=500,
            save_total_limit=3,
            load_best_model_at_end=True,
            gradient_checkpointing=True,
            fp16=torch.cuda.is_available(),
            optim="adamw_torch"
        )

    def _setup_device(self, device: str) -> str:
        """Resolve "auto" to cuda/mps/cpu; pass explicit choices through."""
        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif torch.backends.mps.is_available():
                return "mps"
            else:
                return "cpu"
        return device

    def _load_model_and_tokenizer(self):
        """Load the configured model and tokenizer from the Hub."""
        logger.info(f"Loading model {self.config.model_id}")

        # bfloat16 on CUDA, float32 elsewhere (MPS/CPU).
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        model = AutoModelForCausalLM.from_pretrained(
            self.config.model_id,
            torch_dtype=dtype,
            device_map="auto" if self.device == "cuda" else None,
            # NOTE(review): flash_attention_2 requires the optional flash-attn
            # package — confirm it is installed on CUDA hosts, otherwise
            # loading will fail.
            attn_implementation="flash_attention_2" if self.device == "cuda" else "eager"
        )

        tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_id)
        return model, tokenizer

    def _prepare_dataset(self, tokenizer):
        """Render each manifest entry to a text prompt and tokenize it."""
        logger.info("Preparing dataset")

        # Manifest written by prepare_data.py.
        with open(self.dataset_path / "metadata" / "dataset_info.json") as f:
            metadata = json.load(f)

        def generate_text(item):
            # One training example per track: its tag metadata as plain text.
            return f"Genre: {item['genre']}\nDuration: {item['duration']:.2f}s\nTitle: {item['title']}\nArtist: {item['artist']}\n"

        texts = [generate_text(item) for item in metadata["files"]]
        dataset = Dataset.from_dict({"text": texts})

        def tokenize(examples):
            # Fixed-length (512) padded encoding.
            return tokenizer(
                examples["text"],
                truncation=True,
                padding="max_length",
                max_length=512,
                return_tensors="pt"
            )

        tokenized_dataset = dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names
        )

        return tokenized_dataset

    def train(self):
        """Run fine-tuning with either the HCF loop or the HF Trainer."""
        # Create output directory
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Load model and tokenizer
        model, tokenizer = self._load_model_and_tokenizer()

        # Prepare dataset
        dataset = self._prepare_dataset(tokenizer)

        # 90/10 train/eval split.
        dataset = dataset.train_test_split(test_size=0.1)

        if self.use_hcf:
            logger.info("Using HCF-aware training")
            # Create custom HCF optimizer
            optimizer = HCFTrainingOptimizer(
                model.parameters(),
                lr=self.training_args.learning_rate,
                weight_quantization=True,
                maintain_patterns=True
            )

            # Create HCF trainer
            hcf_trainer = HCFAwareTrainer(model, optimizer)

            # NOTE(review): no collate_fn is supplied here, so default
            # collation of the HF dataset rows applies — confirm the batches
            # arrive in the tensor layout the model and criterion expect.
            train_loader = torch.utils.data.DataLoader(
                dataset["train"],
                batch_size=self.training_args.per_device_train_batch_size,
                shuffle=True
            )

            # Training loop with HCF awareness
            criterion = torch.nn.CrossEntropyLoss()
            for epoch in range(int(self.training_args.num_train_epochs)):
                stats = hcf_trainer.train_epoch(train_loader, criterion, epoch)

                # Log training metrics
                logger.info(f"Epoch {epoch} completed")
                logger.info(f"Memory Savings: {stats.memory_savings/1024/1024:.2f}MB")
                logger.info(f"Quantization Error: {stats.quantization_error:.6f}")
                logger.info(f"Convergence Rate: {stats.convergence_rate:.4f}")

                # Save checkpoint
                self._save_hcf_checkpoint(model, tokenizer, epoch)
        else:
            # Use standard HuggingFace Trainer
            logger.info("Using standard training")
            trainer = Trainer(
                model=model,
                args=self.training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["test"],
                # Causal-LM objective (mlm=False).
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )

            # Train
            logger.info("Starting training")
            trainer.train()

        # Save final model
        logger.info("Saving model")
        model.save_pretrained(str(self.output_dir / "final_model"))
        tokenizer.save_pretrained(str(self.output_dir / "final_model"))

    def _save_hcf_checkpoint(self, model, tokenizer, epoch):
        """Save checkpoint with HCF metadata"""
        checkpoint_dir = self.output_dir / f"checkpoint-{epoch}"
        checkpoint_dir.mkdir(exist_ok=True)

        # Save model and tokenizer
        model.save_pretrained(str(checkpoint_dir))
        tokenizer.save_pretrained(str(checkpoint_dir))

        # Analyze and save HCF metadata
        analyzer = SafeTensorHCFAnalyzer()

        # NOTE(review): only a single-file checkpoint named exactly
        # "model.safetensors" is analyzed; confirm sharded saves (multiple
        # .safetensors files) are acceptable to skip silently.
        model_path = str(checkpoint_dir / "model.safetensors")
        if os.path.exists(model_path):
            results = analyzer.analyze_safetensor_weights(model_path)

            # Save analysis results
            with open(checkpoint_dir / "hcf_analysis.json", "w") as f:
                json.dump(results, f, indent=2)

        logger.info(f"Saved checkpoint at {checkpoint_dir}")
614
def _cli_main() -> None:
    """Parse command-line options and launch fine-tuning."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_size", type=str, choices=["1b", "7b"], required=True)
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--use_hcf", action="store_true", help="Enable HCF-aware training")

    # Every option name matches a MusicFineTuner keyword parameter
    # one-to-one, so the namespace can be forwarded directly.
    MusicFineTuner(**vars(parser.parse_args())).train()


if __name__ == "__main__":
    _cli_main()
train_local.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Training script for ScrapeGoat Music models using local model files with HCF optimization.
Optimized for local training with the models in the provided directory structure.
"""

import os
import sys
import json
import torch
import logging
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List, Dict, Tuple, Any
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
import numpy as np
from accelerate import Accelerator
from safetensors import safe_open
from safetensors.torch import save_file, load_file

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Add xcodec_mini_infer to path to access its modules.
# NOTE(review): the mutation happens before the train_hcf import below —
# presumably train_hcf (or its dependencies) needs modules from that
# directory on sys.path; confirm if the ordering is load-bearing.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
XCODEC_PATH = os.path.join(SCRIPT_DIR, "xcodec_mini_infer")
sys.path.append(XCODEC_PATH)

# Import HCF training components from train_hcf.py (sibling module)
from train_hcf import (
    TensorInfo,
    SafeTensorHCFAnalyzer,
    TrainingStatistics,
    HCFTrainingOptimizer,
    HCFAwareTrainer
)
48
@dataclass
class LocalModelConfig:
    """Describes one locally stored model checkpoint.

    Attributes:
        model_path: Path (possibly relative) to the model directory.
        name: Human-readable label used in log messages.
    """

    model_path: str
    name: str

    @property
    def model_dir(self) -> str:
        """Absolute filesystem path to the model directory."""
        resolved = os.path.abspath(self.model_path)
        return resolved
58
class LocalFineTuner:
    """Fine-tuner that works with local model files.

    Loads a causal-LM checkpoint and tokenizer strictly from the local
    filesystem (``local_files_only=True``), builds a small text dataset from
    the preprocessed metadata, and trains either with the custom HCF-aware
    loop (``use_hcf=True``) or the standard HuggingFace ``Trainer``.
    """

    def __init__(
        self,
        model_config: LocalModelConfig,
        dataset_path: str,
        output_dir: str,
        device: str = "auto",
        batch_size: int = 4,
        gradient_accumulation_steps: int = 4,
        learning_rate: float = 1e-5,
        num_epochs: int = 3,
        use_hcf: bool = True
    ):
        """
        Args:
            model_config: Location and display name of the local model.
            dataset_path: Root of the preprocessed dataset; must contain
                ``metadata/dataset_info.json``.
            output_dir: Directory for checkpoints and the final model.
            device: "cuda", "mps", "cpu", or "auto" to pick the best available.
            batch_size: Per-device training batch size.
            gradient_accumulation_steps: Steps accumulated before an update.
            learning_rate: Optimizer learning rate.
            num_epochs: Number of training epochs.
            use_hcf: Whether to use the HCF-aware training loop.
        """
        self.model_config = model_config
        self.dataset_path = Path(dataset_path)
        self.output_dir = Path(output_dir)
        self.device = self._setup_device(device)
        self.use_hcf = use_hcf

        # Ensure output directory exists
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Training arguments: used directly by the standard Trainer path;
        # the HCF path reads batch size / learning rate / epochs from it.
        self.training_args = TrainingArguments(
            output_dir=str(self.output_dir),
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            num_train_epochs=num_epochs,
            logging_steps=100,
            save_steps=1000,  # multiple of eval_steps, required by load_best_model_at_end
            evaluation_strategy="steps",
            eval_steps=500,
            save_total_limit=3,
            load_best_model_at_end=True,
            gradient_checkpointing=True,
            fp16=torch.cuda.is_available(),
            optim="adamw_torch"
        )

    def _setup_device(self, device: str) -> str:
        """Resolve "auto" to the best available backend; pass anything else through."""
        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif torch.backends.mps.is_available():
                return "mps"
            else:
                return "cpu"
        return device

    def _load_model_and_tokenizer(self):
        """Load model and tokenizer from the local path (no network access).

        Returns:
            Tuple of ``(model, tokenizer)``.
        """
        logger.info(f"Loading model from {self.model_config.model_dir}")

        # bf16 on CUDA; full precision elsewhere (bf16 support on MPS/CPU is spotty).
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        model = AutoModelForCausalLM.from_pretrained(
            self.model_config.model_dir,
            torch_dtype=dtype,
            device_map="auto" if self.device == "cuda" else None,
            # NOTE(review): flash_attention_2 requires the flash-attn package —
            # confirm it is installed in the CUDA environment.
            attn_implementation="flash_attention_2" if self.device == "cuda" else "eager",
            local_files_only=True
        )

        tokenizer = AutoTokenizer.from_pretrained(
            self.model_config.model_dir,
            local_files_only=True
        )

        return model, tokenizer

    def _prepare_dataset(self, tokenizer):
        """Build a tokenized text dataset from the preprocessed metadata.

        Each sample is a short text description (genre/duration/title/artist)
        generated from ``metadata/dataset_info.json``.
        """
        logger.info("Preparing dataset")

        # Load metadata
        with open(self.dataset_path / "metadata" / "dataset_info.json") as f:
            metadata = json.load(f)

        # Define text generation from metadata
        def generate_text(item):
            return f"Genre: {item['genre']}\nDuration: {item['duration']:.2f}s\nTitle: {item['title']}\nArtist: {item['artist']}\n"

        # Generate text samples
        texts = [generate_text(item) for item in metadata["files"]]
        dataset = Dataset.from_dict({"text": texts})

        def tokenize(examples):
            # BUG FIX: the original passed return_tensors="pt" here. Inside a
            # batched Dataset.map() the output is stored as Arrow columns, so
            # requesting torch tensors is pointless at best and is a known
            # source of map() failures.
            return tokenizer(
                examples["text"],
                truncation=True,
                padding="max_length",
                max_length=512
            )

        # Apply tokenization; drop the raw text column.
        tokenized_dataset = dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names
        )

        return tokenized_dataset

    def train(self):
        """Train the model with either the HCF loop or the HF Trainer."""
        # Log training configuration
        logger.info(f"Training {self.model_config.name} model with HCF optimization")
        logger.info(f"Model path: {self.model_config.model_dir}")
        logger.info(f"Dataset path: {self.dataset_path}")
        logger.info(f"Output directory: {self.output_dir}")
        logger.info(f"Device: {self.device}")
        logger.info(f"HCF optimization: {'enabled' if self.use_hcf else 'disabled'}")

        # Load model and tokenizer
        model, tokenizer = self._load_model_and_tokenizer()

        # Prepare dataset and hold out 10% for evaluation
        dataset = self._prepare_dataset(tokenizer)
        dataset = dataset.train_test_split(test_size=0.1)

        if self.use_hcf:
            logger.info("Using HCF-aware training")
            # Create custom HCF optimizer
            optimizer = HCFTrainingOptimizer(
                model.parameters(),
                lr=self.training_args.learning_rate,
                weight_quantization=True,
                maintain_patterns=True
            )

            # Create HCF trainer
            hcf_trainer = HCFAwareTrainer(model, optimizer)

            # BUG FIX: give the split a torch format before wrapping it in a
            # DataLoader. Without this the HF dataset yields Python lists and
            # the default collate builds mis-shaped (per-position) batches
            # instead of [batch, seq] tensors.
            train_loader = torch.utils.data.DataLoader(
                dataset["train"].with_format("torch"),
                batch_size=self.training_args.per_device_train_batch_size,
                shuffle=True
            )

            # Training loop with HCF awareness
            criterion = torch.nn.CrossEntropyLoss()
            for epoch in range(int(self.training_args.num_train_epochs)):
                stats = hcf_trainer.train_epoch(train_loader, criterion, epoch)

                # Log training metrics
                logger.info(f"Epoch {epoch} completed")
                logger.info(f"Memory Savings: {stats.memory_savings/1024/1024:.2f}MB")
                logger.info(f"Quantization Error: {stats.quantization_error:.6f}")
                logger.info(f"Convergence Rate: {stats.convergence_rate:.4f}")

                # Save checkpoint
                self._save_hcf_checkpoint(model, tokenizer, epoch)
        else:
            # Use standard HuggingFace Trainer
            logger.info("Using standard training")
            trainer = Trainer(
                model=model,
                args=self.training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["test"],
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )

            # Train
            logger.info("Starting training")
            trainer.train()

        # Save final model
        logger.info("Saving model")
        final_output_dir = self.output_dir / "final_model"
        final_output_dir.mkdir(exist_ok=True)

        model.save_pretrained(str(final_output_dir))
        tokenizer.save_pretrained(str(final_output_dir))

        logger.info(f"Training complete. Model saved to {final_output_dir}")

    def _save_hcf_checkpoint(self, model, tokenizer, epoch):
        """Save a per-epoch checkpoint plus HCF weight-analysis metadata."""
        checkpoint_dir = self.output_dir / f"checkpoint-{epoch}"
        checkpoint_dir.mkdir(exist_ok=True)

        # Save model and tokenizer
        model.save_pretrained(str(checkpoint_dir))
        tokenizer.save_pretrained(str(checkpoint_dir))

        # Analyze the saved weights and store the HCF report alongside them.
        analyzer = SafeTensorHCFAnalyzer()

        # NOTE(review): large models are saved as sharded files
        # (model-00001-of-000NN.safetensors); in that case this single-file
        # path will not exist and analysis is silently skipped — confirm
        # whether sharded checkpoints should also be analyzed.
        model_path = str(checkpoint_dir / "model.safetensors")
        if os.path.exists(model_path):
            results = analyzer.analyze_safetensor_weights(model_path)

            # Save analysis results
            with open(checkpoint_dir / "hcf_analysis.json", "w") as f:
                json.dump(results, f, indent=2)

        logger.info(f"Saved checkpoint at {checkpoint_dir}")
273
+ def main():
274
+ """Main function for training"""
275
+ import argparse
276
+ parser = argparse.ArgumentParser(description="Retrain ScrapeGoat Music models with HCF optimization")
277
+ parser.add_argument("--model", type=str, choices=["7b", "1b"], required=True,
278
+ help="Model size to train")
279
+ parser.add_argument("--dataset_path", type=str, required=True,
280
+ help="Path to processed dataset")
281
+ parser.add_argument("--output_dir", type=str, required=True,
282
+ help="Directory to save trained model")
283
+ parser.add_argument("--device", type=str, default="auto",
284
+ help="Device to use (cuda, cpu, mps, or auto)")
285
+ parser.add_argument("--batch_size", type=int, default=4,
286
+ help="Batch size for training")
287
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=4,
288
+ help="Gradient accumulation steps")
289
+ parser.add_argument("--learning_rate", type=float, default=1e-5,
290
+ help="Learning rate")
291
+ parser.add_argument("--num_epochs", type=int, default=3,
292
+ help="Number of training epochs")
293
+ parser.add_argument("--use_hcf", action="store_true", default=True,
294
+ help="Enable HCF optimization")
295
+ parser.add_argument("--base_dir", type=str, default=os.getcwd(),
296
+ help="Base directory containing model folders")
297
+
298
+ args = parser.parse_args()
299
+
300
+ # Set up model configuration based on size
301
+ if args.model == "7b":
302
+ model_path = os.path.join(args.base_dir, "scrapegoat/ScrapeGoat-Music-Stage1")
303
+ model_config = LocalModelConfig(
304
+ model_path=model_path,
305
+ name="ScrapeGoatMusic 7B"
306
+ )
307
+ else:
308
+ model_path = os.path.join(args.base_dir, "scrapegoat/ScrapeGoat-Music-Stage2")
309
+ model_config = LocalModelConfig(
310
+ model_path=model_path,
311
+ name="ScrapeGoatMusic 1B"
312
+ )
313
+
314
+ # Create fine-tuner
315
+ fine_tuner = LocalFineTuner(
316
+ model_config=model_config,
317
+ dataset_path=args.dataset_path,
318
+ output_dir=args.output_dir,
319
+ device=args.device,
320
+ batch_size=args.batch_size,
321
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
322
+ learning_rate=args.learning_rate,
323
+ num_epochs=args.num_epochs,
324
+ use_hcf=args.use_hcf
325
+ )
326
+
327
+ # Train model
328
+ fine_tuner.train()
329
+
330
+ if __name__ == "__main__":
331
+ main()