# Reframr-RFM-v1-Base — reframr/hf_import.py
# Imported from the "Release Reframr-RFM-v1-Base public checkpoint" commit (2147ce8).
import json
import re
import site
import sys
from itertools import chain
from pathlib import Path
from .text_quality import clean_answer_text, clean_context_text, clean_training_text
TEXT_FIELD_PREFERENCES = (
"text",
"content",
"body",
"article",
"document",
"passage",
"markdown",
)
DIALOGUE_FIELD_PREFERENCES = (
"messages",
"conversation",
"conversations",
"dialogue",
"dialog",
"turns",
)
PREFERENCE_FIELD_PAIRS = (
("chosen", "rejected"),
("response_j", "response_k"),
("response_0", "response_1"),
)
INSTRUCTION_FIELD_PAIRS = (
("instruction", "output"),
("prompt", "completion"),
("prompt", "response"),
("question", "answer"),
("question", "response"),
("query", "response"),
)
TRANSCRIPT_ROLE_PATTERN = re.compile(r"(?:^|\n\s*\n)(Human|Assistant|System)\s*:\s*", re.IGNORECASE)
ROLE_ALIASES = {
"assistant": "assistant",
"bot": "assistant",
"gpt": "assistant",
"model": "assistant",
"assistant_response": "assistant",
"human": "user",
"user": "user",
"prompter": "user",
"customer": "user",
"system": "system",
}
def _word_count(text: str) -> int:
return len(text.split())
def _alpha_ratio(text: str) -> float:
if not text:
return 0.0
alpha_count = sum(character.isalpha() for character in text)
return alpha_count / len(text)
def _default_record_weight(record_type: str) -> int:
if record_type == "dialogue_turn":
return 2
if record_type == "instruction_answer":
return 2
if record_type == "preference_chosen":
return 3
if record_type == "preference_rejected":
return 0
return 1
def choose_text_field(columns: list[str]) -> str:
    """Pick the most likely free-text column from *columns*.

    Matching is case-insensitive and follows the order of
    TEXT_FIELD_PREFERENCES. Raises ValueError when nothing matches.
    """
    by_casefold = {name.casefold(): name for name in columns}
    hit = next(
        (by_casefold[key] for key in TEXT_FIELD_PREFERENCES if key in by_casefold),
        None,
    )
    if hit is None:
        raise ValueError("Could not infer a text column. Pass --text-field explicitly.")
    return hit
def choose_dialogue_field(columns: list[str]) -> str:
    """Pick the most likely conversation column from *columns*.

    Matching is case-insensitive and follows the order of
    DIALOGUE_FIELD_PREFERENCES. Raises ValueError when nothing matches.
    """
    by_casefold = {name.casefold(): name for name in columns}
    hit = next(
        (by_casefold[key] for key in DIALOGUE_FIELD_PREFERENCES if key in by_casefold),
        None,
    )
    if hit is None:
        raise ValueError("Could not infer a conversation column.")
    return hit
def choose_preference_fields(columns: list[str]) -> tuple[str, str]:
    """Locate a (chosen, rejected) preference column pair in *columns*.

    Matching is case-insensitive and follows the order of
    PREFERENCE_FIELD_PAIRS. Raises ValueError when no pair is present.
    """
    by_casefold = {name.casefold(): name for name in columns}
    match = next(
        (
            (by_casefold[left], by_casefold[right])
            for left, right in PREFERENCE_FIELD_PAIRS
            if left in by_casefold and right in by_casefold
        ),
        None,
    )
    if match is None:
        raise ValueError("Could not infer chosen/rejected preference columns.")
    return match
def choose_instruction_fields(columns: list[str]) -> tuple[str, str]:
    """Locate a (prompt, answer) instruction column pair in *columns*.

    Matching is case-insensitive and follows the order of
    INSTRUCTION_FIELD_PAIRS. Raises ValueError when no pair is present.
    """
    by_casefold = {name.casefold(): name for name in columns}
    match = next(
        (
            (by_casefold[prompt], by_casefold[answer])
            for prompt, answer in INSTRUCTION_FIELD_PAIRS
            if prompt in by_casefold and answer in by_casefold
        ),
        None,
    )
    if match is None:
        raise ValueError("Could not infer instruction/answer columns.")
    return match
