#!/usr/bin/env python3
"""
SFT training script for arithmetic countdown problems.
This script trains a language model using SFT (Supervised Fine-Tuning)
to solve arithmetic problems with proper reasoning and formatting.
"""
import argparse
import logging
import os
from pathlib import Path
from datasets import Dataset
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizerBase,
)
from trl import SFTConfig, SFTTrainer
from src.dataset.sft import (
load_csv_dataset_sft,
map_problem_description_to_conversation_sft,
)
# Set up logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("sft_training")
def load_train_dataset(
dataset_csv: str, max_rows: int = 2000, seed: int = 42
) -> Dataset:
"""
Load, shuffle, and subsample the training dataset.
Args:
        dataset_csv: Path to the dataset CSV file (absolute or relative)
max_rows: Maximum number of rows to select for training
seed: Seed for dataset shuffling
Returns:
Dataset: A datasets.Dataset ready for SFT training
"""
raw_dataset: Dataset = load_csv_dataset_sft(
dataset_csv, "train", map_problem_description_to_conversation_sft
)
raw_dataset = raw_dataset.shuffle(seed=seed)
train_dataset = raw_dataset.select(range(min(max_rows, len(raw_dataset))))
logger.info("Train rows: %d", len(train_dataset))
return train_dataset
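# A minimal usage sketch for the loader above (the CSV path is hypothetical):
#
#     ds = load_train_dataset("data/countdown_train.csv", max_rows=500, seed=0)
#     print(len(ds), ds[0])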
def create_lora_model(model_id: str, device_map: str = "auto") -> PeftModel:
"""
Create a base causal LM and wrap it with LoRA adapters.
Args:
model_id: Hugging Face model identifier to load as the base model
device_map: Device mapping strategy for model loading
Returns:
        PeftModel: A peft.PeftModel with LoRA adapters applied
"""
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map=device_map,
)
lora_cfg = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_cfg)
logger.info("Model with LoRA ready")
return model
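# Quick sanity check on the adapter footprint (a sketch; any causal LM id
# works). PEFT models expose print_trainable_parameters(), which reports
# trainable vs. total parameter counts:
#
#     model = create_lora_model("Qwen/Qwen2.5-3B-Instruct")
#     model.print_trainable_parameters()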
def create_sft_config(
output_dir: str,
learning_rate: float = 2e-4,
num_train_epochs: int = 1,
per_device_train_batch_size: int = 4,
gradient_accumulation_steps: int = 4,
max_length: int = 2048,
save_steps: int = 50,
logging_steps: int = 1,
) -> SFTConfig:
"""
Create SFT training configuration.
Args:
output_dir: Directory where checkpoints and logs will be written
learning_rate: Learning rate for training
num_train_epochs: Number of training epochs
per_device_train_batch_size: Batch size per device
gradient_accumulation_steps: Steps to accumulate gradients
max_length: Maximum sequence length
save_steps: Steps between model saves
logging_steps: Steps between log outputs
Returns:
SFTConfig: A configured trl.SFTConfig instance
"""
return SFTConfig(
output_dir=output_dir,
learning_rate=learning_rate,
weight_decay=0.001,
warmup_ratio=0.03,
lr_scheduler_type="cosine",
optim="paged_adamw_32bit",
remove_unused_columns=False,
gradient_accumulation_steps=gradient_accumulation_steps,
num_train_epochs=num_train_epochs,
bf16=True,
per_device_train_batch_size=per_device_train_batch_size,
# SFT-specific parameters
max_length=max_length,
packing=False,
# Logging and saving
report_to=["tensorboard"],
logging_steps=logging_steps,
save_strategy="steps",
save_steps=save_steps,
eval_strategy="no",
)
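# Note on the defaults above: the effective batch size per optimizer step is
# per_device_train_batch_size * gradient_accumulation_steps * num_devices,
# e.g. 4 * 4 * 1 = 16 samples on a single GPU. A sketch of the arithmetic
# (torch is assumed available, since it already backs transformers):
#
#     import torch
#     num_devices = max(1, torch.cuda.device_count())
#     effective_batch = 4 * 4 * num_devices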
def create_trainer(
    model: PeftModel,
    tokenizer: PreTrainedTokenizerBase,
train_dataset: Dataset,
args: SFTConfig,
) -> SFTTrainer:
"""
Construct an SFTTrainer for supervised fine-tuning.
Args:
model: The LoRA-wrapped pretrained model to train
tokenizer: The tokenizer for the model
train_dataset: The dataset to use for training
args: The SFT configuration
Returns:
SFTTrainer: An initialized trl.SFTTrainer instance
"""
    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        # Newer TRL versions (those that spell SFTConfig's max_length this way)
        # take the tokenizer via processing_class; the old tokenizer= keyword
        # was deprecated and later removed.
        processing_class=tokenizer,
    )
return trainer
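# Note: with a conversational dataset (rows holding {"role", "content"} message
# lists, which map_problem_description_to_conversation_sft is assumed to emit),
# SFTTrainer applies the tokenizer's chat template during preprocessing, so no
# manual prompt formatting is needed here.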
def train_and_save(trainer: SFTTrainer, output_dir: str) -> None:
"""
Run training and save the final model to disk.
Args:
trainer: The configured SFT trainer instance
output_dir: Output directory to save the trained model
Returns:
None
"""
train_result = trainer.train()
logger.info("Training complete: %s", str(train_result))
trainer.save_model(output_dir)
logger.info("Saved to %s", output_dir)
def main() -> None:
"""
Run the full SFT training workflow with command-line arguments.
Returns:
None
"""
parser = argparse.ArgumentParser(
description="Train a language model using SFT for arithmetic countdown problems"
)
# Dataset arguments
parser.add_argument(
"--dataset_csv",
type=str,
required=True,
help="Path to the training dataset CSV file",
)
parser.add_argument(
"--max_rows", type=int, default=2000, help="Maximum number of training samples"
)
parser.add_argument(
"--seed", type=int, default=42, help="Random seed for dataset shuffling"
)
# Model arguments
parser.add_argument(
"--model_id",
type=str,
default="Qwen/Qwen2.5-3B-Instruct",
help="Hugging Face model identifier",
)
parser.add_argument(
"--device_map", type=str, default="auto", help="Device mapping strategy"
)
# Training arguments
parser.add_argument(
"--output_dir",
type=str,
required=True,
help="Directory to save model checkpoints and logs",
)
parser.add_argument(
"--learning_rate", type=float, default=2e-4, help="Learning rate"
)
parser.add_argument(
"--num_train_epochs", type=int, default=1, help="Number of training epochs"
)
parser.add_argument(
"--per_device_train_batch_size",
type=int,
default=4,
help="Batch size per device",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=4,
help="Gradient accumulation steps",
)
parser.add_argument(
"--max_length",
type=int,
default=2048,
help="Maximum sequence length",
)
parser.add_argument(
"--save_steps", type=int, default=50, help="Steps between model saves"
)
parser.add_argument(
"--logging_steps", type=int, default=1, help="Steps between log outputs"
)
args = parser.parse_args()
# Validate arguments
if not Path(args.dataset_csv).exists():
logger.error("Dataset CSV file does not exist: %s", args.dataset_csv)
return
if args.max_rows <= 0:
logger.error("max_rows must be positive")
return
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
logger.info("Output dir: %s", args.output_dir)
# Load dataset
train_dataset = load_train_dataset(args.dataset_csv, args.max_rows, args.seed)
# Create model and tokenizer
model = create_lora_model(args.model_id, args.device_map)
tokenizer = AutoTokenizer.from_pretrained(args.model_id)
# Set pad token if it doesn't exist
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Create training configuration
training_args = create_sft_config(
output_dir=args.output_dir,
learning_rate=args.learning_rate,
num_train_epochs=args.num_train_epochs,
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
max_length=args.max_length,
save_steps=args.save_steps,
logging_steps=args.logging_steps,
)
# Create trainer
trainer = create_trainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
args=training_args,
)
# Train and save
train_and_save(trainer=trainer, output_dir=args.output_dir)
if __name__ == "__main__":
main()