Upload 8 files
Browse files- .gitattributes +3 -0
- Document1.pdf +3 -0
- model-00002-of-00003.safetensors +3 -0
- model-00003-of-00003.safetensors +3 -0
- prepare_data.py +116 -0
- requirements.txt +15 -0
- train_hcf.py +639 -0
- train_local.py +331 -0
.gitattributes
CHANGED
|
@@ -1,2 +1,5 @@
|
|
| 1 |
model-00001-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 2 |
tokenizer.model filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
model-00001-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 2 |
tokenizer.model filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
Document1.pdf filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
model-00002-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
model-00003-of-00003.safetensors filter=lfs diff=lfs merge=lfs -text
|
Document1.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e678058adddf0a03284f79f65242699fee2cf5191b239a9a668ada8be9862e90
|
| 3 |
+
size 6035292
|
model-00002-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d9a7c7adf0142010ea7fb2d6d60b2698b86f36847d00d0afa4170c3a9fb66a9c
|
| 3 |
+
size 4934842808
|
model-00003-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c4c9f1d21524ad189e63230a62a62997c52205f9ce3099948c7fc3d27385d0dc
|
| 3 |
+
size 2598483736
|
prepare_data.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import librosa
|
| 5 |
+
import taglib
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import logging
|
| 8 |
+
import soundfile as sf
|
| 9 |
+
|
| 10 |
+
logging.basicConfig(level=logging.INFO)
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
class MusicDataPreprocessor:
    """Normalize a directory of MP3/WAV music files into a training dataset.

    Each input file is resampled to 16 kHz mono WAV under
    ``<output_dir>/audio`` and its tag metadata is collected into
    ``<output_dir>/metadata/dataset_info.json``.
    """

    # Sample rate every output file is resampled to.
    TARGET_SR = 16000

    def __init__(self, input_dir: str, output_dir: str):
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        self.metadata = []

        # Create necessary directories
        self.output_dir.mkdir(parents=True, exist_ok=True)
        (self.output_dir / "audio").mkdir(exist_ok=True)
        (self.output_dir / "metadata").mkdir(exist_ok=True)

    def extract_metadata(self, audio_path: Path) -> dict:
        """Extract metadata from audio file (MP3 or WAV).

        Returns a dict of tag and audio properties, or None when the file
        cannot be read (the error is logged, not raised).
        """
        audio_file = None
        try:
            audio_format = audio_path.suffix.lower()[1:]  # extension without dot
            audio_file = taglib.File(str(audio_path))

            # Basic audio properties; decode resampled to the target rate.
            y, sr = librosa.load(audio_path, sr=self.TARGET_SR)
            duration = librosa.get_duration(y=y, sr=sr)

            metadata = {
                "filename": audio_path.name,
                "format": audio_format,
                "duration": duration,
                "genre": audio_file.tags.get("GENRE", ["unknown"])[0],
                "title": audio_file.tags.get("TITLE", ["unknown"])[0],
                "artist": audio_file.tags.get("ARTIST", ["unknown"])[0],
                "sample_rate": sr,
                "channels": audio_file.channels
            }
            return metadata

        except Exception as e:
            logger.error(f"Error processing {audio_path}: {str(e)}")
            return None
        finally:
            # BUG FIX: the taglib file handle was previously never closed.
            if audio_file is not None:
                audio_file.close()

    def process_files(self):
        """Process all audio files (MP3 and WAV) in the input directory."""
        # BUG FIX: the old glob "**/*.[mw][pa][3v]" also matched unintended
        # extensions such as .mpv, .ma3 or .wa3. Match mp3/wav explicitly;
        # sort for deterministic processing order.
        audio_files = sorted(
            p for p in self.input_dir.glob("**/*")
            if p.suffix.lower() in (".mp3", ".wav")
        )

        formats_found = {"mp3": 0, "wav": 0, "other": 0}
        formats_processed = {"mp3": 0, "wav": 0}

        logger.info(f"Found {len(audio_files)} audio files to process")

        for audio_path in tqdm(audio_files, desc="Processing audio files"):
            # Track format statistics
            file_ext = audio_path.suffix.lower()[1:]
            if file_ext in ("mp3", "wav"):
                formats_found[file_ext] += 1
            else:
                formats_found["other"] += 1
                logger.warning(f"Unexpected file format: {file_ext} for file {audio_path}")

            metadata = self.extract_metadata(audio_path)

            if metadata:
                # Save processed audio - convert all to WAV
                output_audio_path = self.output_dir / "audio" / f"{audio_path.stem}.wav"
                try:
                    y, sr = librosa.load(audio_path, sr=self.TARGET_SR, mono=True)
                    sf.write(output_audio_path, y, sr, format='WAV')

                    # BUG FIX: guard the counter so an unexpected extension
                    # cannot raise KeyError after a successful conversion.
                    if file_ext in formats_processed:
                        formats_processed[file_ext] += 1

                    # Record where the converted file lives, relative to output_dir.
                    metadata["processed_path"] = str(output_audio_path.relative_to(self.output_dir))
                    self.metadata.append(metadata)

                except Exception as e:
                    logger.error(f"Error saving {audio_path}: {str(e)}")
                    continue

        # Persist collected metadata plus processing statistics.
        with open(self.output_dir / "metadata" / "dataset_info.json", "w") as f:
            json.dump({
                "files": self.metadata,
                "stats": {
                    "total_processed": len(self.metadata),
                    "formats_found": formats_found,
                    "formats_processed": formats_processed
                }
            }, f, indent=2)

        logger.info(f"Processed {len(self.metadata)} files successfully")
        logger.info(f"Files found: MP3: {formats_found['mp3']}, WAV: {formats_found['wav']}")
        logger.info(f"Files processed: MP3: {formats_processed['mp3']}, WAV: {formats_processed['wav']}")
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
    import argparse

    # Command-line entry point: convert one music directory into the
    # normalized dataset layout expected by the training scripts.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--input_dir", type=str, required=True,
                            help="Directory containing music files")
    arg_parser.add_argument("--output_dir", type=str, required=True,
                            help="Directory to save processed files")
    cli_args = arg_parser.parse_args()

    MusicDataPreprocessor(cli_args.input_dir, cli_args.output_dir).process_files()
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
transformers>=4.42.0
|
| 3 |
+
datasets>=2.14.0
|
| 4 |
+
accelerate>=0.27.0
|
| 5 |
+
librosa>=0.10.0
|
| 6 |
+
pytaglib>=2.0.0
|
| 7 |
+
tqdm>=4.65.0
|
| 8 |
+
numpy>=1.24.0
|
| 9 |
+
einops>=0.6.0
|
| 10 |
+
flash-attn>=2.3.0 # Optional, for CUDA acceleration
|
| 11 |
+
safetensors>=0.4.0
|
| 12 |
+
soundfile>=0.12.0
|
| 13 |
+
pydub>=0.25.1 # For better MP3 support
|
| 14 |
+
huggingface_hub>=0.20.3
|
| 15 |
+
tokenizers>=0.15.0
|
train_hcf.py
ADDED
|
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
import logging
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Optional, List, Dict, Tuple, Any
|
| 8 |
+
import transformers
|
| 9 |
+
from transformers import (
|
| 10 |
+
AutoModelForCausalLM,
|
| 11 |
+
AutoTokenizer,
|
| 12 |
+
TrainingArguments,
|
| 13 |
+
Trainer,
|
| 14 |
+
DataCollatorForLanguageModeling
|
| 15 |
+
)
|
| 16 |
+
from datasets import Dataset, load_dataset
|
| 17 |
+
import numpy as np
|
| 18 |
+
from accelerate import Accelerator
|
| 19 |
+
from safetensors import safe_open
|
| 20 |
+
from safetensors.torch import save_file, load_file
|
| 21 |
+
|
| 22 |
+
logging.basicConfig(level=logging.INFO)
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
@dataclass
class TensorInfo:
    """Stores metadata about tensor indices and shape"""
    # Tensor dimensions recorded when the safetensor file was loaded.
    shape: Tuple[int, ...]
    # String form of the torch dtype (e.g. "torch.float32").
    dtype: str
    # Optional per-weight group indices pulled from file metadata.
    indices: Optional[torch.Tensor] = None
    # Per-index statistics filled in during weight-pattern analysis.
    hcf_patterns: Optional[Dict] = None
