Cloud / arbiter /data.py
Grimxlock's picture
Arbiter pipeline: SFT (Orca+Dolphin) -> BDPO -> Heretic
92104ff verified
"""Dataset preparation for Arbiter SFT and BDPO stages."""
from __future__ import annotations

import os
from typing import Any

from datasets import Dataset, concatenate_datasets, load_dataset
# ---------------------------- SFT ---------------------------- #
def _orca_to_messages(example: dict[str, Any]) -> dict[str, Any]:
sys_p = example.get("system_prompt") or ""
q = example.get("question") or ""
a = example.get("response") or ""
msgs = []
if sys_p:
msgs.append({"role": "system", "content": sys_p})
msgs.append({"role": "user", "content": q})
msgs.append({"role": "assistant", "content": a})
return {"messages": msgs}
def _dolphin_to_messages(example: dict[str, Any]) -> dict[str, Any]:
# Dolphin variants carry different field names; handle the common ones.
sys_p = example.get("system") or example.get("system_prompt") or ""
user = example.get("instruction") or example.get("input") or example.get("prompt") or ""
resp = example.get("output") or example.get("response") or example.get("completion") or ""
msgs = []
if sys_p:
msgs.append({"role": "system", "content": sys_p})
msgs.append({"role": "user", "content": user})
msgs.append({"role": "assistant", "content": resp})
return {"messages": msgs}
def _normalize(ds: Dataset, source: str) -> Dataset:
    """Map a raw dataset onto a single ``messages`` column.

    The converter is picked from the source name: any name containing
    "orca" (case-insensitive) uses the Orca converter, everything else
    uses the Dolphin converter. All original columns are removed.

    Args:
        ds: The raw dataset to convert.
        source: The dataset name used to choose the converter.

    Returns:
        The mapped dataset containing only the converter's output.
    """
    is_orca = "orca" in source.lower()
    converter = _orca_to_messages if is_orca else _dolphin_to_messages
    return ds.map(converter, remove_columns=ds.column_names)
def load_sft_mix(spec_list: list[dict[str, Any]], seed: int = 42) -> Dataset:
parts: list[Dataset] = []
for spec in spec_list:
kwargs = {"split": spec.get("split", "train")}
if "subset" in spec:
ds = load_dataset(spec["name"], spec["subset"], **kwargs)
else:
ds = load_dataset(spec["name"], **kwargs)
max_s = spec.get("max_samples")
if max_s is not None:
ds = ds.shuffle(seed=seed).select(range(min(max_s, len(ds))))
parts.append(_normalize(ds, spec["name"]))
mixed = concatenate_datasets(parts).shuffle(seed=seed)
# Drop empties defensively.
return mixed.filter(lambda x: bool(x["messages"]) and bool(x["messages"][-1]["content"]))
# ---------------------------- BDPO --------------------------- #
def _hh_to_pref(example: dict[str, Any]) -> dict[str, Any]:
# Anthropic/hh-rlhf gives full transcripts; convert to (prompt, chosen, rejected)
chosen = example["chosen"]
rejected = example["rejected"]
# Split at the last "Assistant:" — the shared prompt is everything before.
cut = chosen.rfind("\n\nAssistant:")
prompt = chosen[: cut + len("\n\nAssistant:")] if cut >= 0 else ""
return {
"prompt": prompt,
"chosen": chosen[len(prompt):].strip(),
"rejected": rejected[len(prompt):].strip() if rejected.startswith(prompt) else rejected,
}
def load_pref_dataset(spec: dict[str, Any], seed: int = 42) -> Dataset:
    """Load a preference dataset and convert it to (prompt, chosen, rejected).

    Args:
        spec: Dataset spec with a required ``name``, an optional ``split``
            (defaults to ``"train"``) and an optional ``max_samples`` cap.
        seed: Seed for the shuffle applied before subsampling.

    Returns:
        The dataset mapped through ``_hh_to_pref`` with all original
        columns removed.
    """
    raw = load_dataset(spec["name"], split=spec.get("split", "train"))

    limit = spec.get("max_samples")
    if limit is not None:
        # Shuffle first so the kept rows are a seeded random sample.
        shuffled = raw.shuffle(seed=seed)
        keep = min(limit, len(raw))
        raw = shuffled.select(range(keep))

    return raw.map(_hh_to_pref, remove_columns=raw.column_names)