# task2file-llm / trainer-kit / SFT / run_instruct.py
# Uploaded by SirajRLX — commit "Add Training Scripts" (e527a65, verified)
import argparse
import json
import inspect # Added for Transformers version compatibility
import math
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, List
import torch
import yaml
from datasets import load_dataset, DatasetDict
from huggingface_hub import snapshot_download
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoModel,
AutoConfig,
BitsAndBytesConfig,
TrainingArguments,
Trainer,
TrainerCallback,
EarlyStoppingCallback,
default_data_collator,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training,
PeftModel,
)
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
wandb = None
# --------------------------
# Helpers
# --------------------------
def _dtype_from_str(s: str) -> torch.dtype:
s = (s or "").lower()
if s in ("float16", "fp16"):
return torch.float16
if s in ("bfloat16", "bf16"):
return torch.bfloat16
if s in ("float32", "fp32"):
return torch.float32
raise ValueError(f"Unknown torch_dtype: {s}")
def _now_iso() -> str:
return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
def _safe_exp(x: float) -> float:
x = min(float(x), 50.0)
return float(math.exp(x))
def _ensure_dir(p: Path) -> Path:
p.mkdir(parents=True, exist_ok=True)
return p
def _looks_like_model_dir(p: Path) -> bool:
if not p.exists() or not p.is_dir():
return False
if (p / "config.json").exists():
return True
if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
return True
return False
def _infer_target_modules(model) -> List[str]:
names = set()
for n, _ in model.named_modules():
names.add(n.split(".")[-1])
for group in [
["q_proj", "k_proj", "v_proj", "o_proj"],
["Wqkv", "out_proj"],
["query_key_value", "dense"],
["c_attn", "c_proj"],
]:
if all(x in names for x in group):
return group
fallback = [
x
for x in [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"c_attn",
"c_proj",
"out_proj",
"dense",
]
if x in names
]
if fallback:
return fallback
raise ValueError(
"Could not auto-infer target_modules. Set peft.target_modules explicitly."
)
def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
return cfg.get("model", {}).get("attn_implementation", None)
# --------------------------
# Wandb Integration
# --------------------------
def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
    """Start a Weights & Biases run when cfg["wandb"]["enabled"] is true.

    Returns the ``wandb`` module on success, otherwise ``None`` (logging
    disabled, package not installed, or init failure).
    """
    settings = cfg.get("wandb", {})
    if not settings.get("enabled", False):
        print("Wandb logging disabled")
        return None
    if not WANDB_AVAILABLE:
        print("Wandb not available. Install with: pip install wandb")
        return None
    project = settings.get("project", "sft-training")
    run_name = settings.get("name", None)
    try:
        wandb.init(
            project=project,
            entity=settings.get("entity", None),
            name=run_name,
            tags=settings.get("tags", []),
            notes=settings.get("notes", None),
            dir=str(run_dir),
            # Mirror the relevant config sections so the run is reproducible.
            config={
                "model": cfg.get("model", {}),
                "data": cfg.get("data", {}),
                "peft": cfg.get("peft", {}),
                "train": cfg.get("train", {}),
                "run_dir": str(run_dir),
            },
        )
        print(f"Wandb initialized: project='{project}', name='{run_name or 'auto-generated'}'")
        return wandb
    except Exception as e:
        # Training must not abort because remote logging failed.
        print(f"Failed to initialize Wandb: {e}")
        return None
def finish_wandb():
    """Close the active Wandb run, if there is one."""
    if not WANDB_AVAILABLE:
        return
    if wandb.run is not None:
        wandb.finish()
        print("Wandb run finished")
