#!/usr/bin/env python3
"""
SFT training script with Hydra for LoRA (resume supported)
"""
import os
import sys
import logging
from pathlib import Path

# ---------------- FIX IMPORT WHEN USING HYDRA ----------------
FILE_DIR = os.path.dirname(os.path.abspath(__file__))  # src/training/sft
PROJECT_ROOT = os.path.abspath(os.path.join(FILE_DIR, "../../../"))
sys.path.insert(0, PROJECT_ROOT)
# --------------------------------------------------------------

import hydra
from omegaconf import DictConfig, OmegaConf
from datasets import Dataset
from peft import (
    LoraConfig,
    get_peft_model,
    PeftModel,
)
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from trl import SFTTrainer, SFTConfig
from huggingface_hub import login

# dataset utils
from src.dataset.sft import (
    load_csv_dataset_sft,
    map_problem_description_to_conversation_sft,
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("sft_training")


# ------------------------------------------------------------
# Dataset loader
# ------------------------------------------------------------
def load_train_dataset(cfg: DictConfig) -> Dataset:
    raw_dataset = load_csv_dataset_sft(
        cfg.file_path,
        map_problem_description_to_conversation_sft,
    )
    raw_dataset = raw_dataset.shuffle(seed=cfg.seed)
    train_dataset = raw_dataset.select(range(min(cfg.max_rows, len(raw_dataset))))
    logger.info("Train rows: %d", len(train_dataset))
    return train_dataset


# ------------------------------------------------------------
# Create model + LoRA
# ------------------------------------------------------------
def create_lora_model(cfg, resume_path=None):
    """
    RULE:
    - If resume_path is provided: load the base model, then load the existing LoRA adapter
    - Else: load the base model, then attach a new LoRA adapter
    """
    base_model_id = cfg.model_id
    logger.info(f"Loading base model: {base_model_id}")

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        device_map=cfg.device_map,
    )

    if resume_path:
        logger.info(f"Resume from LoRA adapter: {resume_path}")
        model = PeftModel.from_pretrained(base_model, resume_path)
        return model

    # Create new LoRA
    lora_cfg = LoraConfig(
        r=cfg.lora.r,
        lora_alpha=cfg.lora.lora_alpha,
        target_modules=OmegaConf.to_container(cfg.lora.target_modules),
        lora_dropout=cfg.lora.lora_dropout,
        bias=cfg.lora.bias,
        task_type=cfg.lora.task_type,
    )
    model = get_peft_model(base_model, lora_cfg)
    logger.info("New LoRA model created")
    return model


# ------------------------------------------------------------
# TrainingConfig
# ------------------------------------------------------------
def build_sft_config(cfg, output_dir):
    return SFTConfig(
        output_dir=output_dir,
        learning_rate=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        warmup_ratio=cfg.warmup_ratio,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        num_train_epochs=cfg.num_train_epochs,
        max_length=cfg.max_length,
        bf16=cfg.bf16,
        fp16=cfg.fp16,
        logging_steps=cfg.logging_steps,
        save_strategy=cfg.save_strategy,
        save_steps=cfg.save_steps,
        report_to=cfg.report_to,
        lr_scheduler_type=cfg.lr_scheduler_type,
        optim=cfg.optim,
        remove_unused_columns=cfg.remove_unused_columns,
    )


# ------------------------------------------------------------
# Trainer
# ------------------------------------------------------------
def create_trainer(model, tokenizer, train_dataset, training_args):
    # NOTE: recent trl releases rename the `tokenizer=` keyword to
    # `processing_class=`; use whichever matches the installed trl version.
    return SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        args=training_args,
        tokenizer=tokenizer,
    )
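
# ------------------------------------------------------------
# Illustrative only: reloading the saved adapter for inference.
# train_and_save() below writes just the LoRA adapter, so a later
# script would rebuild the base model first. The ids/paths here are
# placeholders, not values from this project's config.
# ------------------------------------------------------------
# base = AutoModelForCausalLM.from_pretrained("<base_model_id>", device_map="auto")
# model = PeftModel.from_pretrained(base, "<output_dir_with_adapter>")
# model = model.merge_and_unload()  # optional: fold LoRA weights into the base model
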
# ------------------------------------------------------------
# Train & Save (LoRA ONLY)
# ------------------------------------------------------------
def train_and_save(trainer, output_dir, tokenizer, hf_repo_id=None):
    logger.info("Start training...")
    trainer.train()
    logger.info("Training finished")

    # SAVE ONLY LORA ADAPTER
    trainer.model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    logger.info(f"Saved LoRA adapter to: {output_dir}")

    if hf_repo_id:
        logger.info(f"Pushing adapter to HF Hub: {hf_repo_id}")
        trainer.model.push_to_hub(hf_repo_id)
        tokenizer.push_to_hub(hf_repo_id)


# ------------------------------------------------------------
# MAIN
# ------------------------------------------------------------
@hydra.main(version_base=None, config_path="../../config/sft", config_name="config")
def main(cfg: DictConfig):
    print("Loaded config:")
    print(OmegaConf.to_yaml(cfg))

    # Login HF (optional)
    if cfg.get("hf_token", None):
        login(cfg.hf_token)
        logger.info("Logged into HF")

    # Check dataset
    if not Path(cfg.dataset.file_path).exists():
        logger.error(f"Dataset not found: {cfg.dataset.file_path}")
        return

    os.makedirs(cfg.output_dir, exist_ok=True)

    # Load dataset
    train_dataset = load_train_dataset(cfg.dataset)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(cfg.model.model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Create model (resume if available)
    resume = cfg.get("resume_from", None)
    model = create_lora_model(cfg.model, resume)

    # Training configuration
    training_args = build_sft_config(cfg.training, cfg.output_dir)

    # Trainer
    trainer = create_trainer(model, tokenizer, train_dataset, training_args)

    # Train & Save
    train_and_save(
        trainer=trainer,
        output_dir=cfg.output_dir,
        tokenizer=tokenizer,
        hf_repo_id=cfg.get("hf_repo_id", None),
    )


if __name__ == "__main__":
    main()
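
# ------------------------------------------------------------
# Illustrative usage (not executed): the keys below are the ones this
# script actually reads from src/config/sft/config.yaml; the values are
# example assumptions, not the project's real defaults.
#
#   python src/training/sft/<this_script>.py output_dir=outputs/sft-lora
#
#   output_dir: outputs/sft-lora
#   resume_from: null            # path to an existing LoRA adapter, or null
#   hf_token: null               # optional, enables login() and Hub pushes
#   hf_repo_id: null             # optional Hub repo for the trained adapter
#   dataset:
#     file_path: data/train.csv
#     seed: 42
#     max_rows: 10000
#   model:
#     model_id: <base-model-id>
#     device_map: auto
#     lora:
#       r: 16
#       lora_alpha: 32
#       lora_dropout: 0.05
#       bias: none
#       task_type: CAUSAL_LM
#       target_modules: [q_proj, k_proj, v_proj, o_proj]
#   training:
#     learning_rate: 2.0e-4
#     weight_decay: 0.0
#     warmup_ratio: 0.03
#     gradient_accumulation_steps: 4
#     per_device_train_batch_size: 2
#     num_train_epochs: 1
#     max_length: 2048
#     bf16: true
#     fp16: false
#     logging_steps: 10
#     save_strategy: steps
#     save_steps: 200
#     report_to: none
#     lr_scheduler_type: cosine
#     optim: adamw_torch
#     remove_unused_columns: false
# ------------------------------------------------------------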