|
| 32 |
+
|
| 33 |
+
class SafeTensorHCFAnalyzer:
    """
    Analyzes HCF patterns in model weights using SafeTensors format.
    Handles efficient loading and analysis of large model weights.
    """

    def __init__(self, tolerance: float = 1e-5):
        # Threshold below which floating-point HCF contributions are ignored.
        self.tolerance = tolerance
        # tensor name -> TensorInfo, populated by load_safetensor_file().
        self.tensor_info = {}
        # Metadata dict parsed from the safetensor header.
        self.metadata = {}

    def load_safetensor_file(self,
                             filepath: str,
                             device: str = 'cpu',
                             load_indices: bool = True) -> Dict[str, TensorInfo]:
        """
        Load and parse a SafeTensor file with proper memory management.

        Args:
            filepath: Path to .safetensors file
            device: Device to load tensors to
            load_indices: Whether to load weight indices

        Returns:
            Dictionary mapping tensor names to their metadata

        Raises:
            RuntimeError: if the file cannot be read or parsed.
        """
        try:
            # First load metadata only to check structure
            with safe_open(filepath, framework="pt") as f:
                self.metadata = json.loads(f.metadata()) if f.metadata() else {}

            # Load tensors efficiently
            tensors = load_file(filepath, device=device)

            for tensor_name, tensor in tensors.items():
                self.tensor_info[tensor_name] = TensorInfo(
                    shape=tuple(tensor.shape),
                    dtype=str(tensor.dtype)
                )

                # Load indices if available in metadata
                if load_indices and tensor_name in self.metadata:
                    if 'indices' in self.metadata[tensor_name]:
                        indices_data = self.metadata[tensor_name]['indices']
                        if isinstance(indices_data, list):
                            self.tensor_info[tensor_name].indices = torch.tensor(
                                indices_data, device=device
                            )
                        elif isinstance(indices_data, str) and os.path.exists(indices_data):
                            # Load indices from separate file if provided as path.
                            # SECURITY NOTE: torch.load unpickles arbitrary objects;
                            # only accept index paths from trusted metadata.
                            self.tensor_info[tensor_name].indices = torch.load(indices_data)

            return self.tensor_info

        except Exception as e:
            # BUG FIX: chain the original exception ("from e") so the root
            # cause is not lost in the traceback.
            raise RuntimeError(f"Error loading SafeTensor file: {str(e)}") from e

    def analyze_safetensor_weights(self,
                                   filepath: str,
                                   batch_size: int = 1000) -> Dict:
        """
        Analyze weights from SafeTensor file in memory-efficient batches.

        Args:
            filepath: Path to .safetensors file
            batch_size: Number of weights to process at once

        Returns:
            Analysis results including HCF patterns and optimization opportunities
        """
        results = {
            'tensor_hcfs': {},
            'shared_patterns': [],
            'optimization_suggestions': [],
            'memory_impact': {}
        }

        # Process tensors one at a time straight from the file handle so the
        # whole checkpoint never has to be resident at once.
        with safe_open(filepath, framework="pt") as f:
            for tensor_name in f.keys():
                # Get tensor info
                tensor_data = f.get_tensor(tensor_name)
                tensor_size = np.prod(tensor_data.shape)

                # Only tensors with previously-loaded index metadata are analyzed.
                if tensor_name in self.tensor_info and self.tensor_info[tensor_name].indices is not None:
                    indices = self.tensor_info[tensor_name].indices
                    unique_indices = torch.unique(indices)

                    # Process each index group
                    tensor_hcfs = {}
                    for idx in unique_indices:
                        mask = (indices == idx)
                        indexed_weights = tensor_data[mask]

                        # Process in batches if needed
                        if len(indexed_weights) > batch_size:
                            hcf = self._process_large_weight_group(indexed_weights, batch_size)
                        else:
                            hcf = self._calculate_hcf(indexed_weights)

                        tensor_hcfs[idx.item()] = hcf

                    results['tensor_hcfs'][tensor_name] = tensor_hcfs

                    # Find optimization opportunities
                    patterns = self._analyze_weight_patterns(tensor_data, indices)
                    self.tensor_info[tensor_name].hcf_patterns = patterns

                    # Calculate potential memory savings
                    savings = self._estimate_memory_savings(patterns, tensor_data.dtype)
                    results['memory_impact'][tensor_name] = {
                        'original_size': tensor_size * tensor_data.element_size(),
                        'potential_savings': savings
                    }

        # Find shared patterns across tensors
        results['shared_patterns'] = self._find_shared_patterns()
        results['optimization_suggestions'] = self._generate_optimization_suggestions(results)

        return results

    def _calculate_hcf(self, weights: torch.Tensor) -> float:
        """Calculate HCF for a tensor of weights, with tolerance for floating point.

        Implementation placeholder: returns 0.0 for an empty tensor, 1.0 otherwise.
        """
        if len(weights) == 0:
            return 0.0
        return 1.0  # Simplified for example

    def _gcd_float(self, a: float, b: float) -> float:
        """Calculate greatest common divisor for floating point numbers.

        Implementation placeholder: returns min(a, b).
        """
        return min(a, b)  # Simplified for example

    def _process_large_weight_group(self,
                                    weights: torch.Tensor,
                                    batch_size: int) -> float:
        """Process large weight groups in batches to manage memory."""
        current_hcf = None

        for i in range(0, len(weights), batch_size):
            batch = weights[i:i + batch_size]
            batch_hcf = self._calculate_hcf(batch)

            if current_hcf is None:
                current_hcf = batch_hcf
            elif batch_hcf > self.tolerance:
                # Fold this batch's HCF into the running value.
                current_hcf = self._gcd_float(current_hcf, batch_hcf)

        return current_hcf if current_hcf is not None else 0.0

    def _analyze_weight_patterns(self,
                                 weights: torch.Tensor,
                                 indices: torch.Tensor) -> Dict:
        """Analyze weight patterns within indexed groups.

        Returns a dict keyed by index value with mean/std/size/hcf statistics.
        """
        patterns = {}
        unique_indices = torch.unique(indices)

        for idx in unique_indices:
            mask = (indices == idx)
            pattern_weights = weights[mask]

            patterns[idx.item()] = {
                'mean': float(pattern_weights.mean()),
                'std': float(pattern_weights.std()),
                'size': len(pattern_weights),
                'hcf': self._calculate_hcf(pattern_weights)
            }

        return patterns

    def _estimate_memory_savings(self, patterns: Dict, dtype: torch.dtype) -> int:
        """Estimate potential memory savings from patterns.

        Implementation placeholder: half the total pattern size. ``dtype`` is
        kept for interface compatibility; a real estimate would use it.
        """
        return sum(p['size'] for p in patterns.values()) // 2  # Simplified estimate

    def _find_shared_patterns(self) -> List[Dict]:
        """Find patterns that could be shared across tensors."""
        shared_patterns = []
        pattern_groups = {}

        for tensor_name, info in self.tensor_info.items():
            if info.hcf_patterns:
                for idx, pattern in info.hcf_patterns.items():
                    # Create pattern signature from rounded mean/std so
                    # near-identical distributions group together.
                    signature = f"{pattern['mean']:.4f}_{pattern['std']:.4f}"

                    if signature not in pattern_groups:
                        pattern_groups[signature] = []
                    pattern_groups[signature].append({
                        'tensor': tensor_name,
                        'index': idx,
                        'pattern': pattern
                    })

        # Find groups with similar patterns; savings assume one copy is kept.
        for signature, group in pattern_groups.items():
            if len(group) > 1:
                shared_patterns.append({
                    'signature': signature,
                    'occurrences': group,
                    'potential_savings': sum(p['pattern']['size'] for p in group[1:])
                })

        return shared_patterns

    def _generate_optimization_suggestions(self, results: Dict) -> List[Dict]:
        """Generate optimization suggestions based on analysis."""
        suggestions = []
        for tensor_name, impact in results['memory_impact'].items():
            if impact['potential_savings'] > 1000000:  # If savings > 1MB
                suggestions.append({
                    'tensor': tensor_name,
                    'suggestion': 'Consider weight quantization',
                    'impact': f"Save {impact['potential_savings'] / 1024 / 1024:.2f}MB"
                })
        return suggestions
