| import json |
| import re |
| import site |
| import sys |
| from itertools import chain |
| from pathlib import Path |
|
|
| from .text_quality import clean_answer_text, clean_context_text, clean_training_text |
|
|
# Column names probed, in priority order, when inferring which field of a
# row holds free-form training text.
TEXT_FIELD_PREFERENCES = (
    "text",
    "content",
    "body",
    "article",
    "document",
    "passage",
    "markdown",
)


# Column names probed, in priority order, when inferring which field holds a
# structured conversation (a list of role/content message dicts).
DIALOGUE_FIELD_PREFERENCES = (
    "messages",
    "conversation",
    "conversations",
    "dialogue",
    "dialog",
    "turns",
)


# (chosen, rejected) column-name pairs recognised in preference datasets.
PREFERENCE_FIELD_PAIRS = (
    ("chosen", "rejected"),
    ("response_j", "response_k"),
    ("response_0", "response_1"),
)


# (prompt, answer) column-name pairs recognised in instruction datasets.
INSTRUCTION_FIELD_PAIRS = (
    ("instruction", "output"),
    ("prompt", "completion"),
    ("prompt", "response"),
    ("question", "answer"),
    ("question", "response"),
    ("query", "response"),
)


# Matches a "Human:", "Assistant:", or "System:" speaker label at the start of
# the text or right after a blank line; used to split flat transcripts.
TRANSCRIPT_ROLE_PATTERN = re.compile(r"(?:^|\n\s*\n)(Human|Assistant|System)\s*:\s*", re.IGNORECASE)
# Maps casefolded speaker labels to the canonical roles
# "system" / "user" / "assistant".
ROLE_ALIASES = {
    "assistant": "assistant",
    "bot": "assistant",
    "gpt": "assistant",
    "model": "assistant",
    "assistant_response": "assistant",
    "human": "user",
    "user": "user",
    "prompter": "user",
    "customer": "user",
    "system": "system",
}
|
|
|
|
| def _word_count(text: str) -> int: |
| return len(text.split()) |
|
|
|
|
| def _alpha_ratio(text: str) -> float: |
| if not text: |
| return 0.0 |
| alpha_count = sum(character.isalpha() for character in text) |
| return alpha_count / len(text) |
|
|
|
|
| def _default_record_weight(record_type: str) -> int: |
| if record_type == "dialogue_turn": |
| return 2 |
| if record_type == "instruction_answer": |
| return 2 |
| if record_type == "preference_chosen": |
| return 3 |
| if record_type == "preference_rejected": |
| return 0 |
| return 1 |
|
|
|
|
def choose_text_field(columns: list[str]) -> str:
    """Return the most-preferred text column present in *columns*.

    Matching is case-insensitive; the original column spelling is returned.

    Raises:
        ValueError: when no preferred text column is present.
    """
    lookup = {name.casefold(): name for name in columns}
    match = next(
        (lookup[candidate] for candidate in TEXT_FIELD_PREFERENCES if candidate in lookup),
        None,
    )
    if match is not None:
        return match
    raise ValueError("Could not infer a text column. Pass --text-field explicitly.")
|
|
|
|
def choose_dialogue_field(columns: list[str]) -> str:
    """Return the most-preferred conversation column present in *columns*.

    Matching is case-insensitive; the original column spelling is returned.

    Raises:
        ValueError: when no known conversation column is present.
    """
    lookup = {name.casefold(): name for name in columns}
    match = next(
        (lookup[candidate] for candidate in DIALOGUE_FIELD_PREFERENCES if candidate in lookup),
        None,
    )
    if match is not None:
        return match
    raise ValueError("Could not infer a conversation column.")
|
|
|
|
def choose_preference_fields(columns: list[str]) -> tuple[str, str]:
    """Return the (chosen, rejected) column pair found in *columns*.

    Matching is case-insensitive; original column spellings are returned.

    Raises:
        ValueError: when no known preference column pair is present.
    """
    lookup = {name.casefold(): name for name in columns}
    for chosen_key, rejected_key in PREFERENCE_FIELD_PAIRS:
        chosen = lookup.get(chosen_key)
        rejected = lookup.get(rejected_key)
        if chosen is not None and rejected is not None:
            return chosen, rejected
    raise ValueError("Could not infer chosen/rejected preference columns.")
