Dataset-Creator / app.py
TitleOS's picture
Upload 9 files
390cebe verified
Raw
History Blame Contribute Delete
14.7 kB
"""Dataset Creator - Gradio app.
Lets you list any number of Hugging Face datasets, auto-detects (or lets
you manually map) how each one's rows turn into a chat-format triplet,
then combines, shuffles, and either pushes the result to your own HF
account or hands you a JSONL download.
UI wiring conventions used throughout this file, so the OAuth injection
keeps working:
- Every per-entry callback takes `entries` (the gr.State list) and
`uid` (a `gr.State(uid)` constant created at render time) as its
first two parameters, matching the first two items in `inputs=`.
- `oauth_token` / `oauth_profile` are declared as trailing, annotated
parameters and are *never* included in `inputs=` - Gradio resolves
them automatically from the signed-in session. See the Gradio OAuth
guide; this is the one part of the app that needs a live Spaces
OAuth session to fully smoke-test.
"""
from __future__ import annotations
import os
import random
import tempfile
import gradio as gr
import field_mapper
import schema_detect
import hf_dataset_loader
import hf_inspect
import hf_publish
from models import DatasetEntry, FieldMapping
# Display text for each known entry status; looked up by `_status_text`.
_STATUS_LABELS = dict(
    empty="Not detected yet.",
    detecting="Checking schema...",
    needs_mapping="Couldn't auto-detect the field layout - map it manually below.",
    ready="Ready.",
    error="Error.",
)
def _find(entries: list, uid: str) -> DatasetEntry:
    """Return the first item of *entries* whose ``uid`` attribute equals *uid*.

    Raises:
        ValueError: If no item in *entries* has that uid.
    """
    # A private sentinel distinguishes "nothing matched" from any real entry.
    not_found = object()
    match = next(
        (candidate for candidate in entries if candidate.uid == uid),
        not_found,
    )
    if match is not_found:
        raise ValueError(f"No dataset entry with uid {uid!r} - it may have been removed.")
    return match
def _status_text(entry: DatasetEntry) -> str:
    """Format the Markdown status line for a single dataset entry.

    The label comes from ``_STATUS_LABELS``; a status with no label there is
    shown verbatim. An ``error`` status appends the entry's error message
    (when non-empty) and a ``ready`` status appends the mapping kind (when a
    mapping is set).
    """
    status = entry.status
    label = _STATUS_LABELS.get(status, status)
    if status == "error":
        if entry.error_message:
            return f"**Status:** {label} {entry.error_message}"
    elif status == "ready":
        if entry.mapping:
            return f"**Status:** {label} (mapping type: `{entry.mapping.kind}`)"
    return f"**Status:** {label}"
# --- entry list mutation ----------------------------------------------------
def add_entry(entries: list) -> list:
    """Return a new list holding *entries* plus one default-constructed ``DatasetEntry``.

    The input list itself is not modified.
    """
    return [*entries, DatasetEntry()]
def remove_entry(entries: list, uid: str) -> list:
    """Return a new list without the entries whose ``uid`` equals *uid*.

    The input list is left untouched; an unknown uid simply yields a copy.
    """
    survivors = []
    for candidate in entries:
        if candidate.uid != uid:
            survivors.append(candidate)
    return survivors
def update_field(entries: list, uid: str, field_name: str, value) -> list:
    """Assign *value* to attribute *field_name* of the entry identified by *uid*.

    The entry is mutated in place and the same *entries* list is returned.
    ``_find`` raises ``ValueError`` when the uid is unknown.
    """
    setattr(_find(entries, uid), field_name, value)
    return entries
def update_limit(entries: list, uid: str, value) -> list:
    """Store *value* as the integer ``limit`` of the entry identified by *uid*.

    Any falsy value (``None``, ``0``, ``""``) is stored as ``0``. The entry is
    mutated in place and the same *entries* list is returned.
    """
    target = _find(entries, uid)
    if value:
        target.limit = int(value)
    else:
        target.limit = 0
    return entries
