Reframr-RFM-v2-Base / reframr /hf_import.py
OkeyMeta's picture
Add Reframr-RFM-v2-Base release files
52da7b7 verified
import json
import re
import site
import sys
from itertools import chain
from pathlib import Path
from .reasoning import TOOL_PROTOCOL_TOKENS
from .text_quality import clean_answer_text, clean_context_text, clean_training_text
# Column names tried, in priority order, by choose_text_field when looking
# for a plain-text column. Matching is done on casefolded column names.
TEXT_FIELD_PREFERENCES = (
    "text",
    "content",
    "body",
    "article",
    "document",
    "passage",
    "markdown",
)
# Column names tried, in priority order, by choose_dialogue_field when
# looking for a multi-turn conversation column.
DIALOGUE_FIELD_PREFERENCES = (
    "messages",
    "conversation",
    "conversations",
    "dialogue",
    "dialog",
    "turns",
    "chat",
)
# (chosen, rejected) column-name pairs tried in order by
# choose_preference_fields. Both names of a pair must be present.
PREFERENCE_FIELD_PAIRS = (
    ("chosen", "rejected"),
    ("response_j", "response_k"),
    ("response_0", "response_1"),
)
# (prompt, answer) column-name pairs tried in order by
# choose_instruction_fields. Both names of a pair must be present.
INSTRUCTION_FIELD_PAIRS = (
    ("instruction", "output"),
    ("prompt", "completion"),
    ("prompt", "response"),
    ("question", "answer"),
    ("question", "response"),
    ("query", "response"),
)
# Matches a speaker label followed by a colon, either at the very start of
# the text or right after a blank line. Case-insensitive. "Function Response"
# is listed before "Function" so the longer label is tried first.
TRANSCRIPT_ROLE_PATTERN = re.compile(
    r"(?:^|\n\s*\n)(Human|Assistant|System|User|Function Response|Function|Tool)\s*:\s*",
    re.IGNORECASE,
)
# Maps casefolded speaker labels to the canonical roles the parsers keep
# ("assistant", "user", "system", "tool"). Labels not listed here are
# returned unchanged by _normalize_role.
ROLE_ALIASES = {
    "assistant": "assistant",
    "bot": "assistant",
    "gpt": "assistant",
    "model": "assistant",
    "assistant_response": "assistant",
    "human": "user",
    "user": "user",
    "prompter": "user",
    "customer": "user",
    "system": "system",
    "function": "tool",
    "function response": "tool",
    "tool": "tool",
    "tool_result": "tool",
}
# Row fields that _tool_definition_text reads to build the
# "Available tools: ..." text.
TOOL_DEFINITION_FIELDS = ("tools_json", "tools", "functions", "available_tools")
def _word_count(text: str) -> int:
return len(text.split())
def _alpha_ratio(text: str) -> float:
if not text:
return 0.0
alpha_count = sum(character.isalpha() for character in text)
return alpha_count / len(text)
def _default_record_weight(record_type: str) -> int:
if record_type == "dialogue_turn":
return 2
if record_type == "instruction_answer":
return 2
if record_type == "preference_chosen":
return 3
if record_type == "preference_rejected":
return 0
return 1
def choose_text_field(columns: list[str]) -> str:
    """Pick the plain-text column from *columns*.

    Column names are compared by their casefolded form against
    ``TEXT_FIELD_PREFERENCES`` in order; the first preference that is present
    wins and the column's original spelling is returned.

    Raises:
        ValueError: If none of the preferred names is among *columns*.
    """
    by_folded_name: dict[str, str] = {}
    for name in columns:
        # Names that differ only by case collapse; the later one wins.
        by_folded_name[name.casefold()] = name
    for candidate in TEXT_FIELD_PREFERENCES:
        try:
            return by_folded_name[candidate]
        except KeyError:
            continue
    raise ValueError("Could not infer a text column. Pass --text-field explicitly.")