|
|
|
|
def choose_instruction_fields(columns: list[str]) -> tuple[str, str]:
    """Return the (prompt, answer) column pair found in *columns*.

    Matching is case-insensitive; original column spellings are returned.

    Raises:
        ValueError: when no known instruction column pair is present.
    """
    lookup = {name.casefold(): name for name in columns}
    for prompt_key, answer_key in INSTRUCTION_FIELD_PAIRS:
        prompt = lookup.get(prompt_key)
        answer = lookup.get(answer_key)
        if prompt is not None and answer is not None:
            return prompt, answer
    raise ValueError("Could not infer instruction/answer columns.")
|
|
|
|
| def _row_identifier(row: dict[str, object]) -> str: |
| for candidate in ("id", "_id", "row_id", "uuid", "prompt_id"): |
| if candidate in row and str(row[candidate]).strip(): |
| return str(row[candidate]).strip() |
| return "" |
|
|
|
|
| def _base_record( |
| *, |
| dataset: str, |
| config: str | None, |
| split: str, |
| row_id: str, |
| ) -> dict[str, str]: |
| return { |
| "source": "huggingface", |
| "dataset": dataset, |
| "config": config or "", |
| "split": split, |
| "row_id": row_id, |
| } |
|
|
|
|
| def _row_language(row: dict[str, object]) -> str: |
| for candidate in ("lang", "language", "locale"): |
| value = row.get(candidate) |
| if isinstance(value, str) and value.strip(): |
| return value.strip() |
| return "" |
|
|
|
|
def _normalize_role(raw_role: object) -> str:
    """Map a raw speaker label onto a canonical role name.

    Known aliases resolve to "system"/"user"/"assistant"; unknown labels
    pass through casefolded, and None/empty input yields "".
    """
    label = str(raw_role or "").strip().casefold()
    if label in ROLE_ALIASES:
        return ROLE_ALIASES[label]
    return label
|
|
|
|
def _message_content(message: dict[str, object]) -> str:
    """Cleaned text from the first populated content-like field, or ''."""
    content_fields = ("content", "value", "text", "message")
    for field in content_fields:
        raw = message.get(field)
        if not isinstance(raw, str):
            continue
        if raw.strip():
            return clean_training_text(raw)
    return ""
|
|
|
|
def _message_role(message: dict[str, object]) -> str:
    """Canonical role from the first role-like field of *message*, or ''."""
    role_fields = ("role", "from", "speaker", "author")
    for field in role_fields:
        raw = message.get(field)
        if raw is None:
            continue
        role = _normalize_role(raw)
        if role:
            return role
    return ""
|
|
|
|
def _parse_dialogue_messages(raw_messages: object) -> list[dict[str, str]]:
    """Normalize a list of message dicts into [{"role", "content"}, ...].

    Entries that are not dicts, carry an unrecognised role, or have no
    usable content are dropped.  Returns [] when *raw_messages* is not a
    list at all.
    """
    if not isinstance(raw_messages, list):
        return []

    valid_roles = {"system", "user", "assistant"}
    normalized: list[dict[str, str]] = []
    for entry in raw_messages:
        if not isinstance(entry, dict):
            continue
        role = _message_role(entry)
        if role not in valid_roles:
            continue
        content = _message_content(entry)
        if not content:
            continue
        normalized.append({"role": role, "content": content})
    return normalized
|
|
|
|
def _parse_transcript_messages(raw_text: object) -> list[dict[str, str]]:
    """Split a flat "Human:/Assistant:/System:" transcript into messages.

    Returns [] when *raw_text* is not a non-empty string or contains no
    recognised speaker labels.
    """
    if not isinstance(raw_text, str):
        return []

    transcript = raw_text.strip()
    if not transcript:
        return []

    label_matches = list(TRANSCRIPT_ROLE_PATTERN.finditer(transcript))
    if not label_matches:
        return []

    # Each segment runs from the end of its label to the start of the next
    # label (or the end of the transcript for the final segment).
    segment_ends = [match.start() for match in label_matches[1:]] + [len(transcript)]
    messages: list[dict[str, str]] = []
    for match, segment_end in zip(label_matches, segment_ends):
        role = _normalize_role(match.group(1))
        if role not in {"system", "user", "assistant"}:
            continue
        content = clean_training_text(transcript[match.end():segment_end].strip())
        if content:
            messages.append({"role": role, "content": content})
    return messages
