File size: 5,694 Bytes
f016eb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""Dataset helpers for synthetic and bootstrap SFT records."""

from __future__ import annotations

import copy
import json
from pathlib import Path
from typing import Any, Iterable

import yaml


def load_jsonl_records(paths: Iterable[str | Path]) -> list[dict[str, Any]]:
    """Load newline-delimited JSON records from one or more files.

    Files are read in the order given, and lines in file order. Blank
    (whitespace-only) lines are skipped. A UTF-8 byte-order mark at the start
    of a file is tolerated.

    Args:
        paths: JSONL files to read.

    Returns:
        Every decoded record, in file order and then line order.

    Raises:
        OSError: A file cannot be opened.
        json.JSONDecodeError: A non-blank line is not valid JSON; the message
            is prefixed with ``path:lineno``.
        TypeError: A line decodes to something other than a JSON object.
    """
    records: list[dict[str, Any]] = []
    for raw_path in paths:
        path = Path(raw_path)
        # "utf-8-sig" drops a leading BOM (common in editor-saved files) that
        # str.strip() keeps and json.loads rejects; BOM-less files decode
        # exactly as with plain "utf-8".
        with path.open("r", encoding="utf-8-sig") as handle:
            for lineno, line in enumerate(handle, start=1):
                text = line.strip()
                if not text:
                    continue
                try:
                    payload = json.loads(text)
                except json.JSONDecodeError as exc:
                    # Same exception type for callers, but say where it failed.
                    raise json.JSONDecodeError(
                        f"{path}:{lineno}: {exc.msg}", exc.doc, exc.pos
                    ) from exc
                if not isinstance(payload, dict):
                    raise TypeError(f"{path}:{lineno} is not a JSON object")
                records.append(payload)
    return records


def load_tool_context(paths: Iterable[str | Path]) -> str:
    """Load and normalize tool-context text from one or more files.

    ``.json``, ``.yaml`` and ``.yml`` files (suffix matched case-insensitively)
    are parsed and rendered to bullet text by ``_render_tool_payload``. Any
    other file is used verbatim after stripping. Empty files contribute
    nothing. So do structured files whose document is null, e.g. a JSON
    ``null`` or a YAML file containing only comments.

    Args:
        paths: Tool-context files, read in the order given.

    Returns:
        The non-empty blocks joined by a blank line; ``""`` if there are none.
    """
    blocks: list[str] = []
    for raw_path in paths:
        path = Path(raw_path)
        suffix = path.suffix.lower()
        # utf-8-sig drops a leading BOM so it neither leaks into the prompt
        # text nor trips json.loads.
        text = path.read_text(encoding="utf-8-sig").strip()
        if not text:
            continue
        if suffix == ".json":
            payload = json.loads(text)
        elif suffix in {".yaml", ".yml"}:
            # safe_load only builds plain data types, never arbitrary objects.
            payload = yaml.safe_load(text)
        else:
            blocks.append(text)
            continue
        # A null document would otherwise render as the literal text "None".
        if payload is None:
            continue
        blocks.append(_render_tool_payload(payload))
    return "\n\n".join(block for block in blocks if block.strip())


def append_tool_context(
    records: list[dict[str, Any]],
    tool_context: str,
) -> list[dict[str, Any]]:
    """Append tool descriptions to the first system prompt in each record.

    The input records are never mutated; every returned record is a deep copy.
    *tool_context* is stripped. Unless it already starts with "available tools"
    (case-insensitive), it is prefixed with an ``Available tools:`` header line.

    Within each record, only the first dict message whose ``role`` is
    ``"system"`` is touched. It is left alone if its content already contains
    the block, so repeated calls do not duplicate it. Records without a list of
    messages, or without a system message, are returned unchanged.

    Args:
        records: Chat records, each expected to carry a ``messages`` list.
        tool_context: Tool description text; blank means "nothing to add".

    Returns:
        One deep-copied (possibly enriched) record per input record, in order.
    """
    if not tool_context.strip():
        return [copy.deepcopy(record) for record in records]

    block = tool_context.strip()
    if not block.lower().startswith("available tools"):
        block = "Available tools:\n" + block

    enriched: list[dict[str, Any]] = []
    for record in records:
        clone = copy.deepcopy(record)
        messages = clone.get("messages", [])
        if isinstance(messages, list):
            system_message = next(
                (
                    message
                    for message in messages
                    if isinstance(message, dict) and message.get("role") == "system"
                ),
                None,
            )
            if system_message is not None:
                raw_content = system_message.get("content")
                # Null/missing content counts as empty, not as the text "None".
                # NOTE(review): other non-string content (e.g. a list of content
                # parts) is still stringified with str(); confirm system prompts
                # are always plain strings.
                content = "" if raw_content is None else str(raw_content).rstrip()
                if block not in content:
                    system_message["content"] = f"{content}\n\n{block}".strip()
        enriched.append(clone)
    return enriched


def extract_bootstrap_messages(
    records: list[dict[str, Any]],
    *,
    role: str = "red",
    limit: int = 0,
) -> list[dict[str, Any]]:
    """Extract few-shot chat messages from prior SFT records.

    Records are visited in descending ``_bootstrap_record_rank`` order; ties
    keep their input order because ``sorted`` is stable. A record is skipped
    when:

    * it declares a role (top-level ``role``, else ``metadata["role"]``) that
      differs from *role*; records with no role match any *role*;
    * its ``messages`` is not a list; or
    * nothing remains after dropping non-dict messages and a leading system
      message.

    The surviving messages of up to *limit* records are deep-copied and
    concatenated, in order, into one flat list.

    Args:
        records: Prior SFT records; not modified.
        role: Role to keep; compared case-insensitively, ignoring surrounding
            whitespace.
        limit: Maximum number of records to draw from; ``<= 0`` returns ``[]``.

    Returns:
        A flat list of message dicts taken from the selected records.
    """
    if limit <= 0:
        return []

    # Normalize the requested role the same way record roles are normalized,
    # so e.g. role="Red" matches a record tagged "red".
    wanted_role = role.strip().lower()
    examples: list[dict[str, Any]] = []
    used = 0
    for record in sorted(records, key=_bootstrap_record_rank, reverse=True):
        metadata = record.get("metadata")
        if not isinstance(metadata, dict):
            metadata = {}
        # Missing, None and blank roles all mean "unset" and fall through.
        record_role = (
            str(record.get("role") or "").strip().lower()
            or str(metadata.get("role") or "").strip().lower()
        )
        if record_role and record_role != wanted_role:
            continue

        messages = record.get("messages", [])
        if not isinstance(messages, list):
            continue
        example = [
            copy.deepcopy(message)
            for message in messages
            if isinstance(message, dict)
        ]
        # The caller supplies its own system prompt; drop the record's.
        if example and example[0].get("role") == "system":
            example = example[1:]
        if not example:
            continue

        examples.extend(example)
        used += 1
        if used >= limit:
            break

    return examples


def write_jsonl_records(path: str | Path, records: list[dict[str, Any]]) -> int:
    """Serialize *records* to *path* as JSON Lines, one object per line.

    Missing parent directories are created. An existing file at *path* is
    overwritten. Each record is encoded with ``json.dumps`` default settings
    and terminated by a newline.

    Args:
        path: Destination file.
        records: JSON-serializable records to write, in order.

    Returns:
        The number of records written.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as sink:
        sink.writelines(json.dumps(entry) + "\n" for entry in records)
    return len(records)


