Buckets:
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Any | |
| from training.hf_fingpt_train import ( | |
| _jsonable, | |
| _metric_rows, | |
| build_trainer_metrics_summary, | |
| copy_file_with_retry, | |
| package_status, | |
| sha256_file, | |
| utc_now, | |
| write_json, | |
| ) | |
def append_jsonl(path: Path, payload: dict[str, Any]) -> None:
    """Append ``payload`` to ``path`` as a single JSON line.

    Missing parent directories are created first. Keys are serialized in sorted
    order, and the line is terminated with a Unix newline on every platform.
    """
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8", newline="\n") as stream:
        encoded = json.dumps(payload, sort_keys=True)
        stream.write(encoded + "\n")
def load_preference_rows(path: Path) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Read a preference-pair JSONL file and partition it into admitted and rejected rows.

    Each non-blank line must be a JSON object. Its ``prompt``, ``chosen`` and
    ``rejected`` fields are coerced to stripped strings. A record is rejected,
    and kept with a ``reject_reason``, when:

    * any of the three fields is empty (``missing_prompt_chosen_or_rejected``),
    * ``chosen`` equals ``rejected`` (``chosen_equals_rejected``), or
    * ``metadata.repair_target.admitted_to_training`` is explicitly ``False``
      (``repair_target_not_admitted``).

    Args:
        path: JSONL file to read. A leading UTF-8 BOM is tolerated.

    Returns:
        ``(admitted, rejected)``. Every row carries ``prompt``, ``chosen``,
        ``rejected``, ``metadata`` and its 1-based ``source_line``.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if a line is not valid JSON or not a JSON object, or if no
            record is admissible.
    """
    if not path.exists():
        raise FileNotFoundError(f"preference JSONL does not exist: {path}")
    admitted: list[dict[str, Any]] = []
    rejected: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8-sig", newline="") as handle:
        for line_no, line in enumerate(handle, start=1):
            if not line.strip():
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError as exc:
                # Point at the offending file/line instead of surfacing a bare
                # decoder error with a line-relative position.
                raise ValueError(f"invalid JSON at {path}:{line_no}: {exc}") from exc
            if not isinstance(item, dict):
                raise ValueError(f"record at {path}:{line_no} must be a JSON object")
            prompt = str(item.get("prompt") or "").strip()
            chosen = str(item.get("chosen") or "").strip()
            rejected_text = str(item.get("rejected") or "").strip()
            metadata = item.get("metadata") or {}
            # Metadata and repair_target are free-form; only a JSON object can
            # carry an admission decision, so anything else is treated as absent
            # rather than crashing on ``.get``.
            target = metadata.get("repair_target") if isinstance(metadata, dict) else None
            if not isinstance(target, dict):
                target = {}
            reason = None
            if not prompt or not chosen or not rejected_text:
                reason = "missing_prompt_chosen_or_rejected"
            elif chosen == rejected_text:
                reason = "chosen_equals_rejected"
            if target.get("admitted_to_training") is False:
                # Content-level reasons take precedence over the admission flag.
                reason = reason or "repair_target_not_admitted"
            row = {
                "prompt": prompt,
                "chosen": chosen,
                "rejected": rejected_text,
                "metadata": metadata,
                "source_line": line_no,
            }
            if reason:
                row["reject_reason"] = reason
                rejected.append(row)
            else:
                admitted.append(row)
    if not admitted:
        raise ValueError(f"no admissible preference pairs found in {path}")
    return admitted, rejected
def split_rows(rows: list[dict[str, Any]], valid_ratio: float) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Split ``rows`` into ``(train, valid)``, holding out validation rows from the tail.

    With fewer than 10 rows nothing is held out: the same list object is
    returned for both roles. Otherwise the validation slice has
    ``round(len(rows) * valid_ratio)`` rows, clamped so that both slices keep at
    least one row. Input order is preserved, and no shuffling happens here.
    """
    total = len(rows)
    if total >= 10:
        requested = int(round(total * valid_ratio))
        holdout = min(max(requested, 1), total - 1)
        boundary = total - holdout
        return rows[:boundary], rows[boundary:]
    return rows, rows
