# lesson-agent-dev / research/finetune.py
# Revision bbff1ca — "Feat/last hour (#24)"
"""
finetune.py — Fine-tune a small LLM: FULL, LoRA, or QLoRA, one script.
======================================================================
Install:
pip install torch transformers datasets peft accelerate
pip install bitsandbytes # only needed for --mode qlora
Model resolution (first match wins)
------------------------------------
1. --model <hf-id-or-path>
2. --preset <key> from models.yaml (or FINETUNE_PRESET env)
3. FINETUNE_MODEL / MODEL_ID / BASE env (raw Hugging Face id or local path)
4. ACTIVE_MODEL preset from models.yaml (text transformers presets only)
Outputs are saved under ./models/finetuned/<preset>-<mode>/ by default.
Examples
--------
# LoRA on the lesson-agent chat dataset using models.yaml preset
python research/finetune.py --preset minicpm5-1b --mode lora --epochs 1
# Same, but read ACTIVE_MODEL / BASE from .env (auto-loaded from repo root)
python research/finetune.py --mode lora --max_steps 50
# LoRA on an instruction dataset from the Hub
python research/finetune.py \
--model Qwen/Qwen2.5-0.5B-Instruct \
--dataset tatsu-lab/alpaca --format alpaca \
--mode lora --epochs 1
# QLoRA (4-bit) on a local JSONL chat file: {"messages": [{"role":..,"content":..}, ...]}
python research/finetune.py \
--model meta-llama/Llama-3.2-1B-Instruct \
--dataset ./data/chats.jsonl --format chat \
--mode qlora
# Hugging Face Hub datasets (--dataset is the repo id; optional --dataset-config / --split)
python research/finetune.py \
--preset minicpm5-1b --mode qlora \
--dataset tatsu-lab/alpaca --format alpaca --dataset-split train
python research/finetune.py \
--preset minicpm5-1b --mode lora \
--dataset HuggingFaceTB/smoltalk --format chat \
--dataset-config all --dataset-split train[:500]
# Env vars also work: FINETUNE_DATASET, FINETUNE_DATASET_CONFIG, FINETUNE_DATASET_SPLIT
# FULL fine-tune on raw text files (continued pretraining style)
python research/finetune.py \
--model HuggingFaceTB/SmolLM2-360M \
--dataset ./data/corpus.txt --format text \
--mode full --lr 2e-5
# After LoRA training, merge adapter into standalone weights:
python research/finetune.py --merge ./models/finetuned/minicpm5-1b-lora \
--out ./models/finetuned/minicpm5-1b-merged
Dataset formats (--format)
--------------------------
alpaca : columns instruction / input(optional) / output
chat : column messages = [{"role": "...", "content": "..."}]
prompt : columns prompt / completion (or prompt / response)
text : column text — or a plain .txt file (each line becomes one example)
Local files: .json, .jsonl, .csv, .txt. Hub ids: any datasets repo.
Hub datasets useful for the lesson / teacher agent (--format must match columns):
tatsu-lab/alpaca alpaca instruction tuning (general)
HuggingFaceTB/smoltalk chat multi-turn chat (use config: all)
Open-Orca/OpenOrca prompt instruction + response pairs
databricks/databricks-dolly-15k alpaca short Q&A, good for small models
After training, metrics are written to <out>/training_results.json
(train/eval loss, perplexity, result_score 0–100).
"""
import argparse
import gc
import json
import math
import os
import subprocess
import sys
from pathlib import Path
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Trainer,
TrainingArguments,
)
# Label value written over prompt and padding positions so they carry no loss.
IGNORE_INDEX = -100
# Repository root: the directory one level above this file's directory.
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Dataset used when neither --dataset nor FINETUNE_DATASET is given.
_DEFAULT_DATASET = _REPO_ROOT / "research/data/education-lesson-chat.jsonl"
# Parent directory of the default "<preset>-<mode>" output folders.
_FINETUNE_ROOT = _REPO_ROOT / "models/finetuned"
# Preset tried first when no usable preset was requested.
_FALLBACK_FINETUNE_PRESET = "minicpm5-1b"
def _load_dotenv(path: Path) -> None:
    """Load KEY=VALUE pairs from a .env file without overriding existing env vars.

    Blank lines, ``#`` comment lines and lines without ``=`` are skipped.
    A leading ``export `` before the key is accepted. Exactly one pair of
    matching surrounding quotes is removed from the value, so quotes that are
    part of the value (e.g. ``say "hi"``) are preserved.

    Args:
        path: Location of the .env file; a missing file is silently ignored.
    """
    if not path.is_file():
        return
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        key = key.strip()
        if key.startswith("export "):
            key = key[len("export "):].strip()
        value = value.strip()
        # Only strip quotes that actually wrap the whole value.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        if key:
            # setdefault: variables already set in the environment win.
            os.environ.setdefault(key, value)
def _ensure_repo_on_path() -> None:
    """Put the repo's ``libs/inference/src`` directory first on ``sys.path`` (once)."""
    src_dir = str(_REPO_ROOT / "libs" / "inference" / "src")
    if src_dir in sys.path:
        return
    sys.path.insert(0, src_dir)
def _is_finetuneable_preset(model) -> bool:
    """Return True for a non-multimodal ``transformers`` preset that has a model id."""
    if model.backend != "transformers":
        return False
    if model.multimodal:
        return False
    return bool(model.model_id)