# --------------------------
# JSONL Logger Callback
# --------------------------
class JsonlLoggerCallback(TrainerCallback):
    """Append training and evaluation progress as JSON lines.

    Writes ``train.jsonl`` (per logging step) and ``eval.jsonl`` (per
    evaluation) under ``<run_dir>/logs/``.
    """

    def __init__(self, run_dir: Path):
        self.run_dir = run_dir
        logs_dir = _ensure_dir(run_dir / "logs")
        self.train_log_path = logs_dir / "train.jsonl"
        self.eval_log_path = logs_dir / "eval.jsonl"
        # Set on on_train_begin; used for ETA estimation.
        self.start_time = None

    def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
        """Estimated remaining wall time as HH:MM:SS, or None if unknown."""
        if self.start_time is None or global_step <= 0 or max_steps <= 0:
            return None
        seconds_per_step = (time.time() - self.start_time) / global_step
        remaining = max(0, max_steps - global_step) * seconds_per_step
        hours, rest = divmod(int(remaining), 3600)
        minutes, seconds = divmod(rest, 60)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}"

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
        step = int(state.global_step)
        progress_pct = None
        if max_steps > 0:
            progress_pct = round(100.0 * step / max_steps, 2)
        epoch_pct = None
        if (
            state.epoch is not None
            and args.num_train_epochs
            and args.num_train_epochs > 0
        ):
            epoch_pct = round(
                100.0 * (float(state.epoch) / float(args.num_train_epochs)), 2
            )
        record = {
            "ts": _now_iso(),
            "event": "train_log",
            "step": step,
            "epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
            "progress_pct": progress_pct,
            "epoch_pct": epoch_pct,
            "eta": self._eta(step, max_steps),
            "max_grad_norm": getattr(args, "max_grad_norm", None),
            **logs,
        }
        with self.train_log_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if not metrics:
            return
        eval_loss = metrics.get("eval_loss", None)
        record = {
            "ts": _now_iso(),
            "event": "eval",
            "step": int(state.global_step),
            "epoch": float(state.epoch) if state.epoch is not None else None,
            **metrics,
            # Derived convenience metric; overflow-safe via _safe_exp.
            "perplexity": _safe_exp(eval_loss) if eval_loss is not None else None,
        }
        with self.eval_log_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
# --------------------------
# Data Pipeline (Instruction Formatting)
# --------------------------
def format_instruction(
    example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
) -> Dict[str, Any]:
    """
    Format one instruction example into a single training string.

    Supports three values of ``data.format_type``: "chatml" (rendered via
    the tokenizer's chat template), "alpaca" (classic Alpaca prompt), and
    "custom" (a template string containing an ``{output}`` placeholder).

    Returns:
        dict with "text" (the full formatted sample, EOS-terminated) and
        "response_start_pos" (character offset where the assistant response
        begins — used downstream to mask the prompt out of the loss).
    """
    data_cfg = cfg["data"]
    format_type = data_cfg.get("format_type", "chatml")
    # Field names are configurable so arbitrary JSONL schemas can be mapped.
    input_field = data_cfg.get("input_field", "input")
    output_field = data_cfg.get("output_field", "output")
    instruction_field = data_cfg.get("instruction_field", "instruction")
    # Missing fields silently default to empty strings rather than raising.
    instruction = example.get(instruction_field, "")
    input_text = example.get(input_field, "")
    output_text = example.get(output_field, "")
    if format_type == "chatml":
        # ChatML: build a message list and let the tokenizer's chat template
        # emit the final string with its model-specific special tokens.
        system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.")
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # Optional input is appended to the instruction as extra context.
        user_content = instruction
        if input_text:
            user_content = f"{instruction}\n\n{input_text}"
        messages.append({"role": "user", "content": user_content})
        messages.append({"role": "assistant", "content": output_text})
        # Apply chat template
        formatted_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        # Add EOS token if not present (some templates already append it).
        if tokenizer.eos_token and not formatted_text.endswith(tokenizer.eos_token):
            formatted_text += tokenizer.eos_token
        # Find where the assistant response starts for loss masking.
        # Several marker spellings are tried because templates vary by model.
        markers = ["<|im_start|>assistant", "<|assistant|>", "Assistant:", "assistant\n"]
        response_start_pos = -1
        for marker in markers:
            idx = formatted_text.find(marker)
            if idx != -1:
                # The response text begins after the newline following the marker.
                newline_idx = formatted_text.find("\n", idx)
                if newline_idx != -1:
                    response_start_pos = newline_idx + 1
                break
        # Fallback: find where the actual output text starts.
        # NOTE(review): if output_text also appears earlier in the prompt,
        # find() returns that earlier index — an accepted approximation here.
        if response_start_pos == -1:
            output_idx = formatted_text.find(output_text)
            if output_idx != -1:
                response_start_pos = output_idx
            else:
                # Last resort: the newline just before where the output should sit.
                response_start_pos = formatted_text.rfind("\n", 0, len(formatted_text) - len(output_text)) + 1
    elif format_type == "alpaca":
        # Alpaca format: fixed English preamble, optional "### Input" section.
        if input_text:
            prefix = f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
        else:
            prefix = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n"
        formatted_text = prefix + output_text
        # Add EOS token
        if tokenizer.eos_token:
            formatted_text += tokenizer.eos_token
        # Response starts right after the fixed prompt prefix.
        response_start_pos = len(prefix)
    elif format_type == "custom":
        # Custom template from config, e.g. "{instruction}\n{input}\n{output}".
        template = data_cfg.get("custom_template", "{instruction}\n{input}\n{output}")
        # For custom format, use system_prompt as instruction if instruction field is empty
        if not instruction:
            instruction = data_cfg.get("system_prompt", "")
        # Everything before the {output} placeholder forms the prompt prefix.
        template_parts = template.split("{output}")
        prefix = template_parts[0].format(instruction=instruction, input=input_text)
        formatted_text = prefix + output_text
        # Add EOS token if not already in template
        if tokenizer.eos_token and not formatted_text.endswith(tokenizer.eos_token):
            formatted_text += tokenizer.eos_token
        # Response starts after the prefix
        response_start_pos = len(prefix)
    else:
        raise ValueError(f"Unsupported format_type: {format_type}")
    return {"text": formatted_text, "response_start_pos": response_start_pos}
