"""Train a base model on the unified Mel corpus with LoRA.

Designed for cloud GPU deployment. Loads the base model in bf16 (or quantized
to 4/8 bits with bitsandbytes), applies LoRA adapters, and trains on the
prepared JSONL data (one JSON object per line with a "text" field).

Usage:
    python train.py --model EleutherAI/pythia-1.4b --data train.jsonl --output mel-pythia-1.4b

For 4-bit quantization (fits on smaller GPUs):
    python train.py --model EleutherAI/pythia-2.8b --data train.jsonl --output mel-pythia-2.8b --use-4bit
"""
import argparse
import json
import os

import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType


# LoRA target module names per model family. Families are tried in this
# order against the lower-cased model name (see _resolve_target_modules).
TARGET_MODULES = {
    'pythia': ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h'],
    'llama': ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
    'qwen': ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
    'phi': ['q_proj', 'k_proj', 'v_proj', 'dense', 'fc1', 'fc2'],
}

# config.model_type values that do not contain a family key verbatim
# (Pythia checkpoints use the GPT-NeoX architecture).
_MODEL_TYPE_ALIASES = {'gpt_neox': 'pythia'}

# Family used when nothing else matches (the original script's behaviour).
DEFAULT_FAMILY = 'pythia'


def load_jsonl(path):
    """Load a JSONL file into a HF Dataset.

    Blank or whitespace-only lines are skipped.

    Raises:
        ValueError: if a non-blank line is not valid JSON; the message
            includes the path and 1-based line number.
    """
    examples = []
    with open(path, encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                examples.append(json.loads(line))
            except json.JSONDecodeError as err:
                # JSONDecodeError is itself a ValueError, so callers catching
                # ValueError keep working; this just adds location info.
                raise ValueError(f"{path}:{lineno}: invalid JSON: {err}") from err
    return Dataset.from_list(examples)


def _parse_args(argv=None):
    """Build the command-line parser and parse *argv* (None -> sys.argv[1:])."""
    parser = argparse.ArgumentParser(description='Train a base model on the Mel corpus with LoRA.')
    parser.add_argument('--model', default='EleutherAI/pythia-1.4b',
                        help='Base model. Use uncontaminated base models, not -Instruct/-Chat variants.')
    parser.add_argument('--data', default='train.jsonl')
    parser.add_argument('--output', default='mel-pythia-1.4b')
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--gradient-accumulation', type=int, default=8)
    parser.add_argument('--learning-rate', type=float, default=2e-4)
    parser.add_argument('--warmup-steps', type=int, default=100,
                        help='Linear warmup steps (lower this for very small datasets)')
    parser.add_argument('--lora-rank', type=int, default=16)
    parser.add_argument('--lora-alpha', type=int, default=32)
    parser.add_argument('--target-modules', default=None,
                        help='Comma-separated LoRA target module names; overrides auto-detection')
    parser.add_argument('--use-4bit', action='store_true', help='4-bit quantization for memory efficiency')
    parser.add_argument('--use-8bit', action='store_true')
    parser.add_argument('--max-length', type=int, default=2048)
    parser.add_argument('--hf-repo', default=None, help='HuggingFace repo to push trained adapter to')
    return parser.parse_args(argv)


def _build_quant_config(args):
    """Return a BitsAndBytesConfig for --use-4bit / --use-8bit, or None for bf16."""
    if args.use_4bit:
        if args.use_8bit:
            print("Warning: both --use-4bit and --use-8bit given; using 4-bit.")
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
    if args.use_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)
    return None