def resolve_model_and_preset(
*,
model_arg: str | None,
preset_arg: str | None,
) -> tuple[str, str | None, bool]:
"""Return (model_id_or_path, preset_key, trust_remote_code)."""
if model_arg:
trust = os.environ.get("TRUST_REMOTE_CODE", "").lower() in {
"1",
"true",
"yes",
}
return model_arg, preset_arg, trust
for env_name in ("FINETUNE_MODEL", "MODEL_ID", "BASE"):
raw = os.environ.get(env_name)
if raw:
trust = os.environ.get("TRUST_REMOTE_CODE", "").lower() in {
"1",
"true",
"yes",
}
return raw, preset_arg, trust
_ensure_repo_on_path()
from inference.config import get_app_config, get_model_config
app_config = get_app_config(reload=True)
preset_key = (
preset_arg
or os.environ.get("FINETUNE_PRESET")
or os.environ.get("ACTIVE_MODEL")
)
if preset_key and preset_key in app_config.models:
model = get_model_config(preset_key)
if not _is_finetuneable_preset(model):
print(
f"Preset {preset_key!r} is {model.backend}"
+ (" multimodal" if model.multimodal else "")
+ "; falling back to a text transformers preset for fine-tuning."
)
preset_key = None
if preset_key is None:
for candidate in (_FALLBACK_FINETUNE_PRESET, *app_config.models):
if candidate not in app_config.models:
continue
model = get_model_config(candidate)
if _is_finetuneable_preset(model):
preset_key = candidate
break
if not preset_key:
raise SystemExit(
"No fine-tunable transformers preset found. Pass --model or set BASE/MODEL_ID."
)
model = get_model_config(preset_key)
if not _is_finetuneable_preset(model):
raise SystemExit(
f"Preset {preset_key!r} cannot be fine-tuned "
f"(backend={model.backend}, multimodal={model.multimodal})."
)
return model.model_id, preset_key, model.trust_remote_code
def default_output_dir(preset_key: str | None, mode: str) -> str:
    """Return the absolute default output path ``<finetune root>/<preset|custom>-<mode>``."""
    label = preset_key if preset_key else "custom"
    target = _FINETUNE_ROOT / f"{label}-{mode}"
    return str(target.resolve())
# ----------------------------------------------------------------------------
# Args
# ----------------------------------------------------------------------------
def parse_args(argv=None):
    """Build the CLI parser and parse arguments.

    Several defaults are read from FINETUNE_* environment variables at call
    time, so a .env file must be loaded before calling this.

    Args:
        argv: Optional argument list; ``None`` (default) parses ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: on invalid arguments or a non-integer FINETUNE_MAX_SAMPLES.
    """
    max_samples_env = os.environ.get("FINETUNE_MAX_SAMPLES")
    try:
        max_samples_default = int(max_samples_env) if max_samples_env else None
    except ValueError:
        raise SystemExit(
            f"FINETUNE_MAX_SAMPLES must be an integer, got {max_samples_env!r}"
        ) from None

    p = argparse.ArgumentParser()
    p.add_argument(
        "--model",
        type=str,
        default=None,
        help="HF id or local path (overrides models.yaml / env)",
    )
    p.add_argument(
        "--preset",
        type=str,
        default=None,
        help="Preset key from models.yaml (default: FINETUNE_PRESET or ACTIVE_MODEL)",
    )
    p.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="HF Hub repo id (e.g. tatsu-lab/alpaca) or local file path",
    )
    p.add_argument(
        "--dataset-config",
        type=str,
        default=os.environ.get("FINETUNE_DATASET_CONFIG"),
        help="HF dataset config/subset name (optional)",
    )
    p.add_argument(
        "--dataset-split",
        type=str,
        default=os.environ.get("FINETUNE_DATASET_SPLIT", "train"),
        help="HF split name or slice, e.g. train or train[:1000]",
    )
    p.add_argument(
        "--dataset-max-samples",
        type=int,
        default=max_samples_default,
        help="Cap examples after loading (useful for Hub smoke tests)",
    )
    p.add_argument(
        "--mix-json",
        type=str,
        default=os.environ.get("FINETUNE_MIX_JSON"),
        help=(
            "JSON list of dataset source specs to mix/replay; overrides "
            "--dataset/--format. Each spec: "
            '{"dataset":..,"format":..,"columns":{..},"dataset_config":..,'
            '"dataset_split":..,"max_samples":..,"max_len":..,"weight":..}'
        ),
    )
    p.add_argument(
        "--format",
        type=str,
        default=os.environ.get("FINETUNE_FORMAT", "chat"),
        choices=["alpaca", "chat", "prompt", "text"],
    )
    # Column-name overrides: let a dataset's own columns map onto a --format
    # without preprocessing (e.g. MetaMathQA query/response -> prompt format,
    # orca-math question/answer -> prompt format).
    p.add_argument("--prompt-key", default=None,
                   help="column to use as the prompt (prompt format)")
    p.add_argument("--response-key", default=None,
                   help="column to use as the response (prompt format)")
    p.add_argument("--instruction-key", default=None,
                   help="column to use as instruction (alpaca format)")
    p.add_argument("--input-key", default=None,
                   help="column to use as optional input (alpaca format)")
    p.add_argument("--output-key", default=None,
                   help="column to use as output (alpaca format)")
    p.add_argument("--mode", type=str, default="lora",
                   choices=["full", "lora", "qlora"])
    p.add_argument(
        "--out",
        type=str,
        default=None,
        help="Output directory (default: ./models/finetuned/<preset>-<mode>)",
    )
    # training hparams
    p.add_argument("--epochs", type=float, default=1.0)
    p.add_argument("--max_steps", type=int, default=-1)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--grad_accum", type=int, default=4)
    p.add_argument("--lr", type=float, default=None,
                   help="default: 2e-4 for (q)lora, 2e-5 for full")
    p.add_argument("--max_len", type=int, default=1024)
    p.add_argument("--warmup_ratio", type=float, default=0.03)
    p.add_argument("--mask_prompt", action="store_true", default=True,
                   help="compute loss only on the response tokens")
    p.add_argument("--no_mask_prompt", dest="mask_prompt", action="store_false")
    # training schedule / regularization (previously hardcoded)
    p.add_argument("--lr_scheduler", type=str, default="cosine",
                   help="LR scheduler type: cosine, linear, constant, ...")
    p.add_argument("--weight_decay", type=float, default=0.01)
    p.add_argument("--max_grad_norm", type=float, default=1.0)
    p.add_argument("--logging_steps", type=int, default=10)
    p.add_argument("--eval_steps", type=int, default=None,
                   help="eval every N steps (default: max_steps//5, else 200)")
    p.add_argument("--save_steps", type=int, default=500)
    p.add_argument("--save_total_limit", type=int, default=2)
    p.add_argument("--early_stopping_patience", type=int, default=0,
                   help=">0 enables early stopping + load_best_model_at_end on eval_loss")
    p.add_argument("--neftune_noise_alpha", type=float, default=None,
                   help="NEFTune noise alpha (e.g. 5) — quick instruction-tuning gain")
    p.add_argument("--report_to", type=str, default="none",
                   help="trainer reporting: none, wandb, tensorboard, ...")
    # lora hparams
    p.add_argument("--lora_r", type=int, default=16)
    p.add_argument("--lora_alpha", type=int, default=32)
    p.add_argument("--lora_dropout", type=float, default=0.05)
    p.add_argument("--lora_targets", type=str,
                   default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
                   help="comma list; 'all-linear' also works")
    # misc
    p.add_argument("--val_split", type=float, default=0.02)
    p.add_argument("--seed", type=int, default=42)
    # bf16 stays None (= auto-detect) unless one of the two switches is given.
    p.add_argument("--bf16", action="store_true", default=None)
    p.add_argument("--no_bf16", dest="bf16", action="store_false", default=None,
                   help="force bf16 off instead of auto-detecting")
    p.add_argument("--gradient_checkpointing", action="store_true", default=True)
    p.add_argument("--no_gradient_checkpointing", dest="gradient_checkpointing",
                   action="store_false", default=True,
                   help="disable gradient checkpointing")
    p.add_argument(
        "--device",
        type=str,
        default=os.environ.get("FINETUNE_DEVICE", "auto"),
        choices=["auto", "cpu", "cuda"],
        help="Training device (default: auto; set FINETUNE_DEVICE=cpu to avoid GPU OOM)",
    )
    p.add_argument("--resume", type=str, default=None)
    # merge mode
    p.add_argument("--merge", type=str, default=None,
                   help="path to a LoRA adapter dir to merge into its base model")
    p.add_argument(
        "--lm-eval-after",
        action="store_true",
        help="run slm-lm-eval on the saved checkpoint after training",
    )
    p.add_argument(
        "--lm-eval-config",
        type=str,
        default=str(_REPO_ROOT / "research/evals/configs/lm_eval_smoke.yaml"),
        help="YAML config for post-training lm-eval (default: lm_eval_smoke.yaml)",
    )
    p.add_argument(
        "--lm-eval-baseline",
        type=str,
        default=None,
        help="optional baseline preset key; runs lm-eval on base model and compares",
    )
    return p.parse_args(argv)