def _row_identifier(row: dict[str, object]) -> str:
for candidate in ("id", "_id", "row_id", "uuid", "prompt_id"):
if candidate in row and str(row[candidate]).strip():
return str(row[candidate]).strip()
return ""
def _base_record(
*,
dataset: str,
config: str | None,
split: str,
row_id: str,
) -> dict[str, str]:
return {
"source": "huggingface",
"dataset": dataset,
"config": config or "",
"split": split,
"row_id": row_id,
}
def _row_language(row: dict[str, object]) -> str:
for candidate in ("lang", "language", "locale"):
value = row.get(candidate)
if isinstance(value, str) and value.strip():
return value.strip()
return ""
def _normalize_role(raw_role: object) -> str:
    """Canonicalize a speaker label via ROLE_ALIASES (falsy input -> "")."""
    label = str(raw_role or "").strip().casefold()
    if label in ROLE_ALIASES:
        return ROLE_ALIASES[label]
    return label
def _message_content(message: dict[str, object]) -> str:
    """Extract and clean the text payload of a chat message dict, or ""."""
    for key in ("content", "value", "text", "message"):
        candidate = message.get(key)
        if not isinstance(candidate, str) or not candidate.strip():
            continue
        return clean_training_text(candidate)
    return ""
def _message_role(message: dict[str, object]) -> str:
    """Return the canonical role of a chat message dict, or "" when absent."""
    for key in ("role", "from", "speaker", "author"):
        raw = message.get(key)
        if raw is None:
            continue
        canonical = _normalize_role(raw)
        if canonical:
            return canonical
    return ""
def _parse_dialogue_messages(raw_messages: object) -> list[dict[str, str]]:
    """Normalize a raw message list into {"role", "content"} dicts.

    Non-list input yields []. Entries that are not dicts, have no usable
    content, or whose role is not system/user/assistant are dropped.
    """
    if not isinstance(raw_messages, list):
        return []
    turns: list[dict[str, str]] = []
    for entry in raw_messages:
        if not isinstance(entry, dict):
            continue
        role = _message_role(entry)
        content = _message_content(entry)
        if content and role in {"system", "user", "assistant"}:
            turns.append({"role": role, "content": content})
    return turns
def _parse_transcript_messages(raw_text: object) -> list[dict[str, str]]:
    """Split a "Human: ... Assistant: ..." transcript string into turns.

    Returns [] for non-string input, blank text, or text without any
    recognizable speaker markers. Each segment runs from the end of one
    marker to the start of the next (or end of text).
    """
    if not isinstance(raw_text, str):
        return []
    text = raw_text.strip()
    if not text:
        return []
    boundaries = list(TRANSCRIPT_ROLE_PATTERN.finditer(text))
    if not boundaries:
        return []
    turns: list[dict[str, str]] = []
    for current, following in zip(boundaries, boundaries[1:] + [None]):
        role = _normalize_role(current.group(1))
        segment_end = len(text) if following is None else following.start()
        content = clean_training_text(text[current.end():segment_end].strip())
        if content and role in {"system", "user", "assistant"}:
            turns.append({"role": role, "content": content})
    return turns
def _render_prompt(messages: list[dict[str, str]]) -> str:
    """Join the cleaned contents of *messages* into one newline-separated prompt."""
    cleaned = (clean_context_text(message["content"]) for message in messages)
    return "\n".join(part for part in cleaned if part).strip()
def _compose_training_text(context: str, answer: str) -> str:
    """Render the "<reason> ... <answer> ..." training string from cleaned parts."""
    reason = clean_context_text(context)
    response = clean_answer_text(answer)
    return f"<reason> {reason} <answer> {response}".strip()
