# Uploaded by Melofhell00 (Hugging Face file view).
# Commit fde73f3 (verified): "Complete training pipeline for unified corpus on uncontaminated base models"
"""Train a base model on the unified Mel corpus with LoRA.
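Designed for cloud GPU deployment. Loads the base model in bf16 (or quantized
to 4-bit/8-bit with bitsandbytes via --use-4bit/--use-8bit), applies LoRA
adapters, and trains on the prepared JSONL data (one JSON object per line with
a "text" field). Training runs in bf16, so a bf16-capable GPU is required.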
Usage:
python train.py --model EleutherAI/pythia-1.4b --data train.jsonl --output mel-pythia-1.4b
For 4-bit quantization (fits on smaller GPUs):
python train.py --model EleutherAI/pythia-2.8b --data train.jsonl --output mel-pythia-2.8b --use-4bit
"""
import argparse
import json
import os
import torch
from datasets import Dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
def load_jsonl(path):
    """Load a JSON Lines file into a Hugging Face ``Dataset``.

    Each non-blank line must hold one JSON object. Blank or whitespace-only
    lines (e.g. extra newlines at end of file) are skipped rather than
    crashing the JSON parser.

    Args:
        path: Path to the ``.jsonl`` file.

    Returns:
        A ``Dataset`` built from the parsed objects with ``Dataset.from_list``.

    Raises:
        ValueError: If a line is not valid JSON; the message names the file
            and 1-based line number.
    """
    examples = []
    # Decode explicitly as UTF-8 so non-ASCII corpus text is read identically
    # regardless of the machine's locale default encoding.
    with open(path, encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                examples.append(json.loads(line))
            except json.JSONDecodeError as err:
                # JSONDecodeError is a ValueError subclass, so existing
                # handlers still catch this; we just add the location.
                raise ValueError(f"{path}:{lineno}: invalid JSON ({err.msg})") from err
    return Dataset.from_list(examples)
def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='EleutherAI/pythia-1.4b',
                        help='Base model. Use uncontaminated base models, not -Instruct/-Chat variants.')
    parser.add_argument('--data', default='train.jsonl')
    parser.add_argument('--output', default='mel-pythia-1.4b')
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--gradient-accumulation', type=int, default=8)
    parser.add_argument('--learning-rate', type=float, default=2e-4)
    parser.add_argument('--lora-rank', type=int, default=16)
    parser.add_argument('--lora-alpha', type=int, default=32)
    # The quantization modes are alternatives: reject passing both instead of
    # silently preferring 4-bit.
    quant_group = parser.add_mutually_exclusive_group()
    quant_group.add_argument('--use-4bit', action='store_true', help='4-bit quantization for memory efficiency')
    quant_group.add_argument('--use-8bit', action='store_true')
    parser.add_argument('--max-length', type=int, default=2048)
    parser.add_argument('--hf-repo', default=None, help='HuggingFace repo to push trained adapter to')
    return parser.parse_args()


def _build_quant_config(args):
    """Return a bitsandbytes config for --use-4bit / --use-8bit, or None to load unquantized."""
    if args.use_4bit:
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
    if args.use_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)
    return None


# LoRA target module names per architecture family (names vary by model code).
_LORA_TARGET_MODULES = {
    'pythia': ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h'],
    'llama': ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
    'qwen': ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
    'phi': ['q_proj', 'k_proj', 'v_proj', 'dense', 'fc1', 'fc2'],
}

# Fallback from a Hugging Face ``config.model_type`` to one of the families
# above, used when the repo id does not mention a family name.
_MODEL_TYPE_TO_FAMILY = {
    'gpt_neox': 'pythia',
    'llama': 'llama',
    'mistral': 'llama',
    'qwen2': 'qwen',
    'phi': 'phi',
}


def _lora_target_modules(model_name, model_type=None):
    """Choose LoRA target module names for a model.

    A family name appearing in the (lower-cased) repo id wins; otherwise the
    model's ``config.model_type`` is mapped to a family; if neither matches,
    a warning is printed and the Pythia/GPT-NeoX module names are used.

    Args:
        model_name: Model repo id or local path, e.g. ``EleutherAI/pythia-1.4b``.
        model_type: Optional ``config.model_type`` string of the loaded model.

    Returns:
        A fresh list of module-name suffixes for ``LoraConfig.target_modules``.
    """
    lowered = model_name.lower()
    for family, modules in _LORA_TARGET_MODULES.items():
        if family in lowered:
            return list(modules)
    family = _MODEL_TYPE_TO_FAMILY.get(model_type)
    if family is None:
        print(f"WARNING: unrecognised model family for {model_name!r} "
              f"(model_type={model_type!r}); defaulting to Pythia/GPT-NeoX LoRA targets.")
        family = 'pythia'
    return list(_LORA_TARGET_MODULES[family])


def main():
    """Fine-tune ``--model`` with LoRA on the JSONL corpus given by ``--data``.

    Steps: parse CLI args, load the tokenizer, load, validate and tokenize the
    data (before the expensive model load, so bad data fails fast), load the
    model (optionally bitsandbytes-quantized), wrap it with LoRA adapters,
    train with ``Trainer``, and save the adapter plus tokenizer to
    ``--output``, pushing to ``--hf-repo`` when given.

    Raises:
        SystemExit: If the data file has no examples or lacks a ``text`` field.
    """
    args = _parse_args()
    print(f"=== Training {args.model} on {args.data} ===")
    print(f"Output: {args.output}")
    print(f"Epochs: {args.epochs}, batch: {args.batch_size}, accum: {args.gradient_accumulation}")
    print(f"LoRA rank: {args.lora_rank}, alpha: {args.lora_alpha}")

    bnb_config = _build_quant_config(args)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        # Needed so the collator can pad batches; reuse EOS as the pad token.
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading data: {args.data}")
    dataset = load_jsonl(args.data)
    print(f"Examples: {len(dataset)}")
    if len(dataset) == 0:
        raise SystemExit(f"No training examples found in {args.data}")
    if 'text' not in dataset.column_names:
        raise SystemExit(f"{args.data}: examples need a 'text' field; found columns {dataset.column_names}")

    def tokenize_fn(examples):
        return tokenizer(
            examples['text'],
            truncation=True,
            max_length=args.max_length,
            padding=False,  # padding is done per batch by the data collator
        )

    dataset = dataset.map(tokenize_fn, batched=True, remove_columns=dataset.column_names)
    # mlm=False selects causal-LM (next-token) labels rather than masked-LM.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16 if not bnb_config else None,
        device_map='auto',
    )
    if bnb_config:
        model = prepare_model_for_kbit_training(model)
    else:
        # With LoRA the base weights are frozen, so the embedding outputs do
        # not require grad. Under gradient checkpointing (enabled below via
        # TrainingArguments) the checkpointed layers -- and the LoRA weights
        # inside them -- would then receive no gradients. This makes the
        # embedding outputs require grad so gradients flow through.
        model.enable_input_require_grads()

    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=_lora_target_modules(args.model, getattr(model.config, 'model_type', None)),
        lora_dropout=0.05,
        bias='none',
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    training_args = TrainingArguments(
        output_dir=args.output,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        learning_rate=args.learning_rate,
        warmup_steps=100,
        logging_steps=10,
        save_steps=500,
        save_total_limit=3,
        bf16=True,  # requires a GPU with bfloat16 support
        gradient_checkpointing=True,
        # The paged 8-bit optimizer is only used alongside quantized loading.
        optim='paged_adamw_8bit' if bnb_config else 'adamw_torch',
        report_to='none',
        push_to_hub=args.hf_repo is not None,
        hub_model_id=args.hf_repo,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
    )
    print("Starting training...")
    trainer.train()
    print("Saving final model...")
    trainer.save_model(args.output)
    # Ship the tokenizer (including the pad token set above) with the adapter;
    # output_dir is args.output, so push_to_hub uploads it too.
    tokenizer.save_pretrained(args.output)
    if args.hf_repo:
        trainer.push_to_hub()
    print(f"Done. Saved to {args.output}")
if __name__ == '__main__':
    # Script entry point: training starts only when run directly, never on import.
    raise SystemExit(main())