Spaces:
Running
Running
File size: 6,193 Bytes
390cebe | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 | """Pure functions for guessing how a raw HF dataset row maps onto a
(system, user, assistant) triplet, based on a handful of sample rows.
Nothing here touches the network - callers fetch sample rows separately
(see hf_inspect.py) and hand them in. That split is deliberate:
this logic is the part worth unit-testing on its own, against fixed
sample data, without spinning up HTTP calls every time.
"""
from __future__ import annotations
from typing import Optional
from models import FieldMapping
# Known flat-column name pairs, checked case-insensitively against the
# dataset's actual column names. Ordered roughly by how common they are
# across instruction-tuning datasets on the Hub.
#
# Each entry is (user_column, assistant_column). Order matters:
# detect_flat_pair walks this list and returns the FIRST pair whose two
# names are both present, so earlier entries win when a dataset matches
# several. For example, a dataset with "instruction", "input" and "output"
# columns maps to ("instruction", "output"), and "answer" is preferred over
# "answers" when a dataset has both. Entries must be lowercase, because
# they are looked up in a lowercased index of the column names.
_KNOWN_FLAT_PAIRS = [
    ("instruction", "output"),
    ("instruction", "response"),
    ("prompt", "chosen"),
    ("prompt", "completion"),
    ("query", "answer"),
    ("query", "answers"),
    ("question", "solution"),
    ("question", "answer"),
    ("input", "output"),
]
# Known (human_tag, assistant_tag) pairs for the role key inside a
# conversation-list column (e.g. ShareGPT's "from", OpenAI's "role").
# detect_conversation_list compares these against the lowercased tag values
# it found, so entries must be lowercase. The first pair with both tags
# present wins.
_KNOWN_TAG_PAIRS = [
    ("human", "gpt"),
    ("user", "assistant"),
    ("user", "model"),
]
def _is_list_of_dicts_column(values: list) -> bool:
for v in values:
if not isinstance(v, list) or not v:
return False
if not all(isinstance(item, dict) for item in v):
return False
return True
def detect_flat_pair(columns: list) -> Optional[FieldMapping]:
    """Look for a known (user, assistant) column-name pair among ``columns``.

    Matching ignores case, but the returned config carries the dataset's
    own spelling of each column name. Pairs are tried in
    ``_KNOWN_FLAT_PAIRS`` order, and the first pair with both names
    present wins.

    Returns a ``flat_pair`` FieldMapping, or None if no known pair is
    present.
    """
    # Lowercased name -> original name. If two columns differ only in case,
    # the later one wins.
    actual_by_lower = {}
    for column in columns:
        actual_by_lower[column.lower()] = column

    for user_key, assistant_key in _KNOWN_FLAT_PAIRS:
        if user_key not in actual_by_lower or assistant_key not in actual_by_lower:
            continue
        return FieldMapping(
            kind="flat_pair",
            config={
                "user_field": actual_by_lower[user_key],
                "assistant_field": actual_by_lower[assistant_key],
            },
        )
    return None
def _guess_role_and_content_key(values: list, keys: set):
"""Pick which key behaves like a role tag (short, low-cardinality
strings) and which behaves like free-text content (longer, varied).
Returns (role_key, content_key), or (None, None) if it can't tell.
"""
candidates = list(keys)
if len(candidates) < 2:
return None, None
avg_lengths = {}
distinct_counts = {}
for key in candidates:
all_values = [item.get(key) for value_list in values for item in value_list if key in item]
string_values = [v for v in all_values if isinstance(v, str)]
if not string_values:
avg_lengths[key] = float("inf")
distinct_counts[key] = len(set(map(str, all_values)))
continue
avg_lengths[key] = sum(len(v) for v in string_values) / len(string_values)
distinct_counts[key] = len(set(string_values))
role_key = min(candidates, key=lambda k: (distinct_counts[k], avg_lengths[k]))
remaining = [k for k in candidates if k != role_key]
content_key = max(remaining, key=lambda k: avg_lengths[k])
return role_key, content_key
def _collect_tag_values(values: list, role_key) -> list:
    """Return the distinct non-None values found under ``role_key``, sorted.

    Unhashable values (lists/dicts) are skipped rather than letting the set
    build raise TypeError. If the tags mix types that ``sorted`` cannot
    compare (e.g. str and int), they are ordered by type name and then by
    text instead of raising.
    """
    tags = set()
    for value_list in values:
        for item in value_list:
            tag = item.get(role_key)
            if tag is None:
                continue
            try:
                tags.add(tag)
            except TypeError:
                # Unhashable (list/dict), so it cannot serve as a role tag.
                continue
    try:
        return sorted(tags)
    except TypeError:
        return sorted(tags, key=lambda t: (type(t).__name__, str(t)))