|
| 250 |
+
|
| 251 |
+
@dataclass
class TrainingStatistics:
    """Running statistics accumulated over one HCF-aware training epoch."""
    memory_savings: int = 0
    quantization_error: float = 0.0
    convergence_rate: float = 0.0
    epoch: int = 0
    batch_count: int = 0

    def update(self, batch_stats: Dict[str, Any]):
        """Fold one batch's reported statistics into the running totals."""
        # Savings accumulate; error and convergence keep their last reported value.
        self.memory_savings = self.memory_savings + batch_stats.get('memory_savings', 0)
        if 'quantization_error' in batch_stats:
            self.quantization_error = batch_stats['quantization_error']
        if 'convergence_rate' in batch_stats:
            self.convergence_rate = batch_stats['convergence_rate']
        self.batch_count = self.batch_count + 1
|
| 266 |
+
|
| 267 |
+
class HCFTrainingOptimizer(torch.optim.Adam):
    """
    Adam optimizer with HCF-awareness for more efficient training.

    After each standard Adam step it can (a) quantize weight matrices in
    place to an 8-bit grid and (b) maintain weight patterns found by HCF
    analysis. Statistics are exposed via get_stats().
    """
    def __init__(self,
                 params,
                 lr=0.001,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 weight_quantization=True,
                 maintain_patterns=True):
        # Pass Adam's options by keyword so a future signature change in
        # torch cannot silently misroute positional arguments.
        super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self.weight_quantization = weight_quantization
        self.maintain_patterns = maintain_patterns
        self.analyzer = SafeTensorHCFAnalyzer()
        self.stats = {'memory_savings': 0, 'quantization_error': 0.0}

    def step(self, closure=None):
        """Perform optimization step with HCF awareness."""
        # Run standard optimization step
        loss = super().step(closure)

        # Apply HCF optimizations if enabled
        if self.weight_quantization:
            self._apply_weight_quantization()

        if self.maintain_patterns:
            self._maintain_weight_patterns()

        return loss

    def _apply_weight_quantization(self):
        """Apply dynamic weight quantization using HCF patterns.

        Quantizes every multi-dimensional parameter (with a gradient) to an
        8-bit symmetric grid, in place, and records the cumulative savings
        estimate and mean-squared quantization error in self.stats.
        """
        savings = 0
        total_error = 0.0

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None or not p.requires_grad:
                    continue

                # Only quantize matrices/tensors, not bias vectors or scalars.
                if p.dim() > 1:
                    # Symmetric 8-bit scale: largest magnitude maps to 127.
                    factor = torch.max(torch.abs(p.data)) / 127

                    # BUG FIX: an all-zero tensor yields factor == 0; the
                    # division below would then fill the weights with NaNs.
                    if factor == 0:
                        continue

                    # Quantize weights
                    quantized = torch.round(p.data / factor) * factor

                    # Calculate error and savings
                    error = torch.mean((p.data - quantized)**2).item()
                    savings += p.numel() * (p.element_size() - 1)  # Assuming 8-bit savings

                    # Apply quantized weights
                    p.data.copy_(quantized)

                    total_error += error

        # Update statistics
        self.stats['memory_savings'] = savings
        self.stats['quantization_error'] = total_error

    def _maintain_weight_patterns(self):
        """Maintain efficient weight patterns identified by HCF analysis.

        Placeholder: a real implementation would analyze weight matrices
        and enforce the shared patterns found by the analyzer.
        """
        pass

    def get_stats(self):
        """Get current optimization statistics (savings and quantization error)."""
        return self.stats
