Spaces:
Running
Running
"""Pure functions for guessing how a raw HF dataset row maps onto a
(system, user, assistant) triplet, based on a handful of sample rows.

Nothing here touches the network - callers fetch sample rows separately
(see hf_inspect.py) and hand them in. That split is deliberate: this
logic is the part worth unit-testing on its own, against fixed sample
data, without making HTTP calls every time.
"""
| from __future__ import annotations | |
| from typing import Optional | |
| from models import FieldMapping | |
| # Known flat-column name pairs, checked case-insensitively against the | |
| # dataset's actual column names. Ordered roughly by how common they are | |
| # across instruction-tuning datasets on the Hub. | |
| _KNOWN_FLAT_PAIRS = [ | |
| ("instruction", "output"), | |
| ("instruction", "response"), | |
| ("prompt", "chosen"), | |
| ("prompt", "completion"), | |
| ("query", "answer"), | |
| ("query", "answers"), | |
| ("question", "solution"), | |
| ("question", "answer"), | |
| ("input", "output"), | |
| ] | |
| # Known (human_tag, assistant_tag) pairs for the role key inside a | |
| # conversation-list column (e.g. ShareGPT's "from", OpenAI's "role"). | |
| _KNOWN_TAG_PAIRS = [ | |
| ("human", "gpt"), | |
| ("user", "assistant"), | |
| ("user", "model"), | |
| ] | |
| def _is_list_of_dicts_column(values: list) -> bool: | |
| for v in values: | |
| if not isinstance(v, list) or not v: | |
| return False | |
| if not all(isinstance(item, dict) for item in v): | |
| return False | |
| return True | |
def detect_flat_pair(columns: list) -> Optional[FieldMapping]:
    """Match column names against _KNOWN_FLAT_PAIRS, case-insensitively.

    Returns a "flat_pair" FieldMapping for the first known pair whose two
    columns both exist, using the column names exactly as the dataset spells
    them. Returns None when no pair matches.
    """
    # lower-cased name -> actual name; a later column wins on a case clash
    actual_by_lower = {}
    for name in columns:
        actual_by_lower[name.lower()] = name

    for user_key, assistant_key in _KNOWN_FLAT_PAIRS:
        if user_key not in actual_by_lower or assistant_key not in actual_by_lower:
            continue
        return FieldMapping(
            kind="flat_pair",
            config={
                "user_field": actual_by_lower[user_key],
                "assistant_field": actual_by_lower[assistant_key],
            },
        )
    return None
| def _guess_role_and_content_key(values: list, keys: set): | |
| """Pick which key behaves like a role tag (short, low-cardinality | |
| strings) and which behaves like free-text content (longer, varied). | |
| Returns (role_key, content_key), or (None, None) if it can't tell. | |
| """ | |
| candidates = list(keys) | |
| if len(candidates) < 2: | |
| return None, None | |
| avg_lengths = {} | |
| distinct_counts = {} | |
| for key in candidates: | |
| all_values = [item.get(key) for value_list in values for item in value_list if key in item] | |
| string_values = [v for v in all_values if isinstance(v, str)] | |
| if not string_values: | |
| avg_lengths[key] = float("inf") | |
| distinct_counts[key] = len(set(map(str, all_values))) | |
| continue | |
| avg_lengths[key] = sum(len(v) for v in string_values) / len(string_values) | |
| distinct_counts[key] = len(set(string_values)) | |
| role_key = min(candidates, key=lambda k: (distinct_counts[k], avg_lengths[k])) | |
| remaining = [k for k in candidates if k != role_key] | |
| content_key = max(remaining, key=lambda k: avg_lengths[k]) | |
| return role_key, content_key | |
def detect_list_column(sample_rows: list) -> Optional[dict]:
    """Find a conversation-list column and guess its role/content sub-keys.

    A candidate column holds, in every sample row, a non-empty list of
    dicts (ShareGPT's `conversations`, OpenAI's `messages`, or anything
    shaped like them). Columns are scanned in the order of the first row's
    keys, and the first one that yields a role/content guess is returned.

    The result is used both for full auto-detection and to pre-fill the
    manual-mapping UI when auto-detect isn't confident.

    Args:
        sample_rows: sample dataset rows (dicts of column -> value).

    Returns:
        A dict with keys "list_field", "role_key", "content_key", and
        "tag_values". "tag_values" holds the distinct non-None role values
        seen, sorted by their str() form. Returns None if no list-shaped
        column turned up.
    """
    if not sample_rows:
        return None
    for col in list(sample_rows[0].keys()):
        values = [row.get(col) for row in sample_rows]
        if not _is_list_of_dicts_column(values):
            continue

        # Only sub-keys present in every message can serve as role/content.
        common_keys = None
        for value_list in values:
            for item in value_list:
                keys = set(item.keys())
                common_keys = keys if common_keys is None else (common_keys & keys)
        if common_keys is None or len(common_keys) < 2:
            continue

        role_key, content_key = _guess_role_and_content_key(values, common_keys)
        if role_key is None:
            continue

        # key=str leaves the order of all-string tags unchanged but avoids a
        # TypeError when role values mix types (e.g. "user" and 0).
        tag_values = sorted(
            {
                item[role_key]
                for value_list in values
                for item in value_list
                if item.get(role_key) is not None
            },
            key=str,
        )
        return {
            "list_field": col,
            "role_key": role_key,
            "content_key": content_key,
            "tag_values": tag_values,
        }
    return None
def detect_conversation_list(sample_rows: list) -> Optional[FieldMapping]:
    """Fully auto-detect a conversation-list column.

    A mapping is returned only when one of _KNOWN_TAG_PAIRS appears,
    case-insensitively, among the role values reported by
    detect_list_column. Anything less certain returns None, on purpose, so
    the caller falls through to manual mapping.
    """
    found = detect_list_column(sample_rows)
    if not found:
        return None

    # lower-cased tag -> tag exactly as it appears in the data
    original_by_lower = {}
    for tag in found["tag_values"]:
        if tag is not None:
            original_by_lower[str(tag).lower()] = tag

    pair = next(
        (
            (human, assistant)
            for human, assistant in _KNOWN_TAG_PAIRS
            if human in original_by_lower and assistant in original_by_lower
        ),
        None,
    )
    if pair is None:
        return None

    human, assistant = pair
    return FieldMapping(
        kind="conversation_list",
        config={
            "list_field": found["list_field"],
            "role_key": found["role_key"],
            "content_key": found["content_key"],
            "human_tag": original_by_lower[human],
            "gpt_tag": original_by_lower[assistant],
        },
    )
def auto_detect(sample_rows: list) -> Optional[FieldMapping]:
    """Run the detectors from most to least confident.

    The conversation-list detector runs first. Flat column-pair matching on
    the first row's column names is the fallback. Returns None when nothing
    lands cleanly; the caller should then fall back to manual mapping.
    """
    if not sample_rows:
        return None
    # `or` short-circuits exactly like the explicit truthiness check: the
    # flat-pair detector only runs when the conversation detector fails.
    return detect_conversation_list(sample_rows) or detect_flat_pair(
        list(sample_rows[0].keys())
    )