"""Dataset Creator - Gradio app.

Lets you list any number of Hugging Face datasets and auto-detects (or lets you
manually map) how each one's rows turn into a chat-format triplet. It then
combines and shuffles the rows, and either pushes the result to your own HF
account or hands you a JSONL download.

UI wiring conventions used throughout this file, so the OAuth injection keeps
working:

- Every per-entry callback takes `entries` (the gr.State list) and `uid` (a
  `gr.State(uid)` constant created at render time) as its first two
  parameters, matching the first two items in `inputs=`.
- `oauth_token` / `oauth_profile` are declared as trailing, annotated
  parameters and are *never* included in `inputs=` - Gradio resolves them
  automatically from the signed-in session.

See the Gradio OAuth guide. OAuth is the one part of the app that needs a live
Spaces OAuth session to smoke-test fully.
"""

from __future__ import annotations

import os
import random
import tempfile

import gradio as gr

import field_mapper
import schema_detect
import hf_dataset_loader
import hf_inspect
import hf_publish
from models import DatasetEntry, FieldMapping

_STATUS_LABELS = {
    "empty": "Not detected yet.",
    "detecting": "Checking schema...",
    "needs_mapping": "Couldn't auto-detect the field layout - map it manually below.",
    "ready": "Ready.",
    "error": "Error.",
}

# Split used when the entry's split field is left blank (by detection, by the
# combine step, and as the textbox's displayed value).
_DEFAULT_SPLIT = "train"


def _find(entries: list, uid: str) -> DatasetEntry:
    """Return the entry in *entries* whose ``uid`` matches.

    Raises:
        ValueError: if no entry has that uid (e.g. it was removed meanwhile).
    """
    for entry in entries:
        if entry.uid == uid:
            return entry
    raise ValueError(f"No dataset entry with uid {uid!r} - it may have been removed.")


def _dataset_coords(entry: DatasetEntry) -> tuple[str, str, str]:
    """Return the ``(repo_id, subset, split)`` actually used to load *entry*.

    Whitespace is stripped and a blank split falls back to ``_DEFAULT_SPLIT``,
    so schema detection and the combine step always agree on what is loaded.
    """
    return (
        (entry.repo_id or "").strip(),
        (entry.subset or "").strip(),
        (entry.split or "").strip() or _DEFAULT_SPLIT,
    )


def _status_text(entry: DatasetEntry) -> str:
    """Build the Markdown status line shown under an entry.

    ``error_message`` is appended for ``error`` entries and for
    ``needs_mapping`` entries (where it carries manual-mapping validation
    feedback). ``ready`` entries also show which mapping kind is active.
    """
    label = _STATUS_LABELS.get(entry.status, entry.status)
    if entry.status in ("error", "needs_mapping") and entry.error_message:
        return f"**Status:** {label} {entry.error_message}"
    if entry.status == "ready" and entry.mapping:
        return f"**Status:** {label} (mapping type: `{entry.mapping.kind}`)"
    return f"**Status:** {label}"


# --- entry list mutation ----------------------------------------------------


def add_entry(entries: list) -> list:
    """Return a new list with a fresh, default ``DatasetEntry`` appended."""
    return entries + [DatasetEntry()]


def remove_entry(entries: list, uid: str) -> list:
    """Return a new list without the entry identified by *uid*."""
    return [e for e in entries if e.uid != uid]


def _reset_detection(entry: DatasetEntry) -> None:
    """Drop any detected/applied mapping so the entry must be detected again."""
    entry.status = "empty"
    entry.mapping = None
    entry.error_message = ""


def update_field(entries: list, uid: str, field_name: str, value) -> list:
    """Set attribute *field_name* of the entry identified by *uid* to *value*.

    If the change alters which dataset rows would be loaded (repo id, subset
    or effective split), the entry's previous detection is discarded. Without
    this, a mapping detected for one dataset would silently be applied to
    another.
    """
    entry = _find(entries, uid)
    before = _dataset_coords(entry)
    setattr(entry, field_name, value)
    # Compare *effective* coordinates so no-op edits (e.g. "" -> "train",
    # trailing whitespace) don't throw away a valid mapping.
    if _dataset_coords(entry) != before:
        _reset_detection(entry)
    return entries


def update_limit(entries: list, uid: str, value) -> list:
    """Store the sample limit for an entry; an empty/zero input is stored as 0.

    How ``0`` is interpreted is up to ``hf_dataset_loader.load_limited``
    (not visible here).
    """
    entry = _find(entries, uid)
    entry.limit = int(value) if value else 0
    return entries