|
| 340 |
+
|
| 341 |
+
class HCFAwareTrainer:
    """
    Trainer that incorporates HCF analysis for better training efficiency.

    Wraps a model and an HCF-aware optimizer; after each epoch it snapshots
    the weights to a temporary safetensors file and analyzes them.
    """
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.analyzer = SafeTensorHCFAnalyzer()

    def train_epoch(self, train_loader, criterion, epoch):
        """Train one epoch with HCF awareness and return its TrainingStatistics."""
        self.model.train()
        stats = TrainingStatistics(epoch=epoch)

        for batch_idx, batch in enumerate(train_loader):
            # Get data
            inputs, targets = self._prepare_batch(batch)

            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = criterion(outputs, targets)

            # Backward pass
            loss.backward()

            # Optimize with HCF awareness
            self.optimizer.step()

            # Get batch statistics
            batch_stats = self.optimizer.get_stats()
            stats.update(batch_stats)

            # Log progress
            if batch_idx % 50 == 0:
                logger.info(f"Epoch {epoch} | Batch {batch_idx}/{len(train_loader)} | "
                            f"Memory Savings: {stats.memory_savings/1024/1024:.2f}MB | "
                            f"Quantization Error: {stats.quantization_error:.6f}")

        # End of epoch analysis
        self._analyze_model_weights()

        return stats

    def _prepare_batch(self, batch):
        """Split a batch into (inputs, targets).

        Dict batches use 'input_ids'/'labels' (falling back to inputs as
        labels); anything else is assumed to be an (inputs, targets) pair.
        """
        if isinstance(batch, dict):
            inputs = batch.get('input_ids')
            targets = batch.get('labels', inputs)
        else:
            # Assume batch is a tuple of (inputs, targets)
            inputs, targets = batch

        return inputs, targets

    def _analyze_model_weights(self):
        """Analyze model weights for patterns and optimizations."""
        # Save model to temporary safetensor file for analysis
        model_path = "temp_model.safetensors"
        # Detach so safetensors serializes plain tensors rather than
        # autograd-tracked parameters.
        tensors = {name: param.detach() for name, param in self.model.named_parameters()}
        try:
            save_file(tensors, model_path)

            # Analyze weights
            results = self.analyzer.analyze_safetensor_weights(model_path)

            # Log findings
            logger.info(f"Weight Analysis: Found {len(results['shared_patterns'])} shared patterns")
            logger.info(f"Potential memory savings: "
                        f"{sum(i['potential_savings'] for i in results['memory_impact'].values())/1024/1024:.2f}MB")
        finally:
            # BUG FIX: previously the temp file leaked when analysis raised;
            # always clean it up.
            if os.path.exists(model_path):
                os.remove(model_path)