# --- schema detection / mapping ---------------------------------------------
def detect_entry(entries: list, uid: str, oauth_token: gr.OAuthToken | None) -> list:
    """Sample a dataset and try to auto-detect its chat field mapping.

    Looks up the entry identified by *uid*, fetches up to 8 sample rows through
    ``hf_inspect.peek_rows`` and records on the entry: the sample rows, the
    column names of the first row, any conversation-list column info, and the
    auto-detected mapping (or ``None``).

    Resulting ``entry.status``:
      - ``"ready"``          a mapping was auto-detected;
      - ``"needs_mapping"``  rows were read but no mapping was detected;
      - ``"error"``          no repo id, the peek failed, or no rows came back
                             (``entry.error_message`` explains which).

    Args:
        entries: The list of dataset entries held in the app state.
        uid: Identifier of the entry to inspect.
        oauth_token: Token of the signed-in user, or ``None`` when signed out.

    Returns:
        The same *entries* list; the matching entry is mutated in place.

    Raises:
        ValueError: Propagated from ``_find`` when *uid* is unknown.
    """
    entry = _find(entries, uid)
    repo_id = entry.repo_id.strip()
    if not repo_id:
        entry.status = "error"
        entry.error_message = "Enter a dataset repo id first."
        return entries

    entry.status = "detecting"
    token = oauth_token.token if oauth_token else None
    try:
        rows = hf_inspect.peek_rows(
            repo_id,
            entry.subset.strip(),
            entry.split.strip() or "train",  # blank split falls back to "train"
            sample_size=8,
            token=token,
        )
    except hf_inspect.DatasetInspectError as exc:
        entry.status = "error"
        entry.error_message = str(exc)
        return entries

    if not rows:
        # Without this guard `rows[0]` below raises IndexError for an empty
        # split; report it on the entry like every other failure instead.
        entry.status = "error"
        entry.error_message = "The dataset returned no rows to inspect."
        return entries

    entry.sample_rows = rows
    entry.detected_columns = list(rows[0].keys())
    entry.detected_list_info = schema_detect.detect_list_column(rows)

    mapping = schema_detect.auto_detect(rows)
    if mapping:
        entry.mapping = mapping
        entry.status = "ready"
    else:
        entry.mapping = None
        entry.status = "needs_mapping"
    entry.error_message = ""
    return entries
def apply_flat_mapping(entries: list, uid: str, user_field: str, assistant_field: str) -> list:
    """Attach a manual ``flat_pair`` mapping to the entry identified by *uid*.

    When both field names are given, the entry receives a ``FieldMapping`` and
    becomes ``"ready"``. If either is missing, the entry is put in the
    ``"error"`` state with an explanatory message and its mapping is left as
    it was. The entry is mutated in place and the same *entries* list is
    returned.
    """
    target = _find(entries, uid)
    if user_field and assistant_field:
        target.mapping = FieldMapping(
            kind="flat_pair",
            config={"user_field": user_field, "assistant_field": assistant_field},
        )
        target.status = "ready"
        target.error_message = ""
    else:
        target.status = "error"
        target.error_message = "Pick both a user field and an assistant field."
    return entries
def apply_list_mapping(
    entries: list,
    uid: str,
    list_field: str,
    role_key: str,
    content_key: str,
    human_tag: str,
    assistant_tag: str,
) -> list:
    """Attach a manual ``conversation_list`` mapping to the entry identified by *uid*.

    Validation, in order: every argument must be non-empty, and the human and
    assistant tags must differ. A failed check puts the entry in the
    ``"error"`` state with a message and leaves its mapping as it was.
    Otherwise the entry receives a ``FieldMapping`` and becomes ``"ready"``.
    The entry is mutated in place and the same *entries* list is returned.
    """
    target = _find(entries, uid)

    problem = None
    if not all((list_field, role_key, content_key, human_tag, assistant_tag)):
        problem = "Fill in every field for the conversation-list mapping."
    elif human_tag == assistant_tag:
        problem = "Human and assistant tags can't be the same value."

    if problem is not None:
        target.status = "error"
        target.error_message = problem
        return entries

    # The assistant tag is stored under the "gpt_tag" config key.
    mapping_config = {
        "list_field": list_field,
        "role_key": role_key,
        "content_key": content_key,
        "human_tag": human_tag,
        "gpt_tag": assistant_tag,
    }
    target.mapping = FieldMapping(kind="conversation_list", config=mapping_config)
    target.status = "ready"
    target.error_message = ""
    return entries