def choose_dialogue_field(columns: list[str]) -> str:
    """Pick the conversation column from *columns*.

    Casefolded column names are checked against ``DIALOGUE_FIELD_PREFERENCES``
    in order and the original spelling of the first hit is returned.

    Raises:
        ValueError: If no preferred conversation column is present.
    """
    lookup = dict(zip((name.casefold() for name in columns), columns))
    hits = [lookup[wanted] for wanted in DIALOGUE_FIELD_PREFERENCES if wanted in lookup]
    if hits:
        return hits[0]
    raise ValueError("Could not infer a conversation column.")
def choose_preference_fields(columns: list[str]) -> tuple[str, str]:
    """Pick the (chosen, rejected) column pair from *columns*.

    Each pair in ``PREFERENCE_FIELD_PAIRS`` is tried in order against the
    casefolded column names; both members must be present. The original
    spellings of the first complete pair are returned.

    Raises:
        ValueError: If no complete pair is present.
    """
    lookup: dict[str, str] = {}
    for name in columns:
        lookup[name.casefold()] = name
    for pair in PREFERENCE_FIELD_PAIRS:
        chosen_key, rejected_key = pair
        if lookup.keys() >= {chosen_key, rejected_key}:
            return lookup[chosen_key], lookup[rejected_key]
    raise ValueError("Could not infer chosen/rejected preference columns.")
def choose_instruction_fields(columns: list[str]) -> tuple[str, str]:
    """Pick the (prompt, answer) column pair from *columns*.

    Each pair in ``INSTRUCTION_FIELD_PAIRS`` is tried in order against the
    casefolded column names; the original spellings of the first pair whose
    two members are both present are returned.

    Raises:
        ValueError: If no complete pair is present.
    """
    lookup = {name.casefold(): name for name in columns}
    for prompt_key, answer_key in INSTRUCTION_FIELD_PAIRS:
        if prompt_key not in lookup:
            continue
        if answer_key not in lookup:
            continue
        return lookup[prompt_key], lookup[answer_key]
    raise ValueError("Could not infer instruction/answer columns.")
def _row_identifier(row: dict[str, object]) -> str:
for candidate in ("id", "_id", "row_id", "uuid", "prompt_id"):
if candidate in row and str(row[candidate]).strip():
return str(row[candidate]).strip()
return ""
def _base_record(
*,
dataset: str,
config: str | None,
split: str,
row_id: str,
) -> dict[str, str]:
return {
"source": "huggingface",
"dataset": dataset,
"config": config or "",
"split": split,
"row_id": row_id,
}
def _row_language(row: dict[str, object]) -> str:
for candidate in ("lang", "language", "locale"):
value = row.get(candidate)
if isinstance(value, str) and value.strip():
return value.strip()
return ""
def _normalize_role(raw_role: object) -> str:
    """Map a raw speaker label to its canonical role.

    Falsy input becomes ``""``. Otherwise the label is stringified, stripped
    and casefolded, then looked up in ``ROLE_ALIASES``; labels without an
    alias are returned in their casefolded form.
    """
    label = str(raw_role) if raw_role else ""
    key = label.strip().casefold()
    if key in ROLE_ALIASES:
        return ROLE_ALIASES[key]
    return key