|
| 415 |
+
|
| 416 |
+
@dataclass
class ModelConfig:
    """Names a pretrained model/tokenizer pair for one fine-tuning target."""
    name: str
    model_id: str
    tokenizer_id: str


# Supported fine-tuning targets, keyed by model size. Each size uses the
# same repository for both model and tokenizer.
CONFIGS = {
    size: ModelConfig(name=size, model_id=repo, tokenizer_id=repo)
    for size, repo in {
        "7b": "scrapegoat/ScrapeGoat-Music-Stage1",
        "1b": "scrapegoat/ScrapeGoat-Music-Stage2",
    }.items()
}
|
| 434 |
+
|
| 435 |
+
class MusicFineTuner:
|
| 436 |
+
    def __init__(
        self,
        model_size: str,
        dataset_path: str,
        output_dir: str,
        device: str = "auto",
        batch_size: int = 4,
        gradient_accumulation_steps: int = 4,
        learning_rate: float = 1e-5,
        num_epochs: int = 3,
        use_hcf: bool = True
    ):
        """Configure a fine-tuning run.

        Args:
            model_size: Key into CONFIGS ("7b" or "1b").
            dataset_path: Root of the prepared dataset (must contain
                metadata/dataset_info.json).
            output_dir: Where checkpoints and logs are written.
            device: "auto" picks cuda/mps/cpu; otherwise used verbatim.
            batch_size: Per-device training batch size.
            gradient_accumulation_steps: Steps to accumulate before update.
            learning_rate: Optimizer learning rate.
            num_epochs: Number of training epochs.
            use_hcf: Use the custom HCF-aware training loop instead of the
                standard HuggingFace Trainer.
        """
        self.config = CONFIGS[model_size]
        self.dataset_path = Path(dataset_path)
        self.output_dir = Path(output_dir)
        self.device = self._setup_device(device)
        self.use_hcf = use_hcf
        # HuggingFace training configuration; fp16 only when CUDA is present.
        # NOTE(review): newer transformers renamed evaluation_strategy ->
        # eval_strategy; confirm the pinned transformers version accepts this.
        self.training_args = TrainingArguments(
            output_dir=str(self.output_dir),
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            num_train_epochs=num_epochs,
            logging_steps=100,
            save_steps=1000,
            evaluation_strategy="steps",
            eval_steps=500,
            save_total_limit=3,
            load_best_model_at_end=True,
            gradient_checkpointing=True,
            fp16=torch.cuda.is_available(),
            optim="adamw_torch"
        )
|
| 469 |
+
|
| 470 |
+
def _setup_device(self, device: str) -> str:
|
| 471 |
+
if device == "auto":
|
| 472 |
+
if torch.cuda.is_available():
|
| 473 |
+
return "cuda"
|
| 474 |
+
elif torch.backends.mps.is_available():
|
| 475 |
+
return "mps"
|
| 476 |
+
else:
|
| 477 |
+
return "cpu"
|
| 478 |
+
return device
|
| 479 |
+
|
| 480 |
+
    def _load_model_and_tokenizer(self):
        """Load the pretrained model and tokenizer named in self.config.

        On CUDA the model is loaded in bfloat16 with automatic device
        placement and flash attention; elsewhere it uses float32 and the
        eager attention implementation.

        Returns:
            (model, tokenizer) tuple.
        """
        logger.info(f"Loading model {self.config.model_id}")

        # Determine dtype based on device
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        model = AutoModelForCausalLM.from_pretrained(
            self.config.model_id,
            torch_dtype=dtype,
            device_map="auto" if self.device == "cuda" else None,
            # NOTE(review): flash_attention_2 requires the optional flash-attn
            # package at runtime on CUDA hosts — confirm it is installed.
            attn_implementation="flash_attention_2" if self.device == "cuda" else "eager"
        )

        tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_id)
        return model, tokenizer
|
| 495 |
+
|
| 496 |
+
    def _prepare_dataset(self, tokenizer):
        """Build a tokenized text dataset from the prepared metadata file.

        Reads dataset_info.json produced by prepare_data.py, renders one
        text prompt per audio file (genre/duration/title/artist), and
        tokenizes to fixed-length sequences of 512 tokens.

        Args:
            tokenizer: The tokenizer returned by _load_model_and_tokenizer.

        Returns:
            A tokenized HuggingFace Dataset with the raw "text" column removed.
        """
        logger.info("Preparing dataset")

        with open(self.dataset_path / "metadata" / "dataset_info.json") as f:
            metadata = json.load(f)

        def generate_text(item):
            # One training example per processed audio file, rendered from its tags.
            return f"Genre: {item['genre']}\nDuration: {item['duration']:.2f}s\nTitle: {item['title']}\nArtist: {item['artist']}\n"

        texts = [generate_text(item) for item in metadata["files"]]
        dataset = Dataset.from_dict({"text": texts})

        def tokenize(examples):
            # NOTE(review): return_tensors="pt" inside a batched Dataset.map
            # is usually unnecessary (map stores lists) — confirm intended.
            return tokenizer(
                examples["text"],
                truncation=True,
                padding="max_length",
                max_length=512,
                return_tensors="pt"
            )

        tokenized_dataset = dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names
        )

        return tokenized_dataset
