# corp-env/scripts/prepare_sft_data.py
# Commit 4e1a75b (Navigam): refactor: update training scripts and
# documentation for SFT and RLVR processes
# File size: 5.71 kB
"""Convert verified CORP-ENV trajectories into chat-format SFT JSONL.
Pass one or more processed JSONLs (e.g. `e1_m1_clean` + `h1_seed_clean`) from
`scripts/verify_examples.py`. Each output row is TRL-style chat SFT data:
{"task_id": "...", "example_id": "...", "messages": [...]}
"""
from __future__ import annotations
import argparse
from collections import defaultdict
import sys
from pathlib import Path
from typing import Any, Dict, List
# ROOT is the directory one level above this script's own directory. It is put
# at the front of sys.path (once) so that the `scripts._trajectory_utils`
# import below resolves, e.g. when this file is executed directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from scripts._trajectory_utils import ( # noqa: E402
actions_to_sft_messages,
deliberation_features,
extract_actions,
read_jsonl,
validate_stepwise_deliberation,
write_jsonl,
)
def convert_example(
    example: Dict[str, Any],
    min_pass_rate: float,
    min_reasoning_steps: int,
    min_conflict_steps: int,
    min_resolution_steps: int,
    require_stepwise_deliberation: bool,
) -> Dict[str, Any] | None:
    """Turn one verified trajectory record into a chat SFT row, or reject it.

    Args:
        example: A single record read from a processed JSONL file.
        min_pass_rate: Lowest acceptable ``verifier_pass_rate``; a record
            without that key is treated as having a pass rate of 1.0.
        min_reasoning_steps: Minimum ``reasoning_steps`` feature count.
        min_conflict_steps: Minimum ``conflict_steps`` feature count.
        min_resolution_steps: Minimum ``resolution_steps`` feature count.
        require_stepwise_deliberation: When true, the record is rejected if
            ``validate_stepwise_deliberation`` returns anything truthy
            (presumably a collection of violations -- confirm in
            ``scripts._trajectory_utils``).

    Returns:
        The SFT row (ids, chat messages and filter metadata), or ``None`` when
        the record fails any of the filters.
    """
    # A missing/empty status is accepted; only an explicit non-"clean" one is not.
    status = example.get("status")
    if status and status != "clean":
        return None

    if float(example.get("verifier_pass_rate", 1.0)) < min_pass_rate:
        return None

    # `task` is accepted as a fallback spelling of `task_id`.
    task_id = str(example.get("task_id") or example.get("task") or "")
    if not task_id:
        return None

    actions = extract_actions(example)
    if require_stepwise_deliberation and validate_stepwise_deliberation(task_id, actions):
        return None

    features = deliberation_features(actions)
    # Checked one at a time, in this order, so a later feature key is only
    # looked up once every earlier threshold has been met.
    step_counts: Dict[str, int] = {}
    for feature_key, minimum in (
        ("reasoning_steps", min_reasoning_steps),
        ("conflict_steps", min_conflict_steps),
        ("resolution_steps", min_resolution_steps),
    ):
        step_counts[feature_key] = int(features[feature_key])
        if step_counts[feature_key] < minimum:
            return None

    messages = actions_to_sft_messages(task_id, actions)
    row: Dict[str, Any] = {
        "example_id": str(example.get("example_id") or example.get("id") or "unknown"),
        "task_id": task_id,
        "messages": messages,
        "num_actions": len(actions),
        "terminal_reward": example.get("terminal_reward"),
        "verifier_pass_rate": example.get("verifier_pass_rate"),
    }
    row.update(step_counts)
    row["phase_progression_ok"] = bool(features["phase_progression_ok"])
    return row
