Spaces:
Runtime error
Runtime error
File size: 5,694 Bytes
"""Dataset helpers for synthetic and bootstrap SFT records."""
from __future__ import annotations
import copy
import json
from pathlib import Path
from typing import Any, Iterable
import yaml
def load_jsonl_records(paths: Iterable[str | Path]) -> list[dict[str, Any]]:
    """Load newline-delimited JSON records from one or more files.

    Blank (whitespace-only) lines are skipped.  Files are decoded as UTF-8
    and a leading byte-order mark, if present, is ignored.

    Args:
        paths: Files to read, processed in the order given.

    Returns:
        Every decoded JSON object, in file order and then line order.

    Raises:
        json.JSONDecodeError: A non-blank line is not valid JSON.  The
            message is prefixed with ``path:lineno`` so the offending line
            can be located.
        TypeError: A line decodes to something other than a JSON object.
    """
    records: list[dict[str, Any]] = []
    for raw_path in paths:
        path = Path(raw_path)
        # "utf-8-sig" decodes plain UTF-8 unchanged but drops a leading BOM,
        # which json.loads would otherwise reject on the first line.
        with path.open("r", encoding="utf-8-sig") as handle:
            for lineno, line in enumerate(handle, start=1):
                text = line.strip()
                if not text:
                    continue
                try:
                    payload = json.loads(text)
                except json.JSONDecodeError as exc:
                    # Same exception type as before, but now it says where
                    # the bad line lives (matching the TypeError below).
                    raise json.JSONDecodeError(
                        f"{path}:{lineno}: {exc.msg}", exc.doc, exc.pos
                    ) from exc
                if not isinstance(payload, dict):
                    raise TypeError(f"{path}:{lineno} is not a JSON object")
                records.append(payload)
    return records
def load_tool_context(paths: Iterable[str | Path]) -> str:
    """Read tool-context files and merge them into one text block.

    ``.json`` files are decoded with ``json.loads`` and ``.yaml``/``.yml``
    files with ``yaml.safe_load``; the decoded payload is then rendered by
    ``_render_tool_payload``.  Any other file is used verbatim after
    stripping surrounding whitespace.  Files that are empty after stripping
    are ignored.

    Args:
        paths: Files to read, processed in the order given.

    Returns:
        The non-blank sections joined by a blank line, or ``""`` when there
        are none.
    """
    sections: list[str] = []
    for entry in paths:
        source = Path(entry)
        body = source.read_text(encoding="utf-8").strip()
        if not body:
            continue
        extension = source.suffix.lower()
        if extension == ".json":
            sections.append(_render_tool_payload(json.loads(body)))
        elif extension in (".yaml", ".yml"):
            sections.append(_render_tool_payload(yaml.safe_load(body)))
        else:
            sections.append(body)
    # A structured payload may render to nothing; drop those sections.
    kept = [section for section in sections if section.strip()]
    return "\n\n".join(kept)
def append_tool_context(
    records: list[dict[str, Any]],
    tool_context: str,
) -> list[dict[str, Any]]:
    """Append tool descriptions to the first system prompt in each record.

    The input records are never mutated; deep copies are returned.  The tool
    text is prefixed with ``"Available tools:"`` unless it already starts
    with that heading (case-insensitively).  Only the first message whose
    ``role`` is ``"system"`` is touched, and it is left alone when it
    already contains the tool block.  Records without a system message are
    copied unchanged.

    Args:
        records: Chat records, each optionally holding a ``messages`` list.
        tool_context: Tool description text; blank text means "no change".

    Returns:
        A new list of deep-copied records with the tool block appended.
    """
    block = tool_context.strip()
    if not block:
        return [copy.deepcopy(record) for record in records]
    if not block.lower().startswith("available tools"):
        block = "Available tools:\n" + block

    enriched: list[dict[str, Any]] = []
    for record in records:
        clone = copy.deepcopy(record)
        messages = clone.get("messages", [])
        if isinstance(messages, list):
            first_system = next(
                (
                    message
                    for message in messages
                    if isinstance(message, dict) and message.get("role") == "system"
                ),
                None,
            )
            if first_system is not None:
                raw_content = first_system.get("content")
                # A null content must count as empty; str(None) would inject
                # the literal text "None" into the prompt.
                content = "" if raw_content is None else str(raw_content).rstrip()
                if block not in content:
                    first_system["content"] = f"{content}\n\n{block}".strip()
        enriched.append(clone)
    return enriched