def _compose_instruction_context(row: dict[str, object], prompt_field: str) -> str:
    """Build the instruction context from the prompt column plus an optional
    Alpaca-style ``input`` column, newline-joined.

    Returns "" when neither column yields usable text after cleaning.
    """
    parts: list[str] = []
    for field in (prompt_field, "input"):
        # Fix: ``or ""`` guards against an explicit None cell, which
        # ``str(row.get(field, ""))`` would stringify as the literal "None".
        cleaned = clean_context_text(str(row.get(field) or "").strip())
        if cleaned:
            parts.append(cleaned)
    return "\n".join(parts).strip()
def _extract_prompt_answer(
    row: dict[str, object],
    *,
    field_name: str,
) -> tuple[str, str]:
    """Pull a (prompt, answer) pair out of one response column of *row*.

    Tries, in order: a structured chat-message list ending in an assistant
    turn; a flat "Human:/Assistant:" transcript string; finally a standalone
    prompt/question column paired with the raw response text. Either element
    of the returned pair may be "" when nothing usable was found.
    """
    # Preferred path: the column already holds structured chat messages.
    dialogue_messages = _parse_dialogue_messages(row.get(field_name))
    if dialogue_messages and dialogue_messages[-1]["role"] == "assistant":
        prompt = _render_prompt(dialogue_messages[:-1])
        answer = dialogue_messages[-1]["content"]
        if prompt and answer:
            return prompt, answer
    # Next: a flat transcript string with speaker markers.
    messages = _parse_transcript_messages(row.get(field_name))
    if messages and messages[-1]["role"] == "assistant":
        prompt = _render_prompt(messages[:-1])
        answer = messages[-1]["content"]
        if prompt and answer:
            return prompt, answer
    # Fallback: separate prompt/question column plus the raw response text.
    # Fix: ``or`` chaining prevents a stored None (or empty prompt) from being
    # stringified as the literal "None"; an empty prompt now falls through to
    # the question column instead of blocking it.
    raw_prompt = row.get("prompt") or row.get("question") or ""
    prompt = clean_training_text(str(raw_prompt).strip())
    answer = clean_answer_text(str(row.get(field_name) or "").strip())
    return prompt, answer
def _ordered_preference_fields(
row: dict[str, object],
*,
left_field: str,
right_field: str,
) -> tuple[str, str]:
if {left_field, right_field} != {"response_0", "response_1"}:
return left_field, right_field
for selector in ("safer_response_id", "better_response_id"):
value = row.get(selector)
try:
preferred = int(value)
except (TypeError, ValueError):
continue
if preferred == 0:
return "response_0", "response_1"
if preferred == 1:
return "response_1", "response_0"
return left_field, right_field
def _passes_quality_gate(
    record: dict[str, str],
    *,
    min_words: int,
    max_words: int,
    min_alpha_ratio: float,
    allowed_languages: set[str],
) -> bool:
    """Apply word-count, alphabetic-ratio, and language filters to a record.

    A threshold of 0 (or an empty language set) disables that gate. On
    success the measured statistics are written back onto the record as
    ``quality_word_count`` and ``quality_alpha_ratio``; a rejected record is
    left unmodified.
    """
    candidate = str(record.get("answer") or record.get("text") or "").strip()
    if not candidate:
        return False
    words = _word_count(candidate)
    too_short = min_words > 0 and words < min_words
    too_long = max_words > 0 and words > max_words
    if too_short or too_long:
        return False
    ratio = _alpha_ratio(candidate)
    if min_alpha_ratio > 0.0 and ratio < min_alpha_ratio:
        return False
    if allowed_languages:
        language = str(record.get("language", "")).strip().casefold()
        # An empty language tag also fails ("" is never in the set).
        if language not in allowed_languages:
            return False
    record["quality_word_count"] = str(words)
    record["quality_alpha_ratio"] = f"{ratio:.4f}"
    return True