# --- combine / export --------------------------------------------------------
def run_pipeline(
    entries: list,
    seed,
    oauth_token: gr.OAuthToken | None,
    progress=gr.Progress(),
):
    """Load every ready dataset, convert its rows to triplets, and shuffle them.

    Only entries whose status is ``"ready"`` and that carry a mapping take
    part. Each entry's rows are loaded through
    ``hf_dataset_loader.load_limited`` and converted with
    ``field_mapper.extract_triplet``; rows for which that returns a falsy
    value are dropped.

    Args:
        entries: The list of dataset entries held in the app state.
        seed: Shuffle seed. ``None`` or ``""`` means an unseeded shuffle;
            anything else is passed through ``int()``.
        oauth_token: Token of the signed-in user, or ``None`` when signed out.
        progress: Gradio progress tracker (the ``gr.Progress()`` default is
            how Gradio expects the tracker to be requested).

    Returns:
        ``(records, summary)``: the shuffled list of triplets and a Markdown
        summary with a per-dataset breakdown.

    Raises:
        gr.Error: When no entry is ready, or when loading a dataset fails
            (the original exception is chained as ``__cause__``).
    """
    ready_entries = [e for e in entries if e.status == "ready" and e.mapping is not None]
    if not ready_entries:
        raise gr.Error("No datasets are ready. Detect or manually map at least one dataset first.")

    token = oauth_token.token if oauth_token else None
    all_records = []
    breakdown_lines = []
    total = len(ready_entries)

    for idx, entry in enumerate(ready_entries):
        progress((idx, total), desc=f"Loading {entry.repo_id}")
        try:
            # Resolve the split once so the summary reports the split that was
            # actually loaded, even when the entry's field was left blank.
            split = entry.split.strip() or "train"
            rows = hf_dataset_loader.load_limited(
                entry.repo_id.strip(),
                entry.subset.strip(),
                split,
                int(entry.limit),
                token=token,
            )
        except Exception as exc:
            # UI boundary: surface any failure as a Gradio error, keeping the
            # original traceback attached for the server log.
            raise gr.Error(f"Failed loading '{entry.repo_id}': {exc}") from exc

        kept = 0
        for row in rows:
            triplet = field_mapper.extract_triplet(row, entry.mapping, entry.system_prompt)
            if triplet:
                all_records.append(triplet)
                kept += 1
        breakdown_lines.append(f"- `{entry.repo_id}` ({split}): {kept}/{len(rows)} rows usable")

    progress((total, total), desc="Shuffling")
    rng = random.Random(int(seed)) if seed not in (None, "") else random.Random()
    rng.shuffle(all_records)

    summary = (
        f"**Combined {len(all_records)} records from {total} dataset(s):**\n\n"
        + "\n".join(breakdown_lines)
    )
    return all_records, summary
def do_push(
    records: list,
    repo_name: str,
    private: bool,
    oauth_token: gr.OAuthToken | None,
    oauth_profile: gr.OAuthProfile | None,
):
    """Push the combined records to a dataset repo under the signed-in user.

    The target repo id is ``"<username>/<repo_name>"`` with *repo_name*
    stripped of surrounding whitespace.

    Args:
        records: The combined records produced by the build step.
        repo_name: Name for the new dataset (without the username prefix).
        private: Whether the dataset should be private.
        oauth_token: Token of the signed-in user, or ``None`` when signed out.
        oauth_profile: Profile of the signed-in user, or ``None`` when signed out.

    Returns:
        A confirmation message containing the URL returned by
        ``hf_publish.push_dataset``.

    Raises:
        gr.Error: When there are no records, the user is not signed in, the
            name is blank, or the push itself fails (original exception
            chained as ``__cause__``).
    """
    if not records:
        raise gr.Error("Build the combined dataset first.")
    if not oauth_token or not oauth_profile:
        raise gr.Error("Sign in with your Hugging Face account first.")

    # Tolerate a missing value (None) as well as a blank string.
    name = (repo_name or "").strip()
    if not name:
        raise gr.Error("Give the dataset a name.")

    repo_id = f"{oauth_profile.username}/{name}"
    try:
        url = hf_publish.push_dataset(records, repo_id, bool(private), oauth_token.token)
    except Exception as exc:
        raise gr.Error(f"Push failed: {exc}") from exc
    return f"Pushed. View it at {url}"