# ----------------------------------------------------------------------------
# Dataset loading + normalization to (prompt, response) or raw text
# ----------------------------------------------------------------------------
def load_raw_dataset(
path: str,
*,
config: str | None = None,
split: str = "train",
max_samples: int | None = None,
):
"""Load from a local file or Hugging Face Hub (datasets.load_dataset)."""
if os.path.exists(path):
ext = os.path.splitext(path)[1].lower()
if ext in (".json", ".jsonl"):
ds = load_dataset("json", data_files=path, split="train")
elif ext == ".csv":
ds = load_dataset("csv", data_files=path, split="train")
elif ext == ".txt":
ds = load_dataset("text", data_files=path, split="train")
else:
raise ValueError(f"Unsupported local file type: {ext}")
else:
kwargs: dict = {"path": path, "split": split}
if config:
kwargs["name"] = config
print(f"Loading Hub dataset: {path}" + (f" (config={config})" if config else "")
+ f" split={split}")
ds = load_dataset(**kwargs)
if max_samples is not None and max_samples > 0:
ds = ds.select(range(min(max_samples, len(ds))))
return ds
def _last_metric(history: list[dict], key: str) -> float | None:
for entry in reversed(history):
if key in entry:
return float(entry[key])
return None
def _result_score(eval_loss: float | None, train_loss: float | None) -> float | None:
"""Higher is better (0–100). Derived from eval loss, else train loss."""
loss = eval_loss if eval_loss is not None else train_loss
if loss is None:
return None
# exp(-loss) maps typical LM losses (~0.5–3) into a readable 0–100 band.
return round(min(100.0, max(0.0, 100.0 * math.exp(-loss))), 2)
def save_training_results(
    out_dir: str,
    *,
    args,
    preset_key: str | None,
    train_count: int,
    eval_count: int,
    train_result,
    log_history: list[dict],
    eval_metrics: dict | None,
) -> Path:
    """Write ``<out_dir>/training_results.json`` and return its path.

    The payload records the run configuration (model, preset, dataset, mix,
    format, mode), sample counts, loss metrics and basic training stats.

    Args:
        out_dir: Directory to write into (created if missing).
        args: Parsed CLI namespace.
        preset_key: Preset the model was resolved from, if any.
        train_count / eval_count: Number of train / eval examples.
        train_result: Object returned by ``trainer.train()``; its ``metrics``
            and ``global_step`` attributes are read when present.
        log_history: Trainer log history; the last "loss" entry is the final
            train loss.
        eval_metrics: Output of ``trainer.evaluate()`` or None.
    """
    train_metrics = getattr(train_result, "metrics", None) or {}
    final_train_loss = _last_metric(log_history, "loss")
    if final_train_loss is None and "train_loss" in train_metrics:
        final_train_loss = float(train_metrics["train_loss"])

    eval_loss = None
    perplexity = None
    # A missing eval_loss must stay None; defaulting it to 0 would report a
    # perfect perplexity of 1 and a result_score of 100.
    raw_eval_loss = eval_metrics.get("eval_loss") if eval_metrics else None
    if raw_eval_loss is not None:
        eval_loss = float(raw_eval_loss)
        if eval_loss < 20:  # avoid overflow-sized perplexities
            perplexity = round(math.exp(eval_loss), 4)
    result_score = _result_score(eval_loss, final_train_loss)

    if eval_loss is not None:
        loss_score = round(eval_loss, 6)
    elif final_train_loss is not None:
        loss_score = round(final_train_loss, 6)
    else:
        loss_score = None

    payload = {
        "model": args.model,
        "preset": preset_key,
        "dataset": args.dataset,
        "dataset_config": args.dataset_config,
        "dataset_split": args.dataset_split,
        "mix": json.loads(args.mix_json) if args.mix_json else None,
        "format": args.format,
        "mode": args.mode,
        "output_dir": out_dir,
        "samples": {"train": train_count, "eval": eval_count},
        "metrics": {
            "final_train_loss": round(final_train_loss, 6)
            if final_train_loss is not None
            else None,
            "eval_loss": round(eval_loss, 6) if eval_loss is not None else None,
            "perplexity": perplexity,
            "loss_score": loss_score,
            "result_score": result_score,
        },
        "training": {
            "epochs": args.epochs,
            "max_steps": args.max_steps,
            "global_step": getattr(train_result, "global_step", None),
            "train_runtime_sec": round(train_metrics.get("train_runtime", 0), 2)
            if train_metrics
            else None,
            "train_samples_per_second": train_metrics.get("train_samples_per_second"),
        },
    }
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    path = out_path / "training_results.json"
    path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")
    return path
