| """Dataset preparation for Arbiter SFT and BDPO stages.""" |
| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| from datasets import Dataset, concatenate_datasets, load_dataset |
|
|
|
|
| |
|
|
| def _orca_to_messages(example: dict[str, Any]) -> dict[str, Any]: |
| sys_p = example.get("system_prompt") or "" |
| q = example.get("question") or "" |
| a = example.get("response") or "" |
| msgs = [] |
| if sys_p: |
| msgs.append({"role": "system", "content": sys_p}) |
| msgs.append({"role": "user", "content": q}) |
| msgs.append({"role": "assistant", "content": a}) |
| return {"messages": msgs} |
|
|
|
|
| def _dolphin_to_messages(example: dict[str, Any]) -> dict[str, Any]: |
| |
| sys_p = example.get("system") or example.get("system_prompt") or "" |
| user = example.get("instruction") or example.get("input") or example.get("prompt") or "" |
| resp = example.get("output") or example.get("response") or example.get("completion") or "" |
| msgs = [] |
| if sys_p: |
| msgs.append({"role": "system", "content": sys_p}) |
| msgs.append({"role": "user", "content": user}) |
| msgs.append({"role": "assistant", "content": resp}) |
| return {"messages": msgs} |
|
|
|
|
def _normalize(ds: Dataset, source: str) -> Dataset:
    """Map every row of *ds* onto the shared ``messages`` schema.

    The converter is chosen from the source name: a name containing "orca"
    (case-insensitive) uses ``_orca_to_messages``; anything else uses
    ``_dolphin_to_messages``. All of the dataset's original columns are
    removed, leaving only what the converter returns.
    """
    is_orca = "orca" in source.lower()
    converter = _orca_to_messages if is_orca else _dolphin_to_messages
    return ds.map(converter, remove_columns=ds.column_names)
|
|
|
|
def load_sft_mix(spec_list: list[dict[str, Any]], seed: int = 42) -> Dataset:
    """Load, subsample, normalize and mix the datasets described by *spec_list*.

    Each spec is a dict with a required ``name`` and optional ``subset``
    (passed as the second positional argument to ``load_dataset``), ``split``
    (defaults to ``"train"``) and ``max_samples`` (when not ``None``, the
    dataset is shuffled with *seed* and truncated to at most that many rows).

    The normalized parts are concatenated and shuffled with *seed*; rows
    whose message list is empty or whose last message has empty content are
    then filtered out.

    Returns:
        The mixed dataset, as produced by ``_normalize`` for each part.
    """

    def load_one(spec: dict[str, Any]) -> Dataset:
        # The subset is forwarded positionally only when the spec names one.
        positional = [spec["name"]]
        if "subset" in spec:
            positional.append(spec["subset"])
        loaded = load_dataset(*positional, split=spec.get("split", "train"))

        limit = spec.get("max_samples")
        if limit is not None:
            loaded = loaded.shuffle(seed=seed).select(range(min(limit, len(loaded))))
        return _normalize(loaded, spec["name"])

    def has_final_content(row: dict[str, Any]) -> bool:
        # Keep rows that have messages and a non-empty last message.
        messages = row["messages"]
        return bool(messages) and bool(messages[-1]["content"])

    combined = concatenate_datasets([load_one(spec) for spec in spec_list])
    return combined.shuffle(seed=seed).filter(has_final_content)
|
|
|
|
| |
|
|
| def _hh_to_pref(example: dict[str, Any]) -> dict[str, Any]: |
| |
| chosen = example["chosen"] |
| rejected = example["rejected"] |
| |
| cut = chosen.rfind("\n\nAssistant:") |
| prompt = chosen[: cut + len("\n\nAssistant:")] if cut >= 0 else "" |
| return { |
| "prompt": prompt, |
| "chosen": chosen[len(prompt):].strip(), |
| "rejected": rejected[len(prompt):].strip() if rejected.startswith(prompt) else rejected, |
| } |
|
|
|
|
def load_pref_dataset(spec: dict[str, Any], seed: int = 42) -> Dataset:
    """Load a preference dataset and reshape it into prompt/chosen/rejected rows.

    *spec* must provide ``name`` and may provide ``split`` (defaults to
    ``"train"``) and ``max_samples``. When ``max_samples`` is not ``None``
    the dataset is shuffled with *seed* and truncated to at most that many
    rows before conversion.

    Returns:
        The dataset mapped through ``_hh_to_pref`` with all of its original
        columns removed.
    """
    raw = load_dataset(spec["name"], split=spec.get("split", "train"))
    limit = spec.get("max_samples")
    # Subsample first so the conversion only runs on the rows that are kept.
    sampled = (
        raw
        if limit is None
        else raw.shuffle(seed=seed).select(range(min(limit, len(raw))))
    )
    return sampled.map(_hh_to_pref, remove_columns=sampled.column_names)
|
|