def do_download(records: list):
    """Write *records* to a JSONL file in a fresh temp directory and return its path.

    Raises:
        gr.Error: When *records* is empty.
    """
    if not records:
        raise gr.Error("Build the combined dataset first.")
    # Each call gets its own directory, so the file name can stay fixed.
    jsonl_path = os.path.join(
        tempfile.mkdtemp(prefix="sage_dataset_"),
        "combined_dataset.jsonl",
    )
    hf_publish.write_jsonl(records, jsonl_path)
    return jsonl_path
# --- UI -----------------------------------------------------------------------
# Top-level UI definition. Components appear in creation order, so statement
# order in this block is also layout order.
# NOTE(review): the nesting of widgets inside each gr.Row()/gr.Group() was
# reconstructed from a listing that had lost its indentation - confirm the
# layout against the running app.
with gr.Blocks(title="Dataset Creator") as demo:
    gr.Markdown("# Dataset Creator")
    gr.Markdown(
        "Combine chat-format data from multiple Hugging Face datasets into one "
        "shuffled set. Sign in to push the result to your own account, or just "
        "download it as JSONL."
    )
    gr.LoginButton()
    # entries_state: the list of DatasetEntry objects being edited.
    # combined_state: the records produced by the last run_pipeline call.
    entries_state = gr.State([])
    combined_state = gr.State([])
    add_btn = gr.Button("+ Add dataset")

    # Builds one editor panel per entry from the current value of entries_state.
    @gr.render(inputs=entries_state)
    def render_entries(entries):
        for entry in entries:
            uid = entry.uid
            # Constant state carrying this entry's uid into every callback as
            # its second input (see the module docstring's wiring conventions).
            uid_state = gr.State(uid)
            with gr.Group():
                gr.Markdown(f"**{entry.repo_id or 'New dataset'}**", key=f"title-{uid}")
                with gr.Row():
                    repo_tb = gr.Textbox(
                        label="HF dataset repo",
                        value=entry.repo_id,
                        placeholder="e.g. NousResearch/hermes-function-calling-v1",
                        key=f"repo-{uid}",
                    )
                    subset_tb = gr.Textbox(
                        label="Config / subset (optional)", value=entry.subset, key=f"subset-{uid}"
                    )
                    split_tb = gr.Textbox(label="Split", value=entry.split or "train", key=f"split-{uid}")
                    limit_num = gr.Number(
                        label="Sample limit", value=entry.limit, precision=0, key=f"limit-{uid}"
                    )
                system_tb = gr.Textbox(
                    label="System prompt override",
                    value=entry.system_prompt,
                    lines=2,
                    key=f"sys-{uid}",
                )
                # Constant states that supply update_field's `field_name`
                # argument, i.e. the DatasetEntry attribute each textbox edits.
                repo_field = gr.State("repo_id")
                subset_field = gr.State("subset")
                split_field = gr.State("split")
                sys_field = gr.State("system_prompt")
                repo_tb.change(update_field, [entries_state, uid_state, repo_field, repo_tb], entries_state)
                subset_tb.change(update_field, [entries_state, uid_state, subset_field, subset_tb], entries_state)
                split_tb.change(update_field, [entries_state, uid_state, split_field, split_tb], entries_state)
                limit_num.change(update_limit, [entries_state, uid_state, limit_num], entries_state)
                system_tb.change(update_field, [entries_state, uid_state, sys_field, system_tb], entries_state)
                with gr.Row():
                    detect_btn = gr.Button("Detect schema", key=f"detect-{uid}")
                    remove_btn = gr.Button("Remove", variant="stop", key=f"remove-{uid}")
                gr.Markdown(_status_text(entry), key=f"status-{uid}")
                # detect_entry's trailing `oauth_token` parameter is deliberately
                # absent from the inputs list (see the module docstring).
                detect_btn.click(detect_entry, [entries_state, uid_state], entries_state)
                remove_btn.click(remove_entry, [entries_state, uid_state], entries_state)
                # Manual mapping controls are only built when auto-detection
                # could not determine a mapping for this entry.
                if entry.status == "needs_mapping":
                    gr.Markdown(
                        f"Columns found: `{', '.join(entry.detected_columns)}`", key=f"cols-{uid}"
                    )
                    with gr.Row():
                        user_field_dd = gr.Dropdown(
                            choices=entry.detected_columns, label="User field", key=f"userf-{uid}"
                        )
                        asst_field_dd = gr.Dropdown(
                            choices=entry.detected_columns, label="Assistant field", key=f"asstf-{uid}"
                        )
                    apply_flat_btn = gr.Button("Apply flat mapping", key=f"applyflat-{uid}")
                    apply_flat_btn.click(
                        apply_flat_mapping,
                        [entries_state, uid_state, user_field_dd, asst_field_dd],
                        entries_state,
                    )
                    # Offer the conversation-list mapping as well when
                    # detect_list_column reported a candidate column.
                    if entry.detected_list_info:
                        info = entry.detected_list_info
                        gr.Markdown(
                            f"Also found a conversation-style column: `{info['list_field']}` "
                            f"(role key `{info['role_key']}`, content key `{info['content_key']}`, "
                            f"tags seen: {info['tag_values']})",
                            key=f"listinfo-{uid}",
                        )
                        with gr.Row():
                            human_tag_dd = gr.Dropdown(
                                choices=info["tag_values"], label="Human/user tag", key=f"humant-{uid}"
                            )
                            asst_tag_dd = gr.Dropdown(
                                choices=info["tag_values"], label="Assistant tag", key=f"asstt-{uid}"
                            )
                        apply_list_btn = gr.Button(
                            "Apply conversation-list mapping", key=f"applylist-{uid}"
                        )
                        # The detected field/key names are passed as constants;
                        # only the two tags are chosen by the user.
                        list_field_state = gr.State(info["list_field"])
                        role_key_state = gr.State(info["role_key"])
                        content_key_state = gr.State(info["content_key"])
                        apply_list_btn.click(
                            apply_list_mapping,
                            [
                                entries_state,
                                uid_state,
                                list_field_state,
                                role_key_state,
                                content_key_state,
                                human_tag_dd,
                                asst_tag_dd,
                            ],
                            entries_state,
                        )

    add_btn.click(add_entry, entries_state, entries_state)
    gr.Markdown("---")
    gr.Markdown("## Combine")
    seed_num = gr.Number(label="Shuffle seed (optional)", precision=0)
    build_btn = gr.Button("Build combined dataset", variant="primary")
    summary_md = gr.Markdown()
    # run_pipeline returns (records, summary), matching the two outputs here.
    build_btn.click(run_pipeline, [entries_state, seed_num], [combined_state, summary_md])
    gr.Markdown("## Export")
    with gr.Row():
        repo_name_tb = gr.Textbox(label="Dataset name (goes under your HF username)")
        private_cb = gr.Checkbox(label="Private", value=True)
    push_btn = gr.Button("Push to Hub")
    push_result_md = gr.Markdown()
    # do_push's `oauth_token` / `oauth_profile` parameters are not listed as inputs.
    push_btn.click(do_push, [combined_state, repo_name_tb, private_cb], push_result_md)
    download_btn = gr.Button("Download as JSONL")
    download_file = gr.File(label="Download")
    download_btn.click(do_download, combined_state, download_file)
# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()