def _coerce_json_payload(payload: object) -> object:
if not isinstance(payload, str):
return payload
stripped = payload.strip()
if not stripped:
return ""
try:
return json.loads(stripped)
except json.JSONDecodeError:
return stripped
def _compact_json(payload: object) -> str:
if isinstance(payload, str):
return payload.strip()
return json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
def _render_tool_call(call: object) -> str:
if not isinstance(call, dict):
return f"<tool_call> {str(call).strip()}".strip()
function_payload = call.get("function", {})
function = function_payload if isinstance(function_payload, dict) else {}
name = str(call.get("name", function.get("name", "tool"))).strip() or "tool"
arguments = call.get("arguments", function.get("arguments", {}))
return f"<tool_call> {name} {_compact_json(arguments)}".strip()
def _render_source_lines(payload: object) -> list[str]:
if not isinstance(payload, dict):
return []
raw_sources = payload.get("sources", payload.get("source", []))
if isinstance(raw_sources, dict):
sources = [raw_sources]
elif isinstance(raw_sources, list):
sources = raw_sources
elif raw_sources:
sources = [raw_sources]
else:
sources = []
lines: list[str] = []
for source in sources:
if isinstance(source, dict):
title = str(source.get("title", source.get("name", "source"))).strip()
url = str(source.get("url", source.get("uri", ""))).strip()
snippet = str(source.get("snippet", source.get("text", source.get("content", "")))).strip()
parts = [part for part in (title, url, snippet) if part]
if parts:
lines.append(f"<source> {' | '.join(parts)}")
elif source:
lines.append(f"<source> {str(source).strip()}")
return lines
def _render_tool_result(name: str, payload: object) -> list[str]:
    """Render a tool result as a ``<tool_result>`` line plus source lines.

    *payload* is first passed through ``_coerce_json_payload``. Non-dict
    results render as ``<tool_result> NAME VALUE``, or ``... empty`` when the
    value is falsy.

    For dict results the tool name may be overridden by the ``name`` or
    ``tool`` entry. The result is reported as failed when ``ok`` is ``False``,
    when ``status`` is one of error/failed/failure/timeout, or when an error
    text is present (``error``, falling back to ``message`` when ``error`` is
    absent). Successful results append a summary (``summary``, ``content`` or
    ``text``) only when no source lines were rendered.

    Null values (and ``"error": false``) are treated as absent so that a
    payload such as ``{"error": null}`` is neither rendered as ``"None"`` nor
    mistaken for a failure.

    Args:
        name: Fallback tool name; ``"tool"`` is used when it is blank.
        payload: Raw tool output (JSON text, dict, or any other value).

    Returns:
        The result line followed by any ``<source>`` lines.
    """

    def _text(value: object) -> str:
        return "" if value is None else str(value).strip()

    tool_name = name.strip() or "tool"
    parsed = _coerce_json_payload(payload)
    if not isinstance(parsed, dict):
        if parsed:
            return [f"<tool_result> {tool_name} {str(parsed).strip()}"]
        return [f"<tool_result> {tool_name} empty"]

    tool_name = _text(parsed.get("name", parsed.get("tool", ""))) or tool_name
    status = _text(parsed.get("status", "")).casefold()
    raw_error = parsed.get("error", parsed.get("message", ""))
    error = "" if raw_error is False else _text(raw_error)
    failed = (
        parsed.get("ok") is False
        or status in {"error", "failed", "failure", "timeout"}
        or bool(error)
    )
    # Rendered once and reused for both the summary decision and the output.
    source_lines = _render_source_lines(parsed)
    if failed:
        first = f"<tool_result> {tool_name} failed: {error or status or 'unknown error'}"
    else:
        first = f"<tool_result> {tool_name} ok"
        summary = _text(
            parsed.get("summary", parsed.get("content", parsed.get("text", "")))
        )
        if summary and not source_lines:
            first = f"{first}: {summary}"
    return [first, *source_lines]