# --- schema detection / mapping ---------------------------------------------


def detect_entry(entries: list, uid: str, oauth_token: gr.OAuthToken | None) -> list:
    """Peek at a few rows of the entry's dataset and try to auto-detect a mapping.

    On success the entry becomes ``ready``. If no layout is recognised it
    becomes ``needs_mapping`` (with the detected columns kept for the manual
    mapping UI). Inspection failures and empty samples set ``error``.
    """
    entry = _find(entries, uid)
    repo_id, subset, split = _dataset_coords(entry)
    if not repo_id:
        entry.status = "error"
        entry.error_message = "Enter a dataset repo id first."
        return entries

    entry.status = "detecting"
    token = oauth_token.token if oauth_token else None
    try:
        rows = hf_inspect.peek_rows(repo_id, subset, split, sample_size=8, token=token)
    except hf_inspect.DatasetInspectError as exc:
        entry.status = "error"
        entry.error_message = str(exc)
        return entries

    # An empty sample would crash on rows[0]; report it instead.
    if not rows:
        entry.status = "error"
        entry.error_message = f"No rows found in `{repo_id}` ({split}) - nothing to detect a schema from."
        return entries

    entry.sample_rows = rows
    entry.detected_columns = list(rows[0].keys())
    entry.detected_list_info = schema_detect.detect_list_column(rows)

    mapping = schema_detect.auto_detect(rows)
    entry.mapping = mapping or None
    entry.status = "ready" if mapping else "needs_mapping"
    entry.error_message = ""
    return entries


def apply_flat_mapping(entries: list, uid: str, user_field: str, assistant_field: str) -> list:
    """Apply a manual "one user column + one assistant column" mapping.

    Validation problems are reported through ``error_message`` while the
    status is left unchanged. The manual-mapping controls are only rendered
    for ``needs_mapping`` entries, so switching to ``error`` would hide the
    very controls the user needs to fix the input.
    """
    entry = _find(entries, uid)
    if not user_field or not assistant_field:
        entry.error_message = "Pick both a user field and an assistant field."
        return entries
    entry.mapping = FieldMapping(
        kind="flat_pair",
        config={"user_field": user_field, "assistant_field": assistant_field},
    )
    entry.status = "ready"
    entry.error_message = ""
    return entries


def apply_list_mapping(
    entries: list,
    uid: str,
    list_field: str,
    role_key: str,
    content_key: str,
    human_tag: str,
    assistant_tag: str,
) -> list:
    """Apply a manual conversation-list mapping (list of role/content turns).

    As in ``apply_flat_mapping``, validation problems only set
    ``error_message`` so the manual-mapping controls stay visible.
    """
    entry = _find(entries, uid)
    if not all([list_field, role_key, content_key, human_tag, assistant_tag]):
        entry.error_message = "Fill in every field for the conversation-list mapping."
        return entries
    if human_tag == assistant_tag:
        entry.error_message = "Human and assistant tags can't be the same value."
        return entries
    entry.mapping = FieldMapping(
        kind="conversation_list",
        config={
            "list_field": list_field,
            "role_key": role_key,
            "content_key": content_key,
            "human_tag": human_tag,
            "gpt_tag": assistant_tag,
        },
    )
    entry.status = "ready"
    entry.error_message = ""
    return entries


# --- combine / export --------------------------------------------------------


