File size: 5,709 Bytes
1ab87df
 
6b13adb
 
1ab87df
 
 
 
 
 
 
4e1a75b
1ab87df
 
 
 
 
 
 
 
 
 
4e1a75b
1ab87df
 
4e1a75b
1ab87df
 
 
 
4e1a75b
 
 
 
 
 
 
 
1ab87df
 
 
 
 
 
 
 
 
4e1a75b
 
 
 
 
 
 
 
 
 
1ab87df
 
 
 
 
 
 
 
4e1a75b
 
 
 
1ab87df
 
 
6b13adb
 
 
 
 
 
 
 
 
 
 
1ab87df
 
6b13adb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ab87df
4e1a75b
 
 
 
 
 
 
 
 
 
 
 
 
 
1ab87df
6b13adb
 
1ab87df
 
4e1a75b
6b13adb
1ab87df
6b13adb
 
 
1ab87df
6b13adb
 
 
 
 
 
4e1a75b
 
 
 
 
 
 
 
6b13adb
 
 
 
 
 
 
4e1a75b
 
 
6b13adb
4e1a75b
6b13adb
 
 
1ab87df
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""Convert verified CORP-ENV trajectories into chat-format SFT JSONL.

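Pass one or more processed JSONLs (e.g. `e1_m1_clean` + `h1_seed_clean`)
produced by `scripts/verify_examples.py`. Each output row is TRL-style chat
SFT data plus per-example metadata:

  {"task_id": "...", "example_id": "...", "messages": [...],
   "num_actions": ..., "terminal_reward": ..., "verifier_pass_rate": ...,
   "reasoning_steps": ..., "conflict_steps": ..., "resolution_steps": ...,
   "phase_progression_ok": ...}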
"""

from __future__ import annotations

import argparse
from collections import defaultdict
import sys
from pathlib import Path
from typing import Any, Dict, List

# Directory one level above the folder containing this file (the name suggests
# the repository root -- presumably; confirm the layout). It is put at the
# front of sys.path, if not already present, so that the
# `scripts._trajectory_utils` import that follows (hence its `noqa: E402`)
# can be resolved.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from scripts._trajectory_utils import (  # noqa: E402
    actions_to_sft_messages,
    deliberation_features,
    extract_actions,
    read_jsonl,
    validate_stepwise_deliberation,
    write_jsonl,
)


def convert_example(
    example: Dict[str, Any],
    min_pass_rate: float,
    min_reasoning_steps: int,
    min_conflict_steps: int,
    min_resolution_steps: int,
    require_stepwise_deliberation: bool,
) -> Dict[str, Any] | None:
    """Turn one verified trajectory into an SFT row, or reject it with ``None``.

    Filters run in order and the first failing one rejects the example:

    * a non-empty ``status`` other than ``"clean"``;
    * ``verifier_pass_rate`` (1.0 when the key is absent) below ``min_pass_rate``;
    * no usable ``task_id`` / ``task``;
    * when ``require_stepwise_deliberation`` is set, a truthy result from
      ``validate_stepwise_deliberation``;
    * reasoning, conflict, then resolution step counts below their minimums.

    Returns:
        A dict with ``example_id``, ``task_id``, chat ``messages`` and
        deliberation metadata, or ``None`` when any filter rejects the example.

    Raises:
        ``TypeError`` / ``ValueError`` when ``verifier_pass_rate`` is present
        but not float-convertible, plus whatever the trajectory helpers raise.
    """
    status = example.get("status")
    if status and status != "clean":
        return None

    if float(example.get("verifier_pass_rate", 1.0)) < min_pass_rate:
        return None

    task_id = str(example.get("task_id") or example.get("task") or "")
    if not task_id:
        return None

    actions = extract_actions(example)
    # NOTE(review): a truthy return is treated as "failed validation", i.e. the
    # helper presumably returns a list of problems -- confirm against
    # scripts/_trajectory_utils.
    if require_stepwise_deliberation and validate_stepwise_deliberation(task_id, actions):
        return None

    features = deliberation_features(actions)
    # Thresholds are checked in this fixed order so an earlier failure
    # short-circuits before later feature keys are read.
    step_counts: Dict[str, int] = {}
    for feature_name, minimum in (
        ("reasoning_steps", min_reasoning_steps),
        ("conflict_steps", min_conflict_steps),
        ("resolution_steps", min_resolution_steps),
    ):
        step_counts[feature_name] = int(features[feature_name])
        if step_counts[feature_name] < minimum:
            return None

    messages = actions_to_sft_messages(task_id, actions)
    row: Dict[str, Any] = {
        "example_id": str(example.get("example_id") or example.get("id") or "unknown"),
        "task_id": task_id,
        "messages": messages,
        "num_actions": len(actions),
        "terminal_reward": example.get("terminal_reward"),
        "verifier_pass_rate": example.get("verifier_pass_rate"),
    }
    # Appended after the metadata above to keep the original key order.
    row.update(step_counts)
    row["phase_progression_ok"] = bool(features["phase_progression_ok"])
    return row


def _parse_input_paths(raw: List[str]) -> List[Path]:
    """Expand comma-separated entries and return unique ordered paths."""
    out: List[Path] = []
    for part in raw:
        for p in part.split(","):
            p = p.strip()
            if p:
                out.append(Path(p))
    return out


def main() -> None:
    """Command-line entry point: filter verified examples and write SFT JSONL.

    Reads every ``--input`` file in order (missing files are reported on
    stderr and skipped). Examples whose id matches an already-kept row are
    skipped. The rest are converted with ``convert_example``, optionally capped
    per task via ``--max-per-task``, and written to ``--output``. A failure
    while converting a single example is reported on stderr and counted as
    skipped instead of aborting the run. A one-line summary goes to stdout.
    """
    parser = argparse.ArgumentParser(description="Prepare chat SFT data from verified examples.")
    default_inputs = (
        "data/processed/e1_m1_clean.jsonl,data/processed/h1_seed_clean.jsonl"
    )
    parser.add_argument(
        "--input",
        dest="inputs",
        action="append",
        default=None,
        metavar="PATH",
        help=(
            "Processed JSONL (repeat flag or use commas). "
            f"Default: {default_inputs}"
        ),
    )
    parser.add_argument("--output", default="data/sft/e1_m1_h1_examples.jsonl")
    parser.add_argument("--min-pass-rate", type=float, default=0.80)
    parser.add_argument("--min-reasoning-steps", type=int, default=1)
    parser.add_argument("--min-conflict-steps", type=int, default=0)
    parser.add_argument("--min-resolution-steps", type=int, default=0)
    parser.add_argument(
        "--require-stepwise-deliberation",
        action="store_true",
        help="Require task-specific SWD step-wise deliberation checks from verification utilities.",
    )
    parser.add_argument(
        "--max-per-task",
        type=int,
        default=0,
        help="Optional cap for kept SFT rows per task (0 = unlimited).",
    )
    args = parser.parse_args()
    # `append` with default=None leaves args.inputs as None when --input is absent.
    raw_inputs = list(args.inputs) if args.inputs else [default_inputs]
    input_paths = _parse_input_paths(raw_inputs)

    rows: List[Dict[str, Any]] = []
    by_task_kept: Dict[str, int] = defaultdict(int)
    # Ids of rows already kept; later examples carrying the same id are skipped.
    seen_ids: set[str] = set()
    skipped = 0
    for path in input_paths:
        if not path.is_file():
            print(f"warning: input missing, skip: {path}", file=sys.stderr)
            continue
        for example in read_jsonl(path):
            # Same id resolution as convert_example, but "" instead of "unknown"
            # so id-less examples are never treated as duplicates.
            eid = str(example.get("example_id") or example.get("id") or "")
            if eid and eid in seen_ids:
                skipped += 1
                continue
            try:
                row = convert_example(
                    example,
                    args.min_pass_rate,
                    args.min_reasoning_steps,
                    args.min_conflict_steps,
                    args.min_resolution_steps,
                    args.require_stepwise_deliberation,
                )
            except Exception as exc:
                # Best-effort boundary: one malformed example must not abort the
                # run. Reported on stderr (like the missing-input warning) so the
                # stdout summary stays clean; label uses the resolved id.
                skipped += 1
                print(f"skip {eid or 'unknown'}: {exc}", file=sys.stderr)
                continue
            if row is None:
                skipped += 1
                continue
            if args.max_per_task > 0 and by_task_kept[row["task_id"]] >= args.max_per_task:
                skipped += 1
                continue
            rows.append(row)
            by_task_kept[row["task_id"]] += 1
            eid2 = str(row.get("example_id") or "unknown")
            if eid2 and eid2 != "unknown":
                seen_ids.add(eid2)

    write_jsonl(Path(args.output), rows)
    print(f"Wrote {len(rows)} SFT conversations to {args.output}; skipped {skipped}.")


if __name__ == "__main__":
    main()