# NOTE: web file-viewer residue captured with this file (not program code):
# linvest21's picture · download · raw · 11.3 kB
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from typing import Any
from training.hf_fingpt_train import (
_jsonable,
_metric_rows,
build_trainer_metrics_summary,
copy_file_with_retry,
package_status,
sha256_file,
utc_now,
write_json,
)
def append_jsonl(path: Path, payload: dict[str, Any]) -> None:
    """Append ``payload`` to ``path`` as a single JSON line.

    The parent directory is created when missing. Keys are emitted in sorted
    order and every record is terminated by ``"\\n"``; the file is opened with
    ``newline="\\n"`` so that terminator is written untranslated.

    Args:
        path: JSONL file to append to (created if it does not exist).
        payload: JSON-serialisable mapping written as one line.
    """
    folder = path.parent
    folder.mkdir(parents=True, exist_ok=True)
    with path.open(mode="a", encoding="utf-8", newline="\n") as sink:
        encoded = json.dumps(payload, sort_keys=True)
        sink.write(f"{encoded}\n")
def load_preference_rows(path: Path) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Read a preference JSONL file and split its records into admitted/rejected.

    Blank lines are skipped. Every other line must hold a JSON object; its
    ``prompt``, ``chosen`` and ``rejected`` values are converted to stripped
    strings (missing or falsy values become ``""``).

    A record is rejected (first matching reason wins) when:

    * ``missing_prompt_chosen_or_rejected`` - any of the three texts is empty;
    * ``chosen_equals_rejected`` - both completions are identical;
    * ``repair_target_not_admitted`` - ``metadata.repair_target`` carries
      ``admitted_to_training`` set to exactly ``False``.

    Args:
        path: Location of the preference JSONL file (a UTF-8 BOM is tolerated).

    Returns:
        ``(admitted, rejected)``. Each row holds ``prompt``, ``chosen``,
        ``rejected``, ``metadata`` and the 1-based ``source_line``; rejected
        rows additionally carry ``reject_reason``.

    Raises:
        FileNotFoundError: ``path`` does not exist.
        json.JSONDecodeError: A line is not valid JSON; the message names the
            file and line number.
        ValueError: A record, its ``metadata`` or its
            ``metadata.repair_target`` is not a JSON object, or no record is
            admissible.
    """
    if not path.exists():
        raise FileNotFoundError(f"preference JSONL does not exist: {path}")
    admitted: list[dict[str, Any]] = []
    rejected: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8-sig", newline="") as handle:
        for line_no, line in enumerate(handle, start=1):
            if not line.strip():
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError as err:
                # Re-raise as the same exception type, but say where the bad
                # line lives; the decoder alone only knows offsets in the line.
                raise json.JSONDecodeError(
                    f"invalid JSON at {path}:{line_no}: {err.msg}", err.doc, err.pos
                ) from err
            if not isinstance(item, dict):
                raise ValueError(f"record at {path}:{line_no} must be a JSON object")

            # Validate the nested containers up front so that a malformed
            # record yields a located ValueError instead of an AttributeError
            # from calling ``.get`` on a string, list or number.
            metadata = item.get("metadata") or {}
            if not isinstance(metadata, dict):
                raise ValueError(f"metadata at {path}:{line_no} must be a JSON object")
            target = metadata.get("repair_target") or {}
            if not isinstance(target, dict):
                raise ValueError(
                    f"metadata.repair_target at {path}:{line_no} must be a JSON object"
                )

            prompt = str(item.get("prompt") or "").strip()
            chosen = str(item.get("chosen") or "").strip()
            rejected_text = str(item.get("rejected") or "").strip()

            reason = None
            if not prompt or not chosen or not rejected_text:
                reason = "missing_prompt_chosen_or_rejected"
            elif chosen == rejected_text:
                reason = "chosen_equals_rejected"
            # Only an explicit ``False`` blocks admission; an earlier reason
            # (if any) is kept.
            if target.get("admitted_to_training") is False:
                reason = reason or "repair_target_not_admitted"

            row = {
                "prompt": prompt,
                "chosen": chosen,
                "rejected": rejected_text,
                "metadata": metadata,
                "source_line": line_no,
            }
            if reason:
                row["reject_reason"] = reason
                rejected.append(row)
            else:
                admitted.append(row)
    if not admitted:
        raise ValueError(f"no admissible preference pairs found in {path}")
    return admitted, rejected
def split_rows(rows: list[dict[str, Any]], valid_ratio: float) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Split ``rows`` into a leading train part and a trailing validation part.

    With fewer than ten rows no split is made: the very same list is returned
    as both the train and the validation set. Otherwise the validation part
    is the last ``round(len(rows) * valid_ratio)`` rows, clamped so that it
    holds at least one row and leaves at least one row for training.

    Args:
        rows: Rows in file order; the input list is not modified.
        valid_ratio: Requested fraction of rows to hold out for validation.

    Returns:
        ``(train_rows, valid_rows)``.
    """
    total = len(rows)
    if total < 10:
        return rows, rows
    requested = int(round(total * valid_ratio))
    holdout = min(max(1, requested), total - 1)
    boundary = total - holdout
    return rows[:boundary], rows[boundary:]