def to_prompt_response(example, fmt, tokenizer, keys=None, prompt_prefix=None):
    """Turn one example of any supported format into training text.

    Returns ``(full_text, prompt_text)``; ``prompt_text`` is None for the raw
    "text" format. ``keys`` optionally remaps the format's expected field
    names onto the dataset's own column names (e.g. {"prompt": "query"}).
    ``prompt_prefix`` is prepended to the user turn of "prompt"-format
    examples.

    Raises:
        ValueError: for an unknown ``fmt``.
    """
    columns = keys or {}

    def col(field):
        return columns.get(field, field)

    if fmt == "text":
        return example[col("text")], None

    if fmt == "chat":
        messages = example["messages"]
    elif fmt == "alpaca":
        instruction = example.get(col("instruction"), "")
        extra = example.get(col("input"), "") or ""
        answer = example.get(col("output"), "")
        user_turn = f"{instruction}\n\n{extra}" if extra else instruction
        messages = [
            {"role": "user", "content": user_turn},
            {"role": "assistant", "content": answer},
        ]
    elif fmt == "prompt":
        user_turn = example.get(col("prompt"), "")
        if prompt_prefix:
            user_turn = f"{prompt_prefix}{user_turn}"
        response_col = columns.get("response")
        if response_col:
            answer = example.get(response_col, "")
        else:
            answer = example.get("completion", example.get("response", ""))
        messages = [
            {"role": "user", "content": user_turn},
            {"role": "assistant", "content": answer},
        ]
    else:
        raise ValueError(fmt)

    # Prefer the model's own chat template when it has one.
    if tokenizer.chat_template:
        full = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False)
        prompt_only = tokenizer.apply_chat_template(
            messages[:-1], tokenize=False, add_generation_prompt=True)
        return full, prompt_only

    # Plain-text fallback: "### Role:" headers, last message is the response.
    turns = [
        f"### {m['role'].capitalize()}:\n{m['content']}\n\n"
        for m in messages[:-1]
    ]
    prompt_only = "".join(turns) + "### Assistant:\n"
    full = prompt_only + messages[-1]["content"] + (tokenizer.eos_token or "")
    return full, prompt_only
def build_tokenize_fn(tokenizer, fmt, max_len, mask_prompt, keys=None, prompt_prefix=None):
    """Return a per-example function producing ``input_ids`` and ``labels``.

    The full text is truncated to ``max_len`` tokens; special tokens are only
    added for the raw "text" format. When ``mask_prompt`` is set and the
    example has a prompt part, the leading prompt tokens are labelled
    IGNORE_INDEX.
    """
    with_special_tokens = fmt == "text"

    def tokenize_example(example):
        full_text, prompt_text = to_prompt_response(
            example, fmt, tokenizer, keys, prompt_prefix=prompt_prefix)
        encoded = tokenizer(full_text, truncation=True, max_length=max_len,
                            add_special_tokens=with_special_tokens)
        input_ids = encoded["input_ids"]
        labels = list(input_ids)
        if mask_prompt and prompt_text is not None:
            prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
            # The prompt may be longer than the truncated sequence.
            n_masked = min(len(prompt_ids), len(labels))
            labels[:n_masked] = [IGNORE_INDEX] * n_masked  # no loss on prompt
        return {"input_ids": input_ids, "labels": labels}

    return tokenize_example
def _source_specs(args) -> list[dict]:
    """Return the list of dataset source specs to train on.

    With --mix-json, the JSON list is parsed and returned as-is after
    validation. Otherwise a single spec is synthesized from the top-level
    --dataset/--format/--*-key args (only column overrides that were actually
    given are included).

    Raises:
        SystemExit: if --mix-json is not valid JSON, is not a non-empty list,
            or contains an entry that is not a JSON object.
    """
    if args.mix_json:
        try:
            specs = json.loads(args.mix_json)
        except json.JSONDecodeError as exc:
            raise SystemExit(f"--mix-json is not valid JSON: {exc}") from exc
        if not isinstance(specs, list) or not specs:
            raise SystemExit("--mix-json must be a non-empty JSON list of source specs")
        for index, spec in enumerate(specs):
            if not isinstance(spec, dict):
                raise SystemExit(
                    f"mix source #{index} must be a JSON object, "
                    f"got {type(spec).__name__}"
                )
        return specs

    column_overrides = {
        "prompt": args.prompt_key,
        "response": args.response_key,
        "instruction": args.instruction_key,
        "input": args.input_key,
        "output": args.output_key,
    }
    return [{
        "dataset": args.dataset,
        "format": args.format,
        "dataset_config": args.dataset_config,
        "dataset_split": args.dataset_split,
        "max_samples": args.dataset_max_samples,
        "columns": {field: name for field, name in column_overrides.items() if name},
    }]