def _resolve_target_modules(model_name, model_type=None, override=None):
    """Choose the LoRA target module names for a model.

    Priority: an explicit comma-separated *override*; then the first
    TARGET_MODULES family whose key appears in the lower-cased *model_name*;
    then the model config's *model_type*; finally DEFAULT_FAMILY (with a
    warning).

    Returns:
        (family_label, list_of_module_names)
    """
    if override:
        modules = [m.strip() for m in override.split(',') if m.strip()]
        if modules:
            return 'custom', modules

    name = model_name.lower()
    for family in TARGET_MODULES:
        if family in name:
            return family, TARGET_MODULES[family]

    if model_type:
        model_type = model_type.lower()
        family = _MODEL_TYPE_ALIASES.get(model_type)
        if family is None:
            family = next((key for key in TARGET_MODULES if key in model_type), None)
        if family is not None:
            return family, TARGET_MODULES[family]

    print(f"Warning: could not infer model family for {model_name!r}; falling back to "
          f"{DEFAULT_FAMILY!r} target modules (pass --target-modules to override).")
    return DEFAULT_FAMILY, TARGET_MODULES[DEFAULT_FAMILY]


def _load_tokenized_dataset(path, tokenizer, max_length):
    """Load *path*, validate it, and tokenize its 'text' field (no padding).

    Raises:
        ValueError: if the file has no examples or lacks a 'text' field.
    """
    print(f"Loading data: {path}")
    dataset = load_jsonl(path)
    print(f"Examples: {len(dataset)}")
    if len(dataset) == 0:
        raise ValueError(f"No training examples found in {path}")
    if 'text' not in dataset.column_names:
        raise ValueError(f"{path}: expected a 'text' field in each record, "
                         f"got columns {dataset.column_names}")

    def tokenize_fn(batch):
        return tokenizer(
            batch['text'],
            truncation=True,
            max_length=max_length,
            padding=False,  # the collator pads per batch
        )

    return dataset.map(tokenize_fn, batched=True, remove_columns=dataset.column_names)


def main():
    """Parse CLI args, build the (optionally quantized) LoRA model, and train it."""
    args = _parse_args()

    print(f"=== Training {args.model} on {args.data} ===")
    print(f"Output: {args.output}")
    print(f"Epochs: {args.epochs}, batch: {args.batch_size}, accum: {args.gradient_accumulation}")
    print(f"LoRA rank: {args.lora_rank}, alpha: {args.lora_alpha}")

    bnb_config = _build_quant_config(args)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        # Base models often ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    # Load and validate data before the (large) model so bad input fails fast.
    dataset = _load_tokenized_dataset(args.data, tokenizer, args.max_length)

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16 if not bnb_config else None,
        device_map='auto',
    )
    # The KV cache is not used in training and is incompatible with gradient
    # checkpointing (enabled via TrainingArguments below).
    model.config.use_cache = False
    if bnb_config:
        model = prepare_model_for_kbit_training(model)
    else:
        # With a frozen base and gradient checkpointing, activations entering the
        # checkpointed blocks must require grad or the LoRA weights get no
        # gradient. prepare_model_for_kbit_training does this for the quantized
        # path; do it explicitly for the bf16 path.
        model.enable_input_require_grads()

    model_family, target_modules = _resolve_target_modules(
        args.model, getattr(model.config, 'model_type', None), args.target_modules)
    print(f"LoRA target modules ({model_family}): {target_modules}")

    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=target_modules,
        lora_dropout=0.05,
        bias='none',
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # NOTE(review): with mlm=False the collator sets labels to -100 wherever
    # input_ids == pad_token_id. If pad was aliased to EOS above, EOS tokens in
    # the text are also excluded from the loss — confirm this is intended.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        learning_rate=args.learning_rate,
        warmup_steps=args.warmup_steps,
        logging_steps=10,
        save_steps=500,
        save_total_limit=3,
        bf16=True,  # requires a GPU with bf16 support
        gradient_checkpointing=True,
        optim='paged_adamw_8bit' if bnb_config else 'adamw_torch',
        report_to='none',
        push_to_hub=args.hf_repo is not None,
        hub_model_id=args.hf_repo,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
    )

    print("Starting training...")
    trainer.train()

    print("Saving final model...")
    trainer.save_model(args.output)
    # Keep the (possibly pad-token-modified) tokenizer next to the adapter so
    # the output directory is self-contained.
    tokenizer.save_pretrained(args.output)
    if args.hf_repo:
        trainer.push_to_hub()
    print(f"Done. Saved to {args.output}")


if __name__ == '__main__':
    main()