def run_pipeline(
    entries: list,
    seed,
    oauth_token: gr.OAuthToken | None,
    progress=gr.Progress(),
):
    """Load every ready entry, extract chat triplets, and shuffle them together.

    Args:
        entries: the per-dataset entries; only ``ready`` ones with a mapping
            are used.
        seed: optional shuffle seed. ``None``/``""`` means a non-deterministic
            shuffle.
        oauth_token: injected by Gradio; used to read gated/private datasets.

    Returns:
        ``(records, summary_markdown)``.

    Raises:
        gr.Error: if nothing is ready or a dataset fails to load.
    """
    ready_entries = [e for e in entries if e.status == "ready" and e.mapping is not None]
    if not ready_entries:
        raise gr.Error("No datasets are ready. Detect or manually map at least one dataset first.")

    token = oauth_token.token if oauth_token else None
    all_records: list = []
    breakdown_lines: list[str] = []
    total = len(ready_entries)

    for idx, entry in enumerate(ready_entries):
        repo_id, subset, split = _dataset_coords(entry)
        progress((idx, total), desc=f"Loading {repo_id}")
        try:
            rows = hf_dataset_loader.load_limited(
                repo_id,
                subset,
                split,
                int(entry.limit or 0),
                token=token,
            )
        except Exception as exc:  # UI boundary: surface any loader failure to the user
            raise gr.Error(f"Failed loading '{repo_id}': {exc}") from exc

        kept = 0
        for row in rows:
            triplet = field_mapper.extract_triplet(row, entry.mapping, entry.system_prompt)
            if triplet:
                all_records.append(triplet)
                kept += 1
        breakdown_lines.append(f"- `{repo_id}` ({split}): {kept}/{len(rows)} rows usable")

    progress((total, total), desc="Shuffling")
    rng = random.Random(int(seed)) if seed not in (None, "") else random.Random()
    rng.shuffle(all_records)

    summary = f"**Combined {len(all_records)} records from {total} dataset(s):**\n\n" + "\n".join(breakdown_lines)
    return all_records, summary


def do_push(
    records: list,
    repo_name: str,
    private: bool,
    oauth_token: gr.OAuthToken | None,
    oauth_profile: gr.OAuthProfile | None,
):
    """Push *records* to ``<signed-in username>/<repo_name>`` on the Hub.

    Raises:
        gr.Error: if there is nothing to push, the user is not signed in, the
            name is blank or contains ``/``, or the upload fails.
    """
    if not records:
        raise gr.Error("Build the combined dataset first.")
    if not oauth_token or not oauth_profile:
        raise gr.Error("Sign in with your Hugging Face account first.")
    name = (repo_name or "").strip()
    if not name:
        raise gr.Error("Give the dataset a name.")
    # The username is prefixed automatically; "user/name" would become an
    # invalid "user/user/name" repo id.
    if "/" in name:
        raise gr.Error("Enter just the dataset name - it's created under your HF username automatically.")

    repo_id = f"{oauth_profile.username}/{name}"
    try:
        url = hf_publish.push_dataset(records, repo_id, bool(private), oauth_token.token)
    except Exception as exc:  # UI boundary: surface any Hub failure to the user
        raise gr.Error(f"Push failed: {exc}") from exc
    return f"Pushed. View it at {url}"


def do_download(records: list):
    """Write *records* to a fresh temp JSONL file and return its path for gr.File."""
    if not records:
        raise gr.Error("Build the combined dataset first.")
    # A unique directory per call keeps concurrent downloads from clobbering
    # each other while still giving the file a friendly name.
    out_dir = tempfile.mkdtemp(prefix="sage_dataset_")
    out_path = os.path.join(out_dir, "combined_dataset.jsonl")
    hf_publish.write_jsonl(records, out_path)
    return out_path


# --- UI -----------------------------------------------------------------------