def build_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
    """
    Build tokenized train/eval datasets for instruction fine-tuning.

    Pipeline: load JSONL -> render each row with ``format_instruction`` ->
    tokenize -> mask the prompt portion of ``labels`` with -100 so only the
    response tokens contribute to the loss.

    Returns:
        (train_dataset, eval_dataset); eval_dataset is None when no eval
        file is configured and ``eval_split_ratio`` is not in (0, 1).
    """
    data_cfg = cfg["data"]
    train_path = data_cfg["train_jsonl"]
    eval_path = data_cfg.get("eval_jsonl", None)
    split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
    max_length = int(data_cfg.get("max_length", 2048))
    shuffle = bool(data_cfg.get("shuffle", True))
    num_proc = int(data_cfg.get("num_proc", 4))
    # Llama-style tokenizers ship without a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    ds = load_dataset("json", data_files={"train": train_path})
    if eval_path:
        ds_eval = load_dataset("json", data_files={"eval": eval_path})
        dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
    else:
        if 0.0 < split_ratio < 1.0:
            # Carve an eval split out of the train file, seeded for reproducibility.
            split = ds["train"].train_test_split(
                test_size=split_ratio, seed=int(cfg["run"].get("seed", 42))
            )
            dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
        else:
            dsd = DatasetDict({"train": ds["train"], "eval": None})

    def format_fn(examples):
        # Batched map: rebuild per-row dicts, format each row, collect results.
        formatted_examples = []
        response_start_positions = []
        for i in range(len(examples[list(examples.keys())[0]])):
            example = {k: examples[k][i] for k in examples.keys()}
            formatted = format_instruction(example, cfg, tokenizer)
            formatted_examples.append(formatted["text"])
            response_start_positions.append(formatted["response_start_pos"])
        return {
            "text": formatted_examples,
            "response_start_pos": response_start_positions,
        }

    formatted_train = dsd["train"].map(
        format_fn,
        batched=True,
        num_proc=num_proc,
        remove_columns=dsd["train"].column_names,
        desc="Formatting train instructions",
    )
    formatted_eval = None
    if dsd["eval"] is not None:
        formatted_eval = dsd["eval"].map(
            format_fn,
            batched=True,
            num_proc=num_proc,
            remove_columns=dsd["eval"].column_names,
            desc="Formatting eval instructions",
        )

    def tokenize_and_mask_fn(examples):
        # Tokenize without padding; batching/padding is the collator's job.
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding=False,
            max_length=max_length,
            return_overflowing_tokens=False,
        )
        # Loss masking — CRITICAL for SFT: prompt tokens get label -100.
        labels = []
        attention_masks = []
        for i in range(len(tokenized["input_ids"])):
            input_ids = tokenized["input_ids"][i]
            response_start_pos = examples["response_start_pos"][i]
            full_text = examples["text"][i]
            # Start fully masked.
            label_ids = [-100] * len(input_ids)
            # Map the character offset of the response onto a token index via
            # the character ratio. NOTE(review): this is an approximation
            # (token lengths vary); an exact mapping would need the fast
            # tokenizer's offset mappings.
            char_ratio = response_start_pos / max(len(full_text), 1)
            response_start_idx = int(len(input_ids) * char_ratio)
            # Clamp to [1, len-1] so at least one token stays unmasked.
            response_start_idx = max(1, min(response_start_idx, len(input_ids) - 1))
            # Unmask response tokens (including EOS).
            for j in range(response_start_idx, len(input_ids)):
                label_ids[j] = input_ids[j]
            # All tokens are real (no padding was applied here).
            attention_mask = [1] * len(input_ids)
            labels.append(label_ids)
            attention_masks.append(attention_mask)
        tokenized["labels"] = labels
        tokenized["attention_mask"] = attention_masks
        return tokenized

    tokenized_train = formatted_train.map(
        tokenize_and_mask_fn,
        batched=True,
        num_proc=num_proc,
        # Bug fix: drop the intermediate "text"/"response_start_pos" columns.
        # This script builds its Trainer with remove_unused_columns=False and
        # default_data_collator, so leftover non-model columns (notably the
        # integer response_start_pos) would be collated into the batch and
        # passed to model.forward(), crashing training.
        remove_columns=formatted_train.column_names,
        desc="Tokenizing and masking train",
    )
    tokenized_eval = None
    if formatted_eval is not None:
        tokenized_eval = formatted_eval.map(
            tokenize_and_mask_fn,
            batched=True,
            num_proc=num_proc,
            remove_columns=formatted_eval.column_names,
            desc="Tokenizing and masking eval",
        )
    if shuffle:
        tokenized_train = tokenized_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
    return tokenized_train, tokenized_eval