def build_plan(args: argparse.Namespace, rows: list[dict[str, Any]], rejects: list[dict[str, Any]]) -> dict[str, Any]:
    """Assemble the training-plan record for a preference (DPO) run.

    The plan captures the run identifiers, the base model and start adapter,
    the preference file path together with its SHA-256 digest, admitted and
    rejected pair counts, the first 25 rejected rows, the hyperparameters read
    from ``args`` and the result of ``package_status()``.

    Args:
        args: Parsed command-line namespace for the run.
        rows: Admitted preference pairs.
        rejects: Rejected preference pairs.

    Returns:
        The plan dictionary, with ``"ok"`` set to ``True``.
    """
    source_file = Path(args.preference_jsonl)
    hyperparameter_names = (
        "max_steps",
        "learning_rate",
        "beta",
        "per_device_train_batch_size",
        "gradient_accumulation_steps",
        "max_seq_length",
        "valid_ratio",
    )

    plan: dict[str, Any] = {"schema_version": "shft_preference_training_plan_v1"}
    plan["created_at"] = utc_now()
    plan["run_id"] = args.run_id
    plan["source_run_id"] = args.source_run_id
    plan["training_stage"] = "preference_dpo"
    plan["trainer"] = "dpo"
    plan["base_model_id"] = args.base_model_id
    plan["start_adapter"] = args.start_adapter
    plan["preference_jsonl"] = str(source_file)
    plan["preference_jsonl_sha256"] = sha256_file(source_file)
    plan["admitted_pair_count"] = len(rows)
    plan["rejected_pair_count"] = len(rejects)
    # Only a sample of the rejects is embedded in the plan.
    plan["rejected_pairs"] = rejects[:25]
    plan["hyperparameters"] = {name: getattr(args, name) for name in hyperparameter_names}
    plan["package_status"] = package_status()
    plan["ok"] = True
    return plan
def _selected_checkpoint(output_dir: Path, trainer: Any) -> dict[str, Any]:
    """Describe which checkpoint the trainer selected as best.

    ``best_model_checkpoint`` and ``best_metric`` are read from
    ``trainer.state``; each defaults to ``None`` when the attribute (or the
    state itself) is absent. If no best checkpoint path is available the
    record points at ``output_dir / "adapter"`` instead.

    Args:
        output_dir: Root output directory of the run.
        trainer: Trainer object whose ``state`` is inspected.

    Returns:
        The selected-checkpoint record, with ``selection_metric`` fixed to
        ``"eval_loss"``.
    """
    trainer_state = getattr(trainer, "state", None)
    best_path = getattr(trainer_state, "best_model_checkpoint", None)
    best_value = getattr(trainer_state, "best_metric", None)

    record: dict[str, Any] = {"schema_version": "shft_selected_preference_checkpoint_v1"}
    record["created_at"] = utc_now()
    record["selection_metric"] = "eval_loss"
    record["selection_metric_value"] = _jsonable(best_value)
    if best_path:
        record["checkpoint_path"] = best_path
    else:
        record["checkpoint_path"] = str(output_dir / "adapter")
    return record