def _message_content(message: dict[str, object], role: str = "") -> str:
if role == "tool":
name = str(message.get("name", message.get("tool_call_id", "tool"))).strip() or "tool"
payload = message.get("content", message.get("value", message.get("text", message)))
return clean_training_text("\n".join(_render_tool_result(name, payload)))
parts: list[str] = []
for field in ("content", "value", "text", "message"):
value = message.get(field)
if isinstance(value, str) and value.strip():
parts.append(clean_training_text(value))
break
tool_calls = message.get("tool_calls", message.get("function_calls", message.get("tools")))
if isinstance(tool_calls, str):
tool_calls = _coerce_json_payload(tool_calls)
if isinstance(tool_calls, dict):
tool_calls = [tool_calls]
if isinstance(tool_calls, list):
for call in tool_calls:
parts.append(_render_tool_call(call))
return "\n".join(part for part in parts if part).strip()
def _message_role(message: dict[str, object]) -> str:
for field in ("role", "from", "speaker", "author"):
value = message.get(field)
if value is not None:
normalized = _normalize_role(value)
if normalized:
return normalized
return ""
def _parse_dialogue_messages(raw_messages: object) -> list[dict[str, str]]:
if isinstance(raw_messages, str):
parsed_json = _coerce_json_payload(raw_messages)
if parsed_json is not raw_messages:
raw_messages = parsed_json
if not isinstance(raw_messages, list):
return []
parsed: list[dict[str, str]] = []
for message in raw_messages:
if not isinstance(message, dict):
continue
role = _message_role(message)
content = _message_content(message, role)
if role not in {"system", "user", "assistant", "tool"} or not content:
continue
parsed.append({"role": role, "content": content})
return parsed
def _parse_transcript_messages(raw_text: object) -> list[dict[str, str]]:
if not isinstance(raw_text, str):
return []
text = raw_text.strip()
if not text:
return []
matches = list(TRANSCRIPT_ROLE_PATTERN.finditer(text))
if not matches:
return []
parsed: list[dict[str, str]] = []
for index, match in enumerate(matches):
role = _normalize_role(match.group(1))
start = match.end()
end = matches[index + 1].start() if index + 1 < len(matches) else len(text)
raw_content = text[start:end].strip()
if role == "tool":
content = clean_training_text("\n".join(_render_tool_result("tool", raw_content)))
else:
content = clean_training_text(raw_content)
if role in {"system", "user", "assistant", "tool"} and content:
parsed.append({"role": role, "content": content})
return parsed
def _render_prompt(messages: list[dict[str, str]]) -> str:
lines = []
for message in messages:
raw_content = message["content"]
if message["role"] in {"system", "tool"} or any(
token in raw_content for token in TOOL_PROTOCOL_TOKENS
):
content = clean_training_text(raw_content)
else:
content = clean_context_text(raw_content)
if content:
lines.append(content)
return "\n".join(lines).strip()
def _tool_definition_text(row: dict[str, object]) -> str:
    """Build the ``Available tools: ...`` text for a row, or ``""``.

    Every field named in ``TOOL_DEFINITION_FIELDS`` is decoded with
    ``_coerce_json_payload`` and rendered with ``_compact_json``; the
    renderings are joined with newlines, prefixed with ``"Available tools: "``
    and cleaned with ``clean_training_text``.

    Fields that are null, blank, or decode to an empty container (for example
    ``[]`` or ``"[]"``) are skipped so that a row without real tool
    definitions does not yield a content-free "Available tools:" message.
    """
    parts: list[str] = []
    for field in TOOL_DEFINITION_FIELDS:
        payload = _coerce_json_payload(row.get(field))
        if payload is None:
            continue
        if isinstance(payload, (str, list, tuple, dict)) and not payload:
            continue
        parts.append(_compact_json(payload))
    if not parts:
        return ""
    return clean_training_text("Available tools: " + "\n".join(parts))
def _compose_training_text(context: str, answer: str) -> str:
    """Combine a context and an answer into one training string.

    The context is cleaned with ``clean_context_text`` and the answer with
    ``clean_answer_text``; the result has the form
    ``<reason> CONTEXT <answer> ANSWER`` with outer whitespace stripped.
    """
    cleaned_context = clean_context_text(context)
    cleaned_answer = clean_answer_text(answer)
    combined = "<reason> {} <answer> {}".format(cleaned_context, cleaned_answer)
    return combined.strip()
def _compose_instruction_context(row: dict[str, object], prompt_field: str) -> str:
parts: list[str] = []
prompt = clean_context_text(str(row.get(prompt_field, "")).strip())
extra_input = clean_context_text(str(row.get("input", "")).strip())
if prompt:
parts.append(prompt)
if extra_input:
parts.append(extra_input)
return "\n".join(parts).strip()
def _extract_prompt_answer(
    row: dict[str, object],
    *,
    field_name: str,
) -> tuple[str, str]:
    """Extract a ``(prompt, answer)`` pair from one preference column.

    The column value is tried first as structured dialogue messages, then as
    a ``Role: text`` transcript. In either case the final turn must be an
    assistant turn; the earlier turns are rendered as the prompt.

    If neither parse yields a usable pair, the prompt falls back to the row's
    ``prompt`` (or ``question``) column and the answer to the column value
    itself. In that fallback, null values produce empty strings, and a
    list/dict value produces an empty answer, so the Python repr of a
    conversation never ends up as answer text.

    Returns:
        ``(prompt, answer)``; either may be empty when nothing usable exists.
    """
    raw_value = row.get(field_name)
    for parser in (_parse_dialogue_messages, _parse_transcript_messages):
        turns = parser(raw_value)
        if turns and turns[-1]["role"] == "assistant":
            prompt = _render_prompt(turns[:-1])
            answer = turns[-1]["content"]
            if prompt and answer:
                return prompt, answer

    raw_prompt = row.get("prompt")
    if raw_prompt is None:
        raw_prompt = row.get("question")
    prompt = "" if raw_prompt is None else clean_training_text(str(raw_prompt).strip())

    if raw_value is None or isinstance(raw_value, (list, dict)):
        answer = ""
    else:
        answer = clean_answer_text(str(raw_value).strip())
    return prompt, answer