|
| 524 |
+
|
| 525 |
+
def train(self):
    """Run fine-tuning end to end: load, tokenize, split, train, save.

    Two training paths:
      * ``use_hcf=True``  — custom loop driven by ``HCFAwareTrainer`` with a
        per-epoch checkpoint (plus HCF weight analysis via
        ``_save_hcf_checkpoint``). The eval split is not used on this path.
      * ``use_hcf=False`` — standard HuggingFace ``Trainer`` with a causal-LM
        collator and the held-out eval split.

    The final model and tokenizer are written to ``<output_dir>/final_model``.
    """
    # Create output directory
    self.output_dir.mkdir(parents=True, exist_ok=True)

    # Load model and tokenizer, then build the tokenized dataset.
    model, tokenizer = self._load_model_and_tokenizer()
    dataset = self._prepare_dataset(tokenizer)

    # Hold out 10% for evaluation (used only by the standard Trainer path).
    dataset = dataset.train_test_split(test_size=0.1)

    if self.use_hcf:
        logger.info("Using HCF-aware training")
        # Custom HCF optimizer wrapping the model parameters.
        optimizer = HCFTrainingOptimizer(
            model.parameters(),
            lr=self.training_args.learning_rate,
            weight_quantization=True,
            maintain_patterns=True
        )

        hcf_trainer = HCFAwareTrainer(model, optimizer)

        # NOTE(review): this DataLoader has no collator, so batches are
        # whatever datasets yields by default — confirm HCFAwareTrainer
        # expects that format.
        train_loader = torch.utils.data.DataLoader(
            dataset["train"],
            batch_size=self.training_args.per_device_train_batch_size,
            shuffle=True
        )

        # Training loop with HCF awareness; one checkpoint per epoch.
        criterion = torch.nn.CrossEntropyLoss()
        for epoch in range(int(self.training_args.num_train_epochs)):
            stats = hcf_trainer.train_epoch(train_loader, criterion, epoch)

            logger.info(f"Epoch {epoch} completed")
            logger.info(f"Memory Savings: {stats.memory_savings/1024/1024:.2f}MB")
            logger.info(f"Quantization Error: {stats.quantization_error:.6f}")
            logger.info(f"Convergence Rate: {stats.convergence_rate:.4f}")

            self._save_hcf_checkpoint(model, tokenizer, epoch)
    else:
        # Standard HuggingFace Trainer path.
        logger.info("Using standard training")
        trainer = Trainer(
            model=model,
            args=self.training_args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["test"],
            data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
        )

        # FIX: trainer.train() must live inside this branch — on the HCF
        # path no `trainer` exists and calling it would raise NameError.
        logger.info("Starting training")
        trainer.train()

    # Save final model + tokenizer regardless of which path ran.
    logger.info("Saving model")
    model.save_pretrained(str(self.output_dir / "final_model"))
    tokenizer.save_pretrained(str(self.output_dir / "final_model"))
|
| 591 |
+
def _save_hcf_checkpoint(self, model, tokenizer, epoch):
    """Write an epoch checkpoint and, when possible, its HCF analysis.

    Saves the model and tokenizer under ``<output_dir>/checkpoint-<epoch>``,
    then runs the safetensors HCF analyzer over the saved weights and dumps
    the result as ``hcf_analysis.json`` next to them.
    """
    checkpoint_dir = self.output_dir / f"checkpoint-{epoch}"
    checkpoint_dir.mkdir(exist_ok=True)

    # Persist this epoch's weights and tokenizer files.
    model.save_pretrained(str(checkpoint_dir))
    tokenizer.save_pretrained(str(checkpoint_dir))

    # NOTE(review): sharded saves emit model-00001-of-*.safetensors rather
    # than a single model.safetensors, in which case the analysis below is
    # silently skipped — confirm that is intended.
    analyzer = SafeTensorHCFAnalyzer()
    weights_file = str(checkpoint_dir / "model.safetensors")
    if os.path.exists(weights_file):
        analysis = analyzer.analyze_safetensor_weights(weights_file)
        with open(checkpoint_dir / "hcf_analysis.json", "w") as f:
            json.dump(analysis, f, indent=2)

    logger.info(f"Saved checkpoint at {checkpoint_dir}")
|
| 614 |
+
if __name__ == "__main__":
    import argparse

    # CLI entry point: configure and launch a MusicFineTuner run.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_size", type=str, choices=["1b", "7b"], required=True)
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--use_hcf", action="store_true", help="Enable HCF-aware training")
    args = parser.parse_args()

    # Argument names match the constructor's keyword parameters one-to-one,
    # so the parsed namespace can be forwarded directly.
    fine_tuner = MusicFineTuner(**vars(args))
    fine_tuner.train()
