frankenstallm / data /prepare_preference_combined.py
pathcosmos's picture
feat: Add data pipeline scripts + phase reports (Tier 3 - reproducibility)
b3d361d verified
#!/usr/bin/env python3
"""
prepare_preference_combined.py — Merge preference datasets and normalize their format
Phase 0F: ORPO pipeline preparation
Input directory: data/preference/
Output file: data/preference/combined_preference.jsonl
Supported formats:
- {prompt, chosen, rejected} (standard DPO/ORPO format)
- {question, chosen, rejected, [system]} (heegyu, kuotient orca-math family)
- {instruction, chosen, rejected} (instruction-key variant)
- {orig_instruction, orig_response_A/B, orig_preference} (nayohan preference-collection)
- {prompt, response_a, response_b, preferred} (response_a/b + preferred key)
- {prompt, response_a, response_b, winner} (winner-key variant)
- {instruction, preferred, dispreferred} (preferred/dispreferred keys)
- {prompt, winning_response, losing_response} (Ultrafeedback family)
- {conversations, chosen, rejected} (conversations-list format)
Quality filter:
- chosen and rejected must both be non-empty
- chosen != rejected
- at least 20 characters (measured on chosen)
Usage:
python data/prepare_preference_combined.py [--input_dir data/preference] [--output data/preference/combined_preference.jsonl]
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from pathlib import Path
from typing import Optional
# Configure the root logger at import time: INFO level, with a timestamp and
# the level name in front of every message.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Module-level logger used by the functions below.
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Automatic field-name detection logic
# ---------------------------------------------------------------------------
def _extract_text(val) -> str:
    """Flatten a raw field value into a single stripped string.

    Handles the value shapes seen in preference datasets:

    * ``None``  -> ``""`` (instead of the literal string ``"None"``).
    * ``str``   -> the string, stripped.
    * ``dict``  -> the first truthy value among ``content`` / ``value`` /
      ``text``, flattened recursively so a non-string payload cannot crash.
    * ``list``  -> every item is converted to text (dict items via their
      ``content`` / ``value`` / ``text`` key) and the parts are joined with
      newlines; the joined result is stripped.
    * anything else -> ``str(val).strip()``.

    Args:
        val: Raw value taken from a JSON record.

    Returns:
        The extracted text; empty string when nothing usable is present.
    """
    if val is None:
        # A JSON null must count as "missing", not as the answer "None".
        return ""
    if isinstance(val, str):
        return val.strip()
    if isinstance(val, dict):
        # Recurse: the payload may itself be a list, a number or None.
        return _extract_text(val.get("content") or val.get("value") or val.get("text"))
    if isinstance(val, list):
        # [{"role": ..., "content": ...}, ...] style message lists.
        parts = []
        for item in val:
            if isinstance(item, dict):
                item = item.get("content") or item.get("value") or item.get("text") or ""
            parts.append("" if item is None else str(item))
        return "\n".join(parts).strip()
    return str(val).strip()
def _build_prompt(record: dict) -> str:
    """Extract the prompt string from a raw record.

    Lookup order:

    1. The first of ``prompt``, ``instruction``, ``question``, ``input``,
       ``user_prompt``, ``orig_instruction`` that yields non-empty text.
       When the record also has a non-empty ``system`` field, it is
       prepended on its own line.
    2. Otherwise, the first ``conversations`` turn whose ``role`` / ``from``
       is ``human`` or ``user``.

    Args:
        record: Raw record dictionary.

    Returns:
        The prompt text, or ``""`` when no prompt can be found.
    """
    # Standard prompt-like keys.
    for key in ("prompt", "instruction", "question", "input", "user_prompt", "orig_instruction"):
        raw = record.get(key)
        if not raw:
            continue
        val = _extract_text(raw)
        if not val:
            continue
        # Flatten ``system`` through the shared helper so a non-string value
        # (list, dict, number) cannot crash; whitespace-only counts as absent.
        system = _extract_text(record.get("system") or "")
        return f"{system}\n{val}" if system else val

    # conversations format: first human/user turn.
    convs = record.get("conversations")
    if isinstance(convs, list):
        for item in convs:
            if not isinstance(item, dict):
                # Malformed turn (plain string, null, ...): nothing to inspect.
                continue
            role = str(item.get("role") or item.get("from") or "").lower()
            if role in ("human", "user"):
                return _extract_text(item.get("content") or item.get("value") or "")
    return ""
def _completion_rating(completion) -> float:
    """Return a sortable numeric rating for one ``completions`` entry."""
    if not isinstance(completion, dict):
        return 0.0
    raw = completion.get("rating", completion.get("score", 0))
    try:
        return float(raw)
    except (TypeError, ValueError):
        # null / non-numeric rating: treat as unrated instead of letting
        # ``sorted`` fail on a None-vs-number comparison.
        return 0.0


def _completion_text(completion) -> str:
    """Return the text of one ``completions`` entry (dict or plain value)."""
    if isinstance(completion, dict):
        return _extract_text(completion.get("text", completion))
    return _extract_text(completion)


def normalize_record(record: dict, source_name: str) -> Optional[dict]:
    """Normalize one raw record to ``{"prompt", "chosen", "rejected"}``.

    Layouts are checked in this order; the first match wins:

    1. ``chosen`` + ``rejected``.
    2. ``orig_preference`` with ``orig_response_A`` / ``orig_response_B``
       (falling back to ``response_A`` / ``response_B``). Preference ``"B"``
       (case-insensitive) selects B; any other value selects A.
    3. ``preferred`` + ``dispreferred``.
    4. ``response_a`` + ``response_b`` with the winner named by ``preferred``
       or ``winner``. ``b`` / ``response_b`` / ``model_b`` selects B; any
       other value selects A.
    5. ``winning_response`` + ``losing_response``.
    6. ``completions`` list with at least two entries: the highest-rated
       entry (``rating``, else ``score``) becomes ``chosen`` and the
       lowest-rated becomes ``rejected``.

    Args:
        record: One parsed JSON value from a source file.
        source_name: Name of the originating dataset. Not used by the
            function body; kept so existing callers keep working.

    Returns:
        The normalized dict, or ``None`` when ``record`` is not a dict or
        matches none of the layouts. Fields may be empty strings; no quality
        filtering happens here.
    """
    if not isinstance(record, dict):
        # A JSONL line may hold any JSON value (null, number, string, list);
        # only objects can carry the fields we look for.
        return None

    chosen = ""
    rejected = ""

    # --- Pattern 1: standard {chosen, rejected} ---
    if "chosen" in record and "rejected" in record:
        chosen = _extract_text(record["chosen"])
        rejected = _extract_text(record["rejected"])
    # --- Pattern 2: nayohan preference-collection (orig_preference + orig_response_A/B) ---
    elif "orig_preference" in record:
        resp_a = _extract_text(record.get("orig_response_A", record.get("response_A", "")))
        resp_b = _extract_text(record.get("orig_response_B", record.get("response_B", "")))
        pref = str(record.get("orig_preference", "")).strip().upper()
        if pref == "B":
            chosen, rejected = resp_b, resp_a
        else:
            chosen, rejected = resp_a, resp_b
    # --- Pattern 3: preferred / dispreferred ---
    elif "preferred" in record and "dispreferred" in record:
        chosen = _extract_text(record["preferred"])
        rejected = _extract_text(record["dispreferred"])
    # --- Pattern 4: response_a/b + preferred or winner key ---
    elif "response_a" in record and "response_b" in record:
        resp_a = _extract_text(record["response_a"])
        resp_b = _extract_text(record["response_b"])
        winner_key = record.get("preferred") or record.get("winner") or ""
        winner = str(winner_key).strip().lower()
        if winner in ("b", "response_b", "model_b"):
            chosen, rejected = resp_b, resp_a
        else:
            # Default: A is the chosen response.
            chosen, rejected = resp_a, resp_b
    # --- Pattern 5: winning_response / losing_response (Ultrafeedback family) ---
    elif "winning_response" in record and "losing_response" in record:
        chosen = _extract_text(record["winning_response"])
        rejected = _extract_text(record["losing_response"])
    # --- Pattern 6: completions list (some HH-RLHF variants) ---
    elif "completions" in record:
        completions = record["completions"]
        if isinstance(completions, list) and len(completions) >= 2:
            # Sort by rating, best first; ``sorted`` is stable, so ties keep
            # their original order.
            ranked = sorted(completions, key=_completion_rating, reverse=True)
            chosen = _completion_text(ranked[0])
            rejected = _completion_text(ranked[-1])
    else:
        return None  # unknown format

    prompt = _build_prompt(record)
    return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
# ---------------------------------------------------------------------------
# Quality filter
# ---------------------------------------------------------------------------
# Minimum number of characters required in the ``chosen`` response.
MIN_LEN = 20


def passes_quality_filter(record: dict) -> bool:
    """Tell whether a normalized record is good enough to keep.

    A record passes only when all of the following hold:

    * ``chosen`` and ``rejected`` are both non-empty,
    * ``chosen`` differs from ``rejected``,
    * ``chosen`` has at least ``MIN_LEN`` characters,
    * ``prompt`` is non-empty (records without a prompt are dropped).

    Args:
        record: Dict with ``prompt`` / ``chosen`` / ``rejected`` entries;
            missing keys are treated as empty strings.

    Returns:
        ``True`` when the record passes every check, ``False`` otherwise.
    """
    chosen = record.get("chosen", "")
    rejected = record.get("rejected", "")
    # Checks short-circuit left to right; the prompt is examined last.
    return bool(
        chosen
        and rejected
        and chosen != rejected
        and len(chosen) >= MIN_LEN
        and record.get("prompt", "")
    )
# ---------------------------------------------------------------------------
# Per-file loader
# ---------------------------------------------------------------------------
def load_jsonl(path: Path):
    """Lazily parse a JSONL file, one JSON value per non-blank line.

    The file is read as UTF-8. Blank lines are skipped. A line that is not
    valid JSON is reported with a warning (file name and 1-based line
    number) and skipped; parsing then continues with the next line.

    Args:
        path: Location of the JSONL file.

    Yields:
        Each successfully parsed JSON value, in file order. Values are
        yielded as parsed and are not checked for being dicts.
    """
    with path.open("r", encoding="utf-8") as handle:
        for lineno, raw in enumerate(handle, 1):
            text = raw.strip()
            if text:
                try:
                    parsed = json.loads(text)
                except json.JSONDecodeError as e:
                    log.warning(f" JSON ํŒŒ์‹ฑ ์˜ค๋ฅ˜ {path.name}:{lineno} โ€” {e}")
                else:
                    yield parsed
def process_file(src_path: Path, out_f, stats: dict) -> None:
    """Normalize one JSONL file and append the surviving records to ``out_f``.

    Every record read from ``src_path`` is normalized; records in an
    unrecognized format or failing the quality filter are counted and
    dropped. The rest are written to ``out_f`` as one JSON object per line
    (non-ASCII characters are kept as-is).

    Args:
        src_path: Source JSONL file.
        out_f: Writable text file object receiving the normalized lines.
        stats: Dict updated in place; the entry keyed by the file stem
            receives the ``loaded`` / ``skipped_format`` /
            ``skipped_quality`` / ``written`` counters.
    """
    source_name = src_path.stem
    loaded = written = skipped_format = skipped_quality = 0

    log.info(f" ๋กœ๋”ฉ: {src_path.name}")

    # ``loaded`` tracks the 1-based index of the last record seen; it stays 0
    # when the file yields nothing.
    for loaded, record in enumerate(load_jsonl(src_path), start=1):
        normalized = normalize_record(record, source_name)
        if normalized is None:
            skipped_format += 1
        elif passes_quality_filter(normalized):
            out_f.write(json.dumps(normalized, ensure_ascii=False) + "\n")
            written += 1
        else:
            skipped_quality += 1

    log.info(
        f" {source_name}: ๋กœ๋”ฉ {loaded:,} โ†’ ํฌ๋งท ์Šคํ‚ต {skipped_format:,} โ†’ ํ’ˆ์งˆ ์Šคํ‚ต {skipped_quality:,} โ†’ ์ถœ๋ ฅ {written:,}"
    )
    stats[source_name] = dict(
        loaded=loaded,
        skipped_format=skipped_format,
        skipped_quality=skipped_quality,
        written=written,
    )
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
# Files to process, looked up by name inside the input directory.
# The order is fixed so that repeated runs produce the same output order
# (reproducibility).
TARGET_FILES = [
"heegyu_orca-math-korean-preference-cleaned.jsonl",
"kuotient_orca-math-korean-dpo-pairs.jsonl",
"nayohan_preference-collection-ko-full.jsonl",
"maywell_ko_Ultrafeedback_binarized.jsonl",
"jojo0217_korean_rlhf_dataset.jsonl",
"lemon-mint_korean-realqa-reasoning-v01-preference.jsonl",
"tellang_yeji-preference-ko-v1.jsonl",
]
def main():
    """Command-line entry point: merge the preference files into one JSONL.

    Steps:

    1. Parse ``--input_dir``, ``--output`` and ``--include_all``.
    2. Select the source files: every ``*.jsonl`` in the input directory
       (except one named like the output file) with ``--include_all``,
       otherwise the ``TARGET_FILES`` entries that exist.
    3. Process each file in order, writing normalized records to the output.
    4. Log a per-dataset table and the grand totals.

    Exits with status 1 when the input directory is missing or no source
    file is available.
    """
    parser = argparse.ArgumentParser(
        description="Preference ๋ฐ์ดํ„ฐ ํ†ตํ•ฉ + ํฌ๋งท ์ •๊ทœํ™” (ORPO ํ˜ธํ™˜)"
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        default="data/preference",
        help="์ž…๋ ฅ ๋””๋ ‰ํ† ๋ฆฌ (๊ธฐ๋ณธ: data/preference)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="data/preference/combined_preference.jsonl",
        help="์ถœ๋ ฅ ํŒŒ์ผ ๊ฒฝ๋กœ",
    )
    parser.add_argument(
        "--include_all",
        action="store_true",
        help="TARGET_FILES ๋ชฉ๋ก ์™ธ์˜ .jsonl ํŒŒ์ผ๋„ ํฌํ•จ",
    )
    args = parser.parse_args()

    input_dir, output_path = Path(args.input_dir), Path(args.output)
    if not input_dir.is_dir():
        log.error(f"์ž…๋ ฅ ๋””๋ ‰ํ† ๋ฆฌ ์—†์Œ: {input_dir}")
        sys.exit(1)

    # Decide which files to process.
    if args.include_all:
        # Skip any file named like the output so the combined file is never
        # fed back into itself.
        src_files = [
            f for f in sorted(input_dir.glob("*.jsonl")) if f.name != output_path.name
        ]
    else:
        src_files = []
        for fname in TARGET_FILES:
            p = input_dir / fname
            if not p.exists():
                log.warning(f"ํŒŒ์ผ ์—†์Œ (์Šคํ‚ต): {p}")
                continue
            src_files.append(p)

    if not src_files:
        log.error("์ฒ˜๋ฆฌํ•  JSONL ํŒŒ์ผ์ด ์—†์Šต๋‹ˆ๋‹ค.")
        sys.exit(1)

    output_path.parent.mkdir(parents=True, exist_ok=True)

    heavy_rule = "=" * 60
    light_rule = "-" * 86

    log.info(heavy_rule)
    log.info("Phase 0F: Preference ๋ฐ์ดํ„ฐ ํ†ตํ•ฉ")
    log.info(f" ์ž…๋ ฅ ํŒŒ์ผ ์ˆ˜ : {len(src_files)}")
    log.info(f" ์ถœ๋ ฅ ํŒŒ์ผ : {output_path}")
    log.info(f" ์ตœ์†Œ ๊ธธ์ด ๊ธฐ์ค€: {MIN_LEN}์ž")
    log.info(heavy_rule)

    stats: dict = {}
    total_written = 0
    with output_path.open("w", encoding="utf-8") as out_f:
        for src_path in src_files:
            process_file(src_path, out_f, stats)
            per_file = stats.get(src_path.stem, {})
            total_written += per_file.get("written", 0)

    # Final statistics summary.
    log.info("")
    log.info(heavy_rule)
    log.info("์ตœ์ข… ํ†ต๊ณ„ ์š”์•ฝ")
    log.info(heavy_rule)
    log.info(f"{'๋ฐ์ดํ„ฐ์…‹':<50} {'๋กœ๋”ฉ':>8} {'ํฌ๋งท์Šคํ‚ต':>8} {'ํ’ˆ์งˆ์Šคํ‚ต':>8} {'์ถœ๋ ฅ':>8}")
    log.info(light_rule)
    for name, s in stats.items():
        log.info(
            f"{name:<50} {s['loaded']:>8,} {s['skipped_format']:>8,} {s['skipped_quality']:>8,} {s['written']:>8,}"
        )
    grand_loaded = sum(s["loaded"] for s in stats.values())
    grand_fmt_skip = sum(s["skipped_format"] for s in stats.values())
    grand_qual_skip = sum(s["skipped_quality"] for s in stats.values())
    log.info(light_rule)
    log.info(
        f"{'ํ•ฉ๊ณ„':<50} {grand_loaded:>8,} {grand_fmt_skip:>8,} {grand_qual_skip:>8,} {total_written:>8,}"
    )
    log.info(heavy_rule)
    log.info(f"์ถœ๋ ฅ ์™„๋ฃŒ: {output_path} ({total_written:,}๊ฐœ ๋ ˆ์ฝ”๋“œ)")


if __name__ == "__main__":
    main()