def build_plan(
    args: argparse.Namespace,
    rows: list[dict[str, Any]],
    rejects: list[dict[str, Any]],
    *,
    max_rejected_examples: int = 25,
) -> dict[str, Any]:
    """Build the ``shft_preference_training_plan_v1`` record for a DPO run.

    Args:
        args: Parsed CLI namespace (see ``build_parser``).
        rows: Admitted preference pairs.
        rejects: Rejected preference pairs, each carrying a ``reject_reason``.
        max_rejected_examples: How many rejected pairs to embed verbatim.
            ``rejected_pairs_truncated`` reports whether more existed.

    Returns:
        A JSON-serializable plan dict. It includes the SHA-256 of the preference
        file and the hyperparameters that shape training.
    """
    preference_path = Path(args.preference_jsonl)
    max_seq_length = args.max_seq_length
    raw_prompt_length = getattr(args, "max_prompt_length", None)
    example_limit = max(0, max_rejected_examples)
    return {
        "schema_version": "shft_preference_training_plan_v1",
        "created_at": utc_now(),
        "run_id": args.run_id,
        "source_run_id": args.source_run_id,
        "training_stage": "preference_dpo",
        "trainer": "dpo",
        "base_model_id": args.base_model_id,
        "start_adapter": args.start_adapter,
        "preference_jsonl": str(preference_path),
        "preference_jsonl_sha256": sha256_file(preference_path),
        "admitted_pair_count": len(rows),
        "rejected_pair_count": len(rejects),
        "rejected_pairs": rejects[:example_limit],
        "rejected_pairs_truncated": len(rejects) > example_limit,
        "hyperparameters": {
            "max_steps": args.max_steps,
            "learning_rate": args.learning_rate,
            "beta": args.beta,
            "per_device_train_batch_size": args.per_device_train_batch_size,
            "gradient_accumulation_steps": args.gradient_accumulation_steps,
            "max_seq_length": max_seq_length,
            # Record the prompt budget actually handed to DPOConfig, which
            # caps the CLI value at max_seq_length.
            "max_prompt_length": (
                None if raw_prompt_length is None else min(raw_prompt_length, max_seq_length)
            ),
            "valid_ratio": args.valid_ratio,
            # Evaluation/save cadence determines which checkpoint is selected.
            "eval_steps": getattr(args, "eval_steps", None),
            "save_steps": getattr(args, "save_steps", None),
        },
        "package_status": package_status(),
        "ok": True,
    }
def _selected_checkpoint(output_dir: Path, trainer: Any) -> dict[str, Any]:
    """Describe which checkpoint the trainer picked as best by ``eval_loss``.

    Reads ``trainer.state.best_model_checkpoint`` and ``best_metric`` when they
    exist. If no best checkpoint path is recorded (missing or falsy), the
    path falls back to ``<output_dir>/adapter``.
    """
    trainer_state = getattr(trainer, "state", None)
    chosen_path = getattr(trainer_state, "best_model_checkpoint", None)
    if not chosen_path:
        chosen_path = str(output_dir / "adapter")
    return {
        "schema_version": "shft_selected_preference_checkpoint_v1",
        "created_at": utc_now(),
        "selection_metric": "eval_loss",
        "selection_metric_value": _jsonable(getattr(trainer_state, "best_metric", None)),
        "checkpoint_path": chosen_path,
    }