def to_json_record(
    *,
    dataset: str,
    config: str | None,
    split: str,
    text_field: str,
    row: dict[str, object],
) -> dict[str, str]:
    """Convert one plain-text dataset row into a single training record.

    Raises ValueError when the text column is empty (or missing) after
    cleaning.
    """
    # Fix: ``or ""`` guards against an explicit None cell, which
    # ``str(row.get(text_field, ""))`` would stringify as the literal "None"
    # and emit as a bogus record.
    text = clean_training_text(str(row.get(text_field) or "").strip())
    if not text:
        raise ValueError("Row is missing usable text.")
    record_type = "text"
    return {
        **_base_record(
            dataset=dataset,
            config=config,
            split=split,
            row_id=_row_identifier(row),
        ),
        "record_type": record_type,
        "language": _row_language(row),
        "text_field": text_field,
        "text": text,
        "word_count": _word_count(text),
        "weight": _default_record_weight(record_type),
    }
def dialogue_to_json_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    conversation_field: str,
    row: dict[str, object],
) -> list[dict[str, str]]:
    """Expand one multi-turn conversation row into per-assistant-turn records.

    Each assistant message becomes one record whose context is the rendered
    history up to that point (including an optional ``system`` column).

    Raises ValueError when the row has no usable turns or no assistant turn
    with a non-empty preceding context.
    """
    messages = _parse_dialogue_messages(row.get(conversation_field))
    if not messages:
        raise ValueError("Row does not contain usable dialogue turns.")
    row_id = _row_identifier(row)
    records: list[dict[str, str]] = []
    history: list[dict[str, str]] = []
    row_language = _row_language(row)
    # Fix: ``or ""`` guards against an explicit None system cell, which
    # ``str(row.get("system", ""))`` would stringify as the literal "None".
    system_text = clean_training_text(str(row.get("system") or "").strip())
    if system_text:
        history.append({"role": "system", "content": system_text})
    assistant_turn_index = 0
    for message in messages:
        if message["role"] != "assistant":
            history.append(message)
            continue
        prompt = _render_prompt(history)
        if not prompt:
            # An assistant turn with no preceding context is not trainable;
            # note it is also kept out of the running history here.
            continue
        assistant_turn_index += 1
        # Hoisted: the cleaned answer feeds both the record field and its
        # word count (previously computed twice).
        answer = clean_answer_text(message["content"])
        records.append(
            {
                **_base_record(
                    dataset=dataset,
                    config=config,
                    split=split,
                    row_id=row_id,
                ),
                "record_type": "dialogue_turn",
                "language": row_language,
                "conversation_field": conversation_field,
                "turn_index": str(assistant_turn_index),
                "context": prompt,
                "answer": answer,
                "text": _compose_training_text(prompt, message["content"]),
                "word_count": _word_count(answer),
                "weight": _default_record_weight("dialogue_turn"),
            }
        )
        history.append(message)
    if not records:
        raise ValueError("Dialogue row did not yield any assistant training turns.")
    return records