# --------------------------
# Model Loading + PEFT
# --------------------------
def _from_pretrained_with_attn_fallback(loader, base_dir: Path, attn_impl, **kwargs):
    """Load via ``loader.from_pretrained`` with the requested attention
    implementation, retrying without it when that implementation fails
    (e.g. flash-attn not installed)."""
    try:
        return loader.from_pretrained(
            str(base_dir), attn_implementation=attn_impl, **kwargs
        )
    except Exception as e:
        if attn_impl is None:
            raise
        print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
        print("[warn] Falling back to default attention implementation.")
        return loader.from_pretrained(str(base_dir), **kwargs)


def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
    """Load the tokenizer and base model from *base_dir* per cfg["model"].

    Handles optional 4-bit (bitsandbytes) quantization, an optional
    attention implementation with graceful fallback, and the Mistral3
    multimodal architecture (whose vision tower is frozen for text-only SFT).

    Returns:
        (model, tokenizer)
    """
    model_cfg = cfg["model"]
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
    device_map = model_cfg.get("device_map", "auto")
    tokenizer = AutoTokenizer.from_pretrained(
        str(base_dir),
        use_fast=use_fast,
        trust_remote_code=trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_4bit = bool(model_cfg.get("use_4bit", False))
    quant_cfg = None
    if use_4bit:
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
            bnb_4bit_use_double_quant=bool(
                model_cfg.get("bnb_4bit_use_double_quant", True)
            ),
            bnb_4bit_compute_dtype=_dtype_from_str(
                model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")
            ),
        )
    attn_impl = _choose_attn_impl(cfg)
    # Shared kwargs for every loading path. With 4-bit quantization the
    # compute dtype comes from BitsAndBytesConfig, so torch_dtype is None.
    common_kwargs = dict(
        device_map=device_map,
        low_cpu_mem_usage=True,
        torch_dtype=(torch_dtype if not use_4bit else None),
        quantization_config=quant_cfg,
    )
    try:
        # Consistency fix: honor model.trust_remote_code here too (the
        # original hardcoded True in these calls despite reading the flag).
        config = AutoConfig.from_pretrained(
            str(base_dir), trust_remote_code=trust_remote_code
        )
        model_type = config.model_type
        architectures = getattr(config, "architectures", [])
        if model_type == "mistral3" or (architectures and "Mistral3" in architectures[0]):
            # Mistral3 is multimodal and needs its dedicated class.
            print("[info] Detected Mistral3 model architecture, loading with specific class")
            from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
            model = _from_pretrained_with_attn_fallback(
                Mistral3ForConditionalGeneration,
                base_dir,
                attn_impl,
                config=config,
                **common_kwargs,
            )
        else:
            # Standard causal-LM loading for everything else.
            model = _from_pretrained_with_attn_fallback(
                AutoModelForCausalLM,
                base_dir,
                attn_impl,
                trust_remote_code=trust_remote_code,
                **common_kwargs,
            )
    except Exception as e:
        print(f"[error] Failed to load model: {e}")
        raise
    # Sanity check: parameters still on the meta device are not materialized
    # and would break training — warn loudly if any remain.
    print("[info] Ensuring all parameters are materialized...")
    meta_params = [
        name
        for name, param in model.named_parameters()
        if param.device.type == "meta"
    ]
    if meta_params:
        print(f"[warn] Found {len(meta_params)} parameters on meta device")
    # Text-only SFT: never train multimodal vision weights.
    if hasattr(model, "vision_tower"):
        print("[info] Freezing vision tower for text-only training")
        for param in model.vision_tower.parameters():
            param.requires_grad = False
    return model, tokenizer