def _apply_weight(ds, weight):
    """Resize a source to ``round(len(ds) * weight)`` examples.

    A weight above 1 repeats examples cyclically; a weight below 1 keeps a
    leading slice. A falsy weight, a weight of 1.0 or an empty dataset
    returns ``ds`` unchanged.
    """
    if not weight or weight == 1.0:
        return ds
    size = len(ds)
    if size == 0:
        return ds
    wanted = max(0, int(round(size * float(weight))))
    if wanted == 0:
        return ds.select([])
    indices = [pos % size for pos in range(wanted)]  # wraps around when wanted > size
    return ds.select(indices)
def build_training_dataset(args, tokenizer):
    """Load, tokenize, weight and concatenate every source into one dataset.

    Each source carries its own format / columns / split / max_len so a skill
    dataset can be mixed with a general-data replay slice in one run.

    Examples are dropped when they have at most one token or when every label
    is IGNORE_INDEX (e.g. the prompt alone fills ``max_len``): such examples
    carry no training signal, and a batch made only of them has no supervised
    token to compute a loss from.

    Raises:
        SystemExit: if a mix source lacks "dataset" or nothing is left to
            train on after filtering.
    """
    from datasets import concatenate_datasets

    def _is_trainable(example) -> bool:
        return len(example["input_ids"]) > 1 and any(
            label != IGNORE_INDEX for label in example["labels"]
        )

    specs = _source_specs(args)
    multi = len(specs) > 1
    if multi:
        print(f"Mixing {len(specs)} dataset source(s):")
    parts = []
    for i, spec in enumerate(specs):
        dataset = spec.get("dataset")
        if not dataset:
            raise SystemExit(f"mix source #{i} is missing 'dataset'")
        fmt = spec.get("format", args.format)
        raw = load_raw_dataset(
            dataset,
            config=spec.get("dataset_config"),
            split=spec.get("dataset_split", "train"),
            max_samples=spec.get("max_samples"),
        )
        raw = raw.shuffle(seed=args.seed)
        keys = spec.get("columns") or {}
        max_len = spec.get("max_len", args.max_len)
        prefix = spec.get("prompt_prefix")
        tokenize = build_tokenize_fn(
            tokenizer, fmt, max_len, args.mask_prompt, keys, prompt_prefix=prefix)
        tok = raw.map(tokenize, remove_columns=raw.column_names,
                      desc=f"tokenizing {dataset}")
        tokenized_count = len(tok)
        tok = tok.filter(_is_trainable)
        dropped = tokenized_count - len(tok)
        if dropped:
            print(f"Dropped {dropped} example(s) from {dataset} with no trainable tokens")
        tok = _apply_weight(tok, spec.get("weight"))
        if multi:
            wnote = f" (weight {spec['weight']})" if spec.get("weight") else ""
            print(f" - {dataset} [{fmt}] -> {len(tok)} examples{wnote}")
        parts.append(tok)
    ds = parts[0] if len(parts) == 1 else concatenate_datasets(parts)
    if len(ds) == 0:
        raise SystemExit(
            "No training examples left after tokenization and filtering; "
            "check --dataset, --format and --max_len."
        )
    return ds.shuffle(seed=args.seed)
class CausalCollator:
    """Batch collator: right-pads every example to the longest one in the batch.

    ``input_ids`` are padded with the tokenizer's pad token id, ``labels``
    with IGNORE_INDEX, and ``attention_mask`` is 1 on real tokens and 0 on
    padding.
    """

    def __init__(self, tokenizer):
        self.tok = tokenizer

    def __call__(self, batch):
        width = max(len(item["input_ids"]) for item in batch)
        pad_id = self.tok.pad_token_id
        # (ids, labels, number of padding positions) per example
        rows = [
            (item["input_ids"], item["labels"], width - len(item["input_ids"]))
            for item in batch
        ]
        return {
            "input_ids": torch.tensor([ids + [pad_id] * gap for ids, _, gap in rows]),
            "labels": torch.tensor(
                [labs + [IGNORE_INDEX] * gap for _, labs, gap in rows]),
            "attention_mask": torch.tensor(
                [[1] * len(ids) + [0] * gap for ids, _, gap in rows]),
        }
# ----------------------------------------------------------------------------
# Model loading for each mode
# ----------------------------------------------------------------------------
def _training_uses_cuda(args) -> bool:
    """Return whether training targets CUDA: forced by --device, else auto-detected."""
    forced = {"cpu": False, "cuda": True}
    if args.device in forced:
        return forced[args.device]
    return torch.cuda.is_available()
def _gpu_memory_summary() -> str:
    """Return a one-line free/total/allocated/reserved summary in MiB."""
    if not torch.cuda.is_available():
        return "CUDA not available"
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    mib = 2**20
    allocated_mib = torch.cuda.memory_allocated() // mib
    reserved_mib = torch.cuda.memory_reserved() // mib
    return (
        f"{free_bytes // mib} MiB free / {total_bytes // mib} MiB total "
        f"(allocated {allocated_mib} MiB, reserved {reserved_mib} MiB)"
    )