|
|
|
|
def _render_prompt(messages: list[dict[str, str]]) -> str:
    """Join the cleaned content of *messages* into one newline-separated prompt."""
    cleaned_parts = (clean_context_text(message["content"]) for message in messages)
    return "\n".join(part for part in cleaned_parts if part).strip()
|
|
|
|
def _compose_training_text(context: str, answer: str) -> str:
    """Render the final training string: "<reason> {context} <answer> {answer}"."""
    reason_part = clean_context_text(context)
    answer_part = clean_answer_text(answer)
    return f"<reason> {reason_part} <answer> {answer_part}".strip()
|
|
|
|
def _compose_instruction_context(row: dict[str, object], prompt_field: str) -> str:
    """Join the prompt column with an optional "input" column, newline-separated."""
    prompt_text = clean_context_text(str(row.get(prompt_field, "")).strip())
    input_text = clean_context_text(str(row.get("input", "")).strip())
    segments = [segment for segment in (prompt_text, input_text) if segment]
    return "\n".join(segments).strip()
|
|
|
|
def _extract_prompt_answer(
    row: dict[str, object],
    *,
    field_name: str,
) -> tuple[str, str]:
    """Derive a (prompt, answer) pair from *row*'s *field_name* column.

    Tries, in order: structured dialogue messages, a flat role-labelled
    transcript, and finally a sibling "prompt"/"question" column paired with
    the raw field value.  Either element of the result may be "".
    """
    # Structured messages: the trailing assistant turn is the answer.
    structured = _parse_dialogue_messages(row.get(field_name))
    if structured and structured[-1]["role"] == "assistant":
        prompt = _render_prompt(structured[:-1])
        answer = structured[-1]["content"]
        if prompt and answer:
            return prompt, answer

    # Flat transcript: same idea, parsed from "Human:/Assistant:" labels.
    transcript = _parse_transcript_messages(row.get(field_name))
    if transcript and transcript[-1]["role"] == "assistant":
        prompt = _render_prompt(transcript[:-1])
        answer = transcript[-1]["content"]
        if prompt and answer:
            return prompt, answer

    # Fallback: plain prompt/question column plus the raw field text.
    fallback_prompt = clean_training_text(str(row.get("prompt", row.get("question", ""))).strip())
    fallback_answer = clean_answer_text(str(row.get(field_name, "")).strip())
    return fallback_prompt, fallback_answer
|
|
|
|
| def _ordered_preference_fields( |
| row: dict[str, object], |
| *, |
| left_field: str, |
| right_field: str, |
| ) -> tuple[str, str]: |
| if {left_field, right_field} != {"response_0", "response_1"}: |
| return left_field, right_field |
|
|
| for selector in ("safer_response_id", "better_response_id"): |
| value = row.get(selector) |
| try: |
| preferred = int(value) |
| except (TypeError, ValueError): |
| continue |
| if preferred == 0: |
| return "response_0", "response_1" |
| if preferred == 1: |
| return "response_1", "response_0" |
| return left_field, right_field |
|
|
|
|
def _passes_quality_gate(
    record: dict[str, str],
    *,
    min_words: int,
    max_words: int,
    min_alpha_ratio: float,
    allowed_languages: set[str],
) -> bool:
    """Check *record* against word-count, alpha-ratio, and language filters.

    A threshold of 0 (or an empty *allowed_languages* set) disables that
    particular check.  On success the measured stats are written back into
    *record* as "quality_word_count" and "quality_alpha_ratio".
    """
    candidate = str(record.get("answer") or record.get("text") or "").strip()
    if not candidate:
        return False

    words = _word_count(candidate)
    too_short = min_words > 0 and words < min_words
    too_long = max_words > 0 and words > max_words
    if too_short or too_long:
        return False

    ratio = _alpha_ratio(candidate)
    if min_alpha_ratio > 0.0 and ratio < min_alpha_ratio:
        return False

    if allowed_languages:
        language = str(record.get("language", "")).strip().casefold()
        if not language or language not in allowed_languages:
            return False

    record["quality_word_count"] = str(words)
    record["quality_alpha_ratio"] = f"{ratio:.4f}"
    return True