def apply_peft(cfg: Dict[str, Any], model):
    """Wrap *model* with a LoRA adapter according to cfg["peft"].

    Also enables gradient checkpointing (per cfg["train"]) and prepares the
    model for k-bit training when 4-bit quantization is active.

    Returns:
        (model, lora_config) — lora_config is None when PEFT is disabled.
    """
    peft_cfg = cfg["peft"]
    model_cfg = cfg["model"]
    tr_cfg = cfg["train"]
    if not bool(peft_cfg.get("enabled", True)):
        return model, None
    use_4bit = bool(model_cfg.get("use_4bit", False))
    gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))
    if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        if hasattr(model, "vision_tower"):
            # Multimodal: checkpoint only the language model where possible;
            # the (frozen) vision tower gains nothing from recomputation.
            print("[info] Disabling gradient checkpointing for vision tower")
            if hasattr(model, "language_model"):
                model.language_model.gradient_checkpointing_enable()
            else:
                model.gradient_checkpointing_enable()
        else:
            # Bug fix: plain text-only models previously fell through without
            # gradient checkpointing ever being enabled.
            model.gradient_checkpointing_enable()
        if hasattr(model, "config"):
            # The KV cache is incompatible with gradient checkpointing.
            model.config.use_cache = False
    if use_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=gradient_checkpointing,
        )
    target_modules = peft_cfg.get("target_modules", "auto")
    if target_modules == "auto":
        target_modules = _infer_target_modules(model)
    # Multimodal safety net: never attach LoRA to vision modules.
    if hasattr(model, "vision_tower") and isinstance(target_modules, list):
        print(f"[info] Filtering target modules to exclude vision tower")
        target_modules = [m for m in target_modules if "vision" not in m.lower()]
    print(f"[info] LoRA target modules: {target_modules}")
    lora_config = LoraConfig(
        r=int(peft_cfg.get("r", 16)),
        lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
        lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
        bias=str(peft_cfg.get("bias", "none")),
        task_type="CAUSAL_LM",
        target_modules=target_modules,
        modules_to_save=None,  # train LoRA weights only, no extra full modules
    )
    model = get_peft_model(model, lora_config)
    return model, lora_config