def _gpu_total_gib() -> float | None:
if not torch.cuda.is_available():
return None
_, total = torch.cuda.mem_get_info()
return total / (1024**3)
def _apply_low_vram_defaults(args) -> None:
    """Shrink training settings in place on CUDA devices with under 6 GiB.

    Caps batch_size at 1 and max_len at 512 and turns gradient checkpointing
    on. Below 4.5 GiB, ``lora`` is switched to ``qlora`` when bitsandbytes is
    importable; otherwise a warning is printed. Prints a summary when any of
    batch_size, max_len or mode actually changed.
    """
    if not _training_uses_cuda(args):
        return
    total_gib = _gpu_total_gib()
    if total_gib is None or total_gib >= 6.0:
        return

    before = (args.batch_size, args.max_len, args.mode)
    args.batch_size = min(args.batch_size, 1)
    args.max_len = min(args.max_len, 512)
    args.gradient_checkpointing = True

    if args.mode == "lora" and total_gib < 4.5:
        try:
            import bitsandbytes  # noqa: F401
        except ImportError:
            print(
                f"Warning: {total_gib:.1f} GiB GPU — full LoRA may OOM. "
                "Install finetune extras and use --mode qlora:\n"
                " uv sync --group finetune"
            )
        else:
            args.mode = "qlora"

    if (args.batch_size, args.max_len, args.mode) == before:
        return
    orig_batch, orig_max_len, orig_mode = before
    mode_note = f", mode {orig_mode}->{args.mode}" if args.mode != orig_mode else ""
    print(
        f"Low VRAM ({total_gib:.1f} GiB): adjusted training defaults — "
        f"batch_size {orig_batch}->{args.batch_size}, "
        f"max_len {orig_max_len}->{args.max_len}"
        + mode_note
    )
def _validate_cuda_device(args) -> None:
    """Exit with troubleshooting hints when CUDA training is wanted but unavailable."""
    wants_cuda = _training_uses_cuda(args)
    if not wants_cuda or torch.cuda.is_available():
        return
    message = (
        "CUDA training was requested (--device cuda or auto with a visible GPU) "
        "but PyTorch cannot use the GPU.\n"
        " torch.cuda.is_available() = False\n"
        f" torch.cuda.device_count() = {torch.cuda.device_count()}\n"
        "Run `nvidia-smi` and check for driver errors (ERR! fields). "
        "If the GPU is busy or broken, free it or reboot, then retry.\n"
        "Fallback: pass --device cpu (slower, higher RAM use)."
    )
    raise SystemExit(message)