with gr.Blocks(title="Dataset Creator") as demo:
    gr.Markdown("# Dataset Creator")
    gr.Markdown(
        "Combine chat-format data from multiple Hugging Face datasets into one "
        "shuffled set. Sign in to push the result to your own account, or just "
        "download it as JSONL."
    )
    gr.LoginButton()

    entries_state = gr.State([])
    combined_state = gr.State([])

    add_btn = gr.Button("+ Add dataset")

    @gr.render(inputs=entries_state)
    def render_entries(entries):
        """Re-draw one group of controls per dataset entry whenever the list changes."""
        for entry in entries:
            uid = entry.uid
            # The uid travels through a State input (not a closure over the
            # loop variable), so each listener stays bound to its own entry.
            uid_state = gr.State(uid)
            with gr.Group():
                gr.Markdown(f"**{entry.repo_id or 'New dataset'}**", key=f"title-{uid}")
                with gr.Row():
                    repo_tb = gr.Textbox(
                        label="HF dataset repo",
                        value=entry.repo_id,
                        placeholder="e.g. NousResearch/hermes-function-calling-v1",
                        key=f"repo-{uid}",
                    )
                    subset_tb = gr.Textbox(
                        label="Config / subset (optional)", value=entry.subset, key=f"subset-{uid}"
                    )
                    split_tb = gr.Textbox(
                        label="Split", value=entry.split or _DEFAULT_SPLIT, key=f"split-{uid}"
                    )
                    limit_num = gr.Number(
                        label="Sample limit", value=entry.limit, precision=0, key=f"limit-{uid}"
                    )
                system_tb = gr.Textbox(
                    label="System prompt override",
                    value=entry.system_prompt,
                    lines=2,
                    key=f"sys-{uid}",
                )

                # Constant inputs naming which DatasetEntry attribute to update.
                repo_field = gr.State("repo_id")
                subset_field = gr.State("subset")
                split_field = gr.State("split")
                sys_field = gr.State("system_prompt")

                repo_tb.change(update_field, [entries_state, uid_state, repo_field, repo_tb], entries_state)
                subset_tb.change(update_field, [entries_state, uid_state, subset_field, subset_tb], entries_state)
                split_tb.change(update_field, [entries_state, uid_state, split_field, split_tb], entries_state)
                limit_num.change(update_limit, [entries_state, uid_state, limit_num], entries_state)
                system_tb.change(update_field, [entries_state, uid_state, sys_field, system_tb], entries_state)

                with gr.Row():
                    detect_btn = gr.Button("Detect schema", key=f"detect-{uid}")
                    remove_btn = gr.Button("Remove", variant="stop", key=f"remove-{uid}")
                gr.Markdown(_status_text(entry), key=f"status-{uid}")

                # oauth_token is injected by Gradio; never list it in inputs=.
                detect_btn.click(detect_entry, [entries_state, uid_state], entries_state)
                remove_btn.click(remove_entry, [entries_state, uid_state], entries_state)

                if entry.status == "needs_mapping":
                    gr.Markdown(
                        f"Columns found: `{', '.join(entry.detected_columns)}`", key=f"cols-{uid}"
                    )
                    with gr.Row():
                        user_field_dd = gr.Dropdown(
                            choices=entry.detected_columns, label="User field", key=f"userf-{uid}"
                        )
                        asst_field_dd = gr.Dropdown(
                            choices=entry.detected_columns, label="Assistant field", key=f"asstf-{uid}"
                        )
                    apply_flat_btn = gr.Button("Apply flat mapping", key=f"applyflat-{uid}")
                    apply_flat_btn.click(
                        apply_flat_mapping,
                        [entries_state, uid_state, user_field_dd, asst_field_dd],
                        entries_state,
                    )

                    if entry.detected_list_info:
                        info = entry.detected_list_info
                        gr.Markdown(
                            f"Also found a conversation-style column: `{info['list_field']}` "
                            f"(role key `{info['role_key']}`, content key `{info['content_key']}`, "
                            f"tags seen: {info['tag_values']})",
                            key=f"listinfo-{uid}",
                        )
                        with gr.Row():
                            human_tag_dd = gr.Dropdown(
                                choices=info["tag_values"], label="Human/user tag", key=f"humant-{uid}"
                            )
                            asst_tag_dd = gr.Dropdown(
                                choices=info["tag_values"], label="Assistant tag", key=f"asstt-{uid}"
                            )
                        apply_list_btn = gr.Button(
                            "Apply conversation-list mapping", key=f"applylist-{uid}"
                        )
                        list_field_state = gr.State(info["list_field"])
                        role_key_state = gr.State(info["role_key"])
                        content_key_state = gr.State(info["content_key"])
                        apply_list_btn.click(
                            apply_list_mapping,
                            [
                                entries_state,
                                uid_state,
                                list_field_state,
                                role_key_state,
                                content_key_state,
                                human_tag_dd,
                                asst_tag_dd,
                            ],
                            entries_state,
                        )

    add_btn.click(add_entry, entries_state, entries_state)

    gr.Markdown("---")
    gr.Markdown("## Combine")
    seed_num = gr.Number(label="Shuffle seed (optional)", precision=0)
    build_btn = gr.Button("Build combined dataset", variant="primary")
    summary_md = gr.Markdown()
    build_btn.click(run_pipeline, [entries_state, seed_num], [combined_state, summary_md])

    gr.Markdown("## Export")
    with gr.Row():
        repo_name_tb = gr.Textbox(label="Dataset name (goes under your HF username)")
        private_cb = gr.Checkbox(label="Private", value=True)
    push_btn = gr.Button("Push to Hub")
    push_result_md = gr.Markdown()
    push_btn.click(do_push, [combined_state, repo_name_tb, private_cb], push_result_md)

    download_btn = gr.Button("Download as JSONL")
    download_file = gr.File(label="Download")
    download_btn.click(do_download, combined_state, download_file)


if __name__ == "__main__":
    demo.launch()