|
|
|
|
def to_json_record(
    *,
    dataset: str,
    config: str | None,
    split: str,
    text_field: str,
    row: dict[str, object],
) -> dict[str, object]:
    """Convert a plain-text *row* into one flat training record.

    Raises:
        ValueError: when *text_field* is absent or its value cleans to "".
    """
    text = clean_training_text(str(row.get(text_field, "")).strip())
    if not text:
        raise ValueError("Row is missing usable text.")

    record_type = "text"
    return {
        **_base_record(
            dataset=dataset,
            config=config,
            split=split,
            row_id=_row_identifier(row),
        ),
        "record_type": record_type,
        "language": _row_language(row),
        "text_field": text_field,
        "text": text,
        # word_count and weight are ints, unlike the str provenance fields,
        # hence the dict[str, object] return annotation.
        "word_count": _word_count(text),
        "weight": _default_record_weight(record_type),
    }
|
|
|
|
def dialogue_to_json_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    conversation_field: str,
    row: dict[str, object],
) -> list[dict[str, object]]:
    """Expand a structured-dialogue *row* into one record per assistant turn.

    Each assistant message becomes a training record whose context is the
    rendered history of everything said before it (seeded with an optional
    "system" column).

    Raises:
        ValueError: when the row has no usable messages, or no assistant
            turn with non-empty preceding context.
    """
    messages = _parse_dialogue_messages(row.get(conversation_field))
    if not messages:
        raise ValueError("Row does not contain usable dialogue turns.")

    row_id = _row_identifier(row)
    records: list[dict[str, object]] = []
    history: list[dict[str, str]] = []
    row_language = _row_language(row)
    # A separate "system" column, when present, seeds the conversation history.
    system_text = clean_training_text(str(row.get("system", "")).strip())
    if system_text:
        history.append({"role": "system", "content": system_text})
    assistant_turn_index = 0
    for message in messages:
        if message["role"] != "assistant":
            history.append(message)
            continue
        prompt = _render_prompt(history)
        # Assistant turns with no preceding context cannot form a training
        # pair; note they are also excluded from the history in this case.
        if not prompt:
            continue
        assistant_turn_index += 1
        records.append(
            {
                **_base_record(
                    dataset=dataset,
                    config=config,
                    split=split,
                    row_id=row_id,
                ),
                "record_type": "dialogue_turn",
                "language": row_language,
                "conversation_field": conversation_field,
                "turn_index": str(assistant_turn_index),
                "context": prompt,
                "answer": clean_answer_text(message["content"]),
                "text": _compose_training_text(prompt, message["content"]),
                # word_count and weight are ints, unlike the str fields above.
                "word_count": _word_count(clean_answer_text(message["content"])),
                "weight": _default_record_weight("dialogue_turn"),
            }
        )
        # The assistant turn itself becomes context for later turns.
        history.append(message)

    if not records:
        raise ValueError("Dialogue row did not yield any assistant training turns.")
    return records
|
|
|
|
def preference_to_json_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    chosen_field: str,
    rejected_field: str,
    row: dict[str, object],
    preference_target: str = "both",
) -> list[dict[str, object]]:
    """Expand a preference-pair *row* into chosen/rejected training records.

    *preference_target* selects which side(s) to emit: "both", "chosen",
    or "rejected".

    Raises:
        ValueError: for an invalid *preference_target*, or when no side
            yields a usable prompt/answer pair.
    """
    row_id = _row_identifier(row)
    pair_id = row_id or f"{chosen_field}:{rejected_field}"
    records: list[dict[str, object]] = []
    row_language = _row_language(row)
    # For response_0/response_1 layouts, the chosen side is picked by a
    # safer/better selector column; other layouts keep the given order.
    chosen_field, rejected_field = _ordered_preference_fields(
        row,
        left_field=chosen_field,
        right_field=rejected_field,
    )

    field_specs = [
        (chosen_field, "preference_chosen"),
        (rejected_field, "preference_rejected"),
    ]
    if preference_target == "chosen":
        field_specs = [(chosen_field, "preference_chosen")]
    elif preference_target == "rejected":
        field_specs = [(rejected_field, "preference_rejected")]
    elif preference_target != "both":
        raise ValueError("preference_target must be one of: both, chosen, rejected.")

    for field_name, record_type in field_specs:
        prompt, answer = _extract_prompt_answer(row, field_name=field_name)
        # Sides lacking either a prompt or an answer are silently dropped.
        if not prompt or not answer:
            continue
        records.append(
            {
                **_base_record(
                    dataset=dataset,
                    config=config,
                    split=split,
                    row_id=row_id,
                ),
                "record_type": record_type,
                "language": row_language,
                "pair_id": pair_id,
                "text_field": field_name,
                "context": prompt,
                "answer": clean_answer_text(answer),
                "text": _compose_training_text(prompt, answer),
                # word_count and weight are ints, unlike the str fields above.
                "word_count": _word_count(clean_answer_text(answer)),
                "weight": _default_record_weight(record_type),
            }
        )

    if not records:
        raise ValueError("Preference row did not yield usable chosen/rejected transcripts.")
    return records