def clear_gpu_memory(*, reset_peak: bool = True) -> None:
    """Run the garbage collector and, when CUDA is available, empty its cache.

    Args:
        reset_peak: Also reset CUDA peak-memory statistics.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        try:
            torch.cuda.ipc_collect()
        except Exception:
            # Best-effort: a failed IPC collect must not abort the cleanup.
            pass
        if reset_peak:
            torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
def _cuda_device_map() -> str | dict[str, int]:
"""Keep weights on one GPU; avoid CPU offload on small cards."""
if torch.cuda.device_count() <= 1:
return {"": 0}
return "auto"
def load_model_and_tokenizer(args):
    """Load the tokenizer and model for ``args.mode`` (full / lora / qlora).

    Returns:
        ``(model, tokenizer, bf16_ok)`` where ``bf16_ok`` says whether bf16 is
        used (explicit ``args.bf16``, else auto-detected on CUDA).

    Raises:
        SystemExit: for QLoRA without CUDA or without bitsandbytes installed.
    """
    common = {"trust_remote_code": args.trust_remote_code}
    tokenizer = AutoTokenizer.from_pretrained(args.model, **common)
    if tokenizer.pad_token is None:
        # The collator needs a pad id; reuse EOS when the model defines none.
        tokenizer.pad_token = tokenizer.eos_token
    use_cuda = _training_uses_cuda(args)
    bf16_ok = (
        args.bf16
        if args.bf16 is not None
        else use_cuda and torch.cuda.is_bf16_supported()
    )
    dtype = torch.bfloat16 if bf16_ok else torch.float32
    if args.mode == "qlora":
        if not use_cuda:
            raise SystemExit("QLoRA requires CUDA. Use --mode lora with --device cpu.")
        try:
            import bitsandbytes  # noqa: F401
        except ImportError as exc:
            raise SystemExit(
                "QLoRA requires bitsandbytes. Install with:\n"
                " uv sync --group finetune"
            ) from exc
        from transformers import BitsAndBytesConfig
        bnb = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16 if bf16_ok else torch.float16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            quantization_config=bnb,
            device_map=_cuda_device_map(),
            **common,
        )
        from peft import prepare_model_for_kbit_training
        model = prepare_model_for_kbit_training(model)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            dtype=dtype,
            device_map=_cuda_device_map() if use_cuda else None,
            **common,
        )
        if not use_cuda:
            model.to("cpu")
    if args.mode in ("lora", "qlora"):
        from peft import LoraConfig, get_peft_model
        if args.lora_targets == "all-linear":
            targets = "all-linear"
        else:
            # Skip empty entries produced by stray commas.
            targets = [t.strip() for t in args.lora_targets.split(",") if t.strip()]
        cfg = LoraConfig(
            r=args.lora_r, lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=targets, task_type="CAUSAL_LM")
        model = get_peft_model(model, cfg)
        model.print_trainable_parameters()
    if args.gradient_checkpointing:
        if args.mode == "lora" and hasattr(model, "enable_input_require_grads"):
            # With a frozen base model the embedding output does not require
            # grad, so reentrant checkpointing would produce no gradients for
            # the adapter. (The QLoRA branch goes through
            # prepare_model_for_kbit_training instead.)
            model.enable_input_require_grads()
        model.gradient_checkpointing_enable()
        # Disabled while gradient checkpointing is active during training.
        model.config.use_cache = False
    return model, tokenizer, bf16_ok
# ----------------------------------------------------------------------------
# Merge a trained LoRA adapter back into base weights
# ----------------------------------------------------------------------------
def merge_adapter(adapter_dir, out_dir):
    """Merge a LoRA adapter into its base model and save standalone weights.

    The base model id is read from the adapter's PEFT config. The
    TRUST_REMOTE_CODE env var ("1", "true" or "yes") is honoured so bases
    that ship custom code can be merged. The tokenizer saved next to the
    adapter is preferred over the base model's, falling back to the base
    when the adapter directory has none.

    Args:
        adapter_dir: Directory containing the trained adapter.
        out_dir: Destination for the merged model and tokenizer.
    """
    from peft import PeftModel, PeftConfig

    trust = os.environ.get("TRUST_REMOTE_CODE", "").lower() in {"1", "true", "yes"}
    cfg = PeftConfig.from_pretrained(adapter_dir)
    base = AutoModelForCausalLM.from_pretrained(
        cfg.base_model_name_or_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=trust,
    )
    has_own_tokenizer = (Path(adapter_dir) / "tokenizer_config.json").is_file()
    tok_source = adapter_dir if has_own_tokenizer else cfg.base_model_name_or_path
    tok = AutoTokenizer.from_pretrained(tok_source, trust_remote_code=trust)
    model = PeftModel.from_pretrained(base, adapter_dir)
    model = model.merge_and_unload()
    model.save_pretrained(out_dir)
    tok.save_pretrained(out_dir)
    print(f"Merged model saved to {out_dir}")
def run_post_lm_eval(
    *,
    checkpoint_path: str,
    config_path: str,
    experiment_name: str,
    baseline_preset: str | None = None,
    adapter_path: str | None = None,
) -> dict | None:
    """Run slm-lm-eval via ``uv run`` and describe where its outputs live.

    When ``baseline_preset`` is given, the baseline is evaluated first and the
    candidate run is compared against it, but only if the baseline run itself
    succeeded and wrote its results file, so a stale file from an earlier run
    is never used.

    Returns:
        A dict with the experiment name, inputs, expected output paths and the
        candidate run's ``exit_code``; None when the command could not be
        launched at all (e.g. ``uv`` is not installed).
    """
    base_cmd = [
        "uv",
        "run",
        "--package",
        "slm-evals",
        "slm-lm-eval",
        "--config",
        config_path,
    ]
    results_root = _REPO_ROOT / "results" / "lm_eval"

    baseline_results: Path | None = None
    if baseline_preset:
        baseline_name = f"{baseline_preset}__lm-eval-baseline"
        baseline_cmd = [
            *base_cmd,
            "--preset",
            baseline_preset,
            "--experiment-name",
            baseline_name,
        ]
        print(f"\n--- lm-eval baseline ({baseline_preset}) ---")
        try:
            baseline_proc = subprocess.run(baseline_cmd, cwd=_REPO_ROOT, check=False)
        except FileNotFoundError as exc:
            print(f"(lm-eval skipped: {exc})")
            return None
        candidate_file = results_root / baseline_name / "results.json"
        if baseline_proc.returncode == 0 and candidate_file.is_file():
            baseline_results = candidate_file
        else:
            print(
                f"(lm-eval baseline exited with code {baseline_proc.returncode}; "
                "skipping comparison)"
            )

    cmd = [
        *base_cmd,
        "--model",
        checkpoint_path,
        "--experiment-name",
        experiment_name,
    ]
    if adapter_path:
        cmd.extend(["--adapter", adapter_path])
    if baseline_results is not None:
        cmd.extend(["--compare-to", str(baseline_results)])
    print(f"\n--- lm-eval candidate ({experiment_name}) ---")
    try:
        proc = subprocess.run(cmd, cwd=_REPO_ROOT, check=False)
    except FileNotFoundError as exc:
        print(f"(lm-eval skipped: {exc})")
        return None

    out_root = results_root / experiment_name
    comparison = out_root / "comparison.md"
    # Only report a comparison that this run could have produced.
    has_comparison = baseline_results is not None and comparison.is_file()
    return {
        "experiment_name": experiment_name,
        "config": config_path,
        "checkpoint_path": checkpoint_path,
        "adapter_path": adapter_path,
        "baseline_preset": baseline_preset,
        "results_json": str(out_root / "results.json"),
        "summary_md": str(out_root / "summary.md"),
        "comparison_md": str(comparison) if has_comparison else None,
        "exit_code": proc.returncode,
    }
# ----------------------------------------------------------------------------
# Main
# ----------------------------------------------------------------------------
def main():
    """CLI entry point: merge an adapter, or run a full fine-tuning job.

    Steps for a training run: load .env, parse args, resolve the model and
    output directory, apply device checks and low-VRAM defaults, load the
    model, build the dataset, train, save the model/tokenizer and
    training_results.json, print a sample generation, and optionally run
    lm-eval on the result.
    """
    _load_dotenv(_REPO_ROOT / ".env")
    args = parse_args()
    if args.merge:
        out_dir = args.out or default_output_dir(None, "merged")
        merge_adapter(args.merge, out_dir)
        return
    model_id, preset_key, trust_remote_code = resolve_model_and_preset(
        model_arg=args.model,
        preset_arg=args.preset,
    )
    args.model = model_id
    args.trust_remote_code = trust_remote_code
    if not args.dataset:
        args.dataset = (
            os.environ.get("FINETUNE_DATASET")
            or str(_DEFAULT_DATASET)
        )
    if not args.out:
        args.out = os.environ.get("FINETUNE_OUT") or default_output_dir(
            preset_key, args.mode
        )
    Path(args.out).mkdir(parents=True, exist_ok=True)
    print(f"Base model: {args.model}")
    if preset_key:
        print(f"Preset: {preset_key}")
    if args.mix_json:
        print(f"Dataset mix: {len(json.loads(args.mix_json))} source(s)")
    else:
        print(f"Dataset: {args.dataset}")
    print(f"Output: {args.out}")
    print(f"Device: {args.device}")
    _validate_cuda_device(args)
    _apply_low_vram_defaults(args)
    if _training_uses_cuda(args):
        print(f"GPU before cleanup: {_gpu_memory_summary()}")
        clear_gpu_memory()
        print(f"GPU after cleanup: {_gpu_memory_summary()}")
    # "is not None" so an explicit --lr 0 is honoured rather than replaced.
    if args.lr is not None:
        lr = args.lr
    else:
        lr = 2e-5 if args.mode == "full" else 2e-4
    model, tokenizer, bf16_ok = load_model_and_tokenizer(args)
    if _training_uses_cuda(args):
        print(f"GPU after model load: {_gpu_memory_summary()}")
    ds = build_training_dataset(args, tokenizer)
    if args.val_split > 0:
        split = ds.train_test_split(test_size=args.val_split, seed=args.seed)
        train_ds, eval_ds = split["train"], split["test"]
    else:
        train_ds, eval_ds = ds, None
    # Default eval cadence to the run length so short (max_steps) runs still
    # evaluate mid-training instead of only at the end.
    eval_steps = args.eval_steps
    if eval_steps is None:
        eval_steps = max(1, args.max_steps // 5) if args.max_steps > 0 else 200
    use_best = args.early_stopping_patience > 0 and eval_ds is not None
    if args.early_stopping_patience > 0 and eval_ds is None:
        print(
            "Warning: --early_stopping_patience is ignored because there is "
            "no eval split (--val_split is 0)."
        )
    # load_best_model_at_end needs save_steps aligned to eval_steps.
    save_steps = eval_steps if use_best else args.save_steps
    targs = TrainingArguments(
        output_dir=args.out,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=lr,
        lr_scheduler_type=args.lr_scheduler,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        max_grad_norm=args.max_grad_norm,
        logging_steps=args.logging_steps,
        eval_strategy="steps" if eval_ds is not None else "no",
        eval_steps=eval_steps,
        save_strategy="steps",
        save_steps=save_steps,
        save_total_limit=args.save_total_limit,
        load_best_model_at_end=use_best,
        metric_for_best_model="eval_loss" if use_best else None,
        greater_is_better=False if use_best else None,
        bf16=bf16_ok,
        fp16=(not bf16_ok and _training_uses_cuda(args)),
        gradient_checkpointing=args.gradient_checkpointing,
        neftune_noise_alpha=args.neftune_noise_alpha,
        report_to=args.report_to,
        seed=args.seed,
    )
    callbacks = []
    if use_best:
        from transformers import EarlyStoppingCallback
        callbacks.append(
            EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience)
        )
    trainer = Trainer(
        model=model,
        args=targs,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=CausalCollator(tokenizer),
        callbacks=callbacks,
    )
    train_result = trainer.train(resume_from_checkpoint=args.resume)
    # ---- save -----------------------------------------------------------
    model.config.use_cache = True
    trainer.save_model(args.out)  # full weights OR adapter only
    tokenizer.save_pretrained(args.out)
    eval_metrics = None
    if eval_ds is not None:
        eval_metrics = trainer.evaluate()
        ppl = (
            math.exp(eval_metrics["eval_loss"])
            if eval_metrics["eval_loss"] < 20
            else float("inf")
        )
        print(
            f"\neval_loss={eval_metrics['eval_loss']:.4f} "
            f"perplexity={ppl:.2f}"
        )
    results_path = save_training_results(
        args.out,
        args=args,
        preset_key=preset_key,
        train_count=len(train_ds),
        eval_count=len(eval_ds) if eval_ds is not None else 0,
        train_result=train_result,
        log_history=trainer.state.log_history,
        eval_metrics=eval_metrics,
    )
    m = json.loads(results_path.read_text())["metrics"]
    print("\n--- scores ---")
    print(f"loss_score = {m['loss_score']} (lower is better)")
    print(f"result_score = {m['result_score']} (0–100, higher is better)")
    print(f"Saved to {results_path}")
    if args.mode in ("lora", "qlora"):
        merged = f"{args.out}-merged"
        print(f"\nAdapter saved to {args.out}")
        print(
            "Use in Gradio: set ACTIVE_MODEL to the matching *-lora preset "
            "in models.yaml, or merge with:\n"
            f" python research/finetune.py --merge {args.out} --out {merged}"
        )
    else:
        print(f"\nFull model saved to {args.out}")
    # quick smoke generation
    try:
        model.eval()
        prompt = "Hello! Briefly introduce yourself."
        if tokenizer.chat_template:
            text = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False, add_generation_prompt=True)
        else:
            text = prompt
        device = next(model.parameters()).device
        ids = tokenizer(text, return_tensors="pt").to(device)
        out = model.generate(**ids, max_new_tokens=60, do_sample=True,
                             temperature=0.7,
                             pad_token_id=tokenizer.pad_token_id)
        print("\n--- sample ---\n" +
              tokenizer.decode(out[0][ids["input_ids"].shape[1]:],
                               skip_special_tokens=True))
    except Exception as e:  # smoke test is best-effort
        print(f"(sample generation skipped: {e})")
    if args.lm_eval_after:
        exp_name = f"{Path(args.out).name}__lm-eval-posttrain"
        if args.mode in ("lora", "qlora"):
            # Adapters are evaluated on top of the base model.
            post_eval = run_post_lm_eval(
                checkpoint_path=args.model,
                config_path=args.lm_eval_config,
                experiment_name=exp_name,
                baseline_preset=args.lm_eval_baseline or preset_key,
                adapter_path=args.out,
            )
        else:
            post_eval = run_post_lm_eval(
                checkpoint_path=args.out,
                config_path=args.lm_eval_config,
                experiment_name=exp_name,
                baseline_preset=args.lm_eval_baseline or preset_key,
            )
        if post_eval:
            payload = json.loads(results_path.read_text())
            payload["post_eval"] = post_eval
            # Keep the trailing newline the file was first written with.
            results_path.write_text(json.dumps(payload, indent=2) + "\n")
            print(f"Appended post_eval to {results_path}")
    if _training_uses_cuda(args):
        clear_gpu_memory()
        print(f"GPU after training: {_gpu_memory_summary()}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()