gpt2_medium_prefix_682k / scripts /train_experiment.py
augustocsc's picture
GPT-2 Medium trained on prefix dataset (682K)
3742716 verified
#!/usr/bin/env python3
"""
Training script for expression generation experiments.
Supports two formats:
- EXP-A (JSON): Uses custom <|endofex|> token
- EXP-B (EOS): Uses native GPT-2 <|endoftext|> token
Usage:
# EXP-A (JSON format)
python scripts/train_experiment.py \
--experiment_name exp_a_json \
--train_file ./data/experiments/exp_a_json/train.csv \
--output_dir ./output/exp_a_json \
--end_marker "<|endofex|>"
# EXP-B (EOS format)
python scripts/train_experiment.py \
--experiment_name exp_b_eos \
--train_file ./data/experiments/exp_b_eos/train.csv \
--output_dir ./output/exp_b_eos \
--end_marker "<|endoftext|>" \
--use_native_eos
"""
import argparse
import itertools
import json
import logging
import os
import random
import sys
from pathlib import Path

import numpy as np
import torch
import wandb
from datasets import load_dataset
from dotenv import load_dotenv
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
    set_seed,
)
# Module-wide logging: INFO and above, each record prefixed with a timestamp
# and level name. `logger` is this module's named logger used by every helper.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
def tokenize_function(examples, tokenizer):
    """Apply ``tokenizer`` to the ``"text"`` column of a batch.

    Args:
        examples: Batch mapping that contains a ``"text"`` entry.
        tokenizer: Callable tokenizer; its output is returned unchanged.
    """
    texts = examples["text"]
    encoded = tokenizer(texts)
    return encoded