def _render_tool_payload(payload: Any) -> str:
    if isinstance(payload, str):
        return payload.strip()
    if isinstance(payload, dict):
        lines = []
        for key, value in payload.items():
            if isinstance(value, str):
                lines.append(f"- {key}: {value}")
            else:
                rendered = json.dumps(value, sort_keys=True)
                lines.append(f"- {key}: {rendered}")
        return "\n".join(lines)
    if isinstance(payload, list):
        lines = []
        for item in payload:
            if isinstance(item, dict):
                name = str(item.get("name", "")).strip()
                description = str(item.get("description", "")).strip()
                if name and description:
                    lines.append(f"- {name}: {description}")
                elif name:
                    lines.append(f"- {name}")
                else:
                    lines.append(f"- {json.dumps(item, sort_keys=True)}")
            else:
                lines.append(f"- {item}")
        return "\n".join(lines)
    return str(payload).strip()


def _bootstrap_record_rank(record: dict[str, Any]) -> tuple[int, int, int]:
    metadata = record.get("metadata", {})
    success = 1 if metadata.get("success") else 0
    total_turns = int(metadata.get("total_turns") or 0)
    tool_turns = sum(
        1
        for message in record.get("messages", [])
        if isinstance(message, dict)
        and message.get("role") == "assistant"
        and message.get("tool_calls")
    )
    return success, tool_turns, total_turns