def _ordered_preference_fields(
row: dict[str, object],
*,
left_field: str,
right_field: str,
) -> tuple[str, str]:
if {left_field, right_field} != {"response_0", "response_1"}:
return left_field, right_field
for selector in ("safer_response_id", "better_response_id"):
value = row.get(selector)
try:
preferred = int(value)
except (TypeError, ValueError):
continue
if preferred == 0:
return "response_0", "response_1"
if preferred == 1:
return "response_1", "response_0"
return left_field, right_field
def _passes_quality_gate(
record: dict[str, str],
*,
min_words: int,
max_words: int,
min_alpha_ratio: float,
allowed_languages: set[str],
) -> bool:
candidate = str(record.get("answer") or record.get("text") or "").strip()
if not candidate:
return False
word_count = _word_count(candidate)
if min_words > 0 and word_count < min_words:
return False
if max_words > 0 and word_count > max_words:
return False
alpha_ratio = _alpha_ratio(candidate)
if min_alpha_ratio > 0.0 and alpha_ratio < min_alpha_ratio:
return False
if allowed_languages:
language = str(record.get("language", "")).strip().casefold()
if not language or language not in allowed_languages:
return False
record["quality_word_count"] = str(word_count)
record["quality_alpha_ratio"] = f"{alpha_ratio:.4f}"
return True
def to_json_record(
*,
dataset: str,
config: str | None,
split: str,
text_field: str,
row: dict[str, object],
) -> dict[str, str]:
text = clean_training_text(str(row.get(text_field, "")).strip())
if not text:
raise ValueError("Row is missing usable text.")
record_type = "text"
return {
**_base_record(
dataset=dataset,
config=config,
split=split,
row_id=_row_identifier(row),
),
"record_type": record_type,
"language": _row_language(row),
"text_field": text_field,
"text": text,
"word_count": _word_count(text),
"weight": _default_record_weight(record_type),
}
def dialogue_to_json_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    conversation_field: str,
    row: dict[str, object],
) -> list[dict[str, str]]:
    """Expand a conversation row into one record per assistant turn.

    The conversation is parsed with ``_parse_dialogue_messages``. The running
    history starts with the row's ``system`` column (when it has text) and
    the tool-definition text (when present and different from the system
    text). Each assistant message becomes a ``dialogue_turn`` record whose
    context is the rendered history up to that point.

    A null ``system`` column is treated as absent instead of becoming a
    ``"None"`` system message. Assistant turns whose answer is empty after
    ``clean_answer_text`` produce no record but still join the history, in
    line with the other converters, which never emit empty answers.

    Note that ``word_count`` and ``weight`` in each record are integers,
    despite the annotated value type.

    Raises:
        ValueError: If the row has no usable dialogue turns or yields no
            assistant training turns.
    """
    messages = _parse_dialogue_messages(row.get(conversation_field))
    if not messages:
        raise ValueError("Row does not contain usable dialogue turns.")
    row_id = _row_identifier(row)
    row_language = _row_language(row)

    history: list[dict[str, str]] = []
    raw_system = row.get("system")
    system_text = "" if raw_system is None else clean_training_text(str(raw_system).strip())
    if system_text:
        history.append({"role": "system", "content": system_text})
    tool_definition = _tool_definition_text(row)
    if tool_definition and tool_definition != system_text:
        history.append({"role": "system", "content": tool_definition})

    records: list[dict[str, str]] = []
    assistant_turn_index = 0
    for message in messages:
        if message["role"] != "assistant":
            history.append(message)
            continue
        prompt = _render_prompt(history)
        if not prompt:
            # An assistant turn with no preceding context is dropped and is
            # not added to the history either.
            continue
        answer = clean_answer_text(message["content"])
        if answer:
            assistant_turn_index += 1
            records.append(
                {
                    **_base_record(
                        dataset=dataset,
                        config=config,
                        split=split,
                        row_id=row_id,
                    ),
                    "record_type": "dialogue_turn",
                    "language": row_language,
                    "conversation_field": conversation_field,
                    "turn_index": str(assistant_turn_index),
                    "context": prompt,
                    "answer": answer,
                    "text": _compose_training_text(prompt, message["content"]),
                    "word_count": _word_count(answer),
                    "weight": _default_record_weight("dialogue_turn"),
                }
            )
        history.append(message)
    if not records:
        raise ValueError("Dialogue row did not yield any assistant training turns.")
    return records