train_local.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Training script for ScrapeGoat Music models using local model files with HCF optimization.
|
| 4 |
+
Optimized for local training with the models in the provided directory structure.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import json
|
| 10 |
+
import torch
|
| 11 |
+
import logging
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from typing import Optional, List, Dict, Tuple, Any
|
| 15 |
+
import transformers
|
| 16 |
+
from transformers import (
|
| 17 |
+
AutoModelForCausalLM,
|
| 18 |
+
AutoTokenizer,
|
| 19 |
+
TrainingArguments,
|
| 20 |
+
Trainer,
|
| 21 |
+
DataCollatorForLanguageModeling
|
| 22 |
+
)
|
| 23 |
+
from datasets import Dataset
|
| 24 |
+
import numpy as np
|
| 25 |
+
from accelerate import Accelerator
|
| 26 |
+
from safetensors import safe_open
|
| 27 |
+
from safetensors.torch import save_file, load_file
|
| 28 |
+
|
| 29 |
+
# Configure logging
|
| 30 |
+
logging.basicConfig(level=logging.INFO,
|
| 31 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
# Add xcodec_mini_infer to path to access its modules
|
| 35 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 36 |
+
XCODEC_PATH = os.path.join(SCRIPT_DIR, "xcodec_mini_infer")
|
| 37 |
+
sys.path.append(XCODEC_PATH)
|
| 38 |
+
|
| 39 |
+
# Import HCF training components from train_hcf.py
|
| 40 |
+
from train_hcf import (
|
| 41 |
+
TensorInfo,
|
| 42 |
+
SafeTensorHCFAnalyzer,
|
| 43 |
+
TrainingStatistics,
|
| 44 |
+
HCFTrainingOptimizer,
|
| 45 |
+
HCFAwareTrainer
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
@dataclass
class LocalModelConfig:
    """Configuration for local model directories"""

    # Path (possibly relative) to the directory holding the model files.
    model_path: str
    # Human-readable model name used only for logging.
    name: str

    @property
    def model_dir(self) -> str:
        """Absolute filesystem path to the model directory."""
        return os.path.abspath(self.model_path)
|
| 58 |
+
class LocalFineTuner:
    """Fine-tuner that works with local model files.

    Loads a model/tokenizer from ``model_config.model_dir`` (no network),
    builds a tokenized dataset from the dataset's metadata file, and trains
    either with the custom HCF loop or the standard HuggingFace Trainer.
    """

    def __init__(
        self,
        model_config: LocalModelConfig,
        dataset_path: str,
        output_dir: str,
        device: str = "auto",
        batch_size: int = 4,
        gradient_accumulation_steps: int = 4,
        learning_rate: float = 1e-5,
        num_epochs: int = 3,
        use_hcf: bool = True
    ):
        self.model_config = model_config
        self.dataset_path = Path(dataset_path)
        self.output_dir = Path(output_dir)
        self.device = self._setup_device(device)
        self.use_hcf = use_hcf

        # Ensure output directory exists
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Set up training arguments (used directly by the standard Trainer
        # path, and mined for lr/batch size/epochs by the HCF path).
        self.training_args = TrainingArguments(
            output_dir=str(self.output_dir),
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            num_train_epochs=num_epochs,
            logging_steps=100,
            save_steps=1000,
            evaluation_strategy="steps",
            eval_steps=500,
            save_total_limit=3,
            load_best_model_at_end=True,
            gradient_checkpointing=True,
            fp16=torch.cuda.is_available(),
            optim="adamw_torch"
        )

    def _setup_device(self, device: str) -> str:
        """Resolve ``device``: pass explicit names through, probe on "auto"."""
        if device != "auto":
            return device
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def _load_model_and_tokenizer(self):
        """Load model and tokenizer strictly from the local model directory.

        Returns:
            Tuple of ``(model, tokenizer)``.
        """
        logger.info(f"Loading model from {self.model_config.model_dir}")

        # bfloat16 only on CUDA; fall back to fp32 elsewhere.
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        model = AutoModelForCausalLM.from_pretrained(
            self.model_config.model_dir,
            torch_dtype=dtype,
            device_map="auto" if self.device == "cuda" else None,
            attn_implementation="flash_attention_2" if self.device == "cuda" else "eager",
            local_files_only=True
        )

        tokenizer = AutoTokenizer.from_pretrained(
            self.model_config.model_dir,
            local_files_only=True
        )
        # FIX: ensure a pad token exists — padding="max_length" in
        # _prepare_dataset raises if the tokenizer ships without one.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return model, tokenizer

    def _prepare_dataset(self, tokenizer):
        """Build a fixed-length tokenized dataset from dataset metadata.

        Reads ``<dataset_path>/metadata/dataset_info.json`` and renders one
        prompt-style text record per file entry, tokenized to 512 tokens.
        """
        logger.info("Preparing dataset")

        # Load metadata
        with open(self.dataset_path / "metadata" / "dataset_info.json") as f:
            metadata = json.load(f)

        # Prompt template: one metadata record per training example.
        def generate_text(item):
            return f"Genre: {item['genre']}\nDuration: {item['duration']:.2f}s\nTitle: {item['title']}\nArtist: {item['artist']}\n"

        texts = [generate_text(item) for item in metadata["files"]]
        dataset = Dataset.from_dict({"text": texts})

        def tokenize(examples):
            # FIX: no return_tensors="pt" — Dataset.map stores python lists
            # regardless, so requesting torch tensors was wasted work.
            return tokenizer(
                examples["text"],
                truncation=True,
                padding="max_length",
                max_length=512
            )

        tokenized_dataset = dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names
        )

        return tokenized_dataset

    def train(self):
        """Train the model end to end and save to ``<output_dir>/final_model``.

        ``use_hcf=True`` runs the custom HCFAwareTrainer loop with per-epoch
        checkpoints; ``use_hcf=False`` runs the standard HuggingFace Trainer.
        """
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Log training configuration
        logger.info(f"Training {self.model_config.name} model with HCF optimization")
        logger.info(f"Model path: {self.model_config.model_dir}")
        logger.info(f"Dataset path: {self.dataset_path}")
        logger.info(f"Output directory: {self.output_dir}")
        logger.info(f"Device: {self.device}")
        logger.info(f"HCF optimization: {'enabled' if self.use_hcf else 'disabled'}")

        model, tokenizer = self._load_model_and_tokenizer()
        dataset = self._prepare_dataset(tokenizer)

        # Hold out 10% for evaluation (standard Trainer path only).
        dataset = dataset.train_test_split(test_size=0.1)

        if self.use_hcf:
            logger.info("Using HCF-aware training")
            optimizer = HCFTrainingOptimizer(
                model.parameters(),
                lr=self.training_args.learning_rate,
                weight_quantization=True,
                maintain_patterns=True
            )

            hcf_trainer = HCFAwareTrainer(model, optimizer)

            # NOTE(review): no collator supplied — confirm HCFAwareTrainer
            # accepts the default batch format from datasets.
            train_loader = torch.utils.data.DataLoader(
                dataset["train"],
                batch_size=self.training_args.per_device_train_batch_size,
                shuffle=True
            )

            criterion = torch.nn.CrossEntropyLoss()
            for epoch in range(int(self.training_args.num_train_epochs)):
                stats = hcf_trainer.train_epoch(train_loader, criterion, epoch)

                logger.info(f"Epoch {epoch} completed")
                logger.info(f"Memory Savings: {stats.memory_savings/1024/1024:.2f}MB")
                logger.info(f"Quantization Error: {stats.quantization_error:.6f}")
                logger.info(f"Convergence Rate: {stats.convergence_rate:.4f}")

                self._save_hcf_checkpoint(model, tokenizer, epoch)
        else:
            logger.info("Using standard training")
            trainer = Trainer(
                model=model,
                args=self.training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["test"],
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )

            # FIX: trainer.train() belongs inside this branch — on the HCF
            # path no `trainer` exists and calling it would raise NameError.
            logger.info("Starting training")
            trainer.train()

        # Save final model
        logger.info("Saving model")
        final_output_dir = self.output_dir / "final_model"
        final_output_dir.mkdir(exist_ok=True)

        model.save_pretrained(str(final_output_dir))
        tokenizer.save_pretrained(str(final_output_dir))

        logger.info(f"Training complete. Model saved to {final_output_dir}")

    def _save_hcf_checkpoint(self, model, tokenizer, epoch):
        """Save checkpoint with HCF metadata"""
        checkpoint_dir = self.output_dir / f"checkpoint-{epoch}"
        checkpoint_dir.mkdir(exist_ok=True)

        # Save model and tokenizer
        model.save_pretrained(str(checkpoint_dir))
        tokenizer.save_pretrained(str(checkpoint_dir))

        # NOTE(review): sharded saves write model-00001-of-*.safetensors, in
        # which case this analysis is silently skipped — confirm intended.
        analyzer = SafeTensorHCFAnalyzer()
        model_path = str(checkpoint_dir / "model.safetensors")
        if os.path.exists(model_path):
            results = analyzer.analyze_safetensor_weights(model_path)
            with open(checkpoint_dir / "hcf_analysis.json", "w") as f:
                json.dump(results, f, indent=2)

        logger.info(f"Saved checkpoint at {checkpoint_dir}")
| 273 |
+
def main():
    """Parse CLI arguments, pick the local model directory, and train.

    Model directories are resolved relative to ``--base_dir``:
    ``scrapegoat/ScrapeGoat-Music-Stage1`` for 7b,
    ``scrapegoat/ScrapeGoat-Music-Stage2`` otherwise.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Retrain ScrapeGoat Music models with HCF optimization")
    parser.add_argument("--model", type=str, choices=["7b", "1b"], required=True,
                        help="Model size to train")
    parser.add_argument("--dataset_path", type=str, required=True,
                        help="Path to processed dataset")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save trained model")
    parser.add_argument("--device", type=str, default="auto",
                        help="Device to use (cuda, cpu, mps, or auto)")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Batch size for training")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4,
                        help="Gradient accumulation steps")
    parser.add_argument("--learning_rate", type=float, default=1e-5,
                        help="Learning rate")
    parser.add_argument("--num_epochs", type=int, default=3,
                        help="Number of training epochs")
    # FIX: the original used action="store_true" with default=True, so the
    # flag could never be turned off. BooleanOptionalAction keeps --use_hcf
    # working and adds --no-use_hcf to disable (backward compatible).
    parser.add_argument("--use_hcf", action=argparse.BooleanOptionalAction, default=True,
                        help="Enable HCF optimization (disable with --no-use_hcf)")
    parser.add_argument("--base_dir", type=str, default=os.getcwd(),
                        help="Base directory containing model folders")

    args = parser.parse_args()

    # Set up model configuration based on size
    if args.model == "7b":
        model_path = os.path.join(args.base_dir, "scrapegoat/ScrapeGoat-Music-Stage1")
        model_config = LocalModelConfig(
            model_path=model_path,
            name="ScrapeGoatMusic 7B"
        )
    else:
        model_path = os.path.join(args.base_dir, "scrapegoat/ScrapeGoat-Music-Stage2")
        model_config = LocalModelConfig(
            model_path=model_path,
            name="ScrapeGoatMusic 1B"
        )

    # Create fine-tuner
    fine_tuner = LocalFineTuner(
        model_config=model_config,
        dataset_path=args.dataset_path,
        output_dir=args.output_dir,
        device=args.device,
        batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        use_hcf=args.use_hcf
    )

    # Train model
    fine_tuner.train()

if __name__ == "__main__":
    main()