"""
|
|
|
SFT training script for arithmetic countdown problems.
|
|
|
|
|
|
This script trains a language model using SFT (Supervised Fine-Tuning)
|
|
|
to solve arithmetic problems with proper reasoning and formatting.
|
|
|
"""
|
|
|
|
|
|
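
# Example invocation (a sketch; the script file name and paths here are
# hypothetical, adjust them to your setup):
#
#   python sft_train.py \
#       --dataset_csv /path/to/countdown_train.csv \
#       --output_dir ./outputs/sft-countdown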

import argparse
import logging
import os
from pathlib import Path

from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
from trl import SFTConfig, SFTTrainer

from src.dataset.sft import (
    load_csv_dataset_sft,
    map_problem_description_to_conversation_sft,
)

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("sft_training")


def load_train_dataset(
    dataset_csv: str, max_rows: int = 2000, seed: int = 42
) -> Dataset:
    """
    Load, shuffle, and subsample the training dataset.

    Args:
        dataset_csv: Absolute path to the dataset CSV file
        max_rows: Maximum number of rows to select for training
        seed: Seed for dataset shuffling

    Returns:
        Dataset: A datasets.Dataset ready for SFT training
    """
    raw_dataset: Dataset = load_csv_dataset_sft(
        dataset_csv, "train", map_problem_description_to_conversation_sft
    )
    raw_dataset = raw_dataset.shuffle(seed=seed)
    train_dataset = raw_dataset.select(range(min(max_rows, len(raw_dataset))))
    logger.info("Train rows: %d", len(train_dataset))
    return train_dataset


def create_lora_model(model_id: str, device_map: str = "auto") -> PreTrainedModel:
    """
    Create a base causal LM and wrap it with LoRA adapters.

    Args:
        model_id: Hugging Face model identifier to load as the base model
        device_map: Device mapping strategy for model loading

    Returns:
        PreTrainedModel: A transformers.PreTrainedModel with LoRA adapters applied
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=device_map,
    )
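    # Note: the weights load in the checkpoint's default dtype here; passing
    # torch_dtype=torch.bfloat16 would match the bf16=True training setting
    # created in create_sft_config below.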

    lora_cfg = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_cfg)
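    # PEFT's built-in helper logs how many parameters remain trainable under LoRA.
    model.print_trainable_parameters()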
    logger.info("Model with LoRA ready")
    return model


def create_sft_config(
    output_dir: str,
    learning_rate: float = 2e-4,
    num_train_epochs: int = 1,
    per_device_train_batch_size: int = 4,
    gradient_accumulation_steps: int = 4,
    max_length: int = 2048,
    save_steps: int = 50,
    logging_steps: int = 1,
) -> SFTConfig:
    """
    Create SFT training configuration.

    Args:
        output_dir: Directory where checkpoints and logs will be written
        learning_rate: Learning rate for training
        num_train_epochs: Number of training epochs
        per_device_train_batch_size: Batch size per device
        gradient_accumulation_steps: Steps to accumulate gradients
        max_length: Maximum sequence length
        save_steps: Steps between model saves
        logging_steps: Steps between log outputs

    Returns:
        SFTConfig: A configured trl.SFTConfig instance
    """
    return SFTConfig(
        output_dir=output_dir,
        learning_rate=learning_rate,
        weight_decay=0.001,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        optim="paged_adamw_32bit",
        remove_unused_columns=False,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_train_epochs=num_train_epochs,
        bf16=True,
        per_device_train_batch_size=per_device_train_batch_size,
        max_length=max_length,
        packing=False,
        report_to=["tensorboard"],
        logging_steps=logging_steps,
        save_strategy="steps",
        save_steps=save_steps,
        eval_strategy="no",
    )


def create_trainer(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    train_dataset: Dataset,
    args: SFTConfig,
) -> SFTTrainer:
    """
    Construct an SFTTrainer for supervised fine-tuning.

    Args:
        model: The LoRA-wrapped pretrained model to train
        tokenizer: The tokenizer for the model
        train_dataset: The dataset to use for training
        args: The SFT configuration

    Returns:
        SFTTrainer: An initialized trl.SFTTrainer instance
    """
    trainer = SFTTrainer(
        model=model,
        # Recent TRL releases renamed the `tokenizer` argument to
        # `processing_class`.
        processing_class=tokenizer,
        args=args,
        train_dataset=train_dataset,
    )
    return trainer


def train_and_save(trainer: SFTTrainer, output_dir: str) -> None:
    """
    Run training and save the final model to disk.

    Args:
        trainer: The configured SFT trainer instance
        output_dir: Output directory to save the trained model

    Returns:
        None
    """
    train_result = trainer.train()
    logger.info("Training complete: %s", str(train_result))
    trainer.save_model(output_dir)
    logger.info("Saved to %s", output_dir)


def main() -> None:
    """
    Run the full SFT training workflow with command-line arguments.

    Returns:
        None
    """
    parser = argparse.ArgumentParser(
        description="Train a language model using SFT for arithmetic countdown problems"
    )

    parser.add_argument(
        "--dataset_csv",
        type=str,
        required=True,
        help="Path to the training dataset CSV file",
    )
    parser.add_argument(
        "--max_rows", type=int, default=2000, help="Maximum number of training samples"
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for dataset shuffling"
    )

    parser.add_argument(
        "--model_id",
        type=str,
        default="Qwen/Qwen2.5-3B-Instruct",
        help="Hugging Face model identifier",
    )
    parser.add_argument(
        "--device_map", type=str, default="auto", help="Device mapping strategy"
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save model checkpoints and logs",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=2e-4, help="Learning rate"
    )
    parser.add_argument(
        "--num_train_epochs", type=int, default=1, help="Number of training epochs"
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=4,
        help="Batch size per device",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=4,
        help="Gradient accumulation steps",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=2048,
        help="Maximum sequence length",
    )
    parser.add_argument(
        "--save_steps", type=int, default=50, help="Steps between model saves"
    )
    parser.add_argument(
        "--logging_steps", type=int, default=1, help="Steps between log outputs"
    )

    args = parser.parse_args()

    if not Path(args.dataset_csv).exists():
        logger.error("Dataset CSV file does not exist: %s", args.dataset_csv)
        return

    if args.max_rows <= 0:
        logger.error("max_rows must be positive")
        return

    os.makedirs(args.output_dir, exist_ok=True)
    logger.info("Output dir: %s", args.output_dir)

    train_dataset = load_train_dataset(args.dataset_csv, args.max_rows, args.seed)

    model = create_lora_model(args.model_id, args.device_map)
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
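
    # Some causal LM tokenizers ship without a dedicated pad token; fall back
    # to EOS so that batched padding works during training.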
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    training_args = create_sft_config(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_length=args.max_length,
        save_steps=args.save_steps,
        logging_steps=args.logging_steps,
    )

    trainer = create_trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        args=training_args,
    )

    train_and_save(trainer=trainer, output_dir=args.output_dir)


if __name__ == "__main__":
    main()