Spaces:
Running
Running
| """Dataset Creator - Gradio app. | |
| Lets you list any number of Hugging Face datasets, auto-detects (or lets | |
| you manually map) how each one's rows turn into a chat-format triplet, | |
| then combines, shuffles, and either pushes the result to your own HF | |
| account or hands you a JSONL download. | |
| UI wiring conventions used throughout this file, so the OAuth injection | |
| keeps working: | |
| - Every per-entry callback takes `entries` (the gr.State list) and | |
| `uid` (a `gr.State(uid)` constant created at render time) as its | |
| first two parameters, matching the first two items in `inputs=`. | |
| - `oauth_token` / `oauth_profile` are declared as trailing, annotated | |
| parameters and are *never* included in `inputs=` - Gradio resolves | |
| them automatically from the signed-in session. See the Gradio OAuth | |
| guide; this is the one part of the app that needs a live Spaces | |
| OAuth session to fully smoke-test. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import random | |
| import tempfile | |
| import gradio as gr | |
| import field_mapper | |
| import schema_detect | |
| import hf_dataset_loader | |
| import hf_inspect | |
| import hf_publish | |
| from models import DatasetEntry, FieldMapping | |
| _STATUS_LABELS = { | |
| "empty": "Not detected yet.", | |
| "detecting": "Checking schema...", | |
| "needs_mapping": "Couldn't auto-detect the field layout - map it manually below.", | |
| "ready": "Ready.", | |
| "error": "Error.", | |
| } | |
| def _find(entries: list, uid: str) -> DatasetEntry: | |
| for entry in entries: | |
| if entry.uid == uid: | |
| return entry | |
| raise ValueError(f"No dataset entry with uid {uid!r} - it may have been removed.") | |
def _status_text(entry: DatasetEntry) -> str:
    """Build the Markdown status line shown under a dataset entry.

    Error entries append their stored error message; ready entries with a
    mapping name the mapping kind. Unknown statuses are shown verbatim.
    """
    status = entry.status
    label = _STATUS_LABELS.get(status, status)
    suffix = ""
    if status == "error" and entry.error_message:
        suffix = f" {entry.error_message}"
    elif status == "ready" and entry.mapping:
        suffix = f" (mapping type: `{entry.mapping.kind}`)"
    return f"**Status:** {label}{suffix}"
| # --- entry list mutation ---------------------------------------------------- | |
def add_entry(entries: list) -> list:
    """Return a new list with a fresh, empty DatasetEntry appended."""
    return [*entries, DatasetEntry()]
def remove_entry(entries: list, uid: str) -> list:
    """Return a new list containing every entry whose ``uid`` is not *uid*."""
    return list(filter(lambda item: item.uid != uid, entries))
def update_field(entries: list, uid: str, field_name: str, value) -> list:
    """Set attribute *field_name* to *value* on the entry identified by *uid*.

    The entry is mutated in place and the same *entries* list is returned so
    it can be written straight back into the gr.State.
    """
    setattr(_find(entries, uid), field_name, value)
    return entries
def update_limit(entries: list, uid: str, value) -> list:
    """Store the sample-limit Number value on the entry identified by *uid*.

    Falsy values (None, 0) are stored as 0; anything else is coerced to int.
    Returns the same, mutated *entries* list.
    """
    target = _find(entries, uid)
    target.limit = 0 if not value else int(value)
    return entries
| # --- schema detection / mapping --------------------------------------------- | |
def detect_entry(entries: list, uid: str, oauth_token: gr.OAuthToken | None) -> list:
    """Peek at sample rows of an entry's dataset and try to auto-detect a mapping.

    Args:
        entries: The gr.State list of DatasetEntry objects (mutated in place).
        uid: uid of the entry to inspect.
        oauth_token: Injected by Gradio from the signed-in session (never in
            ``inputs=``); its token, if any, is forwarded to hf_inspect.

    Returns:
        The same *entries* list. The target entry ends with status "ready"
        (auto-detected mapping), "needs_mapping" (manual mapping required),
        or "error" (with ``error_message`` set).
    """
    entry = _find(entries, uid)
    repo_id = entry.repo_id.strip()
    if not repo_id:
        entry.status = "error"
        entry.error_message = "Enter a dataset repo id first."
        return entries
    entry.status = "detecting"
    token = oauth_token.token if oauth_token else None
    try:
        rows = hf_inspect.peek_rows(
            repo_id,
            entry.subset.strip(),
            entry.split.strip() or "train",
            sample_size=8,
            token=token,
        )
    except hf_inspect.DatasetInspectError as exc:
        entry.status = "error"
        entry.error_message = str(exc)
        return entries
    if not rows:
        # An empty sample previously crashed on rows[0] and left the entry
        # stuck in "detecting"; report it as a normal error instead.
        entry.status = "error"
        entry.error_message = "No rows were returned for this dataset/split, so its schema can't be detected."
        return entries
    entry.sample_rows = rows
    entry.detected_columns = list(rows[0].keys())
    entry.detected_list_info = schema_detect.detect_list_column(rows)
    mapping = schema_detect.auto_detect(rows)
    # A falsy result from auto_detect means the user has to map fields manually.
    entry.mapping = mapping if mapping else None
    entry.status = "ready" if mapping else "needs_mapping"
    entry.error_message = ""
    return entries
def apply_flat_mapping(entries: list, uid: str, user_field: str, assistant_field: str) -> list:
    """Attach a manual ``flat_pair`` mapping (one user column, one assistant column).

    On invalid input the entry's status is set to "error" with a message and
    no mapping is stored. Returns the same, mutated *entries* list.
    """
    entry = _find(entries, uid)
    if not user_field or not assistant_field:
        entry.status = "error"
        entry.error_message = "Pick both a user field and an assistant field."
        return entries
    if user_field == assistant_field:
        # Mirrors the tag check in apply_list_mapping: mapping one column to
        # both roles would presumably yield identical user/assistant text.
        entry.status = "error"
        entry.error_message = "User and assistant fields can't be the same column."
        return entries
    entry.mapping = FieldMapping(
        kind="flat_pair",
        config={"user_field": user_field, "assistant_field": assistant_field},
    )
    entry.status = "ready"
    entry.error_message = ""
    return entries
def apply_list_mapping(
    entries: list,
    uid: str,
    list_field: str,
    role_key: str,
    content_key: str,
    human_tag: str,
    assistant_tag: str,
) -> list:
    """Attach a manual ``conversation_list`` mapping to the entry with *uid*.

    Every argument must be non-empty and the human/assistant tags must differ;
    otherwise the entry gets status "error" and a message. The assistant tag
    is stored in the mapping config under the ``gpt_tag`` key. Returns the
    same, mutated *entries* list.
    """
    target = _find(entries, uid)
    problem = ""
    if not (list_field and role_key and content_key and human_tag and assistant_tag):
        problem = "Fill in every field for the conversation-list mapping."
    elif human_tag == assistant_tag:
        problem = "Human and assistant tags can't be the same value."
    if problem:
        target.status = "error"
        target.error_message = problem
        return entries
    config = {
        "list_field": list_field,
        "role_key": role_key,
        "content_key": content_key,
        "human_tag": human_tag,
        "gpt_tag": assistant_tag,
    }
    target.mapping = FieldMapping(kind="conversation_list", config=config)
    target.status = "ready"
    target.error_message = ""
    return entries
| # --- combine / export -------------------------------------------------------- | |
def run_pipeline(
    entries: list,
    seed,
    oauth_token: gr.OAuthToken | None,
    progress=gr.Progress(),
):
    """Load, map, combine and shuffle every ready dataset entry.

    Args:
        entries: All dataset entries; only those with status "ready" and a
            mapping are used.
        seed: Optional shuffle seed from the Number input; None or "" means an
            unseeded shuffle.
        oauth_token: Injected by Gradio from the signed-in session (never in
            ``inputs=``); its token, if any, is forwarded to the loader.
        progress: Gradio progress tracker.

    Returns:
        ``(records, summary_markdown)``. Rows for which
        ``field_mapper.extract_triplet`` returns a falsy value are dropped.

    Raises:
        gr.Error: if no entry is ready, or a dataset fails to load.
    """
    ready_entries = [e for e in entries if e.status == "ready" and e.mapping is not None]
    if not ready_entries:
        raise gr.Error("No datasets are ready. Detect or manually map at least one dataset first.")
    token = oauth_token.token if oauth_token else None
    all_records = []
    breakdown_lines = []
    total = len(ready_entries)
    for idx, entry in enumerate(ready_entries):
        # Normalise once so progress, loading and the summary all report the
        # same repo/split (a blank split used to display as "()").
        repo_id = entry.repo_id.strip()
        split = entry.split.strip() or "train"
        progress((idx, total), desc=f"Loading {repo_id}")
        try:
            rows = hf_dataset_loader.load_limited(
                repo_id,
                entry.subset.strip(),
                split,
                int(entry.limit),
                token=token,
            )
        except Exception as exc:
            raise gr.Error(f"Failed loading '{repo_id}': {exc}") from exc
        kept = 0
        for row in rows:
            triplet = field_mapper.extract_triplet(row, entry.mapping, entry.system_prompt)
            if triplet:
                all_records.append(triplet)
                kept += 1
        breakdown_lines.append(f"- `{repo_id}` ({split}): {kept}/{len(rows)} rows usable")
    progress((total, total), desc="Shuffling")
    rng = random.Random(int(seed)) if seed not in (None, "") else random.Random()
    rng.shuffle(all_records)
    summary = f"**Combined {len(all_records)} records from {total} dataset(s):**\n\n" + "\n".join(breakdown_lines)
    return all_records, summary
def do_push(
    records: list,
    repo_name: str,
    private: bool,
    oauth_token: gr.OAuthToken | None,
    oauth_profile: gr.OAuthProfile | None,
):
    """Publish the combined records to ``<username>/<repo_name>`` on the Hub.

    ``oauth_token`` / ``oauth_profile`` are injected by Gradio from the
    signed-in session and are not part of ``inputs=``.

    Returns:
        A short message containing the URL returned by hf_publish.

    Raises:
        gr.Error: if nothing has been built, the user is signed out, the name
            is blank, or the push itself fails.
    """
    if not records:
        raise gr.Error("Build the combined dataset first.")
    if not (oauth_token and oauth_profile):
        raise gr.Error("Sign in with your Hugging Face account first.")
    name = repo_name.strip()
    if not name:
        raise gr.Error("Give the dataset a name.")
    target_repo = f"{oauth_profile.username}/{name}"
    try:
        url = hf_publish.push_dataset(records, target_repo, bool(private), oauth_token.token)
    except Exception as exc:
        raise gr.Error(f"Push failed: {exc}")
    return f"Pushed. View it at {url}"
def do_download(records: list):
    """Write the combined records as JSONL into a new temp dir; return the file path.

    Raises:
        gr.Error: if no combined dataset has been built yet.
    """
    if not records:
        raise gr.Error("Build the combined dataset first.")
    target = os.path.join(tempfile.mkdtemp(prefix="sage_dataset_"), "combined_dataset.jsonl")
    hf_publish.write_jsonl(records, target)
    return target
| # --- UI ----------------------------------------------------------------------- | |
with gr.Blocks(title="Dataset Creator") as demo:
    gr.Markdown("# Dataset Creator")
    gr.Markdown(
        "Combine chat-format data from multiple Hugging Face datasets into one "
        "shuffled set. Sign in to push the result to your own account, or just "
        "download it as JSONL."
    )
    gr.LoginButton()

    # Session state: the editable list of DatasetEntry objects, and the
    # combined records produced by run_pipeline (consumed by push/download).
    entries_state = gr.State([])
    combined_state = gr.State([])

    add_btn = gr.Button("+ Add dataset")

    # The per-entry UI is dynamic, so it must be registered with @gr.render:
    # a bare nested function is never called and no entry controls would ever
    # appear. gr.render re-runs this whenever entries_state changes; the `key=`
    # arguments let Gradio match components across re-renders.
    @gr.render(inputs=entries_state)
    def render_entries(entries):
        for entry in entries:
            uid = entry.uid
            # Render-time constant passed as the second input of every
            # per-entry callback so it can locate its entry via _find.
            uid_state = gr.State(uid)
            with gr.Group():
                gr.Markdown(f"**{entry.repo_id or 'New dataset'}**", key=f"title-{uid}")
                with gr.Row():
                    repo_tb = gr.Textbox(
                        label="HF dataset repo",
                        value=entry.repo_id,
                        placeholder="e.g. NousResearch/hermes-function-calling-v1",
                        key=f"repo-{uid}",
                    )
                    subset_tb = gr.Textbox(
                        label="Config / subset (optional)", value=entry.subset, key=f"subset-{uid}"
                    )
                    split_tb = gr.Textbox(label="Split", value=entry.split or "train", key=f"split-{uid}")
                    limit_num = gr.Number(
                        label="Sample limit", value=entry.limit, precision=0, key=f"limit-{uid}"
                    )
                system_tb = gr.Textbox(
                    label="System prompt override",
                    value=entry.system_prompt,
                    lines=2,
                    key=f"sys-{uid}",
                )

                # Constant attribute names fed to the generic update_field callback.
                repo_field = gr.State("repo_id")
                subset_field = gr.State("subset")
                split_field = gr.State("split")
                sys_field = gr.State("system_prompt")
                repo_tb.change(update_field, [entries_state, uid_state, repo_field, repo_tb], entries_state)
                subset_tb.change(update_field, [entries_state, uid_state, subset_field, subset_tb], entries_state)
                split_tb.change(update_field, [entries_state, uid_state, split_field, split_tb], entries_state)
                limit_num.change(update_limit, [entries_state, uid_state, limit_num], entries_state)
                system_tb.change(update_field, [entries_state, uid_state, sys_field, system_tb], entries_state)

                with gr.Row():
                    detect_btn = gr.Button("Detect schema", key=f"detect-{uid}")
                    remove_btn = gr.Button("Remove", variant="stop", key=f"remove-{uid}")
                gr.Markdown(_status_text(entry), key=f"status-{uid}")
                # detect_entry's oauth_token is injected by Gradio, not listed here.
                detect_btn.click(detect_entry, [entries_state, uid_state], entries_state)
                remove_btn.click(remove_entry, [entries_state, uid_state], entries_state)

                if entry.status == "needs_mapping":
                    # NOTE(review): the apply_* callbacks set status "error" on a
                    # validation failure, which hides these controls until the
                    # user runs detection again.
                    gr.Markdown(
                        f"Columns found: `{', '.join(entry.detected_columns)}`", key=f"cols-{uid}"
                    )
                    with gr.Row():
                        user_field_dd = gr.Dropdown(
                            choices=entry.detected_columns, label="User field", key=f"userf-{uid}"
                        )
                        asst_field_dd = gr.Dropdown(
                            choices=entry.detected_columns, label="Assistant field", key=f"asstf-{uid}"
                        )
                    apply_flat_btn = gr.Button("Apply flat mapping", key=f"applyflat-{uid}")
                    apply_flat_btn.click(
                        apply_flat_mapping,
                        [entries_state, uid_state, user_field_dd, asst_field_dd],
                        entries_state,
                    )

                    if entry.detected_list_info:
                        info = entry.detected_list_info
                        gr.Markdown(
                            f"Also found a conversation-style column: `{info['list_field']}` "
                            f"(role key `{info['role_key']}`, content key `{info['content_key']}`, "
                            f"tags seen: {info['tag_values']})",
                            key=f"listinfo-{uid}",
                        )
                        with gr.Row():
                            human_tag_dd = gr.Dropdown(
                                choices=info["tag_values"], label="Human/user tag", key=f"humant-{uid}"
                            )
                            asst_tag_dd = gr.Dropdown(
                                choices=info["tag_values"], label="Assistant tag", key=f"asstt-{uid}"
                            )
                        apply_list_btn = gr.Button(
                            "Apply conversation-list mapping", key=f"applylist-{uid}"
                        )
                        # The list column and its keys come from detection and are
                        # not user-editable, so they are passed as constants.
                        list_field_state = gr.State(info["list_field"])
                        role_key_state = gr.State(info["role_key"])
                        content_key_state = gr.State(info["content_key"])
                        apply_list_btn.click(
                            apply_list_mapping,
                            [
                                entries_state,
                                uid_state,
                                list_field_state,
                                role_key_state,
                                content_key_state,
                                human_tag_dd,
                                asst_tag_dd,
                            ],
                            entries_state,
                        )

    add_btn.click(add_entry, entries_state, entries_state)

    gr.Markdown("---")
    gr.Markdown("## Combine")
    seed_num = gr.Number(label="Shuffle seed (optional)", precision=0)
    build_btn = gr.Button("Build combined dataset", variant="primary")
    summary_md = gr.Markdown()
    build_btn.click(run_pipeline, [entries_state, seed_num], [combined_state, summary_md])

    gr.Markdown("## Export")
    with gr.Row():
        repo_name_tb = gr.Textbox(label="Dataset name (goes under your HF username)")
        private_cb = gr.Checkbox(label="Private", value=True)
    push_btn = gr.Button("Push to Hub")
    push_result_md = gr.Markdown()
    # oauth_token / oauth_profile for do_push are injected by Gradio.
    push_btn.click(do_push, [combined_state, repo_name_tb, private_cb], push_result_md)
    download_btn = gr.Button("Download as JSONL")
    download_file = gr.File(label="Download")
    download_btn.click(do_download, combined_state, download_file)
# Start the Gradio server for `demo` when this file is run as a script.
if __name__ == "__main__":
    demo.launch()