def extract_bootstrap_messages(
    records: list[dict[str, Any]],
    *,
    role: str = "red",
    limit: int = 0,
) -> list[dict[str, Any]]:
    """Extract few-shot chat messages from prior SFT records.

    Records are visited in descending ``_bootstrap_record_rank`` order.  A
    record is skipped when it carries a role (top-level ``role`` or, failing
    that, ``metadata["role"]``) that differs from *role*; records with no
    role at all are accepted.  Role comparison ignores case and surrounding
    whitespace on both sides.  For each accepted record the dict-typed
    messages are deep-copied, a leading system message is dropped, and the
    remainder is appended to the result.

    Args:
        records: Prior SFT records.
        role: Role whose records should be used.
        limit: Maximum number of records to draw messages from; values
            ``<= 0`` disable extraction.

    Returns:
        A flat list of message dicts from at most *limit* records.
    """
    if limit <= 0:
        return []

    def _normalize(value: Any) -> str:
        # Null roles must read as "no role", not as the text "none".
        return "" if value is None else str(value).strip().lower()

    # Record roles are normalized below, so normalize the requested role
    # as well; otherwise role="Red" could never match a tagged record.
    wanted_role = _normalize(role)

    examples: list[dict[str, Any]] = []
    used = 0
    for record in sorted(records, key=_bootstrap_record_rank, reverse=True):
        metadata = record.get("metadata")
        if not isinstance(metadata, dict):
            # e.g. ``"metadata": null`` in the JSONL source.
            metadata = {}
        record_role = _normalize(record.get("role")) or _normalize(metadata.get("role"))
        if record_role and record_role != wanted_role:
            continue
        messages = record.get("messages", [])
        if not isinstance(messages, list):
            continue
        example = [
            copy.deepcopy(message)
            for message in messages
            if isinstance(message, dict)
        ]
        # The caller supplies its own system prompt; drop the recorded one.
        if example and example[0].get("role") == "system":
            example = example[1:]
        if not example:
            continue
        examples.extend(example)
        used += 1
        if used >= limit:
            break
    return examples
def write_jsonl_records(path: str | Path, records: list[dict[str, Any]]) -> int:
    """Serialize *records* to *path*, one JSON document per line.

    Missing parent directories are created and an existing file is
    overwritten.

    Args:
        path: Destination file.
        records: Records to serialize with ``json.dumps``.

    Returns:
        The number of records given.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as sink:
        sink.writelines(f"{json.dumps(entry)}\n" for entry in records)
    return len(records)
def _render_tool_payload(payload: Any) -> str:
    """Render a decoded JSON/YAML tool description as bullet-list text.

    * ``None`` (JSON ``null``, a comment-only YAML file) renders as ``""``.
    * A string is returned stripped.
    * A mapping becomes one ``- key: value`` line per entry; non-string
      values are JSON-encoded with sorted keys.
    * A list becomes one ``- ...`` line per item.  Dict items use their
      ``name`` and ``description`` fields when a name is present and fall
      back to the JSON encoding of the whole item otherwise.
    * Anything else is rendered with ``str()`` and stripped.
    """

    def _field(value: Any) -> str:
        # Null fields count as missing rather than as the text "None".
        return "" if value is None else str(value).strip()

    if payload is None:
        # str(None) would leak the literal word "None" into the prompt.
        return ""
    if isinstance(payload, str):
        return payload.strip()
    if isinstance(payload, dict):
        lines = []
        for key, value in payload.items():
            if isinstance(value, str):
                lines.append(f"- {key}: {value}")
            else:
                rendered = json.dumps(value, sort_keys=True)
                lines.append(f"- {key}: {rendered}")
        return "\n".join(lines)
    if isinstance(payload, list):
        lines = []
        for item in payload:
            if not isinstance(item, dict):
                lines.append(f"- {item}")
                continue
            name = _field(item.get("name"))
            description = _field(item.get("description"))
            if name and description:
                lines.append(f"- {name}: {description}")
            elif name:
                lines.append(f"- {name}")
            else:
                lines.append(f"- {json.dumps(item, sort_keys=True)}")
        return "\n".join(lines)
    return str(payload).strip()
def _bootstrap_record_rank(record: dict[str, Any]) -> tuple[int, int, int]:
    """Return the sort key used to order bootstrap records.

    The key is ``(success, tool_turns, total_turns)``:

    * ``success`` is 1 when ``metadata["success"]`` is truthy, else 0;
    * ``tool_turns`` counts assistant messages with a truthy ``tool_calls``;
    * ``total_turns`` is ``metadata["total_turns"]`` coerced to ``int``.

    Malformed records (non-dict ``metadata``, non-list ``messages``, or a
    non-numeric ``total_turns``) contribute 0 for the affected component
    instead of raising, so one bad record cannot abort the whole sort.
    """
    metadata = record.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}
    success = 1 if metadata.get("success") else 0
    try:
        total_turns = int(metadata.get("total_turns") or 0)
    except (TypeError, ValueError):
        total_turns = 0

    messages = record.get("messages")
    if not isinstance(messages, list):
        # The extractor skips such records; ranking them must not crash.
        messages = []
    tool_turns = sum(
        1
        for message in messages
        if isinstance(message, dict)
        and message.get("role") == "assistant"
        and message.get("tool_calls")
    )
    return success, tool_turns, total_turns
|