def run_preference_training(args: argparse.Namespace) -> dict[str, Any]:
    """Validate the preference data and, unless ``--dry-run`` is set, run DPO.

    Steps:

    1. Load and split the preference pairs, then write
       ``preference_training_plan.json`` into the output directory.
    2. With ``args.dry_run`` set, write a ``dry_run_validated`` result and
       return without importing any training library.
    3. Otherwise load the 4-bit quantized base model, attach the trainable
       start adapter, train with TRL's ``DPOTrainer`` and save the adapter and
       tokenizer under ``<output_dir>/adapter``.
    4. Write the per-step metrics JSONL, the selected-checkpoint record, the
       metrics summary and the result, then copy the four JSON artifacts to
       their stage-independent file names.

    Args:
        args: Parsed command-line namespace (see ``build_parser``).

    Returns:
        The result dictionary that was written to
        ``preference_training_result.json``.

    Raises:
        ValueError: Live training was requested without ``--start-adapter``.
            Errors raised while loading the preference file or by the
            training libraries propagate unchanged.
    """
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    rows, rejects = load_preference_rows(Path(args.preference_jsonl))
    train_rows, valid_rows = split_rows(rows, args.valid_ratio)
    plan = build_plan(args, rows, rejects)
    plan["train_pair_count"] = len(train_rows)
    plan["valid_pair_count"] = len(valid_rows)
    write_json(output_dir / "preference_training_plan.json", plan)

    if args.dry_run:
        result = {
            "schema_version": "shft_preference_training_result_v1",
            "created_at": utc_now(),
            "run_id": args.run_id,
            "source_run_id": args.source_run_id,
            "status": "dry_run_validated",
            "admitted_pair_count": len(rows),
            "rejected_pair_count": len(rejects),
            "ok": True,
        }
        write_json(output_dir / "preference_training_result.json", result)
        return result

    if not args.start_adapter:
        raise ValueError("--start-adapter is required for live preference training")

    # Heavy dependencies are imported lazily so a dry run works without them.
    import torch
    from datasets import Dataset
    from peft import PeftModel, prepare_model_for_kbit_training
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from trl import DPOConfig, DPOTrainer

    # Build the trainer configuration before touching any model weights, so
    # that an invalid combination of settings is reported before the
    # expensive base-model load rather than after it.
    training_args = DPOConfig(
        output_dir=str(output_dir / "checkpoints"),
        max_steps=args.max_steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        beta=args.beta,
        max_length=args.max_seq_length,
        # The prompt budget can never exceed the total sequence budget.
        max_prompt_length=min(args.max_prompt_length, args.max_seq_length),
        logging_steps=args.logging_steps,
        eval_strategy="steps",
        eval_steps=args.eval_steps,
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        bf16=True,
        gradient_checkpointing=True,
        report_to=[],
    )

    tokenizer = AutoTokenizer.from_pretrained(args.base_model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Reuse EOS for padding when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model_id,
        quantization_config=quantization,
        device_map="auto",
        trust_remote_code=True,
    )
    model = prepare_model_for_kbit_training(model)
    # Continue training the existing adapter instead of creating a new one.
    model = PeftModel.from_pretrained(model, args.start_adapter, is_trainable=True)

    trainer = DPOTrainer(
        model=model,
        args=training_args,
        train_dataset=Dataset.from_list(train_rows),
        eval_dataset=Dataset.from_list(valid_rows),
        processing_class=tokenizer,
    )
    train_result = trainer.train()

    adapter_dir = output_dir / "adapter"
    trainer.save_model(str(adapter_dir))
    tokenizer.save_pretrained(str(adapter_dir))

    log_history = list(getattr(trainer.state, "log_history", []) or [])
    # Materialise the metric rows once; they feed both the JSONL file and the
    # summary below.
    metric_rows = list(_metric_rows(log_history))

    # ``append_jsonl`` opens the file in append mode. Start from a clean file
    # so that re-running into the same output directory does not mix rows of
    # an earlier run into this run's metrics (the summary only covers this run).
    metrics_path = output_dir / "preference_trainer_metrics.jsonl"
    metrics_path.unlink(missing_ok=True)
    for row in metric_rows:
        append_jsonl(metrics_path, row)

    selected = _selected_checkpoint(output_dir, trainer)
    write_json(output_dir / "selected_preference_checkpoint.json", selected)

    summary = build_trainer_metrics_summary(
        rows=metric_rows,
        selected_checkpoint=selected,
        overfit_tolerance=args.overfit_tolerance,
    )
    write_json(output_dir / "preference_trainer_metrics_summary.json", summary)

    result = {
        "schema_version": "shft_preference_training_result_v1",
        "created_at": utc_now(),
        "run_id": args.run_id,
        "source_run_id": args.source_run_id,
        "training_stage": "preference_dpo",
        "status": "completed",
        "admitted_pair_count": len(rows),
        "train_pair_count": len(train_rows),
        "valid_pair_count": len(valid_rows),
        "train_metrics": _jsonable(getattr(train_result, "metrics", {})),
        "adapter_dir": str(adapter_dir),
        "ok": True,
    }
    write_json(output_dir / "preference_training_result.json", result)

    # Publish the artifacts a second time under stage-independent file names.
    for source, target in [
        ("preference_training_plan.json", "training_plan.json"),
        ("preference_training_result.json", "training_result.json"),
        ("preference_trainer_metrics_summary.json", "trainer_metrics_summary.json"),
        ("selected_preference_checkpoint.json", "selected_checkpoint.json"),
    ]:
        copy_file_with_retry(output_dir / source, output_dir / target, attempts=3, delay_seconds=0.5)
    return result
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the preference-training script.

    ``--run-id``, ``--source-run-id``, ``--preference-jsonl`` and
    ``--output-dir`` are required. The defaults of ``--base-model-id``,
    ``--max-steps``, ``--learning-rate`` and ``--beta`` are read from the
    environment variables ``SHFT_BASE_MODEL_ID``,
    ``SHFT_PREFERENCE_MAX_STEPS``, ``SHFT_PREFERENCE_LR`` and
    ``SHFT_DPO_BETA`` at the time the parser is built.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Run SHFT DPO preference optimization from paired-eval losses.")
    env = os.environ.get

    # One entry per option, in the order the options are registered.
    option_specs: list[tuple[str, dict[str, Any]]] = [
        ("--run-id", {"required": True}),
        ("--source-run-id", {"required": True}),
        ("--base-model-id", {"default": env("SHFT_BASE_MODEL_ID", "Qwen/Qwen3-32B")}),
        ("--start-adapter", {}),
        ("--preference-jsonl", {"required": True}),
        ("--output-dir", {"required": True}),
        ("--max-steps", {"type": int, "default": int(env("SHFT_PREFERENCE_MAX_STEPS", "200"))}),
        ("--learning-rate", {"type": float, "default": float(env("SHFT_PREFERENCE_LR", "5e-6"))}),
        ("--beta", {"type": float, "default": float(env("SHFT_DPO_BETA", "0.1"))}),
        ("--per-device-train-batch-size", {"type": int, "default": 1}),
        ("--gradient-accumulation-steps", {"type": int, "default": 8}),
        ("--max-seq-length", {"type": int, "default": 4096}),
        ("--max-prompt-length", {"type": int, "default": 1536}),
        ("--valid-ratio", {"type": float, "default": 0.1}),
        ("--logging-steps", {"type": int, "default": 5}),
        ("--eval-steps", {"type": int, "default": 25}),
        ("--save-steps", {"type": int, "default": 25}),
        ("--save-total-limit", {"type": int, "default": 3}),
        ("--overfit-tolerance", {"type": float, "default": 0.1}),
        ("--dry-run", {"action": "store_true"}),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
def main(argv: list[str] | None = None) -> int:
    """Parse ``argv``, run preference training and print the result as JSON.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        ``0`` when the result's ``"ok"`` entry is truthy, otherwise ``2``.
    """
    parser = build_parser()
    namespace = parser.parse_args(argv)
    outcome = run_preference_training(namespace)
    rendered = json.dumps(outcome, indent=2)
    print(rendered)
    if outcome.get("ok"):
        return 0
    return 2
# Script entry point: exit the process with the status code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())

# NOTE: web file-viewer residue captured with this file (not program code):
# Xet Storage Details
#
# Size: 11.3 kB · Xet hash:
# 14fe4d5e7ec64916b31c1f5007b2a6d58fd78c74a6dc90db03b87314dad0de56
#
# Xet efficiently stores files, intelligently splitting them into unique
# chunks and accelerating uploads and downloads. More info.