def group_texts(examples, block_size):
    """Concatenate a batch of tokenized rows and re-split it into fixed-size blocks.

    Every column in ``examples`` (e.g. ``input_ids``, ``attention_mask``) is
    flattened into one sequence and cut into consecutive chunks of
    ``block_size`` tokens. Tokens after the last full block are dropped,
    except when the whole batch is shorter than ``block_size``: then a single
    short block is kept and a warning is logged.

    Args:
        examples: Mapping of column name -> list of token lists (the input of
            a batched ``datasets.map``). Must contain ``"input_ids"``.
        block_size: Number of tokens per output block; must be positive.

    Returns:
        Dict with the same columns chunked into blocks, plus a ``"labels"``
        column holding a copy of ``"input_ids"``.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    # chain.from_iterable is linear; sum(lists, []) re-copies the running
    # result on every addition and is quadratic in the batch size.
    concatenated = {
        key: list(itertools.chain.from_iterable(rows))
        for key, rows in examples.items()
    }
    total_length = len(concatenated["input_ids"])
    if total_length >= block_size:
        # Drop the trailing partial block.
        total_length = (total_length // block_size) * block_size
    else:
        logger.warning(f"Total length ({total_length}) < block_size ({block_size})")

    result = {
        key: [tokens[i:i + block_size] for i in range(0, total_length, block_size)]
        for key, tokens in concatenated.items()
    }
    # Copy each block so labels never share list objects with input_ids.
    result["labels"] = [list(block) for block in result["input_ids"]]
    return result
def validate_data_format(dataset, tokenizer, end_marker, num_samples=10, is_json_format=False):
    """Spot-check a random sample of raw rows for the expected training format.

    Args:
        dataset: Indexable collection supporting ``len()``; each row is read as
            ``dataset[idx]["text"]``.
        tokenizer: Unused; accepted for call-site compatibility.
        end_marker: Substring every non-JSON row is expected to contain.
        num_samples: Maximum number of rows to inspect (at least one is
            inspected when the dataset is non-empty).
        is_json_format: When True, a row is valid if its text parses as a JSON
            object that has both ``"expr"`` and ``"vars"`` keys.

    Returns:
        Percentage (0-100) of sampled rows that passed the check.

    Exits the process with status 1 when the dataset is empty or when no
    sampled row is valid.
    """
    if is_json_format:
        logger.info("Validating JSON format data...")
    else:
        logger.info(f"Validating data contains '{end_marker}'...")

    if len(dataset) == 0:
        logger.error("Dataset is empty! Data not properly prepared.")
        sys.exit(1)

    sample_indices = random.sample(
        range(len(dataset)),
        min(max(num_samples, 1), len(dataset))
    )

    def _is_valid(text):
        # Empty CSV cells come back as None (not str); treat them as invalid.
        if not isinstance(text, str):
            return False
        if not is_json_format:
            return end_marker in text
        try:
            obj = json.loads(text)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return False
        # Require a JSON object: membership on a list/str would give false positives.
        return isinstance(obj, dict) and "expr" in obj and "vars" in obj

    valid_count = sum(1 for idx in sample_indices if _is_valid(dataset[idx]["text"]))

    rate = valid_count / len(sample_indices) * 100
    logger.info(f"Validation: {valid_count}/{len(sample_indices)} ({rate:.1f}%) valid")
    if valid_count == 0:
        logger.error("No valid samples found! Data not properly prepared.")
        sys.exit(1)
    return rate
def _parse_args():
    """Define the command-line interface and return the parsed namespace."""
    parser = argparse.ArgumentParser(
        description="Train expression generation model"
    )
    # Required arguments
    parser.add_argument("--experiment_name", type=str, required=True,
                        help="Experiment name (e.g., 'exp_a_json', 'exp_b_eos')")
    parser.add_argument("--train_file", type=str, required=True,
                        help="Path to training CSV file")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save model")
    # Format options
    parser.add_argument("--end_marker", type=str, default="<|endofex|>",
                        help="End marker token (e.g., '<|endofex|>' or '<|endoftext|>')")
    parser.add_argument("--use_native_eos", action="store_true",
                        help="Use native GPT-2 EOS token instead of custom token")
    parser.add_argument("--json_format", action="store_true",
                        help="Data is in JSON format (for EXP-A)")
    # Optional data arguments
    parser.add_argument("--validation_file", type=str, default=None,
                        help="Path to validation CSV file")
    parser.add_argument("--test_file", type=str, default=None,
                        help="Path to test CSV file")
    # Model arguments
    parser.add_argument("--model_name_or_path", type=str, default="gpt2",
                        help="Base model name")
    parser.add_argument("--block_size", type=int, default=128,
                        help="Block size for tokenization")
    # Training arguments
    parser.add_argument("--num_train_epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--per_device_train_batch_size", type=int, default=8,
                        help="Batch size per device")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=8,
                        help="Eval batch size per device")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4,
                        help="Gradient accumulation steps")
    parser.add_argument("--learning_rate", type=float, default=5e-5,
                        help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="Weight decay")
    parser.add_argument("--warmup_steps", type=int, default=500,
                        help="Warmup steps")
    parser.add_argument("--fp16", action="store_true",
                        help="Use FP16 mixed precision")
    # LoRA arguments
    parser.add_argument("--lora_r", type=int, default=8,
                        help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32,
                        help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.05,
                        help="LoRA dropout")
    # Wandb arguments
    parser.add_argument("--wandb_project", type=str, default="seriguela_experiments",
                        help="Wandb project name")
    parser.add_argument("--wandb_run_name", type=str, default=None,
                        help="Wandb run name")
    # Other
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    parser.add_argument("--logging_steps", type=int, default=100,
                        help="Logging steps")
    # NOTE(review): --save_steps and --eval_steps are parsed but never read;
    # saving and evaluation are configured per epoch in main().
    parser.add_argument("--save_steps", type=int, default=500,
                        help="Save checkpoint steps")
    parser.add_argument("--eval_steps", type=int, default=500,
                        help="Evaluation steps")
    parser.add_argument("--push_to_hub", action="store_true",
                        help="Push model to HuggingFace Hub")
    parser.add_argument("--hub_model_id", type=str, default=None,
                        help="Hub model ID for pushing")
    return parser.parse_args()


def _causal_lm_collator(pad_token_id):
    """Return a collator that pads causal-LM blocks by length, not by token value.

    ``DataCollatorForLanguageModeling(mlm=False)`` rebuilds ``labels`` from
    ``input_ids`` and sets every position equal to ``pad_token_id`` to -100.
    This script uses ``eos_token`` as the pad token, so that collator would
    drop every real ``<|endoftext|>`` end marker from the loss in the
    native-EOS experiment, and the model would never be trained to emit it.
    This collator keeps the labels produced by ``group_texts`` and masks only
    the padding it appends itself.
    """
    def collate(features):
        max_len = max(len(feature["input_ids"]) for feature in features)
        input_ids, attention_mask, labels = [], [], []
        for feature in features:
            ids = list(feature["input_ids"])
            n_pad = max_len - len(ids)
            input_ids.append(ids + [pad_token_id] * n_pad)
            mask = feature.get("attention_mask")
            mask = list(mask) if mask is not None else [1] * len(ids)
            attention_mask.append(mask + [0] * n_pad)
            target = feature.get("labels")
            target = list(target) if target is not None else list(ids)
            labels.append(target + [-100] * n_pad)
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

    return collate


def _check_processed_blocks(dataset, tokenizer, end_marker, num_samples=10):
    """Decode a random sample of token blocks and exit(1) if none contains ``end_marker``."""
    sample_indices = random.sample(
        range(len(dataset)),
        min(num_samples, len(dataset))
    )
    valid_count = sum(
        1 for idx in sample_indices
        if end_marker in tokenizer.decode(dataset[idx]["input_ids"])
    )
    logger.info(f"Processed data validation: {valid_count}/{len(sample_indices)} contain end marker")
    if valid_count == 0:
        logger.error("No processed samples contain end marker! Check data format.")
        sys.exit(1)


def main():
    """Fine-tune a causal LM with LoRA for one expression-generation experiment.

    Parses CLI args, starts a wandb run, loads and spot-checks the CSV data,
    and configures the end-of-expression token. EXP-B uses the native EOS
    token; EXP-A adds the custom ``<|endofex|>`` token. It then tokenizes and
    chunks the text, wraps the model with LoRA on ``c_attn``, and trains with
    ``Trainer``. Finally it saves the model, tokenizer and
    ``experiment_info.json``, and optionally pushes to the Hub. Exits with
    status 1 on data-validation or training failure.
    """
    args = _parse_args()

    # Load environment variables (e.g. WANDB_API_KEY) from a .env file, if present.
    load_dotenv()
    set_seed(args.seed)

    # Configure wandb
    wandb_api_key = os.getenv("WANDB_API_KEY")
    if wandb_api_key:
        wandb.login(key=wandb_api_key)
    wandb_run_name = args.wandb_run_name or args.experiment_name
    wandb.init(
        project=args.wandb_project,
        name=wandb_run_name,
        config=vars(args)
    )

    logger.info("=" * 60)
    logger.info(f"EXPERIMENT: {args.experiment_name}")
    logger.info("=" * 60)
    logger.info(f"End marker: {args.end_marker}")
    logger.info(f"Use native EOS: {args.use_native_eos}")
    logger.info(f"Train file: {args.train_file}")
    logger.info(f"Output dir: {args.output_dir}")
    logger.info("=" * 60)

    # Load dataset
    logger.info("Loading dataset...")
    data_files = {"train": args.train_file}
    if args.validation_file:
        data_files["validation"] = args.validation_file
    if args.test_file:
        data_files["test"] = args.test_file
    raw_datasets = load_dataset("csv", data_files=data_files)
    logger.info(f"Loaded dataset: {raw_datasets}")

    # Validate raw data format before spending time on tokenization.
    validate_data_format(
        raw_datasets["train"],
        tokenizer=None,
        end_marker=args.end_marker,
        is_json_format=args.json_format
    )

    # Load tokenizer
    logger.info(f"Loading tokenizer: {args.model_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Select the end-of-expression token for this experiment.
    if args.use_native_eos:
        # EXP-B: Use native EOS token, no special tokens needed
        logger.info("Using native GPT-2 EOS token (<|endoftext|>)")
        end_token_id = tokenizer.eos_token_id
        logger.info(f"EOS token ID: {end_token_id}")
        if args.end_marker != tokenizer.eos_token:
            logger.warning(
                f"--use_native_eos is set but --end_marker is '{args.end_marker}', "
                f"not '{tokenizer.eos_token}'"
            )
    else:
        # EXP-A: Add custom <|endofex|> token
        logger.info("Adding custom special tokens")
        tokenizer.add_special_tokens({
            "additional_special_tokens": ["<|startofex|>", "<|endofex|>"]
        })
        end_token_id = tokenizer.convert_tokens_to_ids("<|endofex|>")
        logger.info(f"Custom end token ID: {end_token_id}")
        if args.end_marker != "<|endofex|>":
            logger.warning(
                f"Custom-token mode trains on '<|endofex|>' but --end_marker is "
                f"'{args.end_marker}'"
            )

    # Tokenize dataset
    logger.info("Tokenizing dataset...")
    tokenized_datasets = raw_datasets.map(
        lambda examples: tokenize_function(examples, tokenizer),
        batched=True,
        remove_columns=raw_datasets["train"].column_names
    )

    # Group into blocks
    logger.info(f"Grouping texts into blocks of {args.block_size}...")
    lm_datasets = tokenized_datasets.map(
        lambda examples: group_texts(examples, args.block_size),
        batched=True
    )
    logger.info(f"Processed dataset: {lm_datasets}")

    logger.info("Validating processed data...")
    _check_processed_blocks(lm_datasets["train"], tokenizer, args.end_marker)

    # Load model
    logger.info(f"Loading model: {args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)

    if not args.use_native_eos:
        model.resize_token_embeddings(len(tokenizer))
        logger.info(f"Resized embeddings to {len(tokenizer)}")

    # Configure EOS token for generation. generate() reads generation_config,
    # which was copied from config at load time, so update both.
    model.config.eos_token_id = end_token_id
    generation_config = getattr(model, "generation_config", None)
    if generation_config is not None:
        generation_config.eos_token_id = end_token_id
        generation_config.pad_token_id = tokenizer.pad_token_id
    logger.info(f"Model EOS token ID: {model.config.eos_token_id}")

    # Apply LoRA
    # NOTE(review): only c_attn receives adapters. get_peft_model freezes the
    # base weights, so embedding rows for tokens added above keep their fresh
    # initialisation unless PEFT is told to train them (e.g. modules_to_save).
    # Confirm this is intended for EXP-A.
    logger.info("Applying LoRA configuration...")
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=["c_attn"],
        lora_dropout=args.lora_dropout,
        bias="none"
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    model.train()

    # Training arguments
    logger.info("Configuring training...")
    has_validation = "validation" in lm_datasets and len(lm_datasets["validation"]) > 0
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_steps=args.warmup_steps,
        logging_dir=os.path.join(args.output_dir, 'logs'),
        logging_steps=args.logging_steps,
        eval_strategy="epoch" if has_validation else "no",
        save_strategy="epoch",
        save_total_limit=2,
        load_best_model_at_end=has_validation,
        metric_for_best_model="eval_loss" if has_validation else None,
        greater_is_better=False if has_validation else None,
        fp16=args.fp16,
        report_to="wandb",
        run_name=wandb_run_name,
        seed=args.seed,
        # Trainer.push_to_hub reads the repo id from here; without it the repo
        # is named after output_dir and --hub_model_id is ignored.
        hub_model_id=args.hub_model_id,
    )

    # Pad by length so real EOS tokens (== pad token) keep their labels.
    data_collator = _causal_lm_collator(tokenizer.pad_token_id)

    # Trainer
    logger.info("Initializing Trainer...")
    callbacks = []
    if has_validation:
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=2))
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_datasets["train"],
        eval_dataset=lm_datasets.get("validation"),
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks if callbacks else None,
    )

    # Train
    logger.info("=" * 60)
    logger.info("STARTING TRAINING")
    logger.info("=" * 60)
    try:
        train_result = trainer.train()
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)

        logger.info(f"Saving model to {args.output_dir}")
        trainer.save_model()
        tokenizer.save_pretrained(args.output_dir)

        # Record which end token this run trained on, for later inference.
        exp_info = {
            "experiment_name": args.experiment_name,
            "end_marker": args.end_marker,
            "use_native_eos": args.use_native_eos,
            "train_file": args.train_file,
            "end_token_id": end_token_id,
            "final_loss": metrics.get("train_loss", None),
        }
        info_path = os.path.join(args.output_dir, "experiment_info.json")
        with open(info_path, "w", encoding="utf-8") as f:
            json.dump(exp_info, f, indent=2)

        logger.info("=" * 60)
        logger.info("TRAINING COMPLETE")
        logger.info("=" * 60)
        logger.info(f"Final train loss: {metrics.get('train_loss', 'N/A')}")
        logger.info(f"Model saved to: {args.output_dir}")
    except Exception as e:
        logger.exception(f"Training failed: {e}")
        sys.exit(1)
    finally:
        wandb.finish()

    # Push to Hub if requested (best effort: failures are logged, not fatal).
    if args.push_to_hub:
        if not args.hub_model_id:
            logger.warning("--push_to_hub was given without --hub_model_id; skipping push.")
        else:
            logger.info(f"Pushing to Hub: {args.hub_model_id}")
            try:
                trainer.push_to_hub(commit_message=f"Training complete: {args.experiment_name}")
                logger.info("Push successful!")
            except Exception as e:
                logger.error(f"Push failed: {e}")


if __name__ == "__main__":
    main()