|
|
|
|
def instruction_to_json_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    prompt_field: str,
    answer_field: str,
    row: dict[str, object],
) -> list[dict[str, object]]:
    """Convert an instruction/answer *row* into a single training record.

    The context combines *prompt_field* with an optional "input" column.

    Raises:
        ValueError: when either the context or the answer cleans to "".
    """
    context = _compose_instruction_context(row, prompt_field)
    answer = clean_answer_text(str(row.get(answer_field, "")).strip())
    if not context or not answer:
        raise ValueError("Instruction row did not contain usable prompt and answer text.")
    record_type = "instruction_answer"
    return [
        {
            **_base_record(
                dataset=dataset,
                config=config,
                split=split,
                row_id=_row_identifier(row),
            ),
            "record_type": record_type,
            "language": _row_language(row),
            "context": context,
            "answer": answer,
            "text": _compose_training_text(context, answer),
            # word_count and weight are ints, unlike the str fields above.
            "word_count": _word_count(answer),
            "weight": _default_record_weight(record_type),
        }
    ]
|
|
|
|
def _expand_row_records(
    *,
    dataset: str,
    config: str | None,
    split: str,
    row: dict[str, object],
    text_field: str | None,
    preference_target: str,
) -> list[dict[str, object]]:
    """Convert one dataset *row* into training records, inferring its layout.

    When *text_field* is given it is used directly: as a dialogue column if
    its value is a list, otherwise as plain text.  Without it, layouts are
    tried in order: preference pair, instruction/answer, dialogue column,
    and finally a plain text column.

    Raises:
        ValueError: when no layout yields usable records (the final
            plain-text attempt's error propagates to the caller).
    """
    if text_field is not None:
        explicit_value = row.get(text_field)
        # A list value means the explicit column holds structured messages.
        if isinstance(explicit_value, list):
            return dialogue_to_json_records(
                dataset=dataset,
                config=config,
                split=split,
                conversation_field=text_field,
                row=row,
            )
        return [
            to_json_record(
                dataset=dataset,
                config=config,
                split=split,
                text_field=text_field,
                row=row,
            )
        ]

    columns = list(row)
    # 1) Preference pairs (chosen/rejected style columns).
    try:
        chosen_field, rejected_field = choose_preference_fields(columns)
        return preference_to_json_records(
            dataset=dataset,
            config=config,
            split=split,
            chosen_field=chosen_field,
            rejected_field=rejected_field,
            row=row,
            preference_target=preference_target,
        )
    except ValueError:
        pass

    # 2) Instruction/answer column pairs.
    try:
        prompt_field, answer_field = choose_instruction_fields(columns)
        return instruction_to_json_records(
            dataset=dataset,
            config=config,
            split=split,
            prompt_field=prompt_field,
            answer_field=answer_field,
            row=row,
        )
    except ValueError:
        pass

    # 3) Structured dialogue column (only when it actually holds a list).
    try:
        conversation_field = choose_dialogue_field(columns)
        if isinstance(row.get(conversation_field), list):
            return dialogue_to_json_records(
                dataset=dataset,
                config=config,
                split=split,
                conversation_field=conversation_field,
                row=row,
            )
    except ValueError:
        pass

    # 4) Plain text column; a failure here propagates to the caller.
    inferred_text_field = choose_text_field(columns)
    return [
        to_json_record(
            dataset=dataset,
            config=config,
            split=split,
            text_field=inferred_text_field,
            row=row,
        )
    ]
