""" finetune.py — Fine-tune a small LLM: FULL, LoRA, or QLoRA, one script. ====================================================================== Install: pip install torch transformers datasets peft accelerate pip install bitsandbytes # only needed for --mode qlora Model resolution (first match wins) ------------------------------------ 1. --model 2. --preset from models.yaml (or FINETUNE_PRESET env) 3. MODEL_ID / BASE env (raw Hugging Face id or local path) 4. ACTIVE_MODEL preset from models.yaml (text transformers presets only) Outputs are saved under ./models/finetuned/-/ by default. Examples -------- # LoRA on the lesson-agent chat dataset using models.yaml preset python research/finetune.py --preset minicpm5-1b --mode lora --epochs 1 # Same, but read ACTIVE_MODEL / BASE from .env (auto-loaded from repo root) python research/finetune.py --mode lora --max_steps 50 # LoRA on an instruction dataset from the Hub python research/finetune.py \ --model Qwen/Qwen2.5-0.5B-Instruct \ --dataset tatsu-lab/alpaca --format alpaca \ --mode lora --epochs 1 # QLoRA (4-bit) on a local JSONL chat file: {"messages": [{"role":..,"content":..}, ...]} python research/finetune.py \ --model meta-llama/Llama-3.2-1B-Instruct \ --dataset ./data/chats.jsonl --format chat \ --mode qlora # Hugging Face Hub datasets (--dataset is the repo id; optional --dataset-config / --split) python research/finetune.py \ --preset minicpm5-1b --mode qlora \ --dataset tatsu-lab/alpaca --format alpaca --dataset-split train python research/finetune.py \ --preset minicpm5-1b --mode lora \ --dataset HuggingFaceTB/smoltalk --format chat \ --dataset-config all --dataset-split train[:500] # Env vars also work: FINETUNE_DATASET, FINETUNE_DATASET_CONFIG, FINETUNE_DATASET_SPLIT # FULL fine-tune on raw text files (continued pretraining style) python research/finetune.py \ --model HuggingFaceTB/SmolLM2-360M \ --dataset ./data/corpus.txt --format text \ --mode full --lr 2e-5 # After LoRA training, merge adapter into standalone weights: python research/finetune.py --merge ./models/finetuned/minicpm5-1b-lora \ --out ./models/finetuned/minicpm5-1b-merged Dataset formats (--format) -------------------------- alpaca : columns instruction / input(optional) / output chat : column messages = [{"role": "...", "content": "..."}] prompt : columns prompt / completion (or prompt / response) text : column text — or a plain .txt file (one doc per line / whole file) Local files: .json, .jsonl, .csv, .txt. Hub ids: any datasets repo. Hub datasets useful for the lesson / teacher agent (--format must match columns): tatsu-lab/alpaca alpaca instruction tuning (general) HuggingFaceTB/smoltalk chat multi-turn chat (use config: all) Open-Orca/OpenOrca prompt instruction + response pairs databricks/databricks-dolly-15k alpaca short Q&A, good for small models After training, metrics are written to /training_results.json (train/eval loss, perplexity, result_score 0–100). """ import argparse import gc import json import math import os import subprocess import sys from pathlib import Path os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") import torch from datasets import load_dataset from transformers import ( AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, ) IGNORE_INDEX = -100 _REPO_ROOT = Path(__file__).resolve().parents[1] _DEFAULT_DATASET = _REPO_ROOT / "research/data/education-lesson-chat.jsonl" _FINETUNE_ROOT = _REPO_ROOT / "models/finetuned" _FALLBACK_FINETUNE_PRESET = "minicpm5-1b" def _load_dotenv(path: Path) -> None: """Load KEY=VALUE pairs from .env without overriding existing env vars.""" if not path.is_file(): return for line in path.read_text().splitlines(): line = line.strip() if not line or line.startswith("#") or "=" not in line: continue key, _, value = line.partition("=") key = key.strip() value = value.strip().strip('"').strip("'") if key: os.environ.setdefault(key, value) def _ensure_repo_on_path() -> None: libs = _REPO_ROOT / "libs" / "inference" / "src" if str(libs) not in sys.path: sys.path.insert(0, str(libs)) def _is_finetuneable_preset(model) -> bool: return model.backend == "transformers" and not model.multimodal and bool( model.model_id ) def resolve_model_and_preset( *, model_arg: str | None, preset_arg: str | None, ) -> tuple[str, str | None, bool]: """Return (model_id_or_path, preset_key, trust_remote_code).""" if model_arg: trust = os.environ.get("TRUST_REMOTE_CODE", "").lower() in { "1", "true", "yes", } return model_arg, preset_arg, trust for env_name in ("FINETUNE_MODEL", "MODEL_ID", "BASE"): raw = os.environ.get(env_name) if raw: trust = os.environ.get("TRUST_REMOTE_CODE", "").lower() in { "1", "true", "yes", } return raw, preset_arg, trust _ensure_repo_on_path() from inference.config import get_app_config, get_model_config app_config = get_app_config(reload=True) preset_key = ( preset_arg or os.environ.get("FINETUNE_PRESET") or os.environ.get("ACTIVE_MODEL") ) if preset_key and preset_key in app_config.models: model = get_model_config(preset_key) if not _is_finetuneable_preset(model): print( f"Preset {preset_key!r} is {model.backend}" + (" multimodal" if model.multimodal else "") + "; falling back to a text transformers preset for fine-tuning." ) preset_key = None if preset_key is None: for candidate in (_FALLBACK_FINETUNE_PRESET, *app_config.models): if candidate not in app_config.models: continue model = get_model_config(candidate) if _is_finetuneable_preset(model): preset_key = candidate break if not preset_key: raise SystemExit( "No fine-tunable transformers preset found. Pass --model or set BASE/MODEL_ID." ) model = get_model_config(preset_key) if not _is_finetuneable_preset(model): raise SystemExit( f"Preset {preset_key!r} cannot be fine-tuned " f"(backend={model.backend}, multimodal={model.multimodal})." ) return model.model_id, preset_key, model.trust_remote_code def default_output_dir(preset_key: str | None, mode: str) -> str: label = preset_key or "custom" return str((_FINETUNE_ROOT / f"{label}-{mode}").resolve()) # ---------------------------------------------------------------------------- # Args # ---------------------------------------------------------------------------- def parse_args(): p = argparse.ArgumentParser() p.add_argument( "--model", type=str, default=None, help="HF id or local path (overrides models.yaml / env)", ) p.add_argument( "--preset", type=str, default=None, help="Preset key from models.yaml (default: FINETUNE_PRESET or ACTIVE_MODEL)", ) p.add_argument( "--dataset", type=str, default=None, help="HF Hub repo id (e.g. tatsu-lab/alpaca) or local file path", ) p.add_argument( "--dataset-config", type=str, default=os.environ.get("FINETUNE_DATASET_CONFIG"), help="HF dataset config/subset name (optional)", ) p.add_argument( "--dataset-split", type=str, default=os.environ.get("FINETUNE_DATASET_SPLIT", "train"), help="HF split name or slice, e.g. train or train[:1000]", ) p.add_argument( "--dataset-max-samples", type=int, default=int(os.environ["FINETUNE_MAX_SAMPLES"]) if os.environ.get("FINETUNE_MAX_SAMPLES") else None, help="Cap examples after loading (useful for Hub smoke tests)", ) p.add_argument( "--mix-json", type=str, default=os.environ.get("FINETUNE_MIX_JSON"), help=( "JSON list of dataset source specs to mix/replay; overrides " "--dataset/--format. Each spec: " '{"dataset":..,"format":..,"columns":{..},"dataset_config":..,' '"dataset_split":..,"max_samples":..,"max_len":..,"weight":..}' ), ) p.add_argument( "--format", type=str, default=os.environ.get("FINETUNE_FORMAT", "chat"), choices=["alpaca", "chat", "prompt", "text"], ) # Column-name overrides: let a dataset's own columns map onto a --format # without preprocessing (e.g. MetaMathQA query/response -> prompt format, # orca-math question/answer -> prompt format). p.add_argument("--prompt-key", default=None, help="column to use as the prompt (prompt format)") p.add_argument("--response-key", default=None, help="column to use as the response (prompt format)") p.add_argument("--instruction-key", default=None, help="column to use as instruction (alpaca format)") p.add_argument("--input-key", default=None, help="column to use as optional input (alpaca format)") p.add_argument("--output-key", default=None, help="column to use as output (alpaca format)") p.add_argument("--mode", type=str, default="lora", choices=["full", "lora", "qlora"]) p.add_argument( "--out", type=str, default=None, help="Output directory (default: ./models/finetuned/-)", ) # training hparams p.add_argument("--epochs", type=float, default=1.0) p.add_argument("--max_steps", type=int, default=-1) p.add_argument("--batch_size", type=int, default=4) p.add_argument("--grad_accum", type=int, default=4) p.add_argument("--lr", type=float, default=None, help="default: 2e-4 for (q)lora, 2e-5 for full") p.add_argument("--max_len", type=int, default=1024) p.add_argument("--warmup_ratio", type=float, default=0.03) p.add_argument("--mask_prompt", action="store_true", default=True, help="compute loss only on the response tokens") p.add_argument("--no_mask_prompt", dest="mask_prompt", action="store_false") # training schedule / regularization (previously hardcoded) p.add_argument("--lr_scheduler", type=str, default="cosine", help="LR scheduler type: cosine, linear, constant, ...") p.add_argument("--weight_decay", type=float, default=0.01) p.add_argument("--max_grad_norm", type=float, default=1.0) p.add_argument("--logging_steps", type=int, default=10) p.add_argument("--eval_steps", type=int, default=None, help="eval every N steps (default: max_steps//5, else 200)") p.add_argument("--save_steps", type=int, default=500) p.add_argument("--save_total_limit", type=int, default=2) p.add_argument("--early_stopping_patience", type=int, default=0, help=">0 enables early stopping + load_best_model_at_end on eval_loss") p.add_argument("--neftune_noise_alpha", type=float, default=None, help="NEFTune noise alpha (e.g. 5) — quick instruction-tuning gain") p.add_argument("--report_to", type=str, default="none", help="trainer reporting: none, wandb, tensorboard, ...") # lora hparams p.add_argument("--lora_r", type=int, default=16) p.add_argument("--lora_alpha", type=int, default=32) p.add_argument("--lora_dropout", type=float, default=0.05) p.add_argument("--lora_targets", type=str, default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj", help="comma list; 'all-linear' also works") # misc p.add_argument("--val_split", type=float, default=0.02) p.add_argument("--seed", type=int, default=42) p.add_argument("--bf16", action="store_true", default=None) p.add_argument("--gradient_checkpointing", action="store_true", default=True) p.add_argument( "--device", type=str, default=os.environ.get("FINETUNE_DEVICE", "auto"), choices=["auto", "cpu", "cuda"], help="Training device (default: auto; set FINETUNE_DEVICE=cpu to avoid GPU OOM)", ) p.add_argument("--resume", type=str, default=None) # merge mode p.add_argument("--merge", type=str, default=None, help="path to a LoRA adapter dir to merge into its base model") p.add_argument( "--lm-eval-after", action="store_true", help="run slm-lm-eval on the saved checkpoint after training", ) p.add_argument( "--lm-eval-config", type=str, default=str(_REPO_ROOT / "research/evals/configs/lm_eval_smoke.yaml"), help="YAML config for post-training lm-eval (default: lm_eval_smoke.yaml)", ) p.add_argument( "--lm-eval-baseline", type=str, default=None, help="optional baseline preset key; runs lm-eval on base model and compares", ) return p.parse_args() # ---------------------------------------------------------------------------- # Dataset loading + normalization to (prompt, response) or raw text # ---------------------------------------------------------------------------- def load_raw_dataset( path: str, *, config: str | None = None, split: str = "train", max_samples: int | None = None, ): """Load from a local file or Hugging Face Hub (datasets.load_dataset).""" if os.path.exists(path): ext = os.path.splitext(path)[1].lower() if ext in (".json", ".jsonl"): ds = load_dataset("json", data_files=path, split="train") elif ext == ".csv": ds = load_dataset("csv", data_files=path, split="train") elif ext == ".txt": ds = load_dataset("text", data_files=path, split="train") else: raise ValueError(f"Unsupported local file type: {ext}") else: kwargs: dict = {"path": path, "split": split} if config: kwargs["name"] = config print(f"Loading Hub dataset: {path}" + (f" (config={config})" if config else "") + f" split={split}") ds = load_dataset(**kwargs) if max_samples is not None and max_samples > 0: ds = ds.select(range(min(max_samples, len(ds)))) return ds def _last_metric(history: list[dict], key: str) -> float | None: for entry in reversed(history): if key in entry: return float(entry[key]) return None def _result_score(eval_loss: float | None, train_loss: float | None) -> float | None: """Higher is better (0–100). Derived from eval loss, else train loss.""" loss = eval_loss if eval_loss is not None else train_loss if loss is None: return None # exp(-loss) maps typical LM losses (~0.5–3) into a readable 0–100 band. return round(min(100.0, max(0.0, 100.0 * math.exp(-loss))), 2) def save_training_results( out_dir: str, *, args, preset_key: str | None, train_count: int, eval_count: int, train_result, log_history: list[dict], eval_metrics: dict | None, ) -> Path: history = train_result.metrics if hasattr(train_result, "metrics") else {} final_train_loss = _last_metric(log_history, "loss") if final_train_loss is None and "train_loss" in history: final_train_loss = float(history["train_loss"]) eval_loss = None perplexity = None if eval_metrics: eval_loss = float(eval_metrics.get("eval_loss", 0)) if eval_loss < 20: perplexity = round(math.exp(eval_loss), 4) result_score = _result_score(eval_loss, final_train_loss) payload = { "model": args.model, "preset": preset_key, "dataset": args.dataset, "dataset_config": args.dataset_config, "dataset_split": args.dataset_split, "mix": json.loads(args.mix_json) if args.mix_json else None, "format": args.format, "mode": args.mode, "output_dir": out_dir, "samples": {"train": train_count, "eval": eval_count}, "metrics": { "final_train_loss": round(final_train_loss, 6) if final_train_loss is not None else None, "eval_loss": round(eval_loss, 6) if eval_loss is not None else None, "perplexity": perplexity, "loss_score": round(eval_loss, 6) if eval_loss is not None else ( round(final_train_loss, 6) if final_train_loss is not None else None ), "result_score": result_score, }, "training": { "epochs": args.epochs, "max_steps": args.max_steps, "global_step": getattr(train_result, "global_step", None), "train_runtime_sec": round(history.get("train_runtime", 0), 2) if history else None, "train_samples_per_second": history.get("train_samples_per_second"), }, } path = Path(out_dir) / "training_results.json" path.write_text(json.dumps(payload, indent=2) + "\n") return path def to_prompt_response(example, fmt, tokenizer, keys=None, prompt_prefix=None): """Normalize any supported format into a single training string, returning (full_text, prompt_text). prompt_text is None for raw text. `keys` optionally remaps a dataset's column names onto the format's expected fields (e.g. {"prompt": "query"} for MetaMathQA). `prompt_prefix` prepends fixed instruction text to prompt-format user turns.""" keys = keys or {} if fmt == "text": return example[keys.get("text", "text")], None if fmt == "alpaca": instr = example.get(keys.get("instruction", "instruction"), "") inp = example.get(keys.get("input", "input"), "") or "" out = example.get(keys.get("output", "output"), "") user = instr if not inp else f"{instr}\n\n{inp}" messages = [{"role": "user", "content": user}, {"role": "assistant", "content": out}] elif fmt == "prompt": prompt = example.get(keys.get("prompt", "prompt"), "") if prompt_prefix: prompt = f"{prompt_prefix}{prompt}" rkey = keys.get("response") resp = example.get(rkey, "") if rkey else example.get( "completion", example.get("response", "")) messages = [{"role": "user", "content": prompt}, {"role": "assistant", "content": resp}] elif fmt == "chat": messages = example["messages"] else: raise ValueError(fmt) # Use the model's chat template when it has one; else simple fallback. if tokenizer.chat_template: full = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=False) prompt_only = tokenizer.apply_chat_template( messages[:-1], tokenize=False, add_generation_prompt=True) else: prompt_only = "".join( f"### {m['role'].capitalize()}:\n{m['content']}\n\n" for m in messages[:-1]) + "### Assistant:\n" full = prompt_only + messages[-1]["content"] + (tokenizer.eos_token or "") return full, prompt_only def build_tokenize_fn(tokenizer, fmt, max_len, mask_prompt, keys=None, prompt_prefix=None): def fn(example): full, prompt = to_prompt_response( example, fmt, tokenizer, keys, prompt_prefix=prompt_prefix) ids = tokenizer(full, truncation=True, max_length=max_len, add_special_tokens=(fmt == "text"))["input_ids"] labels = list(ids) if mask_prompt and prompt is not None: p_len = len(tokenizer(prompt, add_special_tokens=False)["input_ids"]) p_len = min(p_len, len(labels)) labels[:p_len] = [IGNORE_INDEX] * p_len # no loss on prompt return {"input_ids": ids, "labels": labels} return fn def _source_specs(args) -> list[dict]: """Return the list of dataset source specs to train on. With --mix-json, parse the JSON list verbatim. Otherwise synthesize a single source from the top-level --dataset/--format/--*-key args.""" if args.mix_json: specs = json.loads(args.mix_json) if not isinstance(specs, list) or not specs: raise SystemExit("--mix-json must be a non-empty JSON list of source specs") return specs return [{ "dataset": args.dataset, "format": args.format, "dataset_config": args.dataset_config, "dataset_split": args.dataset_split, "max_samples": args.dataset_max_samples, "columns": {k: v for k, v in { "prompt": args.prompt_key, "response": args.response_key, "instruction": args.instruction_key, "input": args.input_key, "output": args.output_key, }.items() if v}, }] def _apply_weight(ds, weight): """Up-sample (weight > 1, with repeats) or sub-sample (weight < 1) a source.""" if not weight or weight == 1.0 or len(ds) == 0: return ds target = max(0, int(round(len(ds) * float(weight)))) if target == 0: return ds.select([]) n = len(ds) return ds.select([i % n for i in range(target)]) # repeats when target > n def build_training_dataset(args, tokenizer): """Load, tokenize, weight and concatenate every source into one dataset. Each source carries its own format / columns / split / max_len so a skill dataset can be mixed with a general-data replay slice in one run.""" from datasets import concatenate_datasets specs = _source_specs(args) multi = len(specs) > 1 if multi: print(f"Mixing {len(specs)} dataset source(s):") parts = [] for i, spec in enumerate(specs): dataset = spec.get("dataset") if not dataset: raise SystemExit(f"mix source #{i} is missing 'dataset'") fmt = spec.get("format", args.format) raw = load_raw_dataset( dataset, config=spec.get("dataset_config"), split=spec.get("dataset_split", "train"), max_samples=spec.get("max_samples"), ) raw = raw.shuffle(seed=args.seed) keys = spec.get("columns") or {} max_len = spec.get("max_len", args.max_len) prefix = spec.get("prompt_prefix") tokenize = build_tokenize_fn( tokenizer, fmt, max_len, args.mask_prompt, keys, prompt_prefix=prefix) tok = raw.map(tokenize, remove_columns=raw.column_names, desc=f"tokenizing {dataset}") tok = tok.filter(lambda e: len(e["input_ids"]) > 1) tok = _apply_weight(tok, spec.get("weight")) if multi: wnote = f" (weight {spec['weight']})" if spec.get("weight") else "" print(f" - {dataset} [{fmt}] -> {len(tok)} examples{wnote}") parts.append(tok) ds = parts[0] if len(parts) == 1 else concatenate_datasets(parts) return ds.shuffle(seed=args.seed) class CausalCollator: """Pads input_ids with pad_token and labels with IGNORE_INDEX.""" def __init__(self, tokenizer): self.tok = tokenizer def __call__(self, batch): max_len = max(len(b["input_ids"]) for b in batch) input_ids, labels, attn = [], [], [] pad = self.tok.pad_token_id for b in batch: n = max_len - len(b["input_ids"]) input_ids.append(b["input_ids"] + [pad] * n) labels.append(b["labels"] + [IGNORE_INDEX] * n) attn.append([1] * len(b["input_ids"]) + [0] * n) return { "input_ids": torch.tensor(input_ids), "labels": torch.tensor(labels), "attention_mask": torch.tensor(attn), } # ---------------------------------------------------------------------------- # Model loading for each mode # ---------------------------------------------------------------------------- def _training_uses_cuda(args) -> bool: if args.device == "cpu": return False if args.device == "cuda": return True return torch.cuda.is_available() def _gpu_memory_summary() -> str: if not torch.cuda.is_available(): return "CUDA not available" free, total = torch.cuda.mem_get_info() alloc = torch.cuda.memory_allocated() reserved = torch.cuda.memory_reserved() return ( f"{free // 2**20} MiB free / {total // 2**20} MiB total " f"(allocated {alloc // 2**20} MiB, reserved {reserved // 2**20} MiB)" ) def _gpu_total_gib() -> float | None: if not torch.cuda.is_available(): return None _, total = torch.cuda.mem_get_info() return total / (1024**3) def _apply_low_vram_defaults(args) -> None: """Cap batch/seq length and prefer QLoRA on GPUs that cannot fit full LoRA.""" if not _training_uses_cuda(args): return total_gib = _gpu_total_gib() if total_gib is None or total_gib >= 6.0: return orig_batch, orig_max_len, orig_mode = args.batch_size, args.max_len, args.mode args.batch_size = min(args.batch_size, 1) args.max_len = min(args.max_len, 512) args.gradient_checkpointing = True if total_gib < 4.5 and args.mode == "lora": try: import bitsandbytes # noqa: F401 args.mode = "qlora" except ImportError: print( f"Warning: {total_gib:.1f} GiB GPU — full LoRA may OOM. " "Install finetune extras and use --mode qlora:\n" " uv sync --group finetune" ) if ( args.batch_size != orig_batch or args.max_len != orig_max_len or args.mode != orig_mode ): print( f"Low VRAM ({total_gib:.1f} GiB): adjusted training defaults — " f"batch_size {orig_batch}->{args.batch_size}, " f"max_len {orig_max_len}->{args.max_len}" + (f", mode {orig_mode}->{args.mode}" if args.mode != orig_mode else "") ) def _validate_cuda_device(args) -> None: if not _training_uses_cuda(args): return if torch.cuda.is_available(): return raise SystemExit( "CUDA training was requested (--device cuda or auto with a visible GPU) " "but PyTorch cannot use the GPU.\n" f" torch.cuda.is_available() = False\n" f" torch.cuda.device_count() = {torch.cuda.device_count()}\n" "Run `nvidia-smi` and check for driver errors (ERR! fields). " "If the GPU is busy or broken, free it or reboot, then retry.\n" "Fallback: pass --device cpu (slower, higher RAM use)." ) def clear_gpu_memory(*, reset_peak: bool = True) -> None: """Release cached CUDA allocations before loading a model.""" gc.collect() if not torch.cuda.is_available(): return torch.cuda.empty_cache() try: torch.cuda.ipc_collect() except Exception: pass if reset_peak: torch.cuda.reset_peak_memory_stats() torch.cuda.synchronize() def _cuda_device_map() -> str | dict[str, int]: """Keep weights on one GPU; avoid CPU offload on small cards.""" if torch.cuda.device_count() <= 1: return {"": 0} return "auto" def load_model_and_tokenizer(args): common = {"trust_remote_code": args.trust_remote_code} tokenizer = AutoTokenizer.from_pretrained(args.model, **common) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token use_cuda = _training_uses_cuda(args) bf16_ok = ( args.bf16 if args.bf16 is not None else use_cuda and torch.cuda.is_bf16_supported() ) dtype = torch.bfloat16 if bf16_ok else torch.float32 if args.mode == "qlora": if not use_cuda: raise SystemExit("QLoRA requires CUDA. Use --mode lora with --device cpu.") try: import bitsandbytes # noqa: F401 except ImportError as exc: raise SystemExit( "QLoRA requires bitsandbytes. Install with:\n" " uv sync --group finetune" ) from exc from transformers import BitsAndBytesConfig bnb = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16 if bf16_ok else torch.float16, ) model = AutoModelForCausalLM.from_pretrained( args.model, quantization_config=bnb, device_map=_cuda_device_map(), **common, ) from peft import prepare_model_for_kbit_training model = prepare_model_for_kbit_training(model) else: model = AutoModelForCausalLM.from_pretrained( args.model, dtype=dtype, device_map=_cuda_device_map() if use_cuda else None, **common, ) if not use_cuda: model.to("cpu") if args.mode in ("lora", "qlora"): from peft import LoraConfig, get_peft_model targets = ("all-linear" if args.lora_targets == "all-linear" else [t.strip() for t in args.lora_targets.split(",")]) cfg = LoraConfig( r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, target_modules=targets, task_type="CAUSAL_LM") model = get_peft_model(model, cfg) model.print_trainable_parameters() if args.gradient_checkpointing: model.gradient_checkpointing_enable() model.config.use_cache = False return model, tokenizer, bf16_ok # ---------------------------------------------------------------------------- # Merge a trained LoRA adapter back into base weights # ---------------------------------------------------------------------------- def merge_adapter(adapter_dir, out_dir): from peft import PeftModel, PeftConfig cfg = PeftConfig.from_pretrained(adapter_dir) base = AutoModelForCausalLM.from_pretrained( cfg.base_model_name_or_path, torch_dtype=torch.bfloat16) tok = AutoTokenizer.from_pretrained(cfg.base_model_name_or_path) model = PeftModel.from_pretrained(base, adapter_dir) model = model.merge_and_unload() model.save_pretrained(out_dir) tok.save_pretrained(out_dir) print(f"Merged model saved to {out_dir}") def run_post_lm_eval( *, checkpoint_path: str, config_path: str, experiment_name: str, baseline_preset: str | None = None, adapter_path: str | None = None, ) -> dict | None: """Run slm-lm-eval via subprocess; return paths written under post_eval.""" baseline_results: Path | None = None if baseline_preset: baseline_name = f"{baseline_preset}__lm-eval-baseline" baseline_cmd = [ "uv", "run", "--package", "slm-evals", "slm-lm-eval", "--config", config_path, "--preset", baseline_preset, "--experiment-name", baseline_name, ] print(f"\n--- lm-eval baseline ({baseline_preset}) ---") subprocess.run(baseline_cmd, cwd=_REPO_ROOT, check=False) baseline_results = ( _REPO_ROOT / "results" / "lm_eval" / baseline_name / "results.json" ) cmd = [ "uv", "run", "--package", "slm-evals", "slm-lm-eval", "--config", config_path, "--model", checkpoint_path, "--experiment-name", experiment_name, ] if adapter_path: cmd.extend(["--adapter", adapter_path]) if baseline_results and baseline_results.is_file(): cmd.extend(["--compare-to", str(baseline_results)]) print(f"\n--- lm-eval candidate ({experiment_name}) ---") proc = subprocess.run(cmd, cwd=_REPO_ROOT, check=False) out_root = _REPO_ROOT / "results" / "lm_eval" / experiment_name post_eval = { "experiment_name": experiment_name, "config": config_path, "checkpoint_path": checkpoint_path, "adapter_path": adapter_path, "baseline_preset": baseline_preset, "results_json": str(out_root / "results.json"), "summary_md": str(out_root / "summary.md"), "comparison_md": str(out_root / "comparison.md") if (out_root / "comparison.md").is_file() else None, "exit_code": proc.returncode, } return post_eval if proc.returncode == 0 else post_eval # ---------------------------------------------------------------------------- # Main # ---------------------------------------------------------------------------- def main(): _load_dotenv(_REPO_ROOT / ".env") args = parse_args() if args.merge: out_dir = args.out or default_output_dir(None, "merged") merge_adapter(args.merge, out_dir) return model_id, preset_key, trust_remote_code = resolve_model_and_preset( model_arg=args.model, preset_arg=args.preset, ) args.model = model_id args.trust_remote_code = trust_remote_code if not args.dataset: args.dataset = ( os.environ.get("FINETUNE_DATASET") or str(_DEFAULT_DATASET) ) if not args.out: args.out = os.environ.get("FINETUNE_OUT") or default_output_dir( preset_key, args.mode ) Path(args.out).mkdir(parents=True, exist_ok=True) print(f"Base model: {args.model}") if preset_key: print(f"Preset: {preset_key}") if args.mix_json: print(f"Dataset mix: {len(json.loads(args.mix_json))} source(s)") else: print(f"Dataset: {args.dataset}") print(f"Output: {args.out}") print(f"Device: {args.device}") _validate_cuda_device(args) _apply_low_vram_defaults(args) if _training_uses_cuda(args): print(f"GPU before cleanup: {_gpu_memory_summary()}") clear_gpu_memory() print(f"GPU after cleanup: {_gpu_memory_summary()}") lr = args.lr or (2e-5 if args.mode == "full" else 2e-4) model, tokenizer, bf16_ok = load_model_and_tokenizer(args) if _training_uses_cuda(args): print(f"GPU after model load: {_gpu_memory_summary()}") ds = build_training_dataset(args, tokenizer) if args.val_split > 0: split = ds.train_test_split(test_size=args.val_split, seed=args.seed) train_ds, eval_ds = split["train"], split["test"] else: train_ds, eval_ds = ds, None # Default eval cadence to the run length so short (max_steps) runs still # evaluate mid-training instead of only at the end. eval_steps = args.eval_steps if eval_steps is None: eval_steps = max(1, args.max_steps // 5) if args.max_steps > 0 else 200 use_best = args.early_stopping_patience > 0 and eval_ds is not None # load_best_model_at_end needs save_steps aligned to eval_steps. save_steps = eval_steps if use_best else args.save_steps targs = TrainingArguments( output_dir=args.out, num_train_epochs=args.epochs, max_steps=args.max_steps, per_device_train_batch_size=args.batch_size, per_device_eval_batch_size=args.batch_size, gradient_accumulation_steps=args.grad_accum, learning_rate=lr, lr_scheduler_type=args.lr_scheduler, warmup_ratio=args.warmup_ratio, weight_decay=args.weight_decay, max_grad_norm=args.max_grad_norm, logging_steps=args.logging_steps, eval_strategy="steps" if eval_ds is not None else "no", eval_steps=eval_steps, save_strategy="steps", save_steps=save_steps, save_total_limit=args.save_total_limit, load_best_model_at_end=use_best, metric_for_best_model="eval_loss" if use_best else None, greater_is_better=False if use_best else None, bf16=bf16_ok, fp16=(not bf16_ok and _training_uses_cuda(args)), gradient_checkpointing=args.gradient_checkpointing, neftune_noise_alpha=args.neftune_noise_alpha, report_to=args.report_to, seed=args.seed, ) callbacks = [] if use_best: from transformers import EarlyStoppingCallback callbacks.append( EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience) ) trainer = Trainer( model=model, args=targs, train_dataset=train_ds, eval_dataset=eval_ds, data_collator=CausalCollator(tokenizer), callbacks=callbacks, ) train_result = trainer.train(resume_from_checkpoint=args.resume) # ---- save ----------------------------------------------------------- model.config.use_cache = True trainer.save_model(args.out) # full weights OR adapter only tokenizer.save_pretrained(args.out) eval_metrics = None if eval_ds is not None: eval_metrics = trainer.evaluate() ppl = ( math.exp(eval_metrics["eval_loss"]) if eval_metrics["eval_loss"] < 20 else float("inf") ) print( f"\neval_loss={eval_metrics['eval_loss']:.4f} " f"perplexity={ppl:.2f}" ) results_path = save_training_results( args.out, args=args, preset_key=preset_key, train_count=len(train_ds), eval_count=len(eval_ds) if eval_ds is not None else 0, train_result=train_result, log_history=trainer.state.log_history, eval_metrics=eval_metrics, ) m = json.loads(results_path.read_text())["metrics"] print("\n--- scores ---") print(f"loss_score = {m['loss_score']} (lower is better)") print(f"result_score = {m['result_score']} (0–100, higher is better)") print(f"Saved to {results_path}") if args.mode in ("lora", "qlora"): merged = f"{args.out}-merged" print(f"\nAdapter saved to {args.out}") print( "Use in Gradio: set ACTIVE_MODEL to the matching *-lora preset " "in models.yaml, or merge with:\n" f" python research/finetune.py --merge {args.out} --out {merged}" ) else: print(f"\nFull model saved to {args.out}") # quick smoke generation try: model.eval() prompt = "Hello! Briefly introduce yourself." if tokenizer.chat_template: text = tokenizer.apply_chat_template( [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True) else: text = prompt device = next(model.parameters()).device ids = tokenizer(text, return_tensors="pt").to(device) out = model.generate(**ids, max_new_tokens=60, do_sample=True, temperature=0.7, pad_token_id=tokenizer.pad_token_id) print("\n--- sample ---\n" + tokenizer.decode(out[0][ids["input_ids"].shape[1]:], skip_special_tokens=True)) except Exception as e: # smoke test is best-effort print(f"(sample generation skipped: {e})") if args.lm_eval_after: exp_name = f"{Path(args.out).name}__lm-eval-posttrain" if args.mode in ("lora", "qlora"): post_eval = run_post_lm_eval( checkpoint_path=args.model, config_path=args.lm_eval_config, experiment_name=exp_name, baseline_preset=args.lm_eval_baseline or preset_key, adapter_path=args.out, ) else: post_eval = run_post_lm_eval( checkpoint_path=args.out, config_path=args.lm_eval_config, experiment_name=exp_name, baseline_preset=args.lm_eval_baseline or preset_key, ) if post_eval: payload = json.loads(results_path.read_text()) payload["post_eval"] = post_eval results_path.write_text(json.dumps(payload, indent=2)) print(f"Appended post_eval to {results_path}") if _training_uses_cuda(args): clear_gpu_memory() print(f"GPU after training: {_gpu_memory_summary()}") if __name__ == "__main__": main()