# Dataset-Creator / schema_detect.py
# TitleOS's picture
# Upload 9 files
# 390cebe verified
# Raw
# History Blame Contribute Delete
# 6.19 kB
"""Pure functions for guessing how a raw HF dataset row maps onto a
(system, user, assistant) triplet, based on a handful of sample rows.
Nothing here touches the network - callers fetch sample rows separately
(see hf_inspect.py) and hand them in. That split is deliberate:
this logic is the part worth unit-testing on its own, against fixed
sample data, without spinning up HTTP calls every time.
"""
from __future__ import annotations
from typing import Optional
from models import FieldMapping
# (user_column, assistant_column) name pairs recognized for flat datasets.
# The names are lowercase because they are compared case-insensitively with
# the dataset's real column names. Order matters: entries are listed roughly
# from most to least common among instruction-tuning datasets on the Hub.
_KNOWN_FLAT_PAIRS = [
    ("instruction", "output"),
    ("instruction", "response"),
    ("prompt", "chosen"),
    ("prompt", "completion"),
    ("query", "answer"),
    ("query", "answers"),
    ("question", "solution"),
    ("question", "answer"),
    ("input", "output"),
]
# (human_tag, assistant_tag) values recognized in the role key of a
# conversation-list column - for example ShareGPT stores them under "from"
# and OpenAI-style messages under "role".
_KNOWN_TAG_PAIRS = [
    ("human", "gpt"),
    ("user", "assistant"),
    ("user", "model"),
]
def _is_list_of_dicts_column(values: list) -> bool:
    """Tell whether every sampled cell is a non-empty list made only of dicts.
    A single cell that is not a list, is an empty list, or holds a non-dict
    item disqualifies the whole column. An empty `values` yields True.
    """
    return all(
        isinstance(cell, list)
        and len(cell) > 0
        and all(isinstance(entry, dict) for entry in cell)
        for cell in values
    )
def detect_flat_pair(columns: list) -> Optional[FieldMapping]:
    """Look for a known (user, assistant) column pair among `columns`.
    Column names are compared in lowercase, but the mapping that comes back
    carries the names exactly as the dataset spells them. The first entry of
    _KNOWN_FLAT_PAIRS with both columns present wins; None means no pair
    matched.
    """
    # Lowercased name -> original spelling. When two columns differ only in
    # case, the one that appears later in `columns` is kept.
    actual_by_lower = {}
    for column in columns:
        actual_by_lower[column.lower()] = column

    for user_name, asst_name in _KNOWN_FLAT_PAIRS:
        if user_name not in actual_by_lower or asst_name not in actual_by_lower:
            continue
        pair_config = {
            "user_field": actual_by_lower[user_name],
            "assistant_field": actual_by_lower[asst_name],
        }
        return FieldMapping(kind="flat_pair", config=pair_config)
    return None
def _guess_role_and_content_key(values: list, keys: set):
    """Pick which key behaves like a role tag (short, low-cardinality
    strings) and which behaves like free-text content (longer, varied).
    `values` is a list of conversations, each a list of dict items; `keys`
    are the candidate sub-keys to choose between.
    Keys that hold at least one string value are preferred for both picks,
    so a sub-key that is only ever None, a number or a nested list/dict
    (for instance a null-filled extra field) cannot displace the real role
    or text key. Keys without any string value are used only as a fallback
    when no string-valued key is available for that pick.
    Returns (role_key, content_key), or (None, None) if it can't tell.
    """
    # Sort so ties are broken the same way on every run - iterating a set of
    # strings directly depends on per-process hash randomization.
    candidates = sorted(keys, key=str)
    if len(candidates) < 2:
        return None, None

    avg_lengths = {}
    distinct_counts = {}
    text_keys = []  # keys for which at least one string value was observed
    for key in candidates:
        observed = [
            item[key]
            for value_list in values
            for item in value_list
            if key in item
        ]
        strings = [v for v in observed if isinstance(v, str)]
        if not strings:
            # No text at all: infinite "length" keeps this key from looking
            # like a short role tag if it ends up in the fallback pool.
            avg_lengths[key] = float("inf")
            distinct_counts[key] = len(set(map(str, observed)))
            continue
        text_keys.append(key)
        avg_lengths[key] = sum(len(v) for v in strings) / len(strings)
        distinct_counts[key] = len(set(strings))

    # Role tag: fewest distinct values, shortest average length as tie-break.
    role_pool = text_keys or candidates
    role_key = min(role_pool, key=lambda k: (distinct_counts[k], avg_lengths[k]))

    # Content: the longest text among what is left. Non-text keys are only
    # considered when no text key remains (e.g. content stored as a list of
    # parts), which is what the unfiltered max() would have chosen anyway.
    remaining = [k for k in candidates if k != role_key]
    content_pool = [k for k in remaining if k in text_keys] or remaining
    content_key = max(content_pool, key=lambda k: avg_lengths[k])
    return role_key, content_key