|
|
|
|
def import_hf_dataset(
    *,
    dataset: str,
    output_path: str | Path,
    config: str | None = None,
    split: str = "train",
    text_field: str | None = None,
    limit: int = 1000,
    streaming: bool = True,
    preference_target: str = "chosen",
    min_words: int = 0,
    max_words: int = 0,
    min_alpha_ratio: float = 0.0,
    allowed_languages: tuple[str, ...] = (),
) -> dict[str, object]:
    """Load a Hugging Face dataset and write normalized JSONL training records.

    Each row is expanded into one or more flat records (text, dialogue turns,
    preference pairs, or instruction/answer pairs), filtered through the
    quality gate, and written to *output_path* as one JSON object per line.

    Args:
        dataset: Hugging Face dataset identifier.
        output_path: Destination JSONL file; parent directories are created.
        config: Optional dataset configuration ("name") to load.
        split: Dataset split to read.
        text_field: Explicit text/conversation column; inferred per-row when None.
        limit: Maximum number of records to write.
        streaming: Stream rows instead of downloading the whole split.
        preference_target: Which side(s) of preference pairs to keep
            ("both", "chosen", or "rejected").
        min_words: Minimum word count per record (0 disables).
        max_words: Maximum word count per record (0 disables).
        min_alpha_ratio: Minimum alphabetic-character ratio (0.0 disables).
        allowed_languages: Language allowlist (empty disables the check).

    Returns:
        A summary dict: what was written, where, and with which settings.
    """
    try:
        from datasets import load_dataset
    except ModuleNotFoundError:
        # `datasets` may live in the user site-packages without being on
        # sys.path (e.g. pip install --user); retry after adding it.
        user_site = site.getusersitepackages()
        if user_site and user_site not in sys.path:
            sys.path.append(user_site)
        from datasets import load_dataset

    dataset_kwargs: dict[str, object] = {
        "split": split,
        "streaming": streaming,
    }
    if config:
        dataset_kwargs["name"] = config

    hf_dataset = load_dataset(dataset, **dataset_kwargs)
    iterator = iter(hf_dataset)

    first_row: dict[str, object] | None = None
    if text_field is None:
        # Peek at the first row, then push it back onto the iterator.
        # BUGFIX: an empty split used to raise an uncaught StopIteration
        # here via a bare next(iterator); guard with a default instead.
        peeked = next(iterator, None)
        if peeked is not None:
            first_row = dict(peeked)
            iterator = chain([first_row], iterator)

    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)

    written = 0
    record_types: set[str] = set()
    normalized_languages = {language.casefold() for language in allowed_languages if language.strip()}
    with output.open("w", encoding="utf-8") as handle:
        for row in iterator:
            if written >= limit:
                break
            normalized_row = dict(row)
            try:
                records = _expand_row_records(
                    dataset=dataset,
                    config=config,
                    split=split,
                    row=normalized_row,
                    text_field=text_field,
                    preference_target=preference_target,
                )
            except ValueError:
                # Row had no usable content for any known layout; skip it.
                continue

            for record in records:
                if written >= limit:
                    break
                if not _passes_quality_gate(
                    record,
                    min_words=min_words,
                    max_words=max_words,
                    min_alpha_ratio=min_alpha_ratio,
                    allowed_languages=normalized_languages,
                ):
                    continue
                record_types.add(record.get("record_type", "text"))
                handle.write(json.dumps(record, ensure_ascii=False) + "\n")
                written += 1

    inferred_mode = "mixed" if len(record_types) > 1 else (next(iter(record_types)) if record_types else "unknown")
    return {
        "dataset": dataset,
        "config": config or "",
        "split": split,
        "text_field": text_field or "",
        "output_path": str(output.resolve()),
        "records_written": written,
        "record_types": sorted(record_types),
        "mode": inferred_mode,
        "preference_target": preference_target,
        "streaming": streaming,
        "min_words": min_words,
        "max_words": max_words,
        "min_alpha_ratio": min_alpha_ratio,
        "allowed_languages": sorted(normalized_languages),
    }
|
|