File size: 9,037 Bytes
9071ef9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
#!/usr/bin/env python3
"""

Main training entry point for TouchGrass models.

Fine-tunes Qwen3.5 with LoRA and music modules.

"""

import argparse
import sys
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

from configs.touchgrass_3b_config import TOUCHGRASS_3B_CONFIG
from configs.touchgrass_7b_config import TOUCHGRASS_7B_CONFIG
from configs.training_config import (
    TRAINING_CONFIG_3B_CUDA,
    TRAINING_CONFIG_7B_CUDA,
    TRAINING_CONFIG_MPS,
)
from data.dataset_loader import TouchGrassDataset
from training.trainer import TouchGrassTrainer
from tokenizer.music_token_extension import MusicTokenizerExtension


def parse_args():
    """Parse command-line options for the TouchGrass training script.

    Reads ``sys.argv`` through :mod:`argparse` and returns the resulting
    ``argparse.Namespace``. The options cover model-size and device
    selection, data/checkpoint directories, optional max-steps and
    micro-batch-size overrides, LoRA hyper-parameters, checkpoint
    resumption, and synthetic-data generation.
    """
    # (flag, add_argument keyword arguments) in the order they appear in --help.
    option_specs = [
        ("--model_size", {"type": str, "choices": ["3b", "7b"], "default": "3b",
                          "help": "Model size to train"}),
        ("--device", {"type": str, "default": "cuda", "choices": ["cuda", "mps", "cpu"],
                      "help": "Device to train on"}),
        ("--use_mps", {"action": "store_true",
                       "help": "Use MPS backend (Apple Silicon)"}),
        ("--data_dir", {"type": str, "default": "./data/processed",
                        "help": "Directory with processed data shards"}),
        ("--output_dir", {"type": str, "default": "./checkpoints",
                          "help": "Output directory for checkpoints"}),
        ("--max_steps", {"type": int, "default": None,
                         "help": "Override max training steps"}),
        ("--micro_batch_size", {"type": int, "default": None,
                                "help": "Override micro batch size"}),
        ("--lora_r", {"type": int, "default": 16, "help": "LoRA rank"}),
        ("--lora_alpha", {"type": int, "default": 32, "help": "LoRA alpha"}),
        ("--resume_from_checkpoint", {"type": str, "default": None,
                                      "help": "Resume training from checkpoint"}),
        ("--generate_data", {"action": "store_true",
                             "help": "Generate synthetic training data before training"}),
        ("--num_train_samples", {"type": int, "default": 10000,
                                 "help": "Number of training samples to generate"}),
    ]

    cli = argparse.ArgumentParser(description="Train TouchGrass music assistant model")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def load_tokenizer(config: dict, args):
    """Load the base tokenizer and extend it with music tokens.

    Args:
        config: Model config dict. Reads ``config["base_model"]`` (the base
            tokenizer name) and the optional ``config["special_tokens"]``.
        args: Parsed CLI namespace. Not used here; kept so all loaders share
            the same call shape.

    Returns:
        ``(tokenizer_ext, tokenizer)``: the ``MusicTokenizerExtension`` wrapper
        and the tokenizer returned by its ``get_tokenizer()``.
    """
    base_model = config["base_model"]
    print(f"Loading base tokenizer: {base_model}")

    # Extend tokenizer with music tokens
    tokenizer_ext = MusicTokenizerExtension(
        base_tokenizer_name=base_model,
        special_tokens=config.get("special_tokens"),
    )

    tokenizer = tokenizer_ext.get_tokenizer()
    # On Hugging Face tokenizers, ``vocab_size`` counts only the base
    # vocabulary and excludes added tokens. ``len(tokenizer)`` includes them,
    # so it is the number that reflects the extension.
    print(f"Extended tokenizer vocab size: {len(tokenizer)}")

    return tokenizer_ext, tokenizer