def preference_to_json_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    chosen_field: str,
    rejected_field: str,
    row: dict[str, object],
    preference_target: str = "both",
) -> list[dict[str, str]]:
    """Convert one preference-comparison row into training records.

    Emits a record for the chosen response, the rejected response, or both,
    depending on ``preference_target``. PKU-style response_0/response_1
    columns are first reordered by the row's annotator labels.

    Raises ValueError for an unknown ``preference_target`` or when neither
    selected response yields a usable (prompt, answer) pair.
    """
    row_id = _row_identifier(row)
    # Stable pair id so chosen/rejected siblings can be re-associated later.
    pair_id = row_id or f"{chosen_field}:{rejected_field}"
    records: list[dict[str, str]] = []
    row_language = _row_language(row)
    chosen_field, rejected_field = _ordered_preference_fields(
        row,
        left_field=chosen_field,
        right_field=rejected_field,
    )
    field_specs = [
        (chosen_field, "preference_chosen"),
        (rejected_field, "preference_rejected"),
    ]
    if preference_target == "chosen":
        field_specs = [(chosen_field, "preference_chosen")]
    elif preference_target == "rejected":
        field_specs = [(rejected_field, "preference_rejected")]
    elif preference_target != "both":
        raise ValueError("preference_target must be one of: both, chosen, rejected.")
    for field_name, record_type in field_specs:
        prompt, answer = _extract_prompt_answer(row, field_name=field_name)
        if not prompt or not answer:
            continue
        # Hoisted: the cleaned answer feeds both the record field and its
        # word count (previously computed twice).
        cleaned_answer = clean_answer_text(answer)
        records.append(
            {
                **_base_record(
                    dataset=dataset,
                    config=config,
                    split=split,
                    row_id=row_id,
                ),
                "record_type": record_type,
                "language": row_language,
                "pair_id": pair_id,
                "text_field": field_name,
                "context": prompt,
                "answer": cleaned_answer,
                "text": _compose_training_text(prompt, answer),
                "word_count": _word_count(cleaned_answer),
                "weight": _default_record_weight(record_type),
            }
        )
    if not records:
        raise ValueError("Preference row did not yield usable chosen/rejected transcripts.")
    return records
def instruction_to_json_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    prompt_field: str,
    answer_field: str,
    row: dict[str, object],
) -> list[dict[str, str]]:
    """Convert one instruction/answer row into a single training record.

    Raises ValueError when either the composed context or the answer is
    empty after cleaning.
    """
    context = _compose_instruction_context(row, prompt_field)
    # Fix: ``or ""`` guards against an explicit None answer cell, which
    # ``str(row.get(answer_field, ""))`` would stringify as the literal "None".
    answer = clean_answer_text(str(row.get(answer_field) or "").strip())
    if not context or not answer:
        raise ValueError("Instruction row did not contain usable prompt and answer text.")
    record_type = "instruction_answer"
    return [
        {
            **_base_record(
                dataset=dataset,
                config=config,
                split=split,
                row_id=_row_identifier(row),
            ),
            "record_type": record_type,
            "language": _row_language(row),
            "context": context,
            "answer": answer,
            "text": _compose_training_text(context, answer),
            "word_count": _word_count(answer),
            "weight": _default_record_weight(record_type),
        }
    ]
def _expand_row_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    row: dict[str, object],
    text_field: str | None,
    preference_target: str,
) -> list[dict[str, str]]:
    """Turn one dataset row into training records, inferring its schema.

    When ``text_field`` is given it is honored directly (as a dialogue column
    if its value is a list, else as plain text). Otherwise the row's columns
    are probed in a fixed priority order: preference pair, then
    instruction/answer pair, then dialogue column, then plain text.

    Raises ValueError when no interpretation yields usable records; the
    caller treats that as "skip this row".
    """
    if text_field is not None:
        explicit_value = row.get(text_field)
        # A list value in the explicit column is treated as chat messages.
        if isinstance(explicit_value, list):
            return dialogue_to_json_records(
                dataset=dataset,
                config=config,
                split=split,
                conversation_field=text_field,
                row=row,
            )
        return [
            to_json_record(
                dataset=dataset,
                config=config,
                split=split,
                text_field=text_field,
                row=row,
            )
        ]
    columns = list(row)
    # Priority 1: preference (chosen/rejected) schema.
    try:
        chosen_field, rejected_field = choose_preference_fields(columns)
        return preference_to_json_records(
            dataset=dataset,
            config=config,
            split=split,
            chosen_field=chosen_field,
            rejected_field=rejected_field,
            row=row,
        preference_target=preference_target,
        )
    except ValueError:
        pass
    # Priority 2: instruction/answer schema.
    try:
        prompt_field, answer_field = choose_instruction_fields(columns)
        return instruction_to_json_records(
            dataset=dataset,
            config=config,
            split=split,
            prompt_field=prompt_field,
            answer_field=answer_field,
            row=row,
        )
    except ValueError:
        pass
    # Priority 3: a conversation column holding a message list.
    try:
        conversation_field = choose_dialogue_field(columns)
        if isinstance(row.get(conversation_field), list):
            return dialogue_to_json_records(
                dataset=dataset,
                config=config,
                split=split,
                conversation_field=conversation_field,
                row=row,
            )
    except ValueError:
        pass
    # Fallback: plain text; lets choose_text_field's ValueError propagate.
    inferred_text_field = choose_text_field(columns)
    return [
        to_json_record(
            dataset=dataset,
            config=config,
            split=split,
            text_field=inferred_text_field,
            row=row,
        )
    ]
