| """Train a base model on the unified Mel corpus with LoRA. |
| |
Designed for cloud GPU deployment. Loads the base model in bf16 (or 4-/8-bit via
bitsandbytes), applies LoRA adapters, and trains on the prepared JSONL data
(one JSON object per line, each with a "text" field).
| |
| Usage: |
| python train.py --model EleutherAI/pythia-1.4b --data train.jsonl --output mel-pythia-1.4b |
| |
| For 4-bit quantization (fits on smaller GPUs): |
| python train.py --model EleutherAI/pythia-2.8b --data train.jsonl --output mel-pythia-2.8b --use-4bit |
| """ |
| import argparse |
| import json |
| import os |
| import torch |
| from datasets import Dataset |
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForCausalLM, |
| TrainingArguments, |
| Trainer, |
| DataCollatorForLanguageModeling, |
| BitsAndBytesConfig, |
| ) |
| from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType |
|
|
|
|
def load_jsonl(path):
    """Load a JSON Lines file into a Hugging Face ``Dataset``.

    Each non-blank line must hold one JSON object; the objects become the
    rows of the returned dataset via ``Dataset.from_list``. Blank or
    whitespace-only lines (e.g. a trailing empty line) are skipped.

    Args:
        path: Path to the ``.jsonl`` file, decoded as UTF-8.

    Returns:
        A ``datasets.Dataset`` with one row per JSON object.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON; the
            message is prefixed with ``path:lineno`` to locate the bad line.
    """
    examples = []
    # Explicit UTF-8 so non-ASCII corpus text decodes identically on every
    # platform instead of depending on the locale's default encoding.
    with open(path, encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                examples.append(json.loads(line))
            except json.JSONDecodeError as err:
                # The decoder's position is relative to this single line, so
                # add the file line number to make the error actionable.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return Dataset.from_list(examples)
|
|
|
|
def _parse_args():
    """Parse the command-line options for a LoRA training run."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='EleutherAI/pythia-1.4b',
                        help='Base model. Use uncontaminated base models, not -Instruct/-Chat variants.')
    parser.add_argument('--data', default='train.jsonl')
    parser.add_argument('--output', default='mel-pythia-1.4b')
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--gradient-accumulation', type=int, default=8)
    parser.add_argument('--learning-rate', type=float, default=2e-4)
    parser.add_argument('--lora-rank', type=int, default=16)
    parser.add_argument('--lora-alpha', type=int, default=32)
    # The quantization modes are alternatives, so argparse rejects passing both.
    quant = parser.add_mutually_exclusive_group()
    quant.add_argument('--use-4bit', action='store_true',
                       help='4-bit quantization for memory efficiency')
    quant.add_argument('--use-8bit', action='store_true',
                       help='8-bit quantization for memory efficiency')
    parser.add_argument('--max-length', type=int, default=2048)
    parser.add_argument('--hf-repo', default=None,
                        help='HuggingFace repo to push trained adapter to')
    return parser.parse_args()


def _quantization_config(use_4bit, use_8bit):
    """Return a bitsandbytes config for the requested quantization, or None."""
    if use_4bit:
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
    if use_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)
    return None


def _lora_target_modules(model_name, model_type=None):
    """Choose the LoRA target module names for the base model's architecture.

    The architecture reported by the loaded config (``model_type``) is
    preferred. When it is unknown, fall back to substring matching on the
    model name, and finally to the GPT-NeoX/Pythia layout.

    Args:
        model_name: Hub id or path passed as ``--model``.
        model_type: ``model.config.model_type`` of the loaded model, if any.

    Returns:
        A new list of module names for ``LoraConfig(target_modules=...)``.
    """
    neox = ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h']
    llama_like = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
    phi = ['q_proj', 'k_proj', 'v_proj', 'dense', 'fc1', 'fc2']
    by_model_type = {
        'gpt_neox': neox,
        'llama': llama_like,
        'mistral': llama_like,
        'qwen2': llama_like,
        'qwen3': llama_like,
        'gemma': llama_like,
        'phi': phi,
        'phi3': ['qkv_proj', 'o_proj', 'gate_up_proj', 'down_proj'],
    }
    if model_type in by_model_type:
        return list(by_model_type[model_type])

    # Name-based fallback; dict insertion order sets the match priority.
    by_name = {'pythia': neox, 'llama': llama_like, 'qwen': llama_like, 'phi': phi}
    lowered = model_name.lower()
    for key, modules in by_name.items():
        if key in lowered:
            return list(modules)
    return list(neox)


def _tokenize_dataset(dataset, tokenizer, max_length):
    """Tokenize the ``text`` column (truncated to ``max_length``) and drop unusable rows.

    Raises:
        ValueError: If the dataset has no ``text`` column.
    """
    if 'text' not in dataset.column_names:
        raise ValueError(
            f"Training data must have a 'text' field; found columns: {dataset.column_names}"
        )

    def tokenize_fn(batch):
        return tokenizer(
            batch['text'],
            truncation=True,
            max_length=max_length,
            padding=False,
        )

    tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=dataset.column_names)
    # A row needs at least two tokens to yield a next-token target; empty or
    # single-token rows contribute no usable loss, so drop them.
    return tokenized.filter(lambda ex: len(ex['input_ids']) >= 2)


def main():
    """Fine-tune ``--model`` with LoRA on a JSONL corpus and save the adapter.

    Parses CLI options, loads the tokenizer and training data, loads the
    (optionally bitsandbytes-quantized) base model, wraps it with LoRA
    adapters, trains with the HF ``Trainer``, then saves the adapter and
    tokenizer to ``--output`` and pushes to ``--hf-repo`` when given.

    Raises:
        ValueError: If the data has no ``text`` field or no usable examples.
    """
    args = _parse_args()

    print(f"=== Training {args.model} on {args.data} ===")
    print(f"Output: {args.output}")
    print(f"Epochs: {args.epochs}, batch: {args.batch_size}, accum: {args.gradient_accumulation}")
    print(f"LoRA rank: {args.lora_rank}, alpha: {args.lora_alpha}")

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        # Some causal-LM tokenizers ship without a pad token; the collator needs one.
        tokenizer.pad_token = tokenizer.eos_token

    # Load and tokenize data before the model so a malformed or empty corpus
    # fails fast instead of after a potentially large model load.
    print(f"Loading data: {args.data}")
    dataset = load_jsonl(args.data)
    print(f"Examples: {len(dataset)}")
    if len(dataset) == 0:
        raise ValueError(f"No training examples found in {args.data}")
    dataset = _tokenize_dataset(dataset, tokenizer, args.max_length)
    if len(dataset) == 0:
        raise ValueError(f"No usable (>= 2 token) training examples in {args.data}")
    print(f"Examples after tokenization: {len(dataset)}")

    bnb_config = _quantization_config(args.use_4bit, args.use_8bit)

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        quantization_config=bnb_config,
        # Only non-quantized loads are forced to bf16.
        torch_dtype=torch.bfloat16 if bnb_config is None else None,
        device_map='auto',
    )
    # The KV cache is unused in training and incompatible with gradient checkpointing.
    model.config.use_cache = False

    if bnb_config is not None:
        # PEFT's k-bit preparation, which among other things makes the
        # embedding outputs require grad for gradient checkpointing.
        model = prepare_model_for_kbit_training(model)
    else:
        # Gradient checkpointing (enabled in TrainingArguments below) on a
        # frozen base model needs the embedding outputs to require grad;
        # otherwise the checkpointed blocks pass no gradient to the LoRA
        # weights inside them.
        model.enable_input_require_grads()

    target_modules = _lora_target_modules(args.model, getattr(model.config, 'model_type', None))
    print(f"LoRA target modules: {target_modules}")
    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=target_modules,
        lora_dropout=0.05,
        bias='none',
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # mlm=False -> causal-LM labels copied from input_ids with pad positions
    # set to -100. NOTE: if pad_token was set to eos above, real EOS tokens
    # are masked out of the loss as well.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        learning_rate=args.learning_rate,
        warmup_steps=100,
        logging_steps=10,
        save_steps=500,
        save_total_limit=3,
        # bf16 mixed precision requires hardware bf16 support (e.g. Ampere+ GPUs).
        bf16=True,
        gradient_checkpointing=True,
        # Paged 8-bit AdamW (bitsandbytes) is only used alongside quantized loads.
        optim='paged_adamw_8bit' if bnb_config is not None else 'adamw_torch',
        report_to='none',
        push_to_hub=args.hf_repo is not None,
        hub_model_id=args.hf_repo,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
    )

    print("Starting training...")
    trainer.train()

    print("Saving final model...")
    trainer.save_model(args.output)
    # Keep the tokenizer (whose pad token may have been changed above) next to
    # the adapter so inference can load both from the same directory.
    tokenizer.save_pretrained(args.output)
    if args.hf_repo:
        # Pushed after the tokenizer is written so the upload of the output
        # directory (== output_dir) includes it.
        trainer.push_to_hub()

    print(f"Done. Saved to {args.output}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run training only when executed directly, not on import.
    main()
|
|