"""Convert verified CORP-ENV trajectories into chat-format SFT JSONL.

Pass one or more processed JSONL files (e.g. `e1_m1_clean` + `h1_seed_clean`)
produced by `scripts/verify_examples.py`. Each output row is TRL-style chat SFT
data, plus a few bookkeeping fields:
{"task_id": "...", "example_id": "...", "messages": [...]}
"""
from __future__ import annotations
import argparse
from collections import defaultdict
import sys
from pathlib import Path
from typing import Any, Dict, List
# ROOT is the directory one level above the directory containing this file.
# It is put at the front of sys.path (once) before the `scripts.*` import that
# follows -- presumably so that import resolves when this file is executed
# directly rather than as part of an installed package.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from scripts._trajectory_utils import ( # noqa: E402
actions_to_sft_messages,
deliberation_features,
extract_actions,
read_jsonl,
validate_stepwise_deliberation,
write_jsonl,
)
def convert_example(
    example: Dict[str, Any],
    min_pass_rate: float,
    min_reasoning_steps: int,
    min_conflict_steps: int,
    min_resolution_steps: int,
    require_stepwise_deliberation: bool,
) -> Dict[str, Any] | None:
    """Turn one verified example into an SFT row, or return None to drop it.

    An example is dropped (None) when any of these holds, checked in order:
      * it has a non-empty ``status`` other than ``"clean"``;
      * its ``verifier_pass_rate`` (1.0 when the key is absent) is below
        ``min_pass_rate``;
      * neither ``task_id`` nor ``task`` yields a non-empty task id;
      * ``require_stepwise_deliberation`` is set and
        ``validate_stepwise_deliberation`` returns a truthy value
        (presumably a collection of violations -- confirm in the helper);
      * its reasoning / conflict / resolution step counts fall below the
        corresponding minimum.

    Otherwise the returned dict holds the example id, task id, chat
    ``messages``, the number of actions, the terminal reward and pass rate
    copied from the input, and the deliberation feature counts.
    """
    status = example.get("status")
    if status and status != "clean":
        return None

    if float(example.get("verifier_pass_rate", 1.0)) < min_pass_rate:
        return None

    task_id = str(example.get("task_id") or example.get("task") or "")
    if task_id == "":
        return None

    actions = extract_actions(example)
    if require_stepwise_deliberation and validate_stepwise_deliberation(task_id, actions):
        return None

    features = deliberation_features(actions)

    # Check thresholds in the fixed order reasoning -> conflict -> resolution,
    # stopping at the first one that is not met.
    thresholds = (
        ("reasoning_steps", min_reasoning_steps),
        ("conflict_steps", min_conflict_steps),
        ("resolution_steps", min_resolution_steps),
    )
    counts: Dict[str, int] = {}
    for feature_name, minimum in thresholds:
        counts[feature_name] = int(features[feature_name])
        if counts[feature_name] < minimum:
            return None

    messages = actions_to_sft_messages(task_id, actions)
    example_id = str(example.get("example_id") or example.get("id") or "unknown")
    return {
        "example_id": example_id,
        "task_id": task_id,
        "messages": messages,
        "num_actions": len(actions),
        "terminal_reward": example.get("terminal_reward"),
        "verifier_pass_rate": example.get("verifier_pass_rate"),
        "reasoning_steps": counts["reasoning_steps"],
        "conflict_steps": counts["conflict_steps"],
        "resolution_steps": counts["resolution_steps"],
        "phase_progression_ok": bool(features["phase_progression_ok"]),
    }