def detect_list_column(sample_rows: list) -> Optional[dict]:
    """Find a column whose values look like a conversation list (ShareGPT's
    `conversations`, OpenAI's `messages`, or anything shaped like them) and
    figure out which sub-key is the role and which is the content.
    Columns are taken from the first sample row and tried in order; the
    first one that qualifies is reported.
    Returns a dict with keys "list_field", "role_key", "content_key" and
    "tag_values" (the distinct non-None role values seen in the samples) -
    used both to try full auto-detection and to pre-fill the manual-mapping
    UI when auto-detect isn't confident. Returns None if nothing
    list-shaped turned up.
    """
    if not sample_rows:
        return None

    columns = list(sample_rows[0].keys())
    for col in columns:
        values = [row.get(col) for row in sample_rows]
        if not _is_list_of_dicts_column(values):
            continue

        # Only sub-keys present in every item of every sampled row qualify.
        common_keys = None
        for value_list in values:
            for item in value_list:
                item_keys = set(item)
                common_keys = item_keys if common_keys is None else common_keys & item_keys
        if not common_keys or len(common_keys) < 2:
            continue

        role_key, content_key = _guess_role_and_content_key(values, common_keys)
        if role_key is None:
            continue

        tags = set()
        for value_list in values:
            for item in value_list:
                tag = item.get(role_key)
                if tag is None:
                    continue
                try:
                    tags.add(tag)
                except TypeError:
                    # Unhashable value (nested list/dict): it cannot serve
                    # as a role tag, so skip it instead of crashing.
                    continue
        try:
            tag_values = sorted(tags)
        except TypeError:
            # Tags of mixed, non-comparable types: order by text form.
            tag_values = sorted(tags, key=str)

        return {
            "list_field": col,
            "role_key": role_key,
            "content_key": content_key,
            "tag_values": tag_values,
        }
    return None
def detect_conversation_list(sample_rows: list) -> Optional[FieldMapping]:
    """Auto-detect a conversation-list mapping from the sample rows.
    A mapping is produced only if detect_list_column() finds a column and
    its role tags contain both halves of a pair from _KNOWN_TAG_PAIRS
    (compared in lowercase). Anything less certain yields None, leaving the
    decision to manual mapping on purpose.
    """
    found = detect_list_column(sample_rows)
    if not found:
        return None

    # Lowercased text form of each tag -> the tag exactly as the data has it.
    original_by_lower = {}
    for tag in found["tag_values"]:
        if tag is not None:
            original_by_lower[str(tag).lower()] = tag

    for human_tag, gpt_tag in _KNOWN_TAG_PAIRS:
        if human_tag not in original_by_lower or gpt_tag not in original_by_lower:
            continue
        return FieldMapping(
            kind="conversation_list",
            config={
                "list_field": found["list_field"],
                "role_key": found["role_key"],
                "content_key": found["content_key"],
                "human_tag": original_by_lower[human_tag],
                "gpt_tag": original_by_lower[gpt_tag],
            },
        )
    return None
def auto_detect(sample_rows: list) -> Optional[FieldMapping]:
    """Run the detectors from most to least confident and return the first
    mapping found: conversation-list detection first, then flat column
    pairs matched against the first row's column names. None means nothing
    landed cleanly (or there were no sample rows) - the caller should fall
    back to manual mapping.
    """
    if not sample_rows:
        return None
    # `or` keeps the flat-pair lookup lazy: it only runs when the
    # conversation-list detector came back empty-handed.
    return detect_conversation_list(sample_rows) or detect_flat_pair(
        list(sample_rows[0].keys())
    )