def load_model(config: dict, args, tokenizer):
    """Load the base causal-LM, size its embeddings to the tokenizer, apply LoRA.

    Args:
        config: Model config dict. ``config["base_model"]`` is passed to
            ``AutoModelForCausalLM.from_pretrained``.
        args: Parsed CLI namespace. Reads ``device``, ``lora_r`` and ``lora_alpha``.
        tokenizer: The music-extended tokenizer the model must be able to embed.

    Returns:
        The PEFT model with LoRA adapters on the q/k/v/o projections.
    """
    base_model = config["base_model"]
    print(f"Loading base model: {base_model}")

    # Compute dtype: bf16 on CUDA when supported (otherwise fp16). Everything
    # else uses fp32: MPS (no good bf16 support), CPU, or CUDA requested but
    # not available.
    if args.device == "cuda" and torch.cuda.is_available():
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=dtype,
        trust_remote_code=True,
    )

    # The embedding matrix must cover every token id the tokenizer can emit.
    # ``tokenizer.vocab_size`` excludes added tokens (the music tokens and, for
    # many models, the chat special tokens). Resizing to it can leave those ids
    # out of range, or even shrink the matrix. ``len(tokenizer)`` includes them.
    model.resize_token_embeddings(len(tokenizer))

    # NOTE(review): LoRA below targets only attention projections. PEFT freezes
    # the base weights, so the embedding / LM-head rows for the newly added
    # tokens are not trained. If they should be learned, consider
    # ``modules_to_save`` for the embedding and LM-head layers (confirm the
    # module names for the chosen base model).
    print("Applying LoRA...")
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        bias="none",
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    return model