def detect_list_column(sample_rows: list) -> Optional[dict]:
    """Find a column whose values look like a conversation list (ShareGPT's
    `conversations`, OpenAI's `messages`, or anything shaped like them) and
    figure out which sub-key is the role and which is the content.

    Only the columns of the first sample row are scanned, in order. The
    first list-of-dicts column with a usable role/content split is returned.

    Returns a dict describing what was found - used both to try full
    auto-detection and to pre-fill the manual-mapping UI when auto-detect
    isn't confident. Its keys are ``list_field``, ``role_key``,
    ``content_key`` and ``tag_values`` (the sorted distinct non-None values
    seen under ``role_key``). Returns None if nothing list-shaped turned up.
    """
    if not sample_rows:
        return None
    columns = list(sample_rows[0].keys())
    for col in columns:
        values = [row.get(col) for row in sample_rows]
        if not _is_list_of_dicts_column(values):
            continue
        # Keys present in every turn of every sampled conversation.
        common_keys = None
        for value_list in values:
            for item in value_list:
                keys = set(item.keys())
                common_keys = keys if common_keys is None else (common_keys & keys)
        # A role/content split needs at least two shared keys.
        if common_keys is None or len(common_keys) < 2:
            continue
        role_key, content_key = _guess_role_and_content_key(values, common_keys)
        if role_key is None:
            continue
        return {
            "list_field": col,
            "role_key": role_key,
            "content_key": content_key,
            "tag_values": _collect_tag_values(values, role_key),
        }
    return None
def detect_conversation_list(sample_rows: list) -> Optional[FieldMapping]:
    """Fully auto-detect a conversation-list column, or give up.

    Builds on detect_list_column. It only commits to a mapping when one of
    the ``_KNOWN_TAG_PAIRS`` appears among the observed role tags, compared
    case-insensitively. Anything less certain returns None on purpose, so
    the caller falls through to manual mapping. The returned config keeps
    the tags exactly as spelled in the data.
    """
    found = detect_list_column(sample_rows)
    if not found:
        return None

    # Lowercased tag -> tag as it appears in the data (later entries win).
    original_by_lower = {}
    for tag in found["tag_values"]:
        if tag is None:
            continue
        original_by_lower[str(tag).lower()] = tag

    matched = next(
        (
            (human, assistant)
            for human, assistant in _KNOWN_TAG_PAIRS
            if human in original_by_lower and assistant in original_by_lower
        ),
        None,
    )
    if matched is None:
        return None

    human, assistant = matched
    config = {
        "list_field": found["list_field"],
        "role_key": found["role_key"],
        "content_key": found["content_key"],
        "human_tag": original_by_lower[human],
        "gpt_tag": original_by_lower[assistant],
    }
    return FieldMapping(kind="conversation_list", config=config)
def auto_detect(sample_rows: list) -> Optional[FieldMapping]:
    """Run the detectors in order of confidence and return the first hit.

    Conversation-list detection is tried first. If it yields nothing, the
    flat column-pair check runs on the first row's column names. Returns
    None when neither lands cleanly, and the caller should then fall back
    to manual mapping.
    """
    if not sample_rows:
        return None
    first_row_columns = list(sample_rows[0].keys())
    # `or` falls through to the flat-pair check only when the
    # conversation-list detector returns a falsy result.
    return detect_conversation_list(sample_rows) or detect_flat_pair(first_row_columns)
|