def _dpo_columns(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Keep only the ``prompt``/``chosen``/``rejected`` fields of each preference row."""
    return [{key: row[key] for key in ("prompt", "chosen", "rejected")} for row in rows]


def run_preference_training(args: argparse.Namespace) -> dict[str, Any]:
    """Validate preference data and, unless ``--dry-run``, run DPO on a LoRA adapter.

    Steps:
      1. Load and filter preference pairs, split them into train/validation,
         and write ``preference_training_plan.json``.
      2. On a dry run, write ``preference_training_result.json`` and return.
      3. Otherwise:
         * load the base model quantized to 4-bit NF4 and attach
           ``--start-adapter`` as trainable;
         * train with TRL's ``DPOTrainer`` and save adapter plus tokenizer to
           ``<output_dir>/adapter``;
         * write the metrics, selected-checkpoint and result artifacts;
         * copy those artifacts to the generic ``training_*`` /
           ``trainer_*`` / ``selected_checkpoint`` names.

    Returns:
        The result dict that was written to ``preference_training_result.json``.

    Raises:
        FileNotFoundError, ValueError: propagated from preference loading.
        ValueError: if live training is requested without ``--start-adapter``.
    """
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    rows, rejects = load_preference_rows(Path(args.preference_jsonl))
    train_rows, valid_rows = split_rows(rows, args.valid_ratio)
    plan = build_plan(args, rows, rejects)
    plan["train_pair_count"] = len(train_rows)
    plan["valid_pair_count"] = len(valid_rows)
    write_json(output_dir / "preference_training_plan.json", plan)

    if args.dry_run:
        result = {
            "schema_version": "shft_preference_training_result_v1",
            "created_at": utc_now(),
            "run_id": args.run_id,
            "source_run_id": args.source_run_id,
            "training_stage": "preference_dpo",
            "status": "dry_run_validated",
            "admitted_pair_count": len(rows),
            "rejected_pair_count": len(rejects),
            "train_pair_count": len(train_rows),
            "valid_pair_count": len(valid_rows),
            "ok": True,
        }
        write_json(output_dir / "preference_training_result.json", result)
        return result

    if not args.start_adapter:
        raise ValueError("--start-adapter is required for live preference training")

    # Heavy ML dependencies are imported lazily so that --dry-run validation
    # works in environments where they are not installed.
    import torch
    from datasets import Dataset
    from peft import PeftModel, prepare_model_for_kbit_training
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from trl import DPOConfig, DPOTrainer

    tokenizer = AutoTokenizer.from_pretrained(args.base_model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Some causal-LM tokenizers ship without a pad token; fall back to EOS.
        tokenizer.pad_token = tokenizer.eos_token
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model_id,
        quantization_config=quantization,
        device_map="auto",
        trust_remote_code=True,
    )
    model = prepare_model_for_kbit_training(model)
    # Continue training the existing adapter rather than initializing a new one.
    model = PeftModel.from_pretrained(model, args.start_adapter, is_trainable=True)
    training_args = DPOConfig(
        output_dir=str(output_dir / "checkpoints"),
        max_steps=args.max_steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        beta=args.beta,
        max_length=args.max_seq_length,
        max_prompt_length=min(args.max_prompt_length, args.max_seq_length),
        logging_steps=args.logging_steps,
        eval_strategy="steps",
        eval_steps=args.eval_steps,
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        # With load_best_model_at_end, the model saved after training is the
        # lowest-eval_loss checkpoint, per Trainer's documented behavior.
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        bf16=True,
        gradient_checkpointing=True,
        report_to=[],
    )
    trainer = DPOTrainer(
        model=model,
        args=training_args,
        # Only the three DPO text fields are handed to the datasets. Free-form
        # ``metadata`` dicts may have inconsistent schemas across rows, which
        # Arrow schema inference may reject, and nothing here reads them back.
        train_dataset=Dataset.from_list(_dpo_columns(train_rows)),
        eval_dataset=Dataset.from_list(_dpo_columns(valid_rows)),
        processing_class=tokenizer,
    )
    train_result = trainer.train()

    adapter_dir = output_dir / "adapter"
    trainer.save_model(str(adapter_dir))
    tokenizer.save_pretrained(str(adapter_dir))

    log_history = list(getattr(trainer.state, "log_history", []) or [])
    # Materialize once: the rows feed both the JSONL log and the summary.
    metric_rows = list(_metric_rows(log_history))
    metrics_path = output_dir / "preference_trainer_metrics.jsonl"
    # Start from an empty file so a rerun into the same output_dir does not
    # leave a previous run's rows ahead of this run's. The summary below
    # reflects only this run.
    metrics_path.unlink(missing_ok=True)
    for metric_row in metric_rows:
        append_jsonl(metrics_path, metric_row)

    selected = _selected_checkpoint(output_dir, trainer)
    write_json(output_dir / "selected_preference_checkpoint.json", selected)
    summary = build_trainer_metrics_summary(
        rows=metric_rows,
        selected_checkpoint=selected,
        overfit_tolerance=args.overfit_tolerance,
    )
    write_json(output_dir / "preference_trainer_metrics_summary.json", summary)

    result = {
        "schema_version": "shft_preference_training_result_v1",
        "created_at": utc_now(),
        "run_id": args.run_id,
        "source_run_id": args.source_run_id,
        "training_stage": "preference_dpo",
        "status": "completed",
        "admitted_pair_count": len(rows),
        "rejected_pair_count": len(rejects),
        "train_pair_count": len(train_rows),
        "valid_pair_count": len(valid_rows),
        "train_metrics": _jsonable(getattr(train_result, "metrics", {})),
        "adapter_dir": str(adapter_dir),
        "ok": True,
    }
    write_json(output_dir / "preference_training_result.json", result)

    # Mirror stage-specific artifacts under generic names.
    for source, target in [
        ("preference_training_plan.json", "training_plan.json"),
        ("preference_training_result.json", "training_result.json"),
        ("preference_trainer_metrics_summary.json", "trainer_metrics_summary.json"),
        ("selected_preference_checkpoint.json", "selected_checkpoint.json"),
    ]:
        copy_file_with_retry(output_dir / source, output_dir / target, attempts=3, delay_seconds=0.5)
    return result
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the SHFT DPO preference-training entry point.

    The defaults of ``--base-model-id``, ``--max-steps``, ``--learning-rate``
    and ``--beta`` are read when the parser is built from the environment
    variables ``SHFT_BASE_MODEL_ID``, ``SHFT_PREFERENCE_MAX_STEPS``,
    ``SHFT_PREFERENCE_LR`` and ``SHFT_DPO_BETA``, respectively.
    """
    env = os.environ.get
    parser = argparse.ArgumentParser(description="Run SHFT DPO preference optimization from paired-eval losses.")

    # Run identity, model and I/O locations.
    for flag in ("--run-id", "--source-run-id"):
        parser.add_argument(flag, required=True)
    parser.add_argument("--base-model-id", default=env("SHFT_BASE_MODEL_ID", "Qwen/Qwen3-32B"))
    parser.add_argument("--start-adapter")
    for flag in ("--preference-jsonl", "--output-dir"):
        parser.add_argument(flag, required=True)

    # Optimization hyperparameters with environment-overridable defaults.
    parser.add_argument("--max-steps", type=int, default=int(env("SHFT_PREFERENCE_MAX_STEPS", "200")))
    parser.add_argument("--learning-rate", type=float, default=float(env("SHFT_PREFERENCE_LR", "5e-6")))
    parser.add_argument("--beta", type=float, default=float(env("SHFT_DPO_BETA", "0.1")))

    # Batch and sequence sizing.
    sizing_options = (
        ("--per-device-train-batch-size", 1),
        ("--gradient-accumulation-steps", 8),
        ("--max-seq-length", 4096),
        ("--max-prompt-length", 1536),
    )
    for flag, default in sizing_options:
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--valid-ratio", type=float, default=0.1)

    # Logging, evaluation and checkpoint cadence.
    cadence_options = (
        ("--logging-steps", 5),
        ("--eval-steps", 25),
        ("--save-steps", 25),
        ("--save-total-limit", 3),
    )
    for flag, default in cadence_options:
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--overfit-tolerance", type=float, default=0.1)
    parser.add_argument("--dry-run", action="store_true")
    return parser
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: parse ``argv``, run preference training, and print the result as JSON.

    Returns 0 when the result reports ``ok``, otherwise 2.
    """
    parsed_args = build_parser().parse_args(argv)
    outcome = run_preference_training(parsed_args)
    print(json.dumps(outcome, indent=2))
    if outcome.get("ok"):
        return 0
    return 2


if __name__ == "__main__":
    raise SystemExit(main())
Xet Storage Details
- Size:
- 11.3 kB
- Xet hash:
- 14fe4d5e7ec64916b31c1f5007b2a6d58fd78c74a6dc90db03b87314dad0de56
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.