File size: 6,193 Bytes
390cebe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""Pure functions for guessing how a raw HF dataset row maps onto a
(system, user, assistant) triplet, based on a handful of sample rows.

Nothing here touches the network - callers fetch sample rows separately
(see hf_inspect.py) and hand them in. That split is deliberate:
this logic is the part worth unit-testing on its own, against fixed
sample data, without spinning up HTTP calls every time.
"""
from __future__ import annotations

from typing import Optional

from models import FieldMapping

# Known flat-column name pairs, as (user_column, assistant_column).
# detect_flat_pair() lowercases the dataset's actual column names before
# looking them up here, so every entry must itself be lowercase.
# Order matters: detect_flat_pair() returns on the first pair whose two
# names are both present, so earlier pairs take precedence. They are
# ordered roughly by how common they are across instruction-tuning
# datasets on the Hub.
_KNOWN_FLAT_PAIRS = [
    ("instruction", "output"),
    ("instruction", "response"),
    ("prompt", "chosen"),
    ("prompt", "completion"),
    ("query", "answer"),
    ("query", "answers"),
    ("question", "solution"),
    ("question", "answer"),
    ("input", "output"),
]

# Known (human_tag, assistant_tag) pairs for the role key inside a
# conversation-list column (e.g. ShareGPT's "from", OpenAI's "role").
# detect_conversation_list() compares these against the lowercased string
# form of the tag values it observed, so entries must be lowercase. The
# first pair with both tags present wins.
_KNOWN_TAG_PAIRS = [
    ("human", "gpt"),
    ("user", "assistant"),
    ("user", "model"),
]


def _is_list_of_dicts_column(values: list) -> bool:
    for v in values:
        if not isinstance(v, list) or not v:
            return False
        if not all(isinstance(item, dict) for item in v):
            return False
    return True


def detect_flat_pair(columns: list) -> Optional[FieldMapping]:
    """Return a flat_pair mapping for the first known column pair present.

    Column names are compared case-insensitively, but the mapping that
    comes back carries the dataset's original spelling. Returns None when
    no known (user, assistant) pair is fully present in ``columns``.
    """
    # Lowercased name -> original spelling. If two columns differ only by
    # case, the later one in ``columns`` is the one remembered.
    original_by_lower = {}
    for name in columns:
        original_by_lower[name.lower()] = name

    for user_name, asst_name in _KNOWN_FLAT_PAIRS:
        if user_name not in original_by_lower:
            continue
        if asst_name not in original_by_lower:
            continue
        config = {
            "user_field": original_by_lower[user_name],
            "assistant_field": original_by_lower[asst_name],
        }
        return FieldMapping(kind="flat_pair", config=config)

    return None


def _guess_role_and_content_key(values: list, keys: set):
    """Pick which key behaves like a role tag (short, low-cardinality
    strings) and which behaves like free-text content (longer, varied).
    Returns (role_key, content_key), or (None, None) if it can't tell.
    """
    candidates = list(keys)
    if len(candidates) < 2:
        return None, None

    avg_lengths = {}
    distinct_counts = {}
    for key in candidates:
        all_values = [item.get(key) for value_list in values for item in value_list if key in item]
        string_values = [v for v in all_values if isinstance(v, str)]
        if not string_values:
            avg_lengths[key] = float("inf")
            distinct_counts[key] = len(set(map(str, all_values)))
            continue
        avg_lengths[key] = sum(len(v) for v in string_values) / len(string_values)
        distinct_counts[key] = len(set(string_values))

    role_key = min(candidates, key=lambda k: (distinct_counts[k], avg_lengths[k]))
    remaining = [k for k in candidates if k != role_key]
    content_key = max(remaining, key=lambda k: avg_lengths[k])
    return role_key, content_key


def detect_list_column(sample_rows: list) -> Optional[dict]:
    """Find a column whose values look like a conversation list (ShareGPT's
    `conversations`, OpenAI's `messages`, or anything shaped like them) and
    figure out which sub-key is the role and which is the content.

    Args:
        sample_rows: Row dicts sampled from the dataset. Column names are
            taken from the first row.

    Returns:
        A dict with ``list_field``, ``role_key``, ``content_key`` and
        ``tag_values`` (the sorted distinct non-None values seen under the
        role key) for the first qualifying column - used both to try full
        auto-detection and to pre-fill the manual-mapping UI when
        auto-detect isn't confident. Returns None if nothing list-shaped
        turned up.
    """
    if not sample_rows:
        return None

    for col in list(sample_rows[0].keys()):
        values = [row.get(col) for row in sample_rows]
        if not _is_list_of_dicts_column(values):
            continue

        # Keys present in every message of every sampled row. The check
        # above guarantees at least one message, so the intersection is
        # never called with zero arguments.
        common_keys = set.intersection(
            *(set(item.keys()) for value_list in values for item in value_list)
        )
        if len(common_keys) < 2:
            continue

        role_key, content_key = _guess_role_and_content_key(values, common_keys)
        if role_key is None:
            continue

        tags = set()
        for value_list in values:
            for item in value_list:
                tag = item.get(role_key)
                if tag is None:
                    continue
                try:
                    tags.add(tag)
                except TypeError:
                    # Unhashable value (list/dict) under the guessed role
                    # key: it cannot be a role tag, so skip it rather
                    # than abort detection for the whole dataset.
                    continue

        try:
            tag_values = sorted(tags)
        except TypeError:
            # Mixed, mutually unorderable types (e.g. str and int): group
            # by type name, then order by text form, so the result is
            # still deterministic.
            tag_values = sorted(tags, key=lambda t: (type(t).__name__, str(t)))

        return {
            "list_field": col,
            "role_key": role_key,
            "content_key": content_key,
            "tag_values": tag_values,
        }

    return None


def detect_conversation_list(sample_rows: list) -> Optional[FieldMapping]:
    """Auto-detect a conversation-list mapping from sample rows.

    A mapping is produced only when the detected role tags contain both
    halves of a known (human, assistant) pair, compared case-insensitively.
    Anything less certain yields None on purpose, so the caller falls
    through to manual mapping.
    """
    found = detect_list_column(sample_rows)
    if not found:
        return None

    # Lowercased text form of each tag -> the tag exactly as it appears.
    original_by_lower = {}
    for tag in found["tag_values"]:
        if tag is not None:
            original_by_lower[str(tag).lower()] = tag

    matched_pair = next(
        (
            pair
            for pair in _KNOWN_TAG_PAIRS
            if pair[0] in original_by_lower and pair[1] in original_by_lower
        ),
        None,
    )
    if matched_pair is None:
        return None

    human_tag, gpt_tag = matched_pair
    config = {
        "list_field": found["list_field"],
        "role_key": found["role_key"],
        "content_key": found["content_key"],
        "human_tag": original_by_lower[human_tag],
        "gpt_tag": original_by_lower[gpt_tag],
    }
    return FieldMapping(kind="conversation_list", config=config)


def auto_detect(sample_rows: list) -> Optional[FieldMapping]:
    """Try every detector in order of confidence. Returns None if nothing
    lands cleanly - caller should fall back to manual mapping.
    """
    if not sample_rows:
        return None
    mapping = detect_conversation_list(sample_rows)
    if mapping:
        return mapping
    columns = list(sample_rows[0].keys())
    return detect_flat_pair(columns)