def generate_synthetic_data(config: dict, args, tokenizer):
    """Generate synthetic music Q&A data and write train/validation splits.

    Uses ``MusicQAGenerator`` (fixed seed 42) to produce
    ``args.num_train_samples`` items. The raw items are written to
    ``<args.data_dir>/synthetic_music_qa.jsonl``. Each item is then re-formatted
    with ``ChatFormatter.format_qa_pair``, and a 90/10 train/val split is
    written via ``ChatFormatter.create_pretraining_dataset``.

    Args:
        config: Model config dict (not read by this function).
        args: Parsed CLI namespace. Reads ``data_dir`` and ``num_train_samples``.
        tokenizer: Tokenizer handed to ``ChatFormatter``.

    Returns:
        The ``splits`` object returned by ``create_pretraining_dataset``. It is
        indexed with ``"train"`` and ``"val"`` for the summary message.
    """
    # Project-local imports deferred so they are only needed with --generate_data.
    from data.music_qa_generator import MusicQAGenerator
    from data.chat_formatter import ChatFormatter

    print("Generating synthetic training data...")

    qa_source = MusicQAGenerator(seed=42)  # fixed seed for a reproducible corpus

    target_dir = Path(args.data_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    raw_items = qa_source.generate_dataset(
        num_samples=args.num_train_samples,
        output_path=target_dir / "synthetic_music_qa.jsonl",
    )

    # Each raw item carries a "messages" list. Index 1 is used as the question
    # and index 2 as the answer (index 0 is presumably a system prompt —
    # verify against MusicQAGenerator).
    chat = ChatFormatter(tokenizer=tokenizer)
    formatted_samples = [
        chat.format_qa_pair(
            question=entry["messages"][1]["content"],
            answer=entry["messages"][2]["content"],
            context=None,  # Context already in question
        )
        for entry in raw_items
    ]

    splits = chat.create_pretraining_dataset(
        formatted_samples,
        output_dir=target_dir,
        train_split=0.9,
    )

    print(f"Data generation complete. Train: {splits['train']}, Val: {splits['val']}")
    return splits


def load_datasets(args, tokenizer, max_seq_length: int = 4096):
    """Load the training and validation datasets from ``args.data_dir``.

    Expects ``train.jsonl`` and ``val.jsonl`` in that directory. If either is
    missing, it prints a hint to stderr and exits the process with status 1.

    Args:
        args: Parsed CLI namespace. Reads ``data_dir``.
        tokenizer: Tokenizer passed to each ``TouchGrassDataset``.
        max_seq_length: Sequence-length cap passed to both datasets. Defaults
            to 4096, the value previously hard-coded here.

    Returns:
        ``(train_dataset, val_dataset)``, built in ``"train"`` and ``"eval"``
        mode respectively.
    """
    data_dir = Path(args.data_dir)

    train_path = data_dir / "train.jsonl"
    val_path = data_dir / "val.jsonl"

    if not (train_path.exists() and val_path.exists()):
        # Diagnostics go to stderr so they are not mixed into captured stdout.
        print(f"Data not found in {data_dir}. Generate with --generate_data", file=sys.stderr)
        sys.exit(1)

    print(f"Loading datasets from {data_dir}")

    train_dataset = TouchGrassDataset(
        data_path=str(train_path),
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        mode="train",
    )

    val_dataset = TouchGrassDataset(
        data_path=str(val_path),
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        mode="eval",
    )

    return train_dataset, val_dataset


def main():
    """Run the TouchGrass fine-tuning pipeline from the command line.

    Steps:
      1. Parse CLI args and select model/training configs.
      2. Load the music-extended tokenizer and optionally generate synthetic data.
      3. Load the datasets and the LoRA-wrapped model.
      4. Train with ``TouchGrassTrainer``, optionally resuming from a checkpoint.
      5. Save the model, tokenizer and tokenizer-extension metadata to
         ``<output_dir>/touchgrass-<model_size>-final``.
    """
    args = parse_args()

    # ``--use_mps`` is shorthand for ``--device mps``. ``--device`` defaults to
    # "cuda", so without this normalization ``--use_mps`` alone would select
    # the MPS training config while ``load_model`` (which reads
    # ``args.device``) and ``train_config["device"]`` still targeted CUDA.
    if args.use_mps:
        args.device = "mps"

    # Pick configs for the requested size. They are copied so the overrides
    # below don't mutate the shared module-level dicts.
    if args.model_size == "3b":
        model_config = TOUCHGRASS_3B_CONFIG.copy()
        train_config = TRAINING_CONFIG_3B_CUDA.copy()
    else:
        model_config = TOUCHGRASS_7B_CONFIG.copy()
        train_config = TRAINING_CONFIG_7B_CUDA.copy()

    # A single MPS training config is used for both model sizes.
    if args.device == "mps":
        train_config = TRAINING_CONFIG_MPS.copy()
        train_config["use_mps"] = True

    # Apply CLI overrides only when the flag was actually given (default None).
    if args.max_steps is not None:
        train_config["max_steps"] = args.max_steps
    if args.micro_batch_size is not None:
        train_config["micro_batch_size"] = args.micro_batch_size

    # Set device
    device = torch.device(args.device)
    train_config["device"] = args.device

    print(f"Training TouchGrass-{args.model_size.upper()}")
    print(f"Device: {device}")
    print(f"Max steps: {train_config['max_steps']}")
    print(f"Micro batch size: {train_config['micro_batch_size']}")
    print(f"LoRA: r={args.lora_r}, alpha={args.lora_alpha}")

    tokenizer_ext, tokenizer = load_tokenizer(model_config, args)

    if args.generate_data:
        generate_synthetic_data(model_config, args, tokenizer)

    train_dataset, val_dataset = load_datasets(args, tokenizer)
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    model = load_model(model_config, args, tokenizer)

    trainer = TouchGrassTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_config,
        eval_dataset=val_dataset,
    )

    if args.resume_from_checkpoint:
        trainer.load_checkpoint(args.resume_from_checkpoint)

    trainer.train()

    # ``model_size`` already carries the "b" suffix ("3b" / "7b"). Appending
    # another "b" produced directories like "touchgrass-3bb-final".
    output_dir = Path(args.output_dir) / f"touchgrass-{args.model_size}-final"
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving final model to {output_dir}")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Save tokenizer extension metadata
    tokenizer_ext.save_pretrained(output_dir)

    print("Training complete! Model saved.")


if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not on import.
    main()