"""Pure functions for guessing how a raw HF dataset row maps onto a (system, user, assistant) triplet, based on a handful of sample rows. Nothing here touches the network - callers fetch sample rows separately (see hf_inspect.py) and hand them in. That split is deliberate: this logic is the part worth unit-testing on its own, against fixed sample data, without spinning up HTTP calls every time. """ from __future__ import annotations from typing import Optional from models import FieldMapping # Known flat-column name pairs, checked case-insensitively against the # dataset's actual column names. Ordered roughly by how common they are # across instruction-tuning datasets on the Hub. _KNOWN_FLAT_PAIRS = [ ("instruction", "output"), ("instruction", "response"), ("prompt", "chosen"), ("prompt", "completion"), ("query", "answer"), ("query", "answers"), ("question", "solution"), ("question", "answer"), ("input", "output"), ] # Known (human_tag, assistant_tag) pairs for the role key inside a # conversation-list column (e.g. ShareGPT's "from", OpenAI's "role"). _KNOWN_TAG_PAIRS = [ ("human", "gpt"), ("user", "assistant"), ("user", "model"), ] def _is_list_of_dicts_column(values: list) -> bool: for v in values: if not isinstance(v, list) or not v: return False if not all(isinstance(item, dict) for item in v): return False return True def detect_flat_pair(columns: list) -> Optional[FieldMapping]: """Match the dataset's column names against known flat pairs.""" lower_to_actual = {c.lower(): c for c in columns} for user_name, asst_name in _KNOWN_FLAT_PAIRS: if user_name in lower_to_actual and asst_name in lower_to_actual: return FieldMapping( kind="flat_pair", config={ "user_field": lower_to_actual[user_name], "assistant_field": lower_to_actual[asst_name], }, ) return None def _guess_role_and_content_key(values: list, keys: set): """Pick which key behaves like a role tag (short, low-cardinality strings) and which behaves like free-text content (longer, varied). Returns (role_key, content_key), or (None, None) if it can't tell. """ candidates = list(keys) if len(candidates) < 2: return None, None avg_lengths = {} distinct_counts = {} for key in candidates: all_values = [item.get(key) for value_list in values for item in value_list if key in item] string_values = [v for v in all_values if isinstance(v, str)] if not string_values: avg_lengths[key] = float("inf") distinct_counts[key] = len(set(map(str, all_values))) continue avg_lengths[key] = sum(len(v) for v in string_values) / len(string_values) distinct_counts[key] = len(set(string_values)) role_key = min(candidates, key=lambda k: (distinct_counts[k], avg_lengths[k])) remaining = [k for k in candidates if k != role_key] content_key = max(remaining, key=lambda k: avg_lengths[k]) return role_key, content_key def detect_list_column(sample_rows: list) -> Optional[dict]: """Find a column whose values look like a conversation list (ShareGPT's `conversations`, OpenAI's `messages`, or anything shaped like them) and figure out which sub-key is the role and which is the content. Returns a dict describing what was found - used both to try full auto-detection and to pre-fill the manual-mapping UI when auto-detect isn't confident. Returns None if nothing list-shaped turned up. """ if not sample_rows: return None columns = list(sample_rows[0].keys()) for col in columns: values = [row.get(col) for row in sample_rows] if not _is_list_of_dicts_column(values): continue common_keys = None for value_list in values: for item in value_list: keys = set(item.keys()) common_keys = keys if common_keys is None else (common_keys & keys) if not common_keys or len(common_keys) < 2: continue role_key, content_key = _guess_role_and_content_key(values, common_keys) if role_key is None: continue tag_values = sorted( { item.get(role_key) for value_list in values for item in value_list if role_key in item and item.get(role_key) is not None } ) return { "list_field": col, "role_key": role_key, "content_key": content_key, "tag_values": tag_values, } return None def detect_conversation_list(sample_rows: list) -> Optional[FieldMapping]: """Full auto-detect for conversation-list columns. Only returns a mapping when both the human and assistant tags are recognized - anything less certain falls through to manual mapping on purpose. """ found = detect_list_column(sample_rows) if not found: return None tags_lower = {str(t).lower(): t for t in found["tag_values"] if t is not None} for human_tag, gpt_tag in _KNOWN_TAG_PAIRS: if human_tag in tags_lower and gpt_tag in tags_lower: return FieldMapping( kind="conversation_list", config={ "list_field": found["list_field"], "role_key": found["role_key"], "content_key": found["content_key"], "human_tag": tags_lower[human_tag], "gpt_tag": tags_lower[gpt_tag], }, ) return None def auto_detect(sample_rows: list) -> Optional[FieldMapping]: """Try every detector in order of confidence. Returns None if nothing lands cleanly - caller should fall back to manual mapping. """ if not sample_rows: return None mapping = detect_conversation_list(sample_rows) if mapping: return mapping columns = list(sample_rows[0].keys()) return detect_flat_pair(columns)