# --------------------------
# Merge Logic
# --------------------------
def merge_adapter(
    cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path
):
    """Merge a LoRA adapter into its base model and save a standalone copy.

    The merged model (and the base tokenizer) are written to *final_dir*
    as safetensors shards.
    """
    print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")
    model_cfg = cfg["model"]
    merge_cfg = cfg.get("merge", {})
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
    max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))
    # Merge on CPU so no GPU memory is needed for two weight copies.
    base_model = AutoModelForCausalLM.from_pretrained(
        str(base_dir),
        torch_dtype=merged_dtype,
        device_map="cpu",
        low_cpu_mem_usage=True,
        trust_remote_code=trust_remote_code,
    )
    adapter_model = PeftModel.from_pretrained(base_model, str(adapter_dir))
    merged_model = adapter_model.merge_and_unload()
    _ensure_dir(final_dir)
    merged_model.save_pretrained(
        str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
    )
    # Ship the tokenizer alongside the merged weights.
    tok = AutoTokenizer.from_pretrained(
        str(base_dir), trust_remote_code=trust_remote_code
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.save_pretrained(str(final_dir))
    print("--- Merge complete ---")
# --------------------------
# Main
# --------------------------
def main():
    """CLI entry point: SFT-train a base model with LoRA, optionally merge.

    Flow: parse args -> resolve config and run dir -> locate or download the
    base model -> (optionally) merge-only and exit -> train -> save adapter
    -> final eval -> optional merge of the adapter into the base weights.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="Path to YAML config")
    ap.add_argument(
        "--merge-only", action="store_true", help="Skip training, just merge adapter"
    )
    args = ap.parse_args()
    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
    _ensure_dir(run_dir / "logs")
    # Snapshot the resolved config next to the run for reproducibility.
    with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)
    model_cfg = cfg["model"]
    repo_id = str(model_cfg["repo_id"]).strip()
    repo_path = Path(repo_id)
    # Local model path -> load directly; no download
    if repo_path.exists() and repo_path.is_dir() and _looks_like_model_dir(repo_path):
        base_dir = repo_path
        print(f"Using local model at: {base_dir}")
    elif repo_path.exists() and repo_path.is_dir():
        # Bug fix: this message previously interpolated base_dir, which is
        # undefined on this branch (a NameError masked the intended error);
        # repo_path is the directory being complained about.
        raise ValueError(
            f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {repo_path}"
        )
    else:
        # HF repo_id -> download into run_dir/base_local_dir
        base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
        if not _looks_like_model_dir(base_dir):
            print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
            snapshot_download(
                repo_id=repo_id,
                revision=model_cfg.get("revision", None),
                local_dir=str(base_dir),
                local_dir_use_symlinks=False,
            )
    ckpt_dir = _ensure_dir(run_dir / "checkpoints")
    best_adapter_dir = _ensure_dir(run_dir / "best_adapter")
    merge_cfg = cfg.get("merge", {}) or {}
    if merge_cfg.get("output_dir"):
        od = Path(str(merge_cfg["output_dir"]))
        final_dir = od if od.is_absolute() else (run_dir / od)
    else:
        final_dir = run_dir / "final_model"
    # Merge-only mode: reuse a previously trained adapter and exit.
    if args.merge_only:
        if not _looks_like_model_dir(best_adapter_dir):
            raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
        return
    # Initialize Wandb (returns None when disabled or unavailable).
    wandb_run = setup_wandb(cfg, run_dir)
    # Training
    set_seed(int(cfg["run"].get("seed", 42)))
    model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
    model, _ = apply_peft(cfg, model)
    train_ds, eval_ds = build_datasets(cfg, tokenizer)
    tr_cfg = cfg["train"]
    dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_fp16 = dtype == torch.float16
    use_bf16 = dtype == torch.bfloat16
    max_steps = int(tr_cfg.get("max_steps", 0))
    num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))
    # Transformers renamed evaluation_strategy -> eval_strategy; detect which
    # keyword the installed version accepts.
    ta_params = inspect.signature(TrainingArguments.__init__).parameters
    eval_key = (
        "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
    )
    # Report to wandb only when a run was actually initialized.
    report_to = []
    if wandb_run is not None:
        report_to.append("wandb")
    ta_kwargs = dict(
        output_dir=str(ckpt_dir),
        # -1 means "derive from num_train_epochs" in TrainingArguments.
        max_steps=max_steps if max_steps > 0 else -1,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
        per_device_eval_batch_size=int(
            tr_cfg.get(
                "per_device_eval_batch_size",
                tr_cfg.get("per_device_train_batch_size", 1),
            )
        ),
        gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
        learning_rate=float(tr_cfg.get("learning_rate", 2e-5)),
        weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
        warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
        lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
        # Paged 8-bit optimizer pairs with 4-bit quantized training.
        optim=str(
            tr_cfg.get(
                "optim",
                (
                    "paged_adamw_8bit"
                    if bool(model_cfg.get("use_4bit", False))
                    else "adamw_torch"
                ),
            )
        ),
        max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),
        logging_steps=int(tr_cfg.get("logging_steps", 10)),
        save_strategy=str(tr_cfg.get("save_strategy", "steps")),
        save_steps=int(tr_cfg.get("save_steps", 200)),
        save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
        eval_steps=int(tr_cfg.get("eval_steps", 200)),
        load_best_model_at_end=(
            bool(tr_cfg.get("load_best_model_at_end", True))
            if eval_ds is not None
            else False
        ),
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        fp16=use_fp16,
        bf16=use_bf16,
        report_to=report_to,
        # Keep all dataset columns: labels were precomputed with loss masking.
        remove_unused_columns=False,
    )
    # Set the correct evaluation-strategy argument name for this version.
    ta_kwargs[eval_key] = str(
        tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no")
    )
    training_args = TrainingArguments(**ta_kwargs)
    # Callbacks: JSONL progress logger, plus optional early stopping.
    callbacks = [JsonlLoggerCallback(run_dir)]
    early_stopping_cfg = tr_cfg.get("early_stopping", {})
    if early_stopping_cfg.get("enabled", False) and eval_ds is not None:
        early_stopping_callback = EarlyStoppingCallback(
            early_stopping_patience=int(early_stopping_cfg.get("patience", 3)),
            early_stopping_threshold=float(early_stopping_cfg.get("min_delta", 0.001)),
        )
        callbacks.append(early_stopping_callback)
        print(f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}, "
              f"min_delta={early_stopping_cfg.get('min_delta', 0.001)}")
    # NOTE(review): default_data_collator does not pad, so batches rely on
    # per-device batch size 1 or uniform lengths — confirm before raising
    # per_device_train_batch_size.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=default_data_collator,
        callbacks=callbacks,
    )
    # Resume support: "auto" picks the latest checkpoint if one exists.
    resume_from = tr_cfg.get("resume_from_checkpoint", None)
    if resume_from == "auto":
        last = get_last_checkpoint(str(ckpt_dir))
        resume_from = last if last else None
    if resume_from:
        print(f"Resuming from {resume_from}")
    print("Starting instruction fine-tuning...")
    trainer.train(resume_from_checkpoint=resume_from)
    trainer.save_model(str(best_adapter_dir))
    print(f"Saved best adapter -> {best_adapter_dir}")
    if eval_ds is not None:
        metrics = trainer.evaluate()
        eval_loss = metrics.get("eval_loss", None)
        metrics["perplexity"] = _safe_exp(eval_loss) if eval_loss is not None else None
        with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        print(f"Final eval_loss={eval_loss}, ppl={metrics['perplexity']}")
    if bool(cfg.get("merge", {}).get("enabled", False)):
        # Free the training objects before loading a fresh CPU copy to merge.
        del trainer, model
        torch.cuda.empty_cache()
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
    else:
        print("Merge disabled. Run with --merge-only later if needed.")
    # Finish Wandb run
    finish_wandb()
# Script entry guard: importing this module must not start a training run.
if __name__ == "__main__":
    main()