def _parse_input_paths(raw: List[str]) -> List[Path]:
    """Expand comma-separated entries and return unique ordered paths.

    Args:
        raw: Values collected from ``--input``; each one may hold a single
            path or several paths separated by commas.

    Returns:
        The paths in first-seen order. Blank pieces are dropped and repeated
        paths are removed, so the same file is not read twice. Paths are
        compared as written (``Path`` equality); nothing is resolved against
        the filesystem.
    """
    # A dict keeps insertion order, so it serves as an ordered set here.
    unique: Dict[Path, None] = {}
    for entry in raw:
        for piece in entry.split(","):
            piece = piece.strip()
            if piece:
                unique.setdefault(Path(piece), None)
    return list(unique)
def main() -> None:
    """CLI entry point: filter verified examples and write chat SFT JSONL.

    Reads every ``--input`` file that exists, converts each record with
    ``convert_example``, drops records whose id was already kept, optionally
    caps the number of rows per task, and writes the surviving rows to
    ``--output``. Diagnostics (missing inputs, records that raised during
    conversion) go to stderr; the final summary line goes to stdout.
    """
    parser = argparse.ArgumentParser(description="Prepare chat SFT data from verified examples.")
    default_inputs = (
        "data/processed/e1_m1_clean.jsonl,data/processed/h1_seed_clean.jsonl"
    )
    parser.add_argument(
        "--input",
        dest="inputs",
        action="append",
        # Left as None on purpose: with action="append", argparse appends
        # user values onto a list default instead of replacing it, so the
        # real default is applied after parsing.
        default=None,
        metavar="PATH",
        help=(
            "Processed JSONL (repeat flag or use commas). "
            f"Default: {default_inputs}"
        ),
    )
    parser.add_argument("--output", default="data/sft/e1_m1_h1_examples.jsonl")
    parser.add_argument("--min-pass-rate", type=float, default=0.80)
    parser.add_argument("--min-reasoning-steps", type=int, default=1)
    parser.add_argument("--min-conflict-steps", type=int, default=0)
    parser.add_argument("--min-resolution-steps", type=int, default=0)
    parser.add_argument(
        "--require-stepwise-deliberation",
        action="store_true",
        help="Require task-specific SWD step-wise deliberation checks from verification utilities.",
    )
    parser.add_argument(
        "--max-per-task",
        type=int,
        default=0,
        help="Optional cap for kept SFT rows per task (0 = unlimited).",
    )
    args = parser.parse_args()

    raw_inputs = list(args.inputs) if args.inputs else [default_inputs]
    input_paths = _parse_input_paths(raw_inputs)

    rows: List[Dict[str, Any]] = []
    by_task_kept: Dict[str, int] = defaultdict(int)
    seen_ids: set[str] = set()
    skipped = 0

    for path in input_paths:
        if not path.is_file():
            print(f"warning: input missing, skip: {path}", file=sys.stderr)
            continue
        for example in read_jsonl(path):
            # Same id fallback chain that convert_example uses for the row's
            # "example_id"; an empty string means the record carries no id.
            eid = str(example.get("example_id") or example.get("id") or "")
            if eid and eid in seen_ids:
                skipped += 1
                continue
            try:
                row = convert_example(
                    example,
                    args.min_pass_rate,
                    args.min_reasoning_steps,
                    args.min_conflict_steps,
                    args.min_resolution_steps,
                    args.require_stepwise_deliberation,
                )
            except Exception as exc:
                # Best effort: one malformed record must not abort the run.
                # Report it on stderr, labelled with the id resolved above so
                # records keyed by `id` are not all shown as "unknown".
                skipped += 1
                print(f"skip {eid or 'unknown'}: {exc}", file=sys.stderr)
                continue
            if row is None:
                skipped += 1
                continue
            if args.max_per_task > 0 and by_task_kept[row["task_id"]] >= args.max_per_task:
                skipped += 1
                continue
            rows.append(row)
            by_task_kept[row["task_id"]] += 1
            # Only kept rows mark their id as seen. Records without an id, or
            # with the literal placeholder "unknown", are never de-duplicated.
            if eid and eid != "unknown":
                seen_ids.add(eid)

    write_jsonl(Path(args.output), rows)
    print(f"Wrote {len(rows)} SFT conversations to {args.output}; skipped {skipped}.")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()