def preference_to_json_records(
*,
dataset: str,
config: str | None,
split: str,
chosen_field: str,
rejected_field: str,
row: dict[str, object],
preference_target: str = "both",
) -> list[dict[str, str]]:
row_id = _row_identifier(row)
pair_id = row_id or f"{chosen_field}:{rejected_field}"
records: list[dict[str, str]] = []
row_language = _row_language(row)
chosen_field, rejected_field = _ordered_preference_fields(
row,
left_field=chosen_field,
right_field=rejected_field,
)
field_specs = [
(chosen_field, "preference_chosen"),
(rejected_field, "preference_rejected"),
]
if preference_target == "chosen":
field_specs = [(chosen_field, "preference_chosen")]
elif preference_target == "rejected":
field_specs = [(rejected_field, "preference_rejected")]
elif preference_target != "both":
raise ValueError("preference_target must be one of: both, chosen, rejected.")
for field_name, record_type in field_specs:
prompt, answer = _extract_prompt_answer(row, field_name=field_name)
if not prompt or not answer:
continue
records.append(
{
**_base_record(
dataset=dataset,
config=config,
split=split,
row_id=row_id,
),
"record_type": record_type,
"language": row_language,
"pair_id": pair_id,
"text_field": field_name,
"context": prompt,
"answer": clean_answer_text(answer),
"text": _compose_training_text(prompt, answer),
"word_count": _word_count(clean_answer_text(answer)),
"weight": _default_record_weight(record_type),
}
)
if not records:
raise ValueError("Preference row did not yield usable chosen/rejected transcripts.")
return records
def instruction_to_json_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    prompt_field: str,
    answer_field: str,
    row: dict[str, object],
) -> list[dict[str, str]]:
    """Convert an instruction/answer row into one ``instruction_answer`` record.

    The context comes from ``_compose_instruction_context``; the answer is
    the stringified, stripped value of *answer_field* cleaned with
    ``clean_answer_text``. A null answer counts as missing rather than being
    exported as the literal string ``"None"``.

    Note that ``word_count`` and ``weight`` in the record are integers,
    despite the annotated value type.

    Raises:
        ValueError: If the row lacks a usable context or answer.
    """
    context = _compose_instruction_context(row, prompt_field)
    raw_answer = row.get(answer_field)
    answer = "" if raw_answer is None else clean_answer_text(str(raw_answer).strip())
    if not context or not answer:
        raise ValueError("Instruction row did not contain usable prompt and answer text.")
    record_type = "instruction_answer"
    return [
        {
            **_base_record(
                dataset=dataset,
                config=config,
                split=split,
                row_id=_row_identifier(row),
            ),
            "record_type": record_type,
            "language": _row_language(row),
            "context": context,
            "answer": answer,
            "text": _compose_training_text(context, answer),
            "word_count": _word_count(answer),
            "weight": _default_record_weight(record_type),
        }
    ]