def import_hf_dataset(
    *,
    dataset: str,
    output_path: str | Path,
    config: str | None = None,
    split: str = "train",
    text_field: str | None = None,
    limit: int = 1000,
    streaming: bool = True,
    preference_target: str = "chosen",
    min_words: int = 0,
    max_words: int = 0,
    min_alpha_ratio: float = 0.0,
    allowed_languages: tuple[str, ...] = (),
) -> dict[str, object]:
    """Stream a Hugging Face dataset into a local JSONL training file.

    Rows are expanded into records via ``_expand_row_records`` (schema is
    inferred per row when ``text_field`` is None), filtered by
    ``_passes_quality_gate``, and written one JSON object per line until
    ``limit`` records have been written. Rows whose schema cannot be
    interpreted are skipped rather than aborting the import.

    Returns a summary dict (paths, counts, record types, settings).
    """
    try:
        from datasets import load_dataset
    except ModuleNotFoundError:
        # Fall back to the user site-packages directory, which may be missing
        # from sys.path in some interpreter setups.
        user_site = site.getusersitepackages()
        if user_site and user_site not in sys.path:
            sys.path.append(user_site)
        from datasets import load_dataset
    dataset_kwargs: dict[str, object] = {
        "split": split,
        "streaming": streaming,
    }
    if config:
        dataset_kwargs["name"] = config
    hf_dataset = load_dataset(dataset, **dataset_kwargs)
    iterator = iter(hf_dataset)
    if text_field is None:
        # Peek at the first row (schema inference happens per row below),
        # then push it back onto the iterator.
        # Fix: an empty dataset previously raised an uncaught StopIteration
        # here; it now just yields zero records.
        try:
            first_row = dict(next(iterator))
        except StopIteration:
            pass
        else:
            iterator = chain([first_row], iterator)
    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)
    written = 0
    record_types: set[str] = set()
    normalized_languages = {language.casefold() for language in allowed_languages if language.strip()}
    with output.open("w", encoding="utf-8") as handle:
        for row in iterator:
            if written >= limit:
                break
            normalized_row = dict(row)
            try:
                records = _expand_row_records(
                    dataset=dataset,
                    config=config,
                    split=split,
                    row=normalized_row,
                    text_field=text_field,
                    preference_target=preference_target,
                )
            except ValueError:
                # Uninterpretable rows are skipped, not fatal.
                continue
            for record in records:
                if written >= limit:
                    break
                if not _passes_quality_gate(
                    record,
                    min_words=min_words,
                    max_words=max_words,
                    min_alpha_ratio=min_alpha_ratio,
                    allowed_languages=normalized_languages,
                ):
                    continue
                record_types.add(record.get("record_type", "text"))
                handle.write(json.dumps(record, ensure_ascii=False) + "\n")
                written += 1
    inferred_mode = "mixed" if len(record_types) > 1 else (next(iter(record_types)) if record_types else "unknown")
    return {
        "dataset": dataset,
        "config": config or "",
        "split": split,
        "text_field": text_field or "",
        "output_path": str(output.resolve()),
        "records_written": written,
        "record_types": sorted(record_types),
        "mode": inferred_mode,
        "preference_target": preference_target,
        "streaming": streaming,
        "min_words": min_words,
        "max_words": max_words,
        "min_alpha_ratio": min_alpha_ratio,
        "allowed_languages": sorted(normalized_languages),
    }