# linalg-zero / linalg_zero/sft_train.py
# (repository page residue: "atomwalk12's picture", "initial commit", "0dd6c2f")
import os
from typing import Any
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
import unsloth # noqa: I001, F401
import logging
import sys
import transformers
from datasets import DatasetDict, load_dataset
from datasets.load import DownloadMode
from datasets.utils.logging import set_verbosity
from transformers.trainer_utils import get_last_checkpoint, set_seed
from trl.scripts.utils import TrlParser
from trl.trainer.sft_config import SFTConfig
from trl.trainer.sft_trainer import SFTTrainer
from linalg_zero.config.data import ScriptArguments, SFTModelConfig, SFTRunConfig
from linalg_zero.sft.callbacks import get_callbacks
from linalg_zero.sft.utils import (
ensure_tokenizer_has_defaults,
get_unsloth_model,
init_wandb_training,
load_merged_model_for_sft,
)
from linalg_zero.shared.utils import get_logger, setup_logging
def main(  # noqa: C901
    script_args: ScriptArguments, training_args: SFTRunConfig, trl_training_args: SFTConfig, model_args: SFTModelConfig
) -> None:
    """Run supervised fine-tuning end to end: load, train, save, evaluate, push.

    Steps, in order:
      1. Seed RNGs and configure logging for this process.
      2. Detect the last checkpoint in ``trl_training_args.output_dir`` (if any).
      3. Load the dataset, then the model/tokenizer (Unsloth LoRA path when
         ``model_args.use_peft`` is truthy or absent, merged-model path otherwise).
      4. Render every example to a ``text`` column with the tokenizer's chat template.
      5. Train, save the model and model card, optionally evaluate and push to the Hub.

    Args:
        script_args: Dataset name/config and the train/test split names.
        training_args: Project-level run options read here: ``max_seq_length``,
            ``add_special_tokens``, ``max_eval_samples``, ``eval_max_new_tokens``
            and ``final_eval_max_samples``.
        trl_training_args: TRL ``SFTConfig`` handed to ``SFTTrainer``. Note that
            ``max_eval_samples`` and ``eval_max_new_tokens`` are set on this object.
        model_args: Model options; ``use_peft`` and ``model_name_or_path`` are read here.

    Raises:
        TypeError: If ``load_dataset`` does not return a ``DatasetDict``.
        ValueError: If the model's generation config and the tokenizer disagree
            on the pad token id at save time.
        Exception: Any failure during training or saving is logged and re-raised.
            Failures during final evaluation or Hub push are logged only.
    """
    # Reproducibility
    set_seed(trl_training_args.seed)

    #################
    # Setup logging #
    #################
    # Log both to file and console
    setup_logging(level=logging.INFO, include_timestamp=True)
    logger = get_logger(__name__)

    # Adjust script logging level based on the node logging level (main process or replica)
    log_level = trl_training_args.get_process_log_level()
    logger.setLevel(log_level)
    set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    logger.info(f"Model parameters: {model_args}")
    logger.info(f"Script parameters: {script_args}")
    logger.info(f"Training parameters: {training_args}")
    logger.info(f"TRL training parameters: {trl_training_args}")

    # Check for last checkpoint
    last_checkpoint = None
    if trl_training_args.output_dir and os.path.isdir(trl_training_args.output_dir):
        last_checkpoint = get_last_checkpoint(trl_training_args.output_dir)
    if last_checkpoint is not None and trl_training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}")

    # Initialize wandb if requested
    if trl_training_args.report_to and "wandb" in trl_training_args.report_to:
        init_wandb_training(training_args)

    ######################################
    # Load dataset, tokenizer, and model #
    ######################################
    logger.info(f"Loading dataset from {script_args.dataset_name}...")
    # FORCE_REDOWNLOAD bypasses the local datasets cache on every run.
    dataset = load_dataset(
        script_args.dataset_name, script_args.dataset_config, download_mode=DownloadMode.FORCE_REDOWNLOAD
    )
    if not isinstance(dataset, DatasetDict):
        raise TypeError(f"Expected dataset to be a DatasetDict, but got {type(dataset)}")

    # Model, tokenizer, dataset
    logger.info("Loading model and tokenizer...")
    if getattr(model_args, "use_peft", True):
        # Standard LoRA SFT on base model
        model, tokenizer = get_unsloth_model(model_args, training_args, trl_training_args, resume_path=last_checkpoint)
    else:
        # Light touch-up on a merged model: train only I/O layers if requested
        max_seq_len = training_args.max_seq_length or trl_training_args.max_seq_length
        model, tokenizer = load_merged_model_for_sft(
            model_path=model_args.model_name_or_path,
            max_seq_length=max_seq_len,
            dtype=None,
            train_io_only=True,
            add_special_tokens=training_args.add_special_tokens,
        )

    # Ensure pad token and padding side are set consistently for SFT
    # NOTE(review): the reviewed copy of this file had lost its indentation, so it is
    # unclear whether this call belonged to the merged-model branch only. It is applied
    # to both loading paths here - confirm against the canonical source.
    ensure_tokenizer_has_defaults(tokenizer, model)

    def formatting_prompts_func(examples: dict[str, Any]) -> dict[str, list[str]]:
        """Render one batch of conversations to chat-template strings.

        ``examples`` is a column-oriented batch as passed by ``Dataset.map(batched=True)``:
        ``examples["messages"]`` holds one conversation per row and the optional
        ``examples["tools"]`` holds the matching tool specs.
        """
        convos = examples["messages"]
        tools = examples.get("tools", None)
        # A non-empty list is indexed per row; anything else (None, empty, non-list)
        # is forwarded unchanged to every conversation.
        per_example_tools = bool(tools) and isinstance(tools, list)
        texts = [
            tokenizer.apply_chat_template(
                convo,
                tools=tools[i] if per_example_tools else tools,
                tokenize=False,
                add_generation_prompt=False,
            )
            for i, convo in enumerate(convos)
        ]
        return {"text": texts}

    dataset = dataset.map(formatting_prompts_func, batched=True)

    ##############################
    # Initialize the SFT Trainer #
    ##############################
    # Copy the project-level evaluation limits onto the TRL config object.
    trl_training_args.max_eval_samples = training_args.max_eval_samples
    trl_training_args.eval_max_new_tokens = training_args.eval_max_new_tokens

    logger.info("Initializing SFT Trainer...")
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(dataset[script_args.dataset_test_split] if trl_training_args.eval_strategy != "no" else None),
        args=trl_training_args,
        callbacks=get_callbacks(training_args, model_args, script_args, dataset),
    )

    #################
    # Training loop #
    #################
    logger.info("*** Starting Training ***")
    # An explicitly requested checkpoint wins over the auto-detected one.
    checkpoint = None
    if trl_training_args.resume_from_checkpoint is not None:
        checkpoint = trl_training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    try:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        logger.info("Training completed successfully!")
    except KeyboardInterrupt:
        # Not re-raised: execution continues to the save step below.
        logger.info("Training interrupted by user.")
    except Exception:
        logger.exception("Training failed with an unexpected error")
        raise

    ####################################
    # Save model and create model card #
    ####################################
    logger.info("*** Saving Model ***")
    try:
        # Align the model's generation config with the tokenizer's eos token
        # to avoid unbounded generation in the transformers `pipeline()` function
        if trainer.model is not None and trainer.model.generation_config is not None:
            generation_config = trainer.model.generation_config
            generation_config.eos_token_id = tokenizer.eos_token_id
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if generation_config.pad_token_id != tokenizer.pad_token_id:
                raise ValueError(
                    "Pad token ID mismatch: "
                    f"generation_config.pad_token_id={generation_config.pad_token_id}, "
                    f"tokenizer.pad_token_id={tokenizer.pad_token_id}"
                )

        # Restore k,v cache for fast inference before saving
        if trainer.model is not None:
            trainer.model.config.use_cache = True

        trainer.save_model(trl_training_args.output_dir)
        logger.info(f"Model saved to {trl_training_args.output_dir}")

        # Save everything else on main process
        # `kwargs` is reused by the push-to-hub step; it is always bound there
        # because any failure in this block is re-raised.
        kwargs = {
            "dataset_name": script_args.dataset_name,
            "tags": ["linalg-zero", "sft", "tool-use", "linear-algebra"],
            "model_name": model_args.model_name_or_path,
        }
        if trainer.accelerator.is_main_process:
            trainer.create_model_card(**kwargs)
    except Exception:
        logger.exception("Failed to save model")
        raise

    ############
    # Evaluate #
    ############
    if trl_training_args.do_eval:
        logger.info("*** Final Evaluation on Full Dataset ***")
        # Temporarily override max_eval_samples to evaluate on full dataset
        original_max_eval_samples = getattr(trl_training_args, "max_eval_samples", None)
        try:
            trl_training_args.max_eval_samples = training_args.final_eval_max_samples
            metrics = trainer.evaluate()
            metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
            trainer.log_metrics("eval", metrics)
            trainer.save_metrics("eval", metrics)
            logger.info("Evaluation completed successfully!")
        except Exception:
            # Best effort: a failed final evaluation must not lose the saved model.
            logger.exception("Evaluation failed")
        finally:
            # Restore original value even when evaluation fails
            trl_training_args.max_eval_samples = original_max_eval_samples

    ###############
    # Push to hub #
    ###############
    if trl_training_args.push_to_hub:
        logger.info("*** Pushing to Hub ***")
        try:
            trainer.push_to_hub(**kwargs)
            logger.info("Successfully pushed model to HuggingFace Hub!")
        except Exception:
            # Best effort: the model is already saved locally at this point.
            logger.exception("Failed to push to hub")
if __name__ == "__main__":
    # Script entry point for SFT training.
    # When no --config flag is supplied, fall back to the merged-model production config.
    if "--config" not in sys.argv:
        # Other configs that have been used here:
        #   linalg_zero/config/sft/qwen2.5-3B/production_instruct.yaml
        #   linalg_zero/config/sft/qwen2.5-3B/production.yaml
        sys.argv.extend(["--config", "linalg_zero/config/sft/qwen2.5-3B/production_merged.yaml"])

    # One dataclass per argument group; the parser returns them in the same order.
    parser = TrlParser([ScriptArguments, SFTRunConfig, SFTConfig, SFTModelConfig])
    script_args, training_args, trl_training_args, model_args = parser.parse_args_and_config()
    main(
        script_args=script_args,
        training_args=training_args,
        trl_training_args=trl_training_args,
        model_args=model_args,
    )