def _expand_row_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    row: dict[str, object],
    text_field: str | None,
    preference_target: str,
) -> list[dict[str, str]]:
    """Convert one dataset row into export records.

    With an explicit *text_field*, a list value is treated as a conversation
    and anything else as plain text. Without one, the row's columns are
    probed in this order: preference pair, instruction pair, conversation
    column (only when its value is a list), and finally a plain-text column.
    A ``ValueError`` from an earlier strategy moves on to the next one;
    errors from the final plain-text strategy propagate to the caller.
    """
    shared = {"dataset": dataset, "config": config, "split": split, "row": row}

    if text_field is not None:
        if isinstance(row.get(text_field), list):
            return dialogue_to_json_records(conversation_field=text_field, **shared)
        return [to_json_record(text_field=text_field, **shared)]

    column_names = list(row)

    try:
        chosen_name, rejected_name = choose_preference_fields(column_names)
        return preference_to_json_records(
            chosen_field=chosen_name,
            rejected_field=rejected_name,
            preference_target=preference_target,
            **shared,
        )
    except ValueError:
        pass

    try:
        prompt_name, answer_name = choose_instruction_fields(column_names)
        return instruction_to_json_records(
            prompt_field=prompt_name,
            answer_field=answer_name,
            **shared,
        )
    except ValueError:
        pass

    try:
        dialogue_name = choose_dialogue_field(column_names)
        if isinstance(row.get(dialogue_name), list):
            return dialogue_to_json_records(conversation_field=dialogue_name, **shared)
    except ValueError:
        pass

    fallback_field = choose_text_field(column_names)
    return [to_json_record(text_field=fallback_field, **shared)]
def import_hf_dataset(
*,
dataset: str,
output_path: str | Path,
config: str | None = None,
split: str = "train",
text_field: str | None = None,
limit: int = 1000,
streaming: bool = True,
preference_target: str = "chosen",
min_words: int = 0,
max_words: int = 0,
min_alpha_ratio: float = 0.0,
allowed_languages: tuple[str, ...] = (),
) -> dict[str, object]:
try:
from datasets import load_dataset
except ModuleNotFoundError:
user_site = site.getusersitepackages()
if user_site and user_site not in sys.path:
sys.path.append(user_site)
from datasets import load_dataset
dataset_kwargs: dict[str, object] = {
"split": split,
"streaming": streaming,
}
if config:
dataset_kwargs["name"] = config
hf_dataset = load_dataset(dataset, **dataset_kwargs)
iterator = iter(hf_dataset)
first_row: dict[str, object] | None = None
if text_field is None:
first_row = dict(next(iterator))
iterator = chain([first_row], iterator)
output = Path(output_path)
output.parent.mkdir(parents=True, exist_ok=True)
written = 0
record_types: set[str] = set()
normalized_languages = {language.casefold() for language in allowed_languages if language.strip()}
with output.open("w", encoding="utf-8") as handle:
for row in iterator:
if written >= limit:
break
normalized_row = dict(row)
try:
records = _expand_row_records(
dataset=dataset,
config=config,
split=split,
row=normalized_row,
text_field=text_field,
preference_target=preference_target,
)
except ValueError:
continue
for record in records:
if written >= limit:
break
if not _passes_quality_gate(
record,
min_words=min_words,
max_words=max_words,
min_alpha_ratio=min_alpha_ratio,
allowed_languages=normalized_languages,
):
continue
record_types.add(record.get("record_type", "text"))
handle.write(json.dumps(record, ensure_ascii=False) + "\n")
written += 1
inferred_mode = "mixed" if len(record_types) > 1 else (next(iter(record_types)) if record_types else "unknown")
return {
"dataset": dataset,
"config": config or "",
"split": split,
"text_field": text_field or "",
"output_path": str(output.resolve()),
"records_written": written,
"record_types": sorted(record_types),
"mode": inferred_mode,
"preference_target": preference_target,
"streaming": streaming,
"min_words": min_words,
"max_words": max_words,
"min_alpha_ratio": min_alpha_ratio,
"allowed_languages": sorted(normalized_languages),
}