#!/usr/bin/env python3
"""
Hydra-configured SFT training script for LoRA fine-tuning (resuming from a saved adapter is supported)
"""
import os
import sys
import logging
from pathlib import Path
# ---------------- FIX IMPORT WHEN USING HYDRA ----------------
FILE_DIR = os.path.dirname(os.path.abspath(__file__)) # src/training/sft
PROJECT_ROOT = os.path.abspath(os.path.join(FILE_DIR, "../../../"))
sys.path.insert(0, PROJECT_ROOT)
# --------------------------------------------------------------
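# (Running this file directly puts src/training/sft -- not the repo root -- on
# sys.path, so `src.*` imports would fail without the insertion above.)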
import hydra
from omegaconf import DictConfig, OmegaConf
from datasets import Dataset
from peft import (
    LoraConfig,
    get_peft_model,
    PeftModel,
)
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from trl import SFTTrainer, SFTConfig
from huggingface_hub import login
# dataset utils
from src.dataset.sft import (
    load_csv_dataset_sft,
    map_problem_description_to_conversation_sft,
)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("sft_training")
# ------------------------------------------------------------
# Dataset loader
# ------------------------------------------------------------
def load_train_dataset(cfg: DictConfig) -> Dataset:
    raw_dataset = load_csv_dataset_sft(
        cfg.file_path, map_problem_description_to_conversation_sft
    )
    raw_dataset = raw_dataset.shuffle(seed=cfg.seed)
    train_dataset = raw_dataset.select(range(min(cfg.max_rows, len(raw_dataset))))
    logger.info("Train rows: %d", len(train_dataset))
    return train_dataset
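# TRL's SFTTrainer accepts conversational datasets in which each row carries a
# "messages" list; map_problem_description_to_conversation_sft is assumed to
# produce that shape. An illustrative row:
#   {"messages": [{"role": "user", "content": "<problem description>"},
#                 {"role": "assistant", "content": "<solution>"}]}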
# ------------------------------------------------------------
# Create model + LoRA
# ------------------------------------------------------------
def create_lora_model(cfg, resume_path=None):
    """
    RULE:
    - If resume_path is provided: load the base model, then load the LoRA adapter onto it.
    - Else: load the base model, then attach a fresh LoRA adapter.
    """
    base_model_id = cfg.model_id
    logger.info(f"Loading base model: {base_model_id}")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        device_map=cfg.device_map,
    )
    if resume_path:
        logger.info(f"Resuming from LoRA adapter: {resume_path}")
        # is_trainable=True keeps the adapter unfrozen; PeftModel.from_pretrained
        # defaults to inference mode, which would freeze the resumed weights.
        model = PeftModel.from_pretrained(base_model, resume_path, is_trainable=True)
        return model
    # Create a new LoRA adapter
    lora_cfg = LoraConfig(
        r=cfg.lora.r,
        lora_alpha=cfg.lora.lora_alpha,
        target_modules=OmegaConf.to_container(cfg.lora.target_modules),
        lora_dropout=cfg.lora.lora_dropout,
        bias=cfg.lora.bias,
        task_type=cfg.lora.task_type,
    )
    model = get_peft_model(base_model, lora_cfg)
    logger.info("New LoRA model created")
    return model
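# Note: loading an adapter above resumes its weights only. To also restore
# optimizer/scheduler state from a Trainer checkpoint, one would instead pass
# trainer.train(resume_from_checkpoint=<checkpoint dir>) (transformers API).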
# ------------------------------------------------------------
# TrainingConfig
# ------------------------------------------------------------
def build_sft_config(cfg, output_dir):
    return SFTConfig(
        output_dir=output_dir,
        learning_rate=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        warmup_ratio=cfg.warmup_ratio,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        num_train_epochs=cfg.num_train_epochs,
        max_length=cfg.max_length,
        bf16=cfg.bf16,
        fp16=cfg.fp16,
        logging_steps=cfg.logging_steps,
        save_strategy=cfg.save_strategy,
        save_steps=cfg.save_steps,
        report_to=cfg.report_to,
        lr_scheduler_type=cfg.lr_scheduler_type,
        optim=cfg.optim,
        remove_unused_columns=cfg.remove_unused_columns,
    )
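# transformers' TrainingArguments rejects bf16=True together with fp16=True,
# so the config should enable at most one of the two.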
# ------------------------------------------------------------
# Trainer
# ------------------------------------------------------------
def create_trainer(model, tokenizer, train_dataset, training_args):
    return SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        args=training_args,
        # Recent TRL/transformers releases renamed the `tokenizer` argument to
        # `processing_class`; older TRL versions used `tokenizer=tokenizer`.
        processing_class=tokenizer,
    )
# ------------------------------------------------------------
# Train & Save (LoRA ONLY)
# ------------------------------------------------------------
def train_and_save(trainer, output_dir, tokenizer, hf_repo_id=None):
    logger.info("Start training...")
    trainer.train()
    logger.info("Training finished")
    # Save ONLY the LoRA adapter (trainer.model is a PeftModel here)
    trainer.model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    logger.info(f"Saved LoRA adapter to: {output_dir}")
    if hf_repo_id:
        logger.info(f"Pushing adapter to HF Hub: {hf_repo_id}")
        trainer.model.push_to_hub(hf_repo_id)
        tokenizer.push_to_hub(hf_repo_id)
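# Note: save_pretrained on a PeftModel writes only the adapter files
# (adapter_config.json + adapter weights), not the base model. A standalone
# merged checkpoint can be produced later with peft (illustrative):
#   merged = PeftModel.from_pretrained(base_model, adapter_dir).merge_and_unload()
#   merged.save_pretrained("merged-model")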
# ------------------------------------------------------------
# MAIN
# ------------------------------------------------------------
@hydra.main(version_base=None, config_path="../../config/sft", config_name="config")
def main(cfg: DictConfig):
    print("Loaded config:")
    print(OmegaConf.to_yaml(cfg))
    # Login HF (optional)
    if cfg.get("hf_token", None):
        login(cfg.hf_token)
        logger.info("Logged into HF")
    # Check dataset
    if not Path(cfg.dataset.file_path).exists():
        logger.error(f"Dataset not found: {cfg.dataset.file_path}")
        return
    os.makedirs(cfg.output_dir, exist_ok=True)
    # Load dataset
    train_dataset = load_train_dataset(cfg.dataset)
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg.model.model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Create model (resume if available)
    resume = cfg.get("resume_from", None)
    model = create_lora_model(cfg.model, resume)
    # Training configuration
    training_args = build_sft_config(cfg.training, cfg.output_dir)
    # Trainer
    trainer = create_trainer(model, tokenizer, train_dataset, training_args)
    # Train & Save
    train_and_save(
        trainer=trainer,
        output_dir=cfg.output_dir,
        tokenizer=tokenizer,
        hf_repo_id=cfg.get("hf_repo_id", None),
    )
if __name__ == "__main__":
    main()
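
# Example invocation (the script filename and override values are illustrative):
#   python src/training/sft/train_sft.py \
#       training.num_train_epochs=1 resume_from=outputs/sft_lora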