def _parse_input_paths(raw: List[str]) -> List[Path]:
    """Expand comma-separated entries and return unique ordered paths.

    Each element of ``raw`` may hold one path or several separated by commas.
    Surrounding whitespace is stripped and empty pieces are ignored. A path
    that appears more than once is kept only at its first position, so the
    same file is never listed twice. Paths are compared as ``Path`` objects
    without touching the filesystem (no symlink or relative-path resolution).
    """
    seen: set = set()
    ordered: List[Path] = []
    for entry in raw:
        for piece in entry.split(","):
            piece = piece.strip()
            if not piece:
                continue
            candidate = Path(piece)
            if candidate in seen:
                continue
            seen.add(candidate)
            ordered.append(candidate)
    return ordered
def main() -> None:
    """CLI entry point: filter verified examples and write chat SFT JSONL.

    Reads every existing input JSONL (missing files only produce a warning),
    converts each example with ``convert_example``, and writes the kept rows
    to ``--output``. An example is counted as skipped when its id was already
    kept, when conversion raises or returns None, or when its task has
    already reached ``--max-per-task`` kept rows. All diagnostics go to
    stderr; the final summary line goes to stdout.
    """
    parser = argparse.ArgumentParser(description="Prepare chat SFT data from verified examples.")
    default_inputs = (
        "data/processed/e1_m1_clean.jsonl,data/processed/h1_seed_clean.jsonl"
    )
    parser.add_argument(
        "--input",
        dest="inputs",
        action="append",
        default=None,
        metavar="PATH",
        help=(
            "Processed JSONL (repeat flag or use commas). "
            f"Default: {default_inputs}"
        ),
    )
    parser.add_argument("--output", default="data/sft/e1_m1_h1_examples.jsonl")
    parser.add_argument("--min-pass-rate", type=float, default=0.80)
    parser.add_argument("--min-reasoning-steps", type=int, default=1)
    parser.add_argument("--min-conflict-steps", type=int, default=0)
    parser.add_argument("--min-resolution-steps", type=int, default=0)
    parser.add_argument(
        "--require-stepwise-deliberation",
        action="store_true",
        help="Require task-specific SWD step-wise deliberation checks from verification utilities.",
    )
    parser.add_argument(
        "--max-per-task",
        type=int,
        default=0,
        help="Optional cap for kept SFT rows per task (0 = unlimited).",
    )
    args = parser.parse_args()

    # With action="append" the default must stay None; fall back to the
    # comma-separated default string only when no --input was given.
    raw_inputs = list(args.inputs) if args.inputs else [default_inputs]
    input_paths = _parse_input_paths(raw_inputs)

    rows: List[Dict[str, Any]] = []
    by_task_kept: Dict[str, int] = defaultdict(int)
    seen_ids: set[str] = set()
    skipped = 0

    for path in input_paths:
        if not path.is_file():
            print(f"warning: input missing, skip: {path}", file=sys.stderr)
            continue
        for example in read_jsonl(path):
            # Same id fallback chain that convert_example uses for the row.
            eid = str(example.get("example_id") or example.get("id") or "")
            if eid and eid in seen_ids:
                skipped += 1
                continue
            try:
                row = convert_example(
                    example,
                    args.min_pass_rate,
                    args.min_reasoning_steps,
                    args.min_conflict_steps,
                    args.min_resolution_steps,
                    args.require_stepwise_deliberation,
                )
            except Exception as exc:
                # Best-effort batch job: one malformed example must not abort
                # the run. Report it on stderr, labelled with the id that was
                # actually resolved (including the "id" fallback).
                skipped += 1
                print(f"skip {eid or 'unknown'}: {exc}", file=sys.stderr)
                continue
            if row is None:
                skipped += 1
                continue
            if args.max_per_task > 0 and by_task_kept[row["task_id"]] >= args.max_per_task:
                skipped += 1
                continue
            rows.append(row)
            by_task_kept[row["task_id"]] += 1
            # Only ids of kept rows are remembered; the placeholder
            # "unknown" is never used for de-duplication.
            if eid and eid != "unknown":
                seen_ids.add(eid)

    write_jsonl(Path(args.output), rows)
    print(f"Wrote {len(rows)} SFT conversations to {args.output}; skipped {skipped}.")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|