"""Dataset preparation for Arbiter SFT and BDPO stages."""

from __future__ import annotations

from typing import Any

from datasets import Dataset, concatenate_datasets, load_dataset

# ---------------------------- shared ---------------------------- #


def _load_and_subsample(spec: dict[str, Any], seed: int) -> Dataset:
    """Load one dataset split described by ``spec`` and optionally cap its size.

    Recognised ``spec`` keys:
        name: dataset path passed to ``load_dataset`` (required).
        subset: optional config name, passed positionally after ``name``.
        data_dir: optional ``data_dir`` forwarded to ``load_dataset``.
        split: split to load (default ``"train"``).
        max_samples: optional cap; when set, the split is shuffled with
            ``seed`` and the first ``max_samples`` rows are kept.
    """
    kwargs: dict[str, Any] = {"split": spec.get("split", "train")}
    if "data_dir" in spec:
        kwargs["data_dir"] = spec["data_dir"]
    if "subset" in spec:
        ds = load_dataset(spec["name"], spec["subset"], **kwargs)
    else:
        ds = load_dataset(spec["name"], **kwargs)

    max_s = spec.get("max_samples")
    if max_s is not None:
        # Shuffle before truncating so the cap is a random sample rather than
        # the head of the split.
        ds = ds.shuffle(seed=seed).select(range(min(max_s, len(ds))))
    return ds


# ---------------------------- SFT ---------------------------- #


def _build_messages(system: str, user: str, assistant: str) -> dict[str, Any]:
    """Return ``{"messages": [...]}`` with an optional system turn, then user and assistant turns."""
    msgs = []
    if system:
        msgs.append({"role": "system", "content": system})
    msgs.append({"role": "user", "content": user})
    msgs.append({"role": "assistant", "content": assistant})
    return {"messages": msgs}


def _first_nonempty(example: dict[str, Any], *keys: str) -> str:
    """Return the first truthy value found under ``keys`` in ``example``, else ``""``."""
    for key in keys:
        value = example.get(key)
        if value:
            return value
    return ""


def _orca_to_messages(example: dict[str, Any]) -> dict[str, Any]:
    """Convert an Orca-style row (``system_prompt``/``question``/``response``) to chat messages.

    Missing or ``None`` fields become empty strings. The system turn is only
    emitted when the system prompt is non-empty.
    """
    return _build_messages(
        example.get("system_prompt") or "",
        example.get("question") or "",
        example.get("response") or "",
    )


def _dolphin_to_messages(example: dict[str, Any]) -> dict[str, Any]:
    """Convert a Dolphin/alpaca-style row to chat messages.

    Dolphin variants carry different field names; the common ones are handled:
        system prompt: ``system`` or ``system_prompt``
        user turn:     ``instruction`` / ``input`` / ``prompt`` / ``question``
        assistant:     ``output`` / ``response`` / ``completion``

    When both ``instruction`` and ``input`` are non-empty, neither is dropped.
    With no explicit system field, ``instruction`` becomes the system turn and
    ``input`` the user turn. This matches Dolphin's layout, where ``instruction``
    holds the system prompt. With an explicit system field, the two are joined
    into the user turn.
    """
    sys_p = _first_nonempty(example, "system", "system_prompt")
    instruction = example.get("instruction") or ""
    inp = example.get("input") or ""

    if instruction and inp:
        # The old `instruction or input` fallback silently discarded `input`.
        # NOTE(review): for plain Alpaca data (instruction = task, input =
        # context) this maps the task to the system turn; confirm that this is
        # acceptable for the mixes in use.
        if sys_p:
            user = f"{instruction}\n\n{inp}"
        else:
            sys_p, user = instruction, inp
    else:
        user = instruction or inp or _first_nonempty(example, "prompt", "question")

    resp = _first_nonempty(example, "output", "response", "completion")
    return _build_messages(sys_p, user, resp)


def _normalize(ds: Dataset, source: str) -> Dataset:
    """Map ``ds`` to a single ``messages`` column.

    The Orca converter is used when ``source`` contains "orca"
    (case-insensitive); otherwise the Dolphin converter is used.
    """
    if "orca" in source.lower():
        return ds.map(_orca_to_messages, remove_columns=ds.column_names)
    return ds.map(_dolphin_to_messages, remove_columns=ds.column_names)


def _has_content(example: dict[str, Any]) -> bool:
    """Return True when the row has messages and every non-system turn is non-empty."""
    msgs = example["messages"]
    return bool(msgs) and all(m["content"] for m in msgs if m["role"] != "system")


def load_sft_mix(spec_list: list[dict[str, Any]], seed: int = 42) -> Dataset:
    """Load, normalize, concatenate and shuffle the SFT datasets in ``spec_list``.

    Each spec is handled by ``_load_and_subsample`` (keys ``name``,
    optional ``subset``/``data_dir``/``split``/``max_samples``). It is then
    normalized to a ``messages`` column based on ``spec["name"]``.

    Raises:
        ValueError: if ``spec_list`` is empty.
    """
    if not spec_list:
        raise ValueError("load_sft_mix: spec_list is empty; nothing to mix.")

    parts = [_normalize(_load_and_subsample(spec, seed), spec["name"]) for spec in spec_list]
    mixed = concatenate_datasets(parts).shuffle(seed=seed)
    # Drop empties defensively: a row with an empty user or assistant turn is
    # not a usable training example.
    return mixed.filter(_has_content)


# ---------------------------- BDPO --------------------------- #

_ASSISTANT_TAG = "\n\nAssistant:"


def _hh_to_pref(example: dict[str, Any]) -> dict[str, Any]:
    """Split an Anthropic/hh-rlhf transcript pair into (prompt, chosen, rejected).

    The prompt is the longest prefix that:
        - ends with ``"\\n\\nAssistant:"``, and
        - is shared by both the ``chosen`` and ``rejected`` transcripts.

    ``chosen`` and ``rejected`` are the stripped remainders after that prompt.
    If no such shared prefix exists, the prompt is ``""`` and both full
    transcripts are returned stripped.
    """
    chosen = example["chosen"]
    rejected = example["rejected"]

    # Start from the last assistant marker in `chosen` and walk back until
    # `rejected` shares the prefix. Then both completions get the same
    # prompt removed.
    cut = chosen.rfind(_ASSISTANT_TAG)
    while cut >= 0 and not rejected.startswith(chosen[: cut + len(_ASSISTANT_TAG)]):
        cut = chosen.rfind(_ASSISTANT_TAG, 0, cut)
    prompt = chosen[: cut + len(_ASSISTANT_TAG)] if cut >= 0 else ""

    return {
        "prompt": prompt,
        "chosen": chosen[len(prompt):].strip(),
        "rejected": rejected[len(prompt):].strip(),
    }


def load_pref_dataset(spec: dict[str, Any], seed: int = 42) -> Dataset:
    """Load a preference dataset and convert it to ``prompt``/``chosen``/``rejected`` columns.

    ``spec`` uses the same keys as the SFT specs (``name``, optional
    ``subset``/``data_dir``/``split``/``max_samples``). Rows must carry
    hh-rlhf-style ``chosen`` and ``rejected`` transcripts.
    """
    ds = _load_and_subsample(spec, seed)
    return ds.map(_hh